BugFix: FP8 Communication Mismatch with --first-last-layers-bf16 in tp-comm-overlap #1703
Problem Description

When FP8 tensorwise quantization is used together with --first-last-layers-bf16, enabling tp-comm-overlap causes training failures due to a data-format mismatch in the tensor-parallel communication operations.
Issue Details:
- Environment: TE 2.4/2.3, Megatron-LM main
- Configuration: FP8 tensorwise + --first-last-layers-bf16 + tp-comm-overlap enabled
- Root Cause: communication operators still use the FP8 format even though the first/last layers are configured to run in BF16
- Symptom: training crashes with data-format mismatch errors during inter-device communication
- Workaround: disabling TP comm overlap allows training to proceed normally
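For reference, a launch configuration along these lines triggers the failure. The --first-last-layers-bf16 and --tp-comm-overlap flag names come from the report; the exact FP8 recipe flags (--fp8-format, --fp8-recipe) are an assumption and may differ across Megatron-LM versions:

```shell
# FP8 tensorwise recipe with BF16 first/last layers and TP comm overlap
# (flag names for the FP8 recipe are illustrative; check your Megatron-LM version)
--fp8-format hybrid \
--fp8-recipe tensorwise \
--first-last-layers-bf16 \
--tp-comm-overlap
```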
Root Cause Analysis
The communication logic does not detect when the first/last layers are running in BF16, so:
- the first/last layers produce BF16 tensors,
- the communication operators still expect FP8 input,
- and the format mismatch raises runtime errors during tensor-parallel communication.
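The mismatch above can be modeled with a small sketch. This is not Megatron-LM code; the function names are hypothetical and it only illustrates why overlap buffers sized once for FP8 stop matching the tensors the boundary layers actually emit:

```python
# Toy model of the mismatch (hypothetical names, not Megatron-LM code).

def layer_output_dtype(layer_idx, num_layers, first_last_bf16):
    """Dtype each layer actually produces under the mixed recipe."""
    if first_last_bf16 and layer_idx in (0, num_layers - 1):
        return "bf16"
    return "fp8"

def comm_buffer_dtype():
    """Overlap comm buffers are set up once, assuming FP8 everywhere."""
    return "fp8"

def find_mismatched_layers(num_layers=4, first_last_bf16=True):
    """Layers whose output dtype disagrees with the comm buffer dtype."""
    return [
        i for i in range(num_layers)
        if layer_output_dtype(i, num_layers, first_last_bf16) != comm_buffer_dtype()
    ]

# With first/last layers in BF16, exactly those two layers mismatch
# the FP8 comm buffers; with a pure FP8 recipe, nothing mismatches.
```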
Solution
This PR addresses the issue by disabling TP communication overlap whenever --first-last-layers-bf16 is enabled:
Current Fix:
- Disable TP communication overlap for configurations that use --first-last-layers-bf16
- This avoids the problematic communication path and thus the data-format mismatch
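The fix amounts to an argument-validation guard along these lines. This is a minimal sketch; the attribute names (`first_last_layers_bf16`, `tp_comm_overlap`) mirror the CLI flags but are assumptions, not the exact Megatron-LM argument fields:

```python
# Hypothetical validation guard: if both options are set, fall back to
# non-overlapped TP communication instead of crashing during training.
import warnings

def validate_tp_comm_overlap(args):
    if getattr(args, "first_last_layers_bf16", False) and \
       getattr(args, "tp_comm_overlap", False):
        warnings.warn(
            "--tp-comm-overlap is incompatible with --first-last-layers-bf16 "
            "under FP8 tensorwise quantization; disabling TP comm overlap."
        )
        args.tp_comm_overlap = False
    return args
```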
Future Enhancement Consideration:
The ideal long-term solution would be precision-aware communication operators that can:
- dynamically select the appropriate communication algorithm based on each layer's actual precision,
- use BF16 communication for the first/last layers when --first-last-layers-bf16 is enabled,
- use FP8 communication for the intermediate layers under FP8 tensorwise quantization,
- and enable TP overlap seamlessly regardless of the mixed-precision configuration.
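The per-layer selection described above could look like the following sketch. The function and its parameters are hypothetical, meant only to show the dispatch rule, not a proposed Megatron-LM API:

```python
# Hypothetical precision-aware dispatch: pick the comm dtype per layer
# instead of assuming FP8 globally for all overlap buffers.

def select_comm_dtype(layer_idx, num_layers, first_last_bf16, fp8_enabled):
    """Return the communication dtype matching the layer's actual precision."""
    if not fp8_enabled:
        return "bf16"          # pure BF16 run: no FP8 comm anywhere
    if first_last_bf16 and layer_idx in (0, num_layers - 1):
        return "bf16"          # boundary layers forced to BF16
    return "fp8"               # intermediate layers use FP8 tensorwise
```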