Skip to content

Commit a4a7531

Browse files
authored
Consider *all* Exprs a func uses, not just the RHS, in Li2018 (#8326)
Fixes #8312
1 parent cab27d8 commit a4a7531

File tree

1 file changed

+28
-24
lines changed

1 file changed

+28
-24
lines changed

src/DerivativeUtils.cpp

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ map<string, Box> inference_bounds(const vector<Func> &funcs,
253253
for (auto it = order.rbegin(); it != order.rend(); it++) {
254254
Func func = Func(env[*it]);
255255
// We should already have the bounds of this function
256-
internal_assert(bounds.find(*it) != bounds.end());
256+
internal_assert(bounds.find(*it) != bounds.end()) << *it << "\n";
257257
const Box &current_bounds = bounds[*it];
258258
internal_assert(func.args().size() == current_bounds.size());
259259
// We know the range for each argument of this function
@@ -262,29 +262,33 @@ map<string, Box> inference_bounds(const vector<Func> &funcs,
262262
scope.push(arg, current_bounds[i]);
263263
}
264264
// Propagate the bounds
265-
for (int update_id = -1; update_id < func.num_update_definitions(); update_id++) {
266-
// For each rhs expression
267-
Tuple tuple = update_id == -1 ? func.values() : func.update_values(update_id);
268-
for (const auto &expr : tuple.as_vector()) {
269-
// For all the immediate dependencies of this expression,
270-
// find the required ranges
271-
map<string, Box> update_bounds =
272-
boxes_required(expr, scope, func_value_bounds);
273-
// Loop over the dependencies
274-
for (const auto &it : update_bounds) {
275-
if (it.first == func.name()) {
276-
// Skip self reference
277-
continue;
278-
}
279-
// Update the bounds, if not exists then create a new one
280-
auto found = bounds.find(it.first);
281-
if (found == bounds.end()) {
282-
bounds[it.first] = it.second;
283-
} else {
284-
Box new_box = box_union(found->second, it.second);
285-
bounds[it.first] = new_box;
286-
}
287-
}
265+
class CollectExprs : public IRMutator {
266+
public:
267+
using IRMutator::mutate;
268+
Expr mutate(const Expr &e) override {
269+
exprs.push_back(e);
270+
return e;
271+
}
272+
std::vector<Expr> exprs;
273+
} expr_collector;
274+
func.function().mutate(&expr_collector);
275+
276+
Expr bundle = Call::make(Int(32), Call::bundle, expr_collector.exprs, Call::PureIntrinsic);
277+
map<string, Box> update_bounds =
278+
boxes_required(bundle, scope, func_value_bounds);
279+
// Loop over the dependencies
280+
for (const auto &it : update_bounds) {
281+
if (it.first == func.name()) {
282+
// Skip self reference
283+
continue;
284+
}
285+
// Update the bounds, if not exists then create a new one
286+
auto found = bounds.find(it.first);
287+
if (found == bounds.end()) {
288+
bounds[it.first] = it.second;
289+
} else {
290+
Box new_box = box_union(found->second, it.second);
291+
bounds[it.first] = new_box;
288292
}
289293
}
290294
for (int i = 0; i < (int)current_bounds.size(); i++) {

0 commit comments

Comments
 (0)