
Commit cab27d8

Fix horrifying bug in lossless_cast of a subtract (#8155)
* Fix horrifying bug in lossless_cast of a subtract
* Use constant integer intervals to analyze safety for lossless_cast
  TODO:
  - Dedup the constant integer code with the same code in the simplifier.
  - Move constant interval arithmetic operations out of the class.
  - Make the ConstantInterval part of the return type of lossless_cast (and turn it into an inner helper) so that it isn't constantly recomputed.
* Fix ARM and HVX instruction selection
  Also added more TODOs
* Using constant_integer_bounds to strengthen FindIntrinsics
  In particular, we can do better instruction selection for pmulhrsw
* Move new classes to new files
  Also fix up Monotonic.cpp
* Make the simplifier use ConstantInterval
* Handle bounds of narrower types in the simplifier too
* Fix * operator. Add min/max/mod
* Add cache for constant bounds queries
* Fix ConstantInterval multiplication
* Add a simplifier rule which is apparently now necessary
* Misc cleanups and test improvements
* Add missing files
* Account for more aggressive simplification in fuse test
* Remove redundant helpers
* Add missing comment
* clear_bounds_info -> clear_expr_info
* Remove bad TODO
  I can't think of a single case that could cause this
* It's too late to change the semantics of fixed point intrinsics
* Fix some UB
* Stronger assert in Simplify_Div
* Delete bad rewrite rules
* Fix bad test when lowering mul_shift_right
  b_shift + b_shift < missing_q
* Avoid UB in lowering of rounding_shift_right/left
* Add shifts to the lossless cast fuzzer
  This required a more careful signed-integer-overflow detection routine
* Fix bug in lossless_negate
* Add constant interval test
* Rework find_mpy_ops to handle more structures
* Fix bugs in lossless_cast
* Fix mul_shift_right expansion
* Delete commented-out code
* Don't introduce out-of-range shifts in lossless_cast
* Some constant folding only happens after lowering intrinsics in codegen

---------

Co-authored-by: Steven Johnson <[email protected]>
1 parent 9b703f3 commit cab27d8

12 files changed (+514 / -392 lines)
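
The headline change described above is to decide whether a cast is lossless by tracking a conservative constant integer interval for each expression and checking that the interval fits in the destination type, rather than by matching on the expression's structure alone. The sketch below is illustrative only: the toy Interval type and its fits_in helper are made-up names, not Halide's actual ConstantInterval / constant_integer_bounds API.

    #include <cstdint>
    #include <cassert>

    // Illustrative only: a conservative [min, max] bound on an integer expression.
    struct Interval {
        int64_t min, max;

        // Interval arithmetic for subtraction: the lower bound comes from
        // subtracting the other operand's maximum, the upper bound from
        // subtracting its minimum.
        Interval operator-(const Interval &other) const {
            return {min - other.max, max - other.min};
        }

        // A cast is lossless if every value the expression can take is
        // representable in the destination type's range.
        bool fits_in(int64_t type_min, int64_t type_max) const {
            return type_min <= min && max <= type_max;
        }
    };

    int main() {
        // Two uint8 values: each lies in [0, 255].
        Interval a{0, 255}, b{0, 255};
        // Their difference lies in [-255, 255], so casting the subtract back
        // to uint8 (range [0, 255]) is NOT lossless -- the bug class the
        // commit title refers to.
        Interval d = a - b;
        assert(!d.fits_in(0, 255));
        // It does fit in int16 ([-32768, 32767]), so widening to i16 is safe.
        assert(d.fits_in(-32768, 32767));
        return 0;
    }

The uint8 subtraction case is the shape of the bug in the commit title: a difference of two unsigned narrow values can be negative, which structural pattern matching alone can miss but an interval bound catches directly.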

src/CodeGen_ARM.cpp

Lines changed: 28 additions & 36 deletions
@@ -1212,50 +1212,42 @@ void CodeGen_ARM::visit(const Add *op) {
         Expr ac_u8 = Variable::make(UInt(8, 0), "ac"), bc_u8 = Variable::make(UInt(8, 0), "bc");
         Expr cc_u8 = Variable::make(UInt(8, 0), "cc"), dc_u8 = Variable::make(UInt(8, 0), "dc");

-        // clang-format off
+        Expr ma_i8 = widening_mul(a_i8, ac_i8);
+        Expr mb_i8 = widening_mul(b_i8, bc_i8);
+        Expr mc_i8 = widening_mul(c_i8, cc_i8);
+        Expr md_i8 = widening_mul(d_i8, dc_i8);
+
+        Expr ma_u8 = widening_mul(a_u8, ac_u8);
+        Expr mb_u8 = widening_mul(b_u8, bc_u8);
+        Expr mc_u8 = widening_mul(c_u8, cc_u8);
+        Expr md_u8 = widening_mul(d_u8, dc_u8);
+
         static const Pattern patterns[] = {
-            // If we had better normalization, we could drastically reduce the number of patterns here.
             // Signed variants.
-            {init_i32 + widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8)), "dot_product"},
-            {init_i32 + widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), i16(d_i8)), "dot_product", Int(8)},
-            {init_i32 + widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(i16(c_i8), widening_mul(d_i8, dc_i8)), "dot_product", Int(8)},
-            {init_i32 + widening_add(widening_mul(a_i8, ac_i8), i16(b_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8)), "dot_product", Int(8)},
-            {init_i32 + widening_add(i16(a_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8)), "dot_product", Int(8)},
-            // Signed variants (associative).
-            {init_i32 + (widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8))), "dot_product"},
-            {init_i32 + (widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), i16(d_i8))), "dot_product", Int(8)},
-            {init_i32 + (widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(i16(c_i8), widening_mul(d_i8, dc_i8))), "dot_product", Int(8)},
-            {init_i32 + (widening_add(widening_mul(a_i8, ac_i8), i16(b_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8))), "dot_product", Int(8)},
-            {init_i32 + (widening_add(i16(a_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8))), "dot_product", Int(8)},
+            {(init_i32 + widening_add(ma_i8, mb_i8)) + widening_add(mc_i8, md_i8), "dot_product"},
+            {init_i32 + (widening_add(ma_i8, mb_i8) + widening_add(mc_i8, md_i8)), "dot_product"},
+            {widening_add(ma_i8, mb_i8) + widening_add(mc_i8, md_i8), "dot_product"},
+
             // Unsigned variants.
-            {init_u32 + widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8)), "dot_product"},
-            {init_u32 + widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), u16(d_u8)), "dot_product", UInt(8)},
-            {init_u32 + widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(u16(c_u8), widening_mul(d_u8, dc_u8)), "dot_product", UInt(8)},
-            {init_u32 + widening_add(widening_mul(a_u8, ac_u8), u16(b_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8)), "dot_product", UInt(8)},
-            {init_u32 + widening_add(u16(a_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8)), "dot_product", UInt(8)},
-            // Unsigned variants (associative).
-            {init_u32 + (widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8))), "dot_product"},
-            {init_u32 + (widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), u16(d_u8))), "dot_product", UInt(8)},
-            {init_u32 + (widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(u16(c_u8), widening_mul(d_u8, dc_u8))), "dot_product", UInt(8)},
-            {init_u32 + (widening_add(widening_mul(a_u8, ac_u8), u16(b_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8))), "dot_product", UInt(8)},
-            {init_u32 + (widening_add(u16(a_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8))), "dot_product", UInt(8)},
+            {(init_u32 + widening_add(ma_u8, mb_u8)) + widening_add(mc_u8, md_u8), "dot_product"},
+            {init_u32 + (widening_add(ma_u8, mb_u8) + widening_add(mc_u8, md_u8)), "dot_product"},
+            {widening_add(ma_u8, mb_u8) + widening_add(mc_u8, md_u8), "dot_product"},
         };
-        // clang-format on

         std::map<std::string, Expr> matches;
         for (const Pattern &p : patterns) {
             if (expr_match(p.pattern, op, matches)) {
-                Expr init = matches["init"];
-                Expr values = Shuffle::make_interleave({matches["a"], matches["b"], matches["c"], matches["d"]});
-                // Coefficients can be 1 if not in the pattern.
-                Expr one = make_one(p.coeff_type.with_lanes(op->type.lanes()));
-                // This hideous code pattern implements fetching a
-                // default value if the map doesn't contain a key.
-                Expr _ac = matches.try_emplace("ac", one).first->second;
-                Expr _bc = matches.try_emplace("bc", one).first->second;
-                Expr _cc = matches.try_emplace("cc", one).first->second;
-                Expr _dc = matches.try_emplace("dc", one).first->second;
-                Expr coeffs = Shuffle::make_interleave({_ac, _bc, _cc, _dc});
+                Expr init;
+                auto it = matches.find("init");
+                if (it == matches.end()) {
+                    init = make_zero(op->type);
+                } else {
+                    init = it->second;
+                }
+                Expr values = Shuffle::make_interleave({matches["a"], matches["b"],
+                                                        matches["c"], matches["d"]});
+                Expr coeffs = Shuffle::make_interleave({matches["ac"], matches["bc"],
+                                                        matches["cc"], matches["dc"]});
                 value = call_overloaded_intrin(op->type, p.intrin, {init, values, coeffs});
                 if (value) {
                     return;
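
For context, the "dot_product" patterns above target ARM's sdot/udot-style instructions, which accumulate four widening multiplies into a 32-bit sum per output lane. A scalar model of one lane is sketched below for illustration; the function name and signature are made up and are not the actual codegen or intrinsic API.

    #include <cstdint>
    #include <cassert>

    // Scalar model of one lane of a 4-way signed dot-product accumulate,
    // the operation the "dot_product" intrinsic patterns above map to.
    int32_t dot_product_lane(int32_t init, const int8_t a[4], const int8_t c[4]) {
        int32_t acc = init;
        for (int i = 0; i < 4; i++) {
            // widening_mul: multiply in a wider type so the product can't overflow.
            acc += int32_t(a[i]) * int32_t(c[i]);
        }
        return acc;
    }

    int main() {
        int8_t a[4] = {1, -2, 3, -4};
        int8_t c[4] = {10, 10, 10, 10};
        // init + a0*c0 + a1*c1 + a2*c2 + a3*c3 = 100 + 10 - 20 + 30 - 40 = 80
        assert(dot_product_lane(100, a, c) == 80);
        return 0;
    }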

src/CodeGen_X86.cpp

Lines changed: 8 additions & 3 deletions
@@ -538,8 +538,8 @@ void CodeGen_X86::visit(const Cast *op) {
     };

     // clang-format off
-    static const Pattern patterns[] = {
-        // This isn't rounding_multiply_quantzied(i16, i16, 15) because it doesn't
+    static Pattern patterns[] = {
+        // This isn't rounding_mul_shift_right(i16, i16, 15) because it doesn't
         // saturate the result.
         {"pmulhrs", i16(rounding_shift_right(widening_mul(wild_i16x_, wild_i16x_), 15))},

@@ -736,7 +736,12 @@ void CodeGen_X86::visit(const Call *op) {
         // Handle edge case of possible overflow.
         // See https://github.com/halide/Halide/pull/7129/files#r1008331426
         // On AVX512 (and with enough lanes) we can use a mask register.
-        if (target.has_feature(Target::AVX512) && op->type.lanes() >= 32) {
+        ConstantInterval ca = constant_integer_bounds(a);
+        ConstantInterval cb = constant_integer_bounds(b);
+        if (!ca.contains(-32768) || !cb.contains(-32768)) {
+            // Overflow isn't possible
+            pmulhrs.accept(this);
+        } else if (target.has_feature(Target::AVX512) && op->type.lanes() >= 32) {
            Expr expr = select((a == i16_min) && (b == i16_min), i16_max, pmulhrs);
            expr.accept(this);
        } else {
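
The new bounds check above relies on a property of the pmulhrs pattern, i16(rounding_shift_right(widening_mul(a, b), 15)): the only input pair whose result exceeds int16 range is a == b == -32768, which produces 32768. If the constant integer bounds of either operand exclude -32768, that case cannot occur and no saturating fix-up select is needed. A small standalone check of that arithmetic follows; it is illustrative only, not the codegen itself.

    #include <cstdint>
    #include <cassert>

    // rounding_shift_right(widening_mul(a, b), 15) computed in 32 bits,
    // i.e. the value pmulhrsw produces before narrowing back to int16.
    int32_t pmulhrs_wide(int16_t a, int16_t b) {
        int32_t prod = int32_t(a) * int32_t(b);  // widening_mul
        return (prod + (1 << 14)) >> 15;         // round-to-nearest shift
    }

    int main() {
        // Every input pair except (-32768, -32768) stays within int16 range...
        assert(pmulhrs_wide(32767, 32767) == 32766);
        assert(pmulhrs_wide(-32768, 32767) == -32767);
        // ...but the single edge case overflows to 32768 (> int16_max), which
        // is why a saturating select (or a bounds proof that the case can't
        // occur) is required.
        assert(pmulhrs_wide(-32768, -32768) == 32768);
        return 0;
    }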

src/Expr.cpp

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@ const IntImm *IntImm::make(Type t, int64_t value) {
     internal_assert(t.is_int() && t.is_scalar())
         << "IntImm must be a scalar Int\n";
     internal_assert(t.bits() >= 1 && t.bits() <= 64)
-        << "IntImm must have between 1 and 64 bits\n";
+        << "IntImm must have between 1 and 64 bits: " << t << "\n";

     // Normalize the value by dropping the high bits.
     // Since left-shift of negative value is UB in C++, cast to uint64 first;
@@ -28,7 +28,7 @@ const UIntImm *UIntImm::make(Type t, uint64_t value) {
     internal_assert(t.is_uint() && t.is_scalar())
         << "UIntImm must be a scalar UInt\n";
     internal_assert(t.bits() >= 1 && t.bits() <= 64)
-        << "UIntImm must have between 1 and 64 bits\n";
+        << "UIntImm must have between 1 and 64 bits " << t << "\n";

     // Normalize the value by dropping the high bits
     value <<= (64 - t.bits());

