
Commit 58efe1a

fix concat codegen (#1311)
Use lmhlo_disc.concat when the operands have fixed shapes.
Parent: fbe39bc

3 files changed, 33 additions, 9 deletions

tao_compiler/mlir/disc/transforms/disc_lhlo_rewriter.cc

Lines changed: 2 additions & 2 deletions
@@ -87,11 +87,11 @@ struct LhloConcatenateOpConverter
                                 PatternRewriter& rewriter) const override {
     Operation* op = lhloOp.getOperation();
     if (!isFixedShape(lhloOp)) return failure();
-
     auto operands = op->getOperands();
 
     // TODO(yancey): support CPU place
-    if (!placement_utils::isGpuMemRef(operands[0])) return failure();
+    auto deviceAttr = op->getAttrOfType<StringAttr>(kDiscPlaceAssignment);
+    if (!deviceAttr || deviceAttr.getValue() != kGpu) return failure();
     int num_input_operands = op->getNumOperands() - 1;
 
     SmallVector<Value, 4> ptr_array;
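
The behavioral change in this file is the placement check: instead of requiring the first operand to be a GPU memref, the pattern now bails out unless the op itself carries a GPU placement attribute. Below is a minimal standalone sketch of that precondition, not code from this commit; the helper name is hypothetical, and the two constants are local illustrative stand-ins whose values ("disc.device", "gpu") are inferred from the updated test in disc-lhlo-rewrite.mlir rather than copied from placement_utils.

// Sketch only: restates the new precondition of LhloConcatenateOpConverter
// as a standalone predicate.
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"

// Illustrative stand-ins for the pass's kDiscPlaceAssignment / kGpu constants;
// the values match what the updated lit test pins (disc.device = "gpu").
constexpr llvm::StringLiteral kDiscPlaceAssignmentExample("disc.device");
constexpr llvm::StringLiteral kGpuExample("gpu");

// Returns true only when the op is explicitly placed on the GPU.
static bool isGpuPlaced(mlir::Operation* op) {
  auto deviceAttr =
      op->getAttrOfType<mlir::StringAttr>(kDiscPlaceAssignmentExample);
  return deviceAttr && deviceAttr.getValue() == kGpuExample;
}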

tao_compiler/mlir/disc/transforms/lhlo_elemental_utils.cc

Lines changed: 22 additions & 5 deletions
@@ -1305,14 +1305,14 @@ Value elementalLower<lmhlo::ConcatenateOp>(OpBuilder* b, Location loc,
 
     b->setInsertionPointToEnd(&if_inbound_ops[i].getElseRegion().front());
     if (i == num_input_operands - 1) {
-      input_index[axis] = b->create<arith::SubIOp>(loc, out_idx, low_bound);
-      auto operand_memref = op.getOperand(i);
+      // we expect this branch to never be executed
+      input_index[axis] = b->create<arith::ConstantIndexOp>(loc, 0);
       auto ret_value =
           check_cache ? createLoadOrUseCachedValue(
-                            loc, b, op.getOperation(), operand_memref,
+                            loc, b, op.getOperation(), op.getOperand(i),
                             input_index, b->saveInsertionPoint(), lower_config)
                       : createMaySpecificLoad(*b, loc, op.getOperation(),
-                                              operand_memref, input_index,
+                                              op.getOperand(i), input_index,
                                               lower_config);
       b->create<scf::YieldOp>(loc, ret_value);
     } else {
@@ -1360,7 +1360,24 @@ Value elementalLower<lmhlo_disc::ConcatenateOp>(OpBuilder* b, Location loc,
 
   auto int_ptr =
       b->create<memref::LoadOp>(loc, ptr_array, ValueRange{operand_index});
-  Type ptr_type = LLVM::LLVMPointerType::get(FloatType::getF32(ctx));
+  auto elem_ty = out.getType().cast<MemRefType>().getElementType();
+  // pick the LLVM pointer type that matches the output element type
+  Type ptr_type;
+  if (elem_ty.isBF16()) {
+    ptr_type = LLVM::LLVMPointerType::get(FloatType::getBF16(ctx));
+  } else if (elem_ty.isF16()) {
+    ptr_type = LLVM::LLVMPointerType::get(FloatType::getF16(ctx));
+  } else if (elem_ty.isF32()) {
+    ptr_type = LLVM::LLVMPointerType::get(FloatType::getF32(ctx));
+  } else if (elem_ty.isInteger(32) || elem_ty.isInteger(64) ||
+             elem_ty.isInteger(8)) {
+    ptr_type = LLVM::LLVMPointerType::get(
+        IntegerType::get(ctx, elem_ty.getIntOrFloatBitWidth()));
+  } else {
+    op.emitError("unsupported element type for ConcatenateOp");
+    return Value(nullptr);
+  }
+
   auto llvm_ptr = b->create<LLVM::IntToPtrOp>(loc, ptr_type, int_ptr);
 
   SmallVector<Value, 4> input_index;
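
For reference, the new element-type dispatch can be read as a small pure function from the output memref's element type to the LLVM pointer type used by the IntToPtrOp cast. The sketch below uses the same MLIR APIs as the hunk (typed LLVM pointers, FloatType/IntegerType builders); the helper name is made up for illustration, and it signals unsupported types by returning a null Type instead of emitting an error on the op.

// Sketch only: the mapping added above, factored out as a helper.
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/BuiltinTypes.h"

static mlir::Type elementTypeToLLVMPtrType(mlir::Type elem_ty,
                                           mlir::MLIRContext* ctx) {
  using namespace mlir;
  if (elem_ty.isBF16())
    return LLVM::LLVMPointerType::get(FloatType::getBF16(ctx));
  if (elem_ty.isF16())
    return LLVM::LLVMPointerType::get(FloatType::getF16(ctx));
  if (elem_ty.isF32())
    return LLVM::LLVMPointerType::get(FloatType::getF32(ctx));
  if (elem_ty.isInteger(8) || elem_ty.isInteger(32) || elem_ty.isInteger(64))
    return LLVM::LLVMPointerType::get(
        IntegerType::get(ctx, elem_ty.getIntOrFloatBitWidth()));
  return mlir::Type();  // unsupported element type for the concat codegen
}

Dispatching on the output element type keeps the pointer cast consistent with the values stored in ptr_array, which are raw i64 addresses loaded back right before the LLVM::IntToPtrOp.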

tao_compiler/mlir/disc/transforms/tests/disc-lhlo-rewrite.mlir

Lines changed: 9 additions & 2 deletions
@@ -2,8 +2,15 @@
 
 module @main attributes {gpu.container_module} {
   func.func @test_concat(%arg0: memref<2x16xf32, #gpu.address_space<global>>, %arg1: memref<2x16xf32, #gpu.address_space<global>>, %out : memref<4x16xf32, #gpu.address_space<global>>) -> memref<4x16xf32, #gpu.address_space<global>> attributes {gpu.kernel} {
-    // CHECK: lmhlo_disc.concatenate
-    "lmhlo.concatenate"(%arg0, %arg1, %out) { dimension = 0 : i64 } : (memref<2x16xf32, #gpu.address_space<global>>, memref<2x16xf32, #gpu.address_space<global>>, memref<4x16xf32, #gpu.address_space<global>>) -> ()
+    // CHECK: memref.alloc() : memref<3xi64>
+    // CHECK: "disc_ral.get_pointer"(%arg0)
+    // CHECK: memref.store %0, %alloc[%c0]
+    // CHECK: "disc_ral.get_pointer"(%arg1)
+    // CHECK: memref.store %1, %alloc[%c1]
+    // CHECK: memref.alloc() : memref<3xi64>
+    // CHECK: "lmhlo_disc.h2d"(%alloc, %alloc_0)
+    // CHECK: "lmhlo_disc.concatenate"(%arg0, %arg1, %alloc_0, %arg2)
+    "lmhlo.concatenate"(%arg0, %arg1, %out) { dimension = 0 : i64, disc.device = "gpu"} : (memref<2x16xf32, #gpu.address_space<global>>, memref<2x16xf32, #gpu.address_space<global>>, memref<4x16xf32, #gpu.address_space<global>>) -> ()
     return %out : memref<4x16xf32, #gpu.address_space<global>>
   }
 }
