Commit 44fa0ea

Merge branch 'boxiangw/fsdp2-te2-fp8-warning' into 'main'
Add TE 2.0 check for FSDP2 with fp8-param-gather

See merge request ADLR/megatron-lm!3338

2 parents: ee082bf + d2c2210

File tree

1 file changed: +6 −0 lines


megatron/training/arguments.py

Lines changed: 6 additions & 0 deletions
@@ -30,6 +30,7 @@
 )
 from megatron.core.utils import (
     get_torch_version,
+    is_te_min_version,
     is_torch_min_version,
 )
 from megatron.training.activations import squared_relu
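
The newly imported is_te_min_version helper gates the check below on the installed Transformer Engine version. As a rough sketch of what such a helper typically does (an assumption for illustration; the actual implementation lives in megatron.core.utils and may resolve the version differently):

    # Sketch only: assumes the Transformer Engine version is discoverable
    # via package metadata; megatron.core.utils may use another mechanism.
    from importlib.metadata import version as installed_version

    from packaging.version import Version

    def is_te_min_version_sketch(min_version: str) -> bool:
        """Return True if the installed transformer-engine is >= min_version."""
        return Version(installed_version("transformer-engine")) >= Version(min_version)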
@@ -619,6 +620,11 @@ def validate_args(args, defaults={}):
     assert os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1", \
         'FSDP always requires CUDA_DEVICE_MAX_CONNECTIONS value larger than one'

+    if args.fp8_param_gather and is_te_min_version("2.0.0"):
+        args.fp8_param_gather = False
+        warnings.warn('FSDP2 FP8 param gather is not supported yet in TE 2.0, will fall back to bf16 '
+                      'all_gather instead, turning off fp8_param_gather')
+
     if args.overlap_param_gather_with_optimizer_step:
         assert args.use_distributed_optimizer, \
             '--overlap-param-gather-with-optimizer-step only supported with distributed optimizer'
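
Taken in isolation, the new branch means that requesting --fp8-param-gather under TE 2.0+ now degrades gracefully to a bf16 all_gather instead of failing later. A minimal sketch of the resulting behavior (args here is a hypothetical stand-in for Megatron's parsed arguments, and the boolean replaces the real is_te_min_version("2.0.0") call):

    import warnings
    from types import SimpleNamespace

    # Hypothetical stand-ins for the parsed args and the TE version probe.
    args = SimpleNamespace(fp8_param_gather=True)
    te_is_at_least_2_0 = True

    if args.fp8_param_gather and te_is_at_least_2_0:
        args.fp8_param_gather = False
        warnings.warn('FSDP2 FP8 param gather is not supported yet in TE 2.0, '
                      'will fall back to bf16 all_gather instead, '
                      'turning off fp8_param_gather')

    assert args.fp8_param_gather is False  # flag was turned off, warning emitted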

Comments (0)