Fast Function Approximations lowering. #8566
Conversation
Force-pushed bea8612 to 0de4dbc
Thanks so much for doing this, and sorry it took me so long to review it (I'm finally clawing out from under my deadlines). It generally looks good but I have some review comments.
What order should this be done in vs our change to strict_float? Is there any reason to delay this until after strict_float is changed?
@@ -33,6 +33,12 @@ bool should_extract(const Expr &e, bool lift_all) {
        return false;
    }

    if (const Call *c = e.as<Call>()) {
The need for this makes me think the extra args would be better just as flat extra args to the math intrinsic, instead of being packed into a struct.
Well, if I don't pass these as a `make_struct()`, CSE will still lift those arguments out of the call if you have multiple such calls. That makes it harder to figure out what the actual precision arguments were to the Call node in the lowering pass that wants to read them back: they might have become a Let or LetStmt, instead of a simple FloatImm.
CSE will never lift constants, so they should be fine.
strict_float is already broken today in a few scenarios. All these transcendentals are broken right now with regard to strict_float(): every GPU backend has a different implementation of them, with different precisions. The good thing is that this PR actually paves the way towards fixing strict_float for transcendentals. Long term, if we have our own implementation for all of them, we can guarantee under strict_float that they yield the same result. However, this would require us to have an FMA intrinsic as well (to be able to express the polynomials using Horner's method with fma ops). So yeah, I'm definitely in favor of accepting that strict_float is badly broken, and moving forward with this PR.
No worries! Thanks a lot for looking into this! I'll address your feedback, questions and improvements tomorrow, so that this is all still fresh in your head.
Update: Nevermind, I found another approach that works well for now.
Force-pushed 3211d3a to f6f7fd0
I took care of all feedback, except for the make_struct wrapper for the precision arguments. That's for later. I updated the original post of the PR at the top. More info with the latest concerns and notes can be found there!
Force-pushed f28a8b0 to 7000f21
Force-pushed a171ec1 to 6cebc56
Update: I changed it to have accessor functions that simply return the static member. That seems to be the way it's done everywhere in Halide header files.
@alexreinking Can you assess what this macOS buildbot is up to?
…Selectively disable some tests that require strict_float on GPU backends.
Force-pushed 4ceca2c to 845d83a
@abadams There seems to be an issue with the strict float behavior on the WebAssembly target. It seems to be powered by LLVM, so it's weird it doesn't work. The other LLVM-powered backends seem to work fine. Any clues what might be going wrong there? Is the Wasm runtime further simplifying and not respecting the stream of instructions as-is?
using namespace Halide;
using namespace Halide::Internal;

const bool use_icons = true;
Get rid of this. If you want icons, print them for each subcase and then print a single unambiguous piece of text for the set that can be searched for via grep or in an editor. (The text is the worst case result.) Colored text would also be fine, but requires terminal support. The icons have already wasted my time as I can't search for them easily and there are a lot of test cases in the test. A flag that one edits in the source code for this is just silly. Yeah, we ought to have a library for writing tests, but we don't. Keep it simple.
I agree with @zvookin. No information should be emoji-only. Emojis shouldn't be sent to terminals that don't support them (e.g. on Windows, I think you need to `_setmode(_fileno(stdout), _O_U8TEXT)`). Let's avoid introducing emojis to the codebase in this PR.
Can I just default the bool to false? I found it very useful to have these emojis. Color is a very good dimension of information when you're skimming through the many results.
src/FastMathFunctions.cpp
Outdated
Expr flip = x < make_const(type, 0.0);
Expr use_cotan = abs_x > make_const(type, PI / 4.0);
Expr pi_over_two_minus_abs_x;
if (type == Float(64)) {
Changing both branches of the if to use strict add or sub fixes the WebAssembly failures. This is due to reassociation in fast math: commenting out `fp_flags.setAllowReassoc` in `CodeGen_LLVM::set_fp_fast_math` makes the tests pass for wasm without this change. Thus I'm guessing strictness is legitimately needed here.
if (type == Float(64)) {
// TODO(mcourteaux): We could do split floats here too.
pi_over_two_minus_abs_x = strict_sub(make_const(type, PI_OVER_TWO), abs_x);
} else if (type == Float(32)) { // We want to do this trick always, because we invert later.
auto [hi, lo] = split_float(PI_OVER_TWO);
pi_over_two_minus_abs_x = strict_add(strict_sub(make_const(type, hi), abs_x), make_const(type, lo));
}
The tests are currently not testing the f64 variants, so I'm assuming you're referring to the `+ make_const(type, lo)` addition needing to be a strict_add instead? I feel like that shouldn't be necessary, as I stated that the nested sub is already strict. Is it optimizing away the addition because it peeks into the strict_sub and still does constant folding? Am I fundamentally misunderstanding what the strict ops should do, or would you consider this a bug in LLVM's behavior? In other words, I wouldn't have expected any difference between this version and one where the + is exchanged for a strict_add.
Maybe LLVM does reassociative rewrites/folding if the root node (+) is non-strict, without checking if the child nodes (-) are also non-strict.
That would be very strange and render every LLVM expression in IR non-strict. For example: a binary tree of 5 levels deep of nested strict operations, followed by a non-strict + 1, destroys the whole strict tree? That's at least how I interpret your guess.
Well, I can see that if the non-strict + is directly followed by a strict +, you can tell that the secondary add is not gonna do anything if you constant fold it with the other addition. So you could argue that the "optimization" is correct, because the addition was non-strict, so LLVM decided it would not have any impact. I am not too sure if that reason is sound. However, adding strictness to the second op should be fine either way. We do want the add, so making it strict will never hurt.
There are a set of rewrite rules that are only allowed when certain operators are tagged with certain flags. A rewrite won't be done if any of the nodes it is trying to match against don't allow the type of optimization being applied, but it does not matter if nodes that are just inputs, i.e. wildcards, do not have the flags. Those inner nodes won't see their own rewrites, but they're just treated as variables in the rewrite that is actually happening.
Here's the problematic Halide IR:
for (tan_approx_MULPE_poly1.s0.i.i, 0, t42) {
let t28 = input[tan_approx_MULPE_poly1.s0.i.i]
let t29 = (float32)round(t28*0.318310f)
let t30 = (float32)strict_sub(t28 - (t29*3.141593f), t29*-8.742278e-08f)
let t31 = (float32)abs(t30)
let t32 = 0.785398f < t31
let t33 = select(t32, (float32)strict_sub(1.570796f, t31) + -4.371139e-08f, t31)
let t35 = let t44 = (t33*t33) in (t44*((t44*0.009524f) + -0.428571f))
let t37 = let t45 = ((((t33*t33)*-0.095238f) + 1.000000f)*t33) in (select(t32, t35 + 1.000000f, t45)/select(t32, t45, t35 + 1.000000f))
tan_approx_MULPE_poly1[t43 + tan_approx_MULPE_poly1.s0.i.i] = select(t30 < 0.000000f, 0.000000f - t37, t37)
}
Look at the initialization of t33. The second argument to the select has a non-strict `+` at the top level. This is then used directly with non-strict `*` and `+`. LLVM has to guarantee that the 1.570796f in the `strict_sub` is not moved around, but it can freely move the -4.371139e-08f with regard to the other expressions. And in fact for wasm, the select is done between `t31` and `strict_sub(1.570796f, t31)` without the addition of -4.371139e-08f, because it has noticed that it can constant fold it with the addition of 1, since the select conditions are the same. (I didn't completely reverse engineer the wasm instructions, but it is definitely doing the select without the addition of the negative constant and taking care of it later. There are other, more complicated uses of t33 and it may be doing constant folding there, or it may be just a dumb optimization overall. But it is allowed.)
I'm pretty sure the Float(64) case also had to be fixed to get rid of all test failures, but feel free to verify that on your own.
Thanks a lot for the analysis @zvookin. I fixed the issue in the PR, while introducing f64 and f16 generation of those functions.
…t float calculations for f64 and f16.
Fix FloatImm codegen on several GPU backends. Fix gpu_float16_intrinsics test. Was not really using many float16 ops at all, because fast_pow was historically casting to float. Implement a few quick workarounds for NVIDIA not properly implementing fp16 built-in functions.
…ch failed on x87.
…n case that's marked as supported by the GPU backend.) Change printing style of float-literals to use scientific notation with enough digits to be exact. Relax performance test for fast_tanh on WebGPU. Bugfix float16 nan/inf constants on WebGPU. Separately print out compilation log in runtime/opencl as those logs can get very large, beyond the size of the HeapPrinter capacity.
The big transcendental lowering update! Replaces #8388.
TODO
I still have to do:
Overview:
- Fast transcendentals implemented for: sin, cos, tan, atan, atan2, exp, log, expm1, tanh, asin, acos.
- Simple API to specify precision requirements. Default-initialized precision (`AUTO` without constraints) means "don't care about precision, as long as it's reasonable and fast", which gives you the highest chance of selecting a high-performance implementation based on hardware instructions. Optimization objectives `MULPE` (max ULP error) and `MAE` (max absolute error) are available. Compared to the previous PR, I removed `MULPE_MAE`, as I didn't see a good purpose for it.
- Tabular info on intrinsics and native functions, their precision and speed, to select an appropriate implementation for lowering to something that is definitely not slower, while satisfying the precision requirements.
  - `native_cos`, `native_exp`, etc...
  - `fast::cos`, `fast::exp`, etc...
- Tabular info measuring exact precision, obtained by exhaustively iterating over all floats in the polynomial's native interval. Measured both MULPE and MAE for Float32. Precisions are not yet evaluated on f16 or f64, which is future work (which I have currently not planned). Precisions are measured by `correctness/determine_fast_function_approximation_metrics.cpp`.
- Performance tests validating that:
- Accuracy tests validating that:
- Drive-by fix for adding `libOpenCL.so.1` to the list of tested sonames for the OpenCL runtime.

Review guide
- The precision arguments are packed into a `Call::make_struct` node with 4 parameters (see API below). This approximation-precision Call node survives until the lowering pass where the transcendentals are lowered; in this pass, they are extracted again from the Call node's arguments. I conceptually like that this way they are bundled and clearly not at the same level as the actual mathematical arguments. Is this a good approach? In order for this to work, I had to stop `CSE` from extracting those precision arguments, and `StrictifyFloat` from recursing down into that struct and littering `strict_float` on those numbers. I have seen the `Call::bundle` intrinsic; perhaps that one is better for this purpose? @abadams
- `Float(16)` and `Float(64)` are not yet implemented/tested. The polynomial approximations should work correctly (although untested) for these other data types.
- `native_tan()` compiles to the same three instructions as I implemented on CUDA: `sin.approx.f32`, `cos.approx.f32`, `div.approx.f32`. I haven't investigated AMD's documentation on available hardware instructions.

Concerns
- `exp()` (regular `exp()`, not `fast_exp()`) claims to be bit-exact, which is proven wrong by the pre-existing test `correctness/vector_math.cpp`.
- `tan_f32(vec2<f32>)` did not get converted to `tan_f32(first element)`, `tan_f32(second element)`.

API

Fixes #8243.