Skip to content

Commit 5334021

Browse files
authored
avoid mix reduction type in one fusion block (#1101)
* fix row reduce fusion * update
1 parent 40d424a commit 5334021

File tree

4 files changed

+39
-2
lines changed

4 files changed

+39
-2
lines changed

tao_compiler/mlir/disc/transforms/fusion_utils.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1467,6 +1467,23 @@ Value BaseGpuFusionStrategy::getEffectiveShape(FusionPattern& target, Value v) {
14671467
return isa<lmhlo::ReduceOp>(result_op) ? result_op->getOperand(0) : v;
14681468
}
14691469

1470+
bool BaseGpuFusionStrategy::tryFuse(ShapeAnalysis& shapeAnalysis,
1471+
FusionPattern& lhs, FusionPattern& rhs,
1472+
FusionPattern& target) {
1473+
// TODO(Yancey): support fusion with different reduction type
1474+
bool has_row_reduction = llvm::any_of(target.getOpList(), [](Operation* op) {
1475+
return isRank2RowReduction(op);
1476+
});
1477+
bool has_col_reduciton = llvm::any_of(target.getOpList(), [](Operation* op) {
1478+
return isRank2ColReduction(op);
1479+
});
1480+
1481+
if (has_row_reduction && has_col_reduciton) {
1482+
return false;
1483+
}
1484+
return BaseFusionStrategy::tryFuse(shapeAnalysis, lhs, rhs, target);
1485+
}
1486+
14701487
////////////////////// Stitch-Base CPU FusionStrategy Implemenation /////
14711488
//////////////////////////////////////////////////////////////////
14721489

tao_compiler/mlir/disc/transforms/fusion_utils.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -791,7 +791,8 @@ class BaseGpuFusionStrategy : public BaseFusionStrategy {
791791
FusionPattern& target) {
792792
return lhs.isKInputFusion() && rhs.isKInputFusion();
793793
}
794-
794+
virtual bool tryFuse(ShapeAnalysis& shapeAnalysis, FusionPattern& lhs,
795+
FusionPattern& rhs, FusionPattern& target) override;
795796
Value getEffectiveShape(FusionPattern& target, Value v) override;
796797
virtual StringRef getName() override { return "BaseGpuFusionStrategy"; }
797798
};
@@ -801,7 +802,8 @@ class StitchGpuFusionStrategy : public FusionStrategy {
801802
StitchGpuFusionStrategy(const FusionOptions& options)
802803
: FusionStrategy(options) {}
803804
virtual bool isFusible(Operation* op) override;
804-
805+
virtual bool tryFuse(ShapeAnalysis& shapeAnalysis, FusionPattern& lhs,
806+
FusionPattern& rhs, FusionPattern& target) override;
805807
virtual bool initFusionPattern(ShapeAnalysis& shapeAnalysis,
806808
FusionPattern& fusion_pattern) override;
807809
virtual StringRef getName() override { return "StitchGpuFusionStrategy"; }

tao_compiler/mlir/disc/transforms/fusion_utils_stitch_gpu.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,23 @@ bool StitchGpuFusionStrategy::isFusible(Operation* op) {
5555
return true;
5656
}
5757

58+
bool StitchGpuFusionStrategy::tryFuse(ShapeAnalysis& shapeAnalysis,
59+
FusionPattern& lhs, FusionPattern& rhs,
60+
FusionPattern& target) {
61+
// TODO(Yancey): support fusion with different reduction type
62+
bool has_row_reduction = llvm::any_of(target.getOpList(), [](Operation* op) {
63+
return isRank2RowReduction(op);
64+
});
65+
bool has_col_reduciton = llvm::any_of(target.getOpList(), [](Operation* op) {
66+
return isRank2ColReduction(op);
67+
});
68+
69+
if (has_row_reduction && has_col_reduciton) {
70+
return false;
71+
}
72+
return FusionStrategy::tryFuse(shapeAnalysis, lhs, rhs, target);
73+
}
74+
5875
Value StitchGpuFusionStrategy::getEffectiveShape(FusionPattern& target,
5976
Value v) {
6077
Operation* result_op = target.findLastWriter(v);

tao_compiler/tao

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
/workspace/BladeDISC/tao

0 commit comments

Comments
 (0)