@@ -1305,14 +1305,14 @@ Value elementalLower<lmhlo::ConcatenateOp>(OpBuilder* b, Location loc,
1305
1305
1306
1306
b->setInsertionPointToEnd (&if_inbound_ops[i].getElseRegion ().front ());
1307
1307
if (i == num_input_operands - 1 ) {
1308
- input_index[axis] = b-> create <arith::SubIOp>(loc, out_idx, low_bound);
1309
- auto operand_memref = op. getOperand (i );
1308
+ // we expect this branch to never be executed
1309
+ input_index[axis] = b-> create <arith::ConstantIndexOp>(loc, 0 );
1310
1310
auto ret_value =
1311
1311
check_cache ? createLoadOrUseCachedValue (
1312
- loc, b, op.getOperation (), operand_memref ,
1312
+ loc, b, op.getOperation (), op. getOperand (i) ,
1313
1313
input_index, b->saveInsertionPoint (), lower_config)
1314
1314
: createMaySpecificLoad (*b, loc, op.getOperation (),
1315
- operand_memref , input_index,
1315
+ op. getOperand (i) , input_index,
1316
1316
lower_config);
1317
1317
b->create <scf::YieldOp>(loc, ret_value);
1318
1318
} else {
@@ -1360,7 +1360,24 @@ Value elementalLower<lmhlo_disc::ConcatenateOp>(OpBuilder* b, Location loc,
1360
1360
1361
1361
auto int_ptr =
1362
1362
b->create <memref::LoadOp>(loc, ptr_array, ValueRange{operand_index});
1363
- Type ptr_type = LLVM::LLVMPointerType::get (FloatType::getF32 (ctx));
1363
+ auto elem_ty = out.getType ().cast <MemRefType>().getElementType ();
1364
+ // choose the LLVM pointer type whose pointee matches elem_ty (bf16/f16/f32/i8/i32/i64)
1365
+ Type ptr_type;
1366
+ if (elem_ty.isBF16 ()) {
1367
+ ptr_type = LLVM::LLVMPointerType::get (FloatType::getBF16 (ctx));
1368
+ } else if (elem_ty.isF16 ()) {
1369
+ ptr_type = LLVM::LLVMPointerType::get (FloatType::getF16 (ctx));
1370
+ } else if (elem_ty.isF32 ()) {
1371
+ ptr_type = LLVM::LLVMPointerType::get (FloatType::getF32 (ctx));
1372
+ } else if (elem_ty.isInteger (32 ) || elem_ty.isInteger (64 ) ||
1373
+ elem_ty.isInteger (8 )) {
1374
+ ptr_type = LLVM::LLVMPointerType::get (
1375
+ IntegerType::get (ctx, elem_ty.getIntOrFloatBitWidth ()));
1376
+ } else {
1377
+ op.emitError (" unsupported element type for ConcatenateOp" );
1378
+ return Value (nullptr );
1379
+ }
1380
+
1364
1381
auto llvm_ptr = b->create <LLVM::IntToPtrOp>(loc, ptr_type, int_ptr);
1365
1382
1366
1383
SmallVector<Value, 4 > input_index;
0 commit comments