
Fast Function Approximations lowering. #8566


Open · mcourteaux wants to merge 84 commits into main from fast-math-lowering

Conversation

@mcourteaux (Contributor) commented Feb 8, 2025

The big transcendental lowering update! Replaces #8388.

TODO

I still have to do:

  • Validate that all is fine with the build bots.

Overview:

  • Fast transcendentals implemented for: sin, cos, tan, atan, atan2, exp, log, expm1, tanh, asin, acos.

  • Simple API to specify precision requirements. Default-initialized precision (AUTO without constraints) means "don't care about precision, as long as it's reasonable and fast", which gives you the highest chance of selecting a high-performance implementation based on hardware instructions. The optimization objectives MULPE (max ULP error) and MAE (max absolute error) are available. Compared to the previous PR, I removed MULPE_MAE, as I didn't see a good purpose for it.

  • Tabular info on the precision and speed of intrinsics and native functions, used to select an appropriate implementation when lowering: one that is definitely not slower, while still satisfying the precision requirements.

    • OpenCL: lower to native_cos, native_exp, etc...
    • Metal: lower to fast::cos, fast::exp, etc...
    • CUDA: lower to dedicated PTX instructions when available.
    • When no fast hardware versions are available: polynomial approximations.
  • Tabular info on the exact precision of each polynomial, obtained by exhaustively iterating over all floats in the polynomial's native interval. Both MULPE and MAE are measured for Float32. Precisions are not yet evaluated for f16 or f64; that is future work (which I have currently not planned). Precisions are measured by correctness/determine_fast_function_approximation_metrics.cpp; a simplified sketch of such a measurement loop follows this list.

  • Performance tests validating that:

    • the AUTO versions are always at least as fast as the non-fast equivalents.
    • all known-to-be faster functions are faster.
  • Accuracy tests validating that:

    • the AUTO versions are at least reasonably precise (max error no worse than about 1e-4).
    • all polynomials satisfy the precision they advertise on their non-range-reduced interval.
  • Drive-by fix for adding libOpenCL.so.1 to the list of tested sonames for the OpenCL runtime.
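
As referenced in the overview above, here is a minimal, self-contained sketch of how such an exhaustive Float32 measurement can work. It is illustrative only: the toy atan polynomial and its coefficients are made up. The PR's real harness is correctness/determine_fast_function_approximation_metrics.cpp.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>

// Map a float's bit pattern onto a monotonically ordered integer line,
// so that adjacent floats differ by exactly 1.
static int64_t ordered(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof(f));
    return (u & 0x80000000u) ? -(int64_t)(u & 0x7fffffffu) : (int64_t)u;
}

static int64_t ulp_distance(float a, float b) {
    return std::llabs(ordered(a) - ordered(b));
}

int main() {
    // Toy atan approximation with made-up coefficients.
    auto approx = [](float x) { return x * (0.9724f - 0.1919f * x * x); };

    double max_abs_err = 0.0;  // MAE
    int64_t max_ulp_err = 0;   // MULPE
    for (uint32_t bits = 0; bits <= 0x3f800000u; bits++) {  // every float in [0, 1]
        float x;
        std::memcpy(&x, &bits, sizeof(x));
        float y = approx(x);
        float ref = (float)std::atan((double)x);  // double-precision reference, rounded once
        max_abs_err = std::max(max_abs_err, std::fabs((double)y - (double)ref));
        max_ulp_err = std::max(max_ulp_err, ulp_distance(y, ref));
    }
    std::printf("MAE = %.3e, MULPE = %lld\n", max_abs_err, (long long)max_ulp_err);
    return 0;
}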

Review guide

  • I pass the ApproximationPrecision parameters as a Call::make_struct node with 4 parameters (see API below); a sketch of this bundling follows this list. This approximation-precision Call node survives until the lowering pass in which the transcendentals are lowered; there, the parameters are extracted again from the Call node's arguments. I like that, this way, they are conceptually bundled and clearly not at the same level as the actual mathematical arguments. Is this a good approach? In order for this to work, I had to stop CSE from extracting those precision arguments, and StrictifyFloat from recursing down into that struct and littering strict_float on those numbers. I have seen the Call::bundle intrinsic. Perhaps that one is better suited for this purpose? @abadams
  • I tried to design the API such that it is also compatible with Float(16) and Float(64), but those are not yet implemented/tested. The polynomial approximations should work correctly (although untested) for these other data types.
  • The intrinsics table and its behavior (MULPE/MAE precision) was measured on devices I have available (and the build bots). On some backends (such as OpenCL and Vulkan) these intrinsics have implementation-defined behavior. This probably means it's AMD or NVIDIA that gets to implement them and determine the precision. I do not have any AMD GPU available to test how these functions behave on the OpenCL and Vulkan backends. I have realized that, for example, Vulkan's native_tan() compiles to the same three instructions as I implemented on CUDA: sin.approx.f32, cos.approx.f32, div.approx.f32. I haven't investigated AMD's documentation on available hardware instructions.
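
For readers unfamiliar with the bundling discussed in the first bullet, a rough sketch of the shape of the construction. The internal Call API signatures here are recalled from Halide's IR headers and may not match exactly; this is explicitly not the PR's actual code.

#include "Halide.h"
using namespace Halide;
using namespace Halide::Internal;

// Sketch only: bundle the four precision parameters into one opaque struct
// argument, so they travel together and are clearly separate from the
// mathematical argument x.
Expr bundle_precision(const ApproximationPrecision &p) {
    std::vector<Expr> fields = {
        make_const(Int(32), (int)p.optimized_for),
        make_const(UInt(64), p.constraint_max_ulp_error),
        make_const(Float(64), p.constraint_max_absolute_error),
        make_const(Int(32), p.force_halide_polynomial),
    };
    return Call::make(Handle(), Call::make_struct, fields, Call::Intrinsic);
}

Expr make_fast_sin(const Expr &x, const ApproximationPrecision &p) {
    // The precision struct rides along as a non-mathematical argument and is
    // unpacked again in the lowering pass that replaces this intrinsic.
    return Call::make(x.type(), "fast_sin", {x, bundle_precision(p)}, Call::PureIntrinsic);
}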

Concerns

  • I disabled bit-exact accuracy tests on GPU backends, because they don't play nice without proper control over floating point optimizations.
  • We need proper strict_float ops (i.e.: strict_add(), strict_mul(), strict_sub(), strict_div(), strict_fma()).
  • Documentation for exp() (regular exp(), not fast_exp()) claims to be bit-exact, which is proven wrong by the pre-existing test correctness/vector_math.cpp:
    exp(47.812500) = 581706671813124161536.0000000000 instead of 581707832897403092992.0000000000 (mantissa: 8144523 vs 8144490)
    log mantissa error: 2
    exp mantissa error: 33
    pow mantissa error: 24
    fast_log mantissa error: 16
    fast_exp mantissa error: 59
    fast_pow mantissa error: 54
    
  • Buildbot reveals that WebGPU does not support vectorization of scalar function calls. tan_f32(vec2<f32>) did not get converted to tan_f32(first element), tan_f32(second element).

API

struct ApproximationPrecision {
    enum OptimizationObjective {
        AUTO,   //< No preference, but favor speed.
        MAE,    //< Optimized for Max Absolute Error.
        MULPE,  //< Optimized for Max ULP Error. ULP is "Units in Last Place", when represented in IEEE 32-bit floats.
    } optimized_for{AUTO};

    /**
     * Most function approximations have a range where the approximation works
     * natively (typically close to zero), without any range reduction tricks
     * (e.g., exploiting symmetries, repetitions). You may specify a maximal
     * absolute error or maximal units in last place error, which will be
     * interpreted as the maximal error within this native range of the
     * approximation. This will be used as a hint as to which implementation to
     * use.
     */
    // @{
    uint64_t constraint_max_ulp_error{0};
    double constraint_max_absolute_error{0.0};
    // @}

    /**
     * For most functions, Halide has a built-in table of polynomial
     * approximations. However, some targets have specialized instructions or
     * intrinsics available that can produce an even faster approximation.
     * Setting this integer to a non-zero value forces Halide to use a
     * polynomial with at least this many terms, instead of the specialized
     * device-specific code. It remains combinable with the other constraints.
     * This is mostly useful for testing and benchmarking.
     */
    int force_halide_polynomial{0};

    /** MULPE-optimized, with max ULP error. */
    static ApproximationPrecision max_ulp_error(uint64_t mulpe) {
        return ApproximationPrecision{MULPE, mulpe, 0.0f, 0};
    }
    /** MAE-optimized, with max absolute error. */
    static ApproximationPrecision max_abs_error(float mae) {
        return ApproximationPrecision{MAE, 0, mae, 0};
    }
    /** MULPE-optimized, forced Halide polynomial with given number of terms. */
    static ApproximationPrecision poly_mulpe(int num_terms) {
        user_assert(num_terms > 0);
        return ApproximationPrecision{MULPE, 0, 0.0f, num_terms};
    }
    /** MAE-optimized, forced Halide polynomial with given number of terms. */
    static ApproximationPrecision poly_mae(int num_terms) {
        user_assert(num_terms > 0);
        return ApproximationPrecision{MAE, 0, 0.0f, num_terms};
    }
};
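
To make the factory methods concrete, a small hypothetical usage example. The pipeline and Func names are invented; fast_atan2 and the ApproximationPrecision factories are the ones declared in this API.

#include "Halide.h"
using namespace Halide;

// Hypothetical pipeline combining three precision choices, for illustration.
Func polar_angle(Func dx, Func dy) {
    Var x("x"), y("y");

    // AUTO: "don't care, as long as it's reasonable and fast". On CUDA this
    // is free to lower to the approx hardware instructions.
    Func a("a");
    a(x, y) = fast_atan2(dy(x, y), dx(x, y));

    // Constrain the max absolute error in the native range to 1e-5 (MAE-optimized).
    Func b("b");
    b(x, y) = fast_atan2(dy(x, y), dx(x, y),
                         ApproximationPrecision::max_abs_error(1e-5f));

    // Force a MULPE-optimized Halide polynomial with at least 5 terms,
    // e.g. to benchmark against the hardware instruction.
    Func c("c");
    c(x, y) = fast_atan2(dy(x, y), dx(x, y),
                         ApproximationPrecision::poly_mulpe(5));

    Func out("out");
    out(x, y) = a(x, y) + b(x, y) + c(x, y);  // arbitrary combination for illustration
    return out;
}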

/** Fast approximations to some trigonometric functions for Float(32).
 * Slow on x86 if you don't have at least SSE4.1.
 * Vectorize cleanly when using polynomials.
 * See \ref ApproximationPrecision for details on specifying precision.
 */
// @{
/** Caution: Might exceed the range (-1, 1) by a tiny bit.
 * On NVIDIA CUDA: default-precision maps to a dedicated sin.approx.f32 instruction. */
Expr fast_sin(const Expr &x, ApproximationPrecision precision = {});
/** Caution: Might exceed the range (-1, 1) by a tiny bit.
 * On NVIDIA CUDA: default-precision maps to a dedicated cos.approx.f32 instruction. */
Expr fast_cos(const Expr &x, ApproximationPrecision precision = {});
/** On NVIDIA CUDA: default-precision maps to a combination of sin.approx.f32,
 * cos.approx.f32, div.approx.f32 instructions. */
Expr fast_tan(const Expr &x, ApproximationPrecision precision = {});
Expr fast_asin(const Expr &x, ApproximationPrecision precision = {});
Expr fast_acos(const Expr &x, ApproximationPrecision precision = {});
Expr fast_atan(const Expr &x, ApproximationPrecision precision = {});
Expr fast_atan2(const Expr &y, const Expr &x, ApproximationPrecision = {});
// @}

/** Fast approximate log for Float(32).
 * Returns nonsense for x <= 0.0f.
 * Approximations available up to max 5 ULP / mean 2 ULP error.
 * Vectorizes cleanly when using polynomials.
 * Slow on x86 if you don't have at least SSE4.1.
 * On NVIDIA CUDA: default precision maps to a combination of lg2.approx.f32 and a multiplication.
 * See \ref ApproximationPrecision for details on specifying precision.
 */
Expr fast_log(const Expr &x, ApproximationPrecision precision = {});

/** Fast approximate exp for Float(32).
 * Returns nonsense for inputs that would overflow.
 * Approximations available up to max 3 ULP / mean 1 ULP error.
 * Vectorizes cleanly when using polynomials.
 * Slow on x86 if you don't have at least SSE4.1.
 * On NVIDIA CUDA: default precision maps to a combination of ex2.approx.f32 and a multiplication.
 * See \ref ApproximationPrecision for details on specifying precision.
 */
Expr fast_exp(const Expr &x, ApproximationPrecision precision = {});

/** Fast approximate expm1 for Float(32).
 * Returns nonsense for inputs that would overflow.
 * Slow on x86 if you don't have at least SSE4.1.
 */
Expr fast_expm1(const Expr &x, ApproximationPrecision precision = {});

/** Fast approximate pow for Float(32).
 * Returns nonsense for x < 0.0f.
 * Returns 1 when x == y == 0.0.
 * Approximations accurate up to max 53 ULPs, mean 13 ULPs.
 * Gets worse when approaching overflow.
 * Vectorizes cleanly when using polynomials.
 * Slow on x86 if you don't have at least SSE4.1.
 * On NVIDIA CUDA: default precision maps to a combination of ex2.approx.f32 and lg2.approx.f32.
 * See \ref ApproximationPrecision for details on specifying precision.
 */
Expr fast_pow(const Expr &x, const Expr &y, ApproximationPrecision precision = {});
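
The CUDA note above reflects the usual identity x^y = exp(y * log(x)) (ex2/lg2 compute the same thing in base 2). A conceptual sketch only, ignoring the x <= 0 special cases documented above; not claimed to be the PR's exact lowering:

// Conceptual only: for x > 0, x^y == exp(y * log(x)).
Expr fast_pow_sketch(const Expr &x, const Expr &y) {
    return fast_exp(y * fast_log(x));
}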

/** Fast approximate tanh for Float(32).
 * Approximations accurate to 2e-7 MAE and max 2500 ULPs (on average < 1 ULP) are available.
 * Caution: might exceed the range (-1, 1) by a tiny bit.
 * Vectorizes cleanly when using polynomials.
 * Slow on x86 if you don't have at least SSE4.1.
 * On NVIDIA CUDA: default precision maps to a combination of ex2.approx.f32 and lg2.approx.f32.
 * See \ref ApproximationPrecision for details on specifying precision.
 */
Expr fast_tanh(const Expr &x, ApproximationPrecision precision = {});

Fixes #8243.

@mcourteaux mcourteaux marked this pull request as draft February 8, 2025 21:36
@mcourteaux mcourteaux requested a review from abadams February 10, 2025 18:11
@mcourteaux mcourteaux marked this pull request as ready for review February 10, 2025 18:12
@mcourteaux mcourteaux added the enhancement, performance, gpu, and release_notes labels on Feb 10, 2025
@mcourteaux mcourteaux force-pushed the fast-math-lowering branch 2 times, most recently from bea8612 to 0de4dbc on February 11, 2025
@mcourteaux mcourteaux added and then removed the skip_buildbots label on Feb 11, 2025
@abadams (Member) left a comment:

Thanks so much for doing this, and sorry it took me so long to review it (I'm finally clawing out from under my deadlines). It generally looks good but I have some review comments.

What order should this be done in vs our change to strict_float? Is there any reason to delay this until after strict_float is changed?

Review thread on the CSE change (src/CSE.cpp):

@@ -33,6 +33,12 @@ bool should_extract(const Expr &e, bool lift_all) {
        return false;
    }

    if (const Call *c = e.as<Call>()) {
@abadams (Member):

The need for this makes me think the extra args would be better just as flat extra args to the math intrinsic, instead of being packed into a struct.

@mcourteaux (Contributor, Author):

Well, if I don't pass these things as a make_struct(), CSE will still lift the arguments out of the call when there are multiple of them. That only makes it harder for the lowering pass that wants to read them back to figure out what the actual precision arguments to the Call node were: they might have become a Let or LetStmt, instead of a simple FloatImm.

@abadams (Member):

CSE will never lift constants, so they should be fine.

@mcourteaux (Contributor, Author):

> What order should this be done in vs our change to strict_float? Is there any reason to delay this until after strict_float is changed?

strict_float is already broken today in a few scenarios. All these transcendentals are broken right now regarding strict_float(). Every GPU-backend has a different implementation of them, with different precisions. So a strict_float(tan(strict_float(x))) will give wildly different results on different backends. Vulkan in particular has a very imprecise native tan() function. They already default to some approximation. So given that all transcendentals are already broken (because we rely on some third-party function), I wouldn't be concerned that strict_float has serious limitations and is not playing nicely with the lowering pass selecting a polynomial in this PR.

The good thing is that this PR actually paves the way towards fixing strict_float for transcendentals. Long term, if we have our own implementation for all of them, we can guarantee under strict_float that they will yield the same result everywhere. However, this would require us to have an FMA intrinsic as well, to be able to express the polynomials using Horner's method with fma ops (a sketch of that follows below). So yeah, I'm definitely in favor of accepting that strict_float is badly broken, and moving forward with this PR.
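
For context, a minimal sketch in plain C++ of what evaluating a polynomial with Horner's method and fused multiply-adds looks like; std::fmaf stands in for the hypothetical strict_fma intrinsic discussed above.

#include <cmath>

// Evaluate c[0] + c[1]*x + ... + c[n-1]*x^(n-1) with Horner's method,
// one fused multiply-add per step. A strict_fma intrinsic would pin down
// exactly this sequence of roundings on every backend.
float horner_fma(const float *c, int n, float x) {
    float acc = c[n - 1];
    for (int i = n - 2; i >= 0; i--) {
        acc = std::fmaf(acc, x, c[i]);  // acc*x + c[i] with a single rounding
    }
    return acc;
}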

@mcourteaux (Contributor, Author):

> Thanks so much for doing this, and sorry it took me so long to review it (I'm finally clawing out from under my deadlines). It generally looks good but I have some review comments.

No worries! Thanks a lot for looking into this! I'll address your feedback, questions, and improvements tomorrow, while it's all still fresh in your head.

@mcourteaux (Contributor, Author) commented Mar 11, 2025:

@abadams Any chance you can put together an fma intrinsic in Halide with reasonable effort? I don't want to ask too much, but I think you had a pretty good idea of what that would look like. It would be helpful for finalizing this PR.

Update: Nevermind, I found another approach that works well for now.

@mcourteaux mcourteaux added the skip_buildbots label on Mar 12, 2025
@mcourteaux mcourteaux removed the skip_buildbots label on Mar 14, 2025
@mcourteaux (Contributor, Author) commented Mar 15, 2025:

I took care of all feedback, except for the make_struct wrapper for the precision arguments; that's for later. I updated the original post at the top of the PR; more info with the latest concerns and notes can be found there!

@mcourteaux mcourteaux requested a review from abadams March 15, 2025 11:49
@mcourteaux mcourteaux force-pushed the fast-math-lowering branch from a171ec1 to 6cebc56 on June 1, 2025
@mcourteaux (Contributor, Author) commented Jun 1, 2025:

@slomp or @alexreinking Can you identify the issue with the Windows builds here? https://buildbot.halide-lang.org/master/#/builders/107/builds/98

The test doesn't see the extern-declared symbol compiled in ApproximationTables.cpp. I tried adding HL_EXPORT_SYMBOL, but that didn't resolve it.

Update: I changed it to use accessor functions that simply return the static member (sketched below). That seems to be the way it's done everywhere in Halide header files.
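
The accessor pattern mentioned in the update, sketched with hypothetical names; the real tables live in ApproximationTables.cpp.

#include <vector>

// Hypothetical type; stands in for the PR's table entries.
struct Approximation {
    // polynomial coefficients, measured MAE/MULPE, ...
};

// Header: expose an accessor instead of an `extern const` object, which
// sidesteps DLL symbol-visibility issues on Windows.
const std::vector<Approximation> &get_atan_approximations();

// ApproximationTables.cpp: the static table lives behind the accessor.
const std::vector<Approximation> &get_atan_approximations() {
    static const std::vector<Approximation> table = {
        // ...entries...
    };
    return table;
}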

@mcourteaux (Contributor, Author):

@alexreinking Can you assess what this macos buildbot is up to?

@mcourteaux mcourteaux force-pushed the fast-math-lowering branch from 4ceca2c to 845d83a on June 14, 2025
@mcourteaux (Contributor, Author):

@abadams There seems to be an issue with the strict float behavior on the WebAssembly target. It seems to be powered by LLVM, so it's weird it doesn't work. The other LLVM-powered backends seem to work fine. Any clues what might be going wrong there? Is the Wasm runtime further simplifying and not respecting the stream of instructions as-is?

Review thread on the test code:

using namespace Halide;
using namespace Halide::Internal;

const bool use_icons = true;
@zvookin (Member) commented Jun 30, 2025:
Get rid of this. If you want icons, print them for each subcase and then print a single unambiguous piece of text for the set that can be searched for via grep or in an editor. (The text is the worst case result.) Colored text would also be fine, but requires terminal support. The icons have already wasted my time as I can't search for them easily and there are a lot of test cases in the test. A flag that one edits in the source code for this is just silly. Yeah, we ought to have a library for writing tests, but we don't. Keep it simple.

@alexreinking (Member) commented Jul 3, 2025:

I agree with @zvookin. No information should be emoji-only. Emojis shouldn't be sent to terminals that don't support them (e.g. on Windows, I think you need to _setmode(_fileno(stdout), _O_U8TEXT)). Let's avoid introducing emojis to the codebase in this PR.

@mcourteaux (Contributor, Author):

Can I just default the bool to false? I found it very useful to have these emojis. Color is a very good dimension of information when you're skimming through the many results.

Review thread on the tan lowering:

Expr flip = x < make_const(type, 0.0);
Expr use_cotan = abs_x > make_const(type, PI / 4.0);
Expr pi_over_two_minus_abs_x;
if (type == Float(64)) {
@zvookin (Member) commented Jul 1, 2025:

Changing both branches of the if to use strict add or sub fixes the WebAssembly failures. This is due to reassociation in fast math -- as in commenting out fp_flags.setAllowReassoc in CodeGen_LLVM::set_fp_fast_math makes the tests pass for wasm without this change. Thus I'm guessing strictness is legitimately needed here.

    if (type == Float(64)) {
        // TODO(mcourteaux): We could do split floats here too.
        pi_over_two_minus_abs_x = strict_sub(make_const(type, PI_OVER_TWO), abs_x);
    } else if (type == Float(32)) {  // We want to do this trick always, because we invert later.
        auto [hi, lo] = split_float(PI_OVER_TWO);
        pi_over_two_minus_abs_x = strict_add(strict_sub(make_const(type, hi), abs_x), make_const(type, lo));
    }

@mcourteaux (Contributor, Author):

The tests are currently not testing the f64 variants. So I'm assuming you're referring to the + make_const(type, lo) addition needing to be a strict_add instead? I feel like that shouldn't be necessary, as the nested sub is already strict. Is it optimizing away the addition because it peeks into the strict_sub and still does constant folding? Am I fundamentally misunderstanding what the strict ops should do, or would you consider this a bug in LLVM's behavior? In other words, I wouldn't have expected any difference between this version and one where the + is exchanged for a strict_add.

A member replied:

Maybe LLVM does reassociative rewrites/folding if the root node (+) is non-strict, without checking if the child nodes (-) are also non-strict.

@mcourteaux (Contributor, Author):

That would be very strange, and would render every LLVM expression in IR non-strict. For example: a binary tree of nested strict operations 5 levels deep, followed by a non-strict + 1, destroys the whole strict tree? That's at least how I interpret your guess.

Well, I can see that if the non-strict + is directly followed by a strict +, you can tell that the secondary add is not gonna do anything if you constant fold it with the other addition. So you could argue that the "optimization" is correct, because the addition was non-strict, so LLVM decided it would not have any impact. I am not too sure if that reason is sound. However, adding strictness to the second op should be fine either way. We do want the add, so making it strict will never hurt.

@zvookin (Member) commented Jul 2, 2025:

There are a set of rewrite rules that are only allowed when certain operators are tagged with certain flags. A rewrite won't be done if any of the nodes it is trying to match against don't allow the type of optimization being applied, but it does not matter if nodes that are just inputs, i.e. wildcards, do not have the flags. Those inner nodes won't see their own rewrites, but they're just treated as variables in the rewrite that is actually happening.

Here's the problematic Halide IR:

   for (tan_approx_MULPE_poly1.s0.i.i, 0, t42) {
    let t28 = input[tan_approx_MULPE_poly1.s0.i.i]
    let t29 = (float32)round(t28*0.318310f)
    let t30 = (float32)strict_sub(t28 - (t29*3.141593f), t29*-8.742278e-08f)
    let t31 = (float32)abs(t30)
    let t32 = 0.785398f < t31
    let t33 = select(t32, (float32)strict_sub(1.570796f, t31) + -4.371139e-08f, t31)
    let t35 = let t44 = (t33*t33) in (t44*((t44*0.009524f) + -0.428571f))
    let t37 = let t45 = ((((t33*t33)*-0.095238f) + 1.000000f)*t33) in (select(t32, t35 + 1.000000f, t45)/select(t32, t45, t35 + 1.000000f))
    tan_approx_MULPE_poly1[t43 + tan_approx_MULPE_poly1.s0.i.i] = select(t30 < 0.000000f, 0.000000f - t37, t37)
   }

Look at the initialization of t33. The second argument to the select has a non-strict + at the top level. This is then used directly with non-strict * and +. LLVM has to guarantee that the 1.570796f in the strict_sub is not moved around, but it can freely move the -4.371139e-08f with regard to the other expressions. And in fact for wasm, the select is done between the t31 and strict_sub(1.570796f, t31) without the addition of -4.371139e-08f because it has noticed that it can constant fold it with the addition of 1 because the select conditions are the same. (I didn't completely reverse engineer the wasm instructions but it is definitely doing the select without the addition of the negative constant and taking care of it later. There are other more complicated uses of t33 and it may be doing constant folding there or it may be just a dumb optimization overall. But it is allowed.)

I'm pretty sure the Float(64) case also had to be fixed to get rid of all test failures, but feel free to verify that on your own.

@mcourteaux (Contributor, Author):

Thanks a lot for the analysis @zvookin. I fixed the issue in the PR, while introducing f64 and f16 generation of those functions.

Subsequent commits:

  • Fix FloatImm codegen on several GPU backends.
  • Fix gpu_float16_intrinsics test. It was not really using many float16 ops at all, because fast_pow was historically casting to float.
  • Implement a few quick workarounds for NVIDIA not properly implementing fp16 built-in functions.
  • …n case that's marked as supported by the GPU backend.)
  • Change printing style of float literals to use scientific notation with enough digits to be exact.
  • Relax performance test for fast_tanh on WebGPU.
  • Bugfix float16 nan/inf constants on WebGPU.
  • Separately print out the compilation log in runtime/opencl, as those logs can get very large, beyond the capacity of the HeapPrinter.
Labels: enhancement, gpu, performance, release_notes
Development

Successfully merging this pull request may close these issues:

  • We should make a cleanly-vectorizing fast-approximation for atan2f. (#8243)