Commit bc0994a

support Llama bf16 amp training (#1293)
1 parent 52cb669 commit bc0994a

File tree

13 files changed: +107 additions, −25 deletions
Lines changed: 8 additions & 0 deletions

@@ -0,0 +1,8 @@
+TF_NEED_CUDA=1
+TF_CUDA_CLANG=0
+TF_CUDA_VERSION=11.8
+TF_CUDNN_VERSION=8
+TF_CUDA_COMPUTE_CAPABILITIES="6.0,6.1,7.0,7.5,8.0,8.6"
+TF_NEED_TENSORRT=0
+TF_NEED_ROCM=0
+TF_SET_ANDROID_WORKSPACE=0

tao_compiler/mlir/custom_ops/custom_library/transpose_gpu.cu.cc

Lines changed: 3 additions & 1 deletion

@@ -57,10 +57,12 @@ void LaunchTransposeKernel(cudaStream_t stream, T* input,
 template void LaunchTransposeKernel<float>(cudaStream_t stream, float* input,
                                            std::vector<int64_t> input_dims,
                                            float* output);
-
 template void LaunchTransposeKernel<Eigen::half>(
     cudaStream_t stream, Eigen::half* input, std::vector<int64_t> input_dims,
     Eigen::half* output);
+template void LaunchTransposeKernel<Eigen::bfloat16>(
+    cudaStream_t stream, Eigen::bfloat16* input,
+    std::vector<int64_t> input_dims, Eigen::bfloat16* output);
 #endif

 }  // namespace ral
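Because the kernel template's definition lives in this separately compiled .cu.cc, every supported element type needs an explicit instantiation here; without the new Eigen::bfloat16 instantiation, the ral_transpose<Eigen::bfloat16, N> registrations below would fail to link. A minimal sketch of the pattern, with hypothetical names:

#include <Eigen/Core>  // Eigen::bfloat16

// Definition visible only inside this translation unit.
template <typename T>
void LaunchKernel(const T* in, T* out, int n) {
  // ... launch the CUDA kernel ...
}

// Explicit instantiation: the only T = Eigen::bfloat16 symbol the linker
// will ever see; callers in other translation units link against this.
template void LaunchKernel<Eigen::bfloat16>(const Eigen::bfloat16*,
                                            Eigen::bfloat16*, int);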

tao_compiler/mlir/custom_ops/transpose_impl.cc

Lines changed: 3 additions & 1 deletion

@@ -61,11 +61,13 @@ void ral_transpose(ExecutionContext* ctx, void* stream_handle,

   LaunchTransposeKernel<T>(stream, d_in, input_dims, d_out);
 }
-
+DEFINE_TAO_TYPE_NAME_HELPER(Eigen::bfloat16, "bf16");
 TAO_RAL_API("ral_transpose", "gpu", ral_transpose<float, 2>);
 TAO_RAL_API("ral_transpose", "gpu", ral_transpose<float, 3>);
 TAO_RAL_API("ral_transpose", "gpu", ral_transpose<Eigen::half, 2>);
 TAO_RAL_API("ral_transpose", "gpu", ral_transpose<Eigen::half, 3>);
+TAO_RAL_API("ral_transpose", "gpu", ral_transpose<Eigen::bfloat16, 2>);
+TAO_RAL_API("ral_transpose", "gpu", ral_transpose<Eigen::bfloat16, 3>);
 #endif

 }  // namespace ral
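DEFINE_TAO_TYPE_NAME_HELPER supplies the short string ("bf16") that the RAL uses when composing the name under which a kernel such as ral_transpose<Eigen::bfloat16, 2> is registered and later looked up. The macro's actual definition is not part of this diff; a purely hypothetical sketch of the idea:

#include <string>

template <typename T>
struct TaoTypeNameHelper;  // primary template intentionally undefined

// Hypothetical expansion: map a C++ element type to the short name that
// gets embedded in the registered kernel's signature string.
#define DEFINE_TAO_TYPE_NAME_HELPER(type, name)   \
  template <>                                     \
  struct TaoTypeNameHelper<type> {                \
    static std::string Invoke() { return name; }  \
  };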

tao_compiler/mlir/disc/disc_compiler.cc

Lines changed: 3 additions & 1 deletion

@@ -563,10 +563,12 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) {
     // optimization. Then this pass will be enabled by default.
     pm.addNestedPass<FuncOp>(disc_ral::createForLoopUnrollInterleavePass());
   }
-  pm.addNestedPass<FuncOp>(arith::createArithExpandOpsPass());
   // Origin: https://reviews.llvm.org/D147585
   // Should be removed after rebasing to the latest llvm head
   pm.addNestedPass<FuncOp>(disc_ral::createDiscBF16ExpansionPass());
+  mlir::arith::ArithExpandOpsOptions arith_option;
+  arith_option.includeBf16 = true;
+  pm.addNestedPass<FuncOp>(arith::createArithExpandOpsPass(arith_option));
   pm.addNestedPass<FuncOp>(mlir::memref::createFoldMemRefAliasOpsPass());

   // Flatten multi dim memref accesses to its 1D format to enable more
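Ordering matters in this hunk: DISC's own bf16 expansion runs first so its patterns (including the rounding sequence added below) claim the bf16 ext/trunc ops, and the upstream arith expansion, now constructed with includeBf16 = true, lowers whatever bf16 casts remain. Condensed, the relevant fragment of the pipeline is:

// Condensed from the hunk above; both passes are nested on FuncOp.
pm.addNestedPass<FuncOp>(disc_ral::createDiscBF16ExpansionPass());
mlir::arith::ArithExpandOpsOptions arith_option;
arith_option.includeBf16 = true;  // also expand bf16 <-> f32 ext/trunc
pm.addNestedPass<FuncOp>(arith::createArithExpandOpsPass(arith_option));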

tao_compiler/mlir/disc/transforms/disc_bf16_expansion.cc

Lines changed: 16 additions & 14 deletions

@@ -64,7 +64,7 @@ struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
   Type resultETy = getElementTypeOrSelf(resultTy);

   if (!operandETy.isBF16() || !resultETy.isF32()) {
-    return failure();
+    return rewriter.notifyMatchFailure(op, "not an ext of bf16 to f32.");
   }

   Type i16Ty = b.getI16Type();

@@ -98,21 +98,9 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
   Type resultETy = getElementTypeOrSelf(resultTy);

   if (!operandETy.isF32() || !resultETy.isBF16()) {
-    return failure();
+    return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
   }

-#if defined(TAO_AARCH64)
-  if (isBFCVTEnabled()) {
-    auto intrinsicName =
-        StringAttr::get(rewriter.getContext(), "llvm.aarch64.neon.bfcvt");
-    SmallVector<Value, 2> args;
-    args.push_back(operand);
-    rewriter.replaceOpWithNewOp<LLVM::CallIntrinsicOp>(op, resultETy,
-                                                       intrinsicName, args);
-    return success();
-  }
-#endif
-
   Type i1Ty = b.getI1Type();
   Type i16Ty = b.getI16Type();
   Type i32Ty = b.getI32Type();

@@ -125,7 +113,21 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
   }

   Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand);
+
+  // Fast rounding algorithm to truncate fp32 to bf16:
+  //   uint32_t lsb = (input >> 16) & 1;
+  //   uint32_t rounding_bias = 0x7fff + lsb;
+  //   input += rounding_bias;
+  //   output.value = static_cast<uint16_t>(input >> 16);
+  // ref:
+  // https://hhhhhojeihsu.github.io/tensorflow_1.8_woboq/tensorflow_1.8_xla/tensorflow/tensorflow/core/lib/bfloat16/bfloat16.h.html#196
   Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
+  Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
+  Value lsb = b.create<arith::ShRUIOp>(bitcast, c16);
+  lsb = b.create<arith::AndIOp>(lsb, c1);
+  Value rounding_bias = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
+  rounding_bias = b.create<arith::AddIOp>(rounding_bias, lsb);
+  bitcast = b.create<arith::AddIOp>(bitcast, rounding_bias);
   Value shr = b.create<arith::ShRUIOp>(bitcast, c16);
   Value trunc = b.create<arith::TruncIOp>(i16Ty, shr);
   Value result = b.create<arith::BitcastOp>(resultTy, trunc);
tao_compiler/mlir/disc/transforms/disc_supported_list.h.inc

Lines changed: 1 addition & 1 deletion

@@ -22,7 +22,7 @@ limitations under the License.
 lmhlo::AbsOp, lmhlo::CeilOp, lmhlo::FloorOp, lmhlo::ConvertOp, lmhlo::CosineOp,
 lmhlo::ExpOp, lmhlo::LogOp, lmhlo::NegOp, lmhlo::RsqrtOp, lmhlo::SqrtOp,
 lmhlo::SignOp, lmhlo::TanhOp, lmhlo::LogisticOp, lmhlo::Log1pOp,
-lmhlo::SineOp, lmhlo::RoundOp, lmhlo::RoundNearestEvenOp,
+lmhlo::SineOp, lmhlo::RoundOp, lmhlo::RoundNearestEvenOp, lmhlo::BitcastConvertOp,

 // Binary Elementwise Ops
 lmhlo::AddOp, lmhlo::DivOp, lmhlo::MaxOp, lmhlo::MinOp, lmhlo::MulOp,

tao_compiler/mlir/disc/transforms/disc_to_llvm.cc

Lines changed: 12 additions & 1 deletion

@@ -87,7 +87,18 @@ LogicalResult getTypeEncoding(MLIRContext* ctx, Type t, StrT& out) {
     out.append(Twine("i").concat(Twine(int_type.getWidth())).str());
   }
 } else if (auto fp_type = t.dyn_cast<FloatType>()) {
-  out.append(Twine("f").concat(Twine(fp_type.getWidth())).str());
+  if (fp_type.isF16()) {
+    out.append("f16");
+  } else if (fp_type.isBF16()) {
+    out.append("bf16");
+  } else if (fp_type.isF32()) {
+    out.append("f32");
+  } else if (fp_type.isF64()) {
+    out.append("f64");
+  } else {
+    return failure();
+  }
+  // out.append(Twine("f").concat(Twine(fp_type.getWidth())).str());
 } else if (auto ctx_type = t.dyn_cast<RalExecutionContextType>() ||
            t == llvm_i8ptr_type || t == llvm_ptr_type) {
   out.append("pvoid");

tao_compiler/mlir/disc/transforms/lhlo_elemental_utils.cc

Lines changed: 13 additions & 3 deletions

@@ -1231,10 +1231,11 @@ Value elementalLower<lmhlo::ConcatenateOp>(OpBuilder* b, Location loc,

   Value zero_element;
   if (result_elem_type.isF16() || result_elem_type.isF32() ||
-      result_elem_type.isF64()) {
+      result_elem_type.isF64() || result_elem_type.isBF16()) {
     auto float_result_elem_type = result_elem_type.cast<FloatType>();
     zero_element = b->create<arith::ConstantFloatOp>(
-        loc, APFloat::getZero(float_result_elem_type.getFloatSemantics()),
+        loc,
+        APFloat::getZero(float_result_elem_type.getFloatSemantics(), false),
         float_result_elem_type);
   } else if (result_elem_type.isSignlessInteger() ||
              result_elem_type.isSignedInteger() ||

@@ -1304,7 +1305,16 @@ Value elementalLower<lmhlo::ConcatenateOp>(OpBuilder* b, Location loc,

   b->setInsertionPointToEnd(&if_inbound_ops[i].getElseRegion().front());
   if (i == num_input_operands - 1) {
-    b->create<scf::YieldOp>(loc, zero_element);  // expect never used
+    input_index[axis] = b->create<arith::SubIOp>(loc, out_idx, low_bound);
+    auto operand_memref = op.getOperand(i);
+    auto ret_value =
+        check_cache ? createLoadOrUseCachedValue(
+                          loc, b, op.getOperation(), operand_memref,
+                          input_index, b->saveInsertionPoint(), lower_config)
+                    : createMaySpecificLoad(*b, loc, op.getOperation(),
+                                            operand_memref, input_index,
+                                            lower_config);
+    b->create<scf::YieldOp>(loc, ret_value);
   } else {
     b->create<scf::YieldOp>(loc, if_inbound_ops[i + 1].getResults());
   }
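Previously the innermost else-branch yielded zero_element on the assumption it could never be taken; the new code instead computes the operand-local index and loads the real value from the last operand. A C-like sketch (hypothetical names, three operands concatenated along one axis) of the selection chain the lowering emits:

// Illustrative only: the generated scf.if chain behaves like this function.
float load_concat_element(const float* op0, const float* op1, const float* op2,
                          long e0, long e1, long out_idx) {
  if (out_idx < e0) return op0[out_idx];
  if (out_idx < e0 + e1) return op1[out_idx - e0];
  // Final branch: the index must fall in the last operand, so load the real
  // value instead of yielding a placeholder zero.
  return op2[out_idx - (e0 + e1)];
}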

tao_compiler/mlir/ral/context/base/base_context.cc

Lines changed: 9 additions & 0 deletions

@@ -235,6 +235,7 @@ void ral_base_cuda_send_output_0d(ExecutionContext* ctx, int64_t output_idx,
 TAO_RAL_API(tao::ral::kRalSendOutput, "cpu", ral_base_cuda_send_output_0d<T>);

 RAL_REGISTER_IO_FUNC_0D(float);
+RAL_REGISTER_IO_FUNC_0D(bfloat16);
 RAL_REGISTER_IO_FUNC_0D(double);
 RAL_REGISTER_IO_FUNC_0D(int8_t);
 RAL_REGISTER_IO_FUNC_0D(int32_t);

@@ -306,5 +307,13 @@ RAL_REGISTER_IO_FUNC(Eigen::half, 5);
 RAL_REGISTER_IO_FUNC(Eigen::half, 6);
 RAL_REGISTER_IO_FUNC(Eigen::half, 7);
 RAL_REGISTER_IO_FUNC(Eigen::half, 8);
+RAL_REGISTER_IO_FUNC(bfloat16, 1);
+RAL_REGISTER_IO_FUNC(bfloat16, 2);
+RAL_REGISTER_IO_FUNC(bfloat16, 3);
+RAL_REGISTER_IO_FUNC(bfloat16, 4);
+RAL_REGISTER_IO_FUNC(bfloat16, 5);
+RAL_REGISTER_IO_FUNC(bfloat16, 6);
+RAL_REGISTER_IO_FUNC(bfloat16, 7);
+RAL_REGISTER_IO_FUNC(bfloat16, 8);
 }  // namespace ral
 }  // namespace tao

tao_compiler/mlir/ral/context/base/cpu/cpu_context_impl.cc

Lines changed: 1 addition & 0 deletions

@@ -370,6 +370,7 @@ RAL_REGISTER_BITCAST_FUNC_0D(double);
 RAL_REGISTER_BITCAST_FUNC_0D(int32_t);
 RAL_REGISTER_BITCAST_FUNC_0D(int64_t);
 RAL_REGISTER_BITCAST_FUNC_0D(bool);
+RAL_REGISTER_BITCAST_FUNC_0D(bfloat16);
 RAL_REGISTER_BITCAST_FUNC(float, 1);
 RAL_REGISTER_BITCAST_FUNC(float, 2);
 RAL_REGISTER_BITCAST_FUNC(float, 3);

tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.cc

Lines changed: 10 additions & 0 deletions

@@ -671,6 +671,7 @@ void ral_base_cuda_d2h(ExecutionContext* ctx, void* stream_handle,
 TAO_RAL_API(tao::ral::kRalBitcast, "gpu", ral_base_cuda_bitcast_0d<T, 8, 0>);

 RAL_REGISTER_BITCAST_FUNC_0D(Eigen::half);
+RAL_REGISTER_BITCAST_FUNC_0D(Eigen::bfloat16);
 RAL_REGISTER_BITCAST_FUNC_0D(float);
 RAL_REGISTER_BITCAST_FUNC_0D(double);
 RAL_REGISTER_BITCAST_FUNC_0D(int32_t);

@@ -684,6 +685,14 @@ RAL_REGISTER_BITCAST_FUNC(Eigen::half, 5);
 RAL_REGISTER_BITCAST_FUNC(Eigen::half, 6);
 RAL_REGISTER_BITCAST_FUNC(Eigen::half, 7);
 RAL_REGISTER_BITCAST_FUNC(Eigen::half, 8);
+RAL_REGISTER_BITCAST_FUNC(Eigen::bfloat16, 1);
+RAL_REGISTER_BITCAST_FUNC(Eigen::bfloat16, 2);
+RAL_REGISTER_BITCAST_FUNC(Eigen::bfloat16, 3);
+RAL_REGISTER_BITCAST_FUNC(Eigen::bfloat16, 4);
+RAL_REGISTER_BITCAST_FUNC(Eigen::bfloat16, 5);
+RAL_REGISTER_BITCAST_FUNC(Eigen::bfloat16, 6);
+RAL_REGISTER_BITCAST_FUNC(Eigen::bfloat16, 7);
+RAL_REGISTER_BITCAST_FUNC(Eigen::bfloat16, 8);
 RAL_REGISTER_BITCAST_FUNC(float, 1);
 RAL_REGISTER_BITCAST_FUNC(float, 2);
 RAL_REGISTER_BITCAST_FUNC(float, 3);

@@ -745,5 +754,6 @@ TAO_RAL_API(tao::ral::gpu::kRalGpuSyncOnStream, "gpu",
 TAO_RAL_API(tao::ral::gpu::kRalGpuMemset, "gpu", ral_base_cuda_memset);

 }  // namespace gpu
+DEFINE_TAO_TYPE_NAME_HELPER(Eigen::bfloat16, "bf16");
 }  // namespace ral
 }  // namespace tao

tao_compiler/mlir/ral/context/stream_executor_based_impl.cc

Lines changed: 18 additions & 3 deletions

@@ -151,6 +151,11 @@ inline se::blas::ComputationType NativeTypeToBlasType<double>() {
   return se::blas::ComputationType::kF64;
 }

+template <>
+inline se::blas::ComputationType NativeTypeToBlasType<Eigen::bfloat16>() {
+  return se::blas::ComputationType::kBF16AsF32;
+}
+
 // The template was introduced because not every instantiation of
 // DoGemmWithAlgorithm's template arguments is supported by ThenBlasGemv.
 template <typename InT, typename OutT, typename AlphaBeta>

@@ -293,7 +298,8 @@ se::blas::AlgorithmType tuningGemm(se::Stream* stream,
   se::blas::ProfileResult profile_result;
   DoGemmWithAlgorithm<InT, OutT, AlphaBeta>(
       /*batch_size*/ 1, lhs_matrix, rhs_matrix, output_matrix,
-      /*alpha*/ 1., /*beta*/ 0., stream, algorithm, &profile_result);
+      /*alpha*/ AlphaBeta(1.0), /*beta*/ AlphaBeta(0.0), stream, algorithm,
+      &profile_result);

   if (!profile_result.is_valid()) {
     TAO_VLOG(1) << "algo: " << algorithm << " is invalid.";

@@ -1952,10 +1958,18 @@ void ral_qconv(ExecutionContext* ctx, void* stream_handle,
 }  // namespace gpu

 // gemm ops
-TAO_RAL_API("ral_gemm", "gpu", gpu::se_impl::ral_gemm<float, float>);
-TAO_RAL_API("ral_gemm", "gpu", gpu::se_impl::ral_gemm<double, double, double>);
+#ifndef DISC_BUILD_FROM_TF_BRIDGE
+DEFINE_TAO_TYPE_NAME_HELPER(Eigen::bfloat16, "bf16");
 TAO_RAL_API("ral_gemm", "gpu",
             gpu::se_impl::ral_gemm<Eigen::half, Eigen::half>);
+TAO_RAL_API("ral_gemm", "gpu",
+            gpu::se_impl::ral_gemm<Eigen::bfloat16, Eigen::bfloat16, float>);
+TAO_RAL_API("ral_gemm", "gpu",
+            gpu::se_impl::ral_batch_gemm<Eigen::bfloat16, Eigen::bfloat16, 3>);
+#endif
+TAO_RAL_API("ral_gemm", "gpu", gpu::se_impl::ral_gemm<float, float>);
+TAO_RAL_API("ral_gemm", "gpu", gpu::se_impl::ral_gemm<double, double, double>);
+
 TAO_RAL_API("ral_qgemm", "gpu", gpu::se_impl::ral_qgemm);
 TAO_RAL_API("ral_gemm", "gpu", gpu::se_impl::ral_batch_gemm<float, float, 3>);
 TAO_RAL_API("ral_gemm", "gpu", gpu::se_impl::ral_batch_gemm<float, float, 4>);

@@ -1965,6 +1979,7 @@ TAO_RAL_API("ral_gemm", "gpu",
             gpu::se_impl::ral_batch_gemm<double, double, 4, double>);
 TAO_RAL_API("ral_gemm", "gpu",
             gpu::se_impl::ral_batch_gemm<Eigen::half, Eigen::half, 3>);
+
 TAO_RAL_API("ral_gemm", "gpu",
             gpu::se_impl::ral_batch_gemm<Eigen::half, Eigen::half, 4>);
 #ifdef BLAZE_OPT
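The new registration ral_gemm<Eigen::bfloat16, Eigen::bfloat16, float> pairs bf16 inputs and outputs with f32 alpha/beta and accumulation (ComputationType::kBF16AsF32), which is presumably also why tuningGemm now spells its constants as AlphaBeta(1.0) and AlphaBeta(0.0) rather than raw double literals. On the CUDA side this corresponds conceptually to a cuBLAS call like the sketch below (an assumption for illustration; the RAL actually goes through StreamExecutor rather than calling cuBLAS directly):

#include <cublas_v2.h>

// Sketch: bf16 in/out, f32 compute. alpha/beta are f32 to match
// CUBLAS_COMPUTE_32F, mirroring AlphaBeta = float in the registration.
void bf16_gemm(cublasHandle_t h, int m, int n, int k,
               const void* A, const void* B, void* C) {
  float alpha = 1.0f, beta = 0.0f;
  cublasGemmEx(h, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
               &alpha, A, CUDA_R_16BF, /*lda=*/m,
               B, CUDA_R_16BF, /*ldb=*/k,
               &beta, C, CUDA_R_16BF, /*ldc=*/m,
               CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
}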

tao_compiler/mlir/ral/context/tensorflow/tf_context_impl.cc

Lines changed: 10 additions & 0 deletions

@@ -909,6 +909,7 @@ void ral_tf_send_output_0d(ExecutionContext* ctx, int64_t output_idx,
 TAO_RAL_API(::tao::ral::kRalBitcast, "cpu", ral_tf_bitcast_0d<T, 8, 0>);

 RAL_REGISTER_IO_FUNC_0D(float);
+RAL_REGISTER_IO_FUNC_0D(bfloat16);
 RAL_REGISTER_IO_FUNC_0D(double);
 RAL_REGISTER_IO_FUNC_0D(Eigen::half);
 RAL_REGISTER_IO_FUNC_0D(int8_t);

@@ -980,13 +981,22 @@ RAL_REGISTER_IO_FUNC(bool, 5);
 RAL_REGISTER_IO_FUNC(bool, 6);
 RAL_REGISTER_IO_FUNC(bool, 7);
 RAL_REGISTER_IO_FUNC(bool, 8);
+RAL_REGISTER_IO_FUNC(bfloat16, 1);
+RAL_REGISTER_IO_FUNC(bfloat16, 2);
+RAL_REGISTER_IO_FUNC(bfloat16, 3);
+RAL_REGISTER_IO_FUNC(bfloat16, 4);
+RAL_REGISTER_IO_FUNC(bfloat16, 5);
+RAL_REGISTER_IO_FUNC(bfloat16, 6);
+RAL_REGISTER_IO_FUNC(bfloat16, 7);
+RAL_REGISTER_IO_FUNC(bfloat16, 8);

 }  // namespace tensorflow

 namespace tao {
 namespace ral {

 DEFINE_TAO_TYPE_NAME_HELPER(Eigen::half, "f16");
+DEFINE_TAO_TYPE_NAME_HELPER(Eigen::bfloat16, "bf16");

 TAO_RAL_API(::tao::ral::cpu::kRalCpuAlloc, "cpu", tensorflow::ral_tf_cpu_alloc);
 TAO_RAL_API(::tao::ral::cpu::kRalCpuAllocPersistent, "cpu",
