
BugFix: FP8 Communication Mismatch with --first-last-layers-bf16 in tp-comm-overlap #1703


Open · wants to merge 1 commit into main

Conversation

xiaomin-D

Problem Description
When FP8 tensorwise quantization is used together with --first-last-layers-bf16, enabling --tp-comm-overlap causes training failures due to a data-format mismatch in the tensor-parallel communication operations.
[Screenshot: Clipboard_Screenshot_1753326925]

Issue Details:
Environment: TE 2.4/2.3, Megatron-LM main
Configuration: FP8 tensorwise + --first-last-layers-bf16 + --tp-comm-overlap enabled
Root Cause: Communication operators still use FP8 format even though head/tail layers are configured to use BF16
Symptom: Training crashes with data format mismatch errors during inter-device communication
Workaround: Disabling TP overlap allows normal training to proceed
Root Cause Analysis
The communication logic doesn't properly detect when head/tail layers are using BF16 format, leading to:

Head/tail layers produce BF16 tensors
Communication operators expect FP8 format
Format mismatch causes runtime errors during tensor-parallel communication (see the conceptual sketch below)
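
To make the mismatch concrete, here is a conceptual, standalone illustration in plain PyTorch (not the actual Transformer Engine / Megatron userbuffers code): a communication buffer sized for FP8 payloads (1 byte per element) cannot carry the BF16 output (2 bytes per element) of a first/last layer.

```python
# Conceptual illustration only -- not the actual TE/Megatron userbuffers code.
# It shows why a comm path set up for FP8 payloads breaks when a first/last
# layer hands it BF16 activations instead.
import torch

seq, hidden = 8, 4096

# Overlap buffer registered for FP8 data: FP8 values travel as raw bytes,
# so the buffer holds 1 byte per element.
fp8_comm_buffer = torch.empty(seq * hidden, dtype=torch.uint8)

# Output of a first/last layer kept in BF16 by --first-last-layers-bf16:
# 2 bytes per element.
bf16_activation = torch.randn(seq, hidden, dtype=torch.bfloat16)

payload_bytes = bf16_activation.view(torch.uint8).numel()
print(f"BF16 payload: {payload_bytes} bytes, "
      f"FP8-sized comm buffer: {fp8_comm_buffer.numel()} bytes, "
      f"fits: {payload_bytes <= fp8_comm_buffer.numel()}")
# The BF16 payload is twice the size the FP8 comm path was set up for.
```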

Solution
This PR addresses the issue by disabling TP communication overlap when --first-last-layers-bf16 is enabled:

Current Fix:

Disable TP communication overlap for configurations using --first-last-layers-bf16
This prevents the data format mismatch by avoiding the problematic communication path (a minimal sketch of such a guard follows)
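
A minimal sketch of such a guard, assuming Megatron-LM-style argument attributes (first_last_layers_bf16, tp_comm_overlap); these names mirror the CLI flags but are assumptions here, not the literal diff in this PR:

```python
# Sketch of an argument-validation guard matching the fix described above.
# The attribute names are assumptions mirroring the CLI flags, not the exact
# change in this PR.
import warnings

def check_tp_comm_overlap_compat(args):
    if getattr(args, "first_last_layers_bf16", False) and \
            getattr(args, "tp_comm_overlap", False):
        warnings.warn(
            "--tp-comm-overlap is disabled because --first-last-layers-bf16 "
            "is set: the overlapped TP comm path assumes FP8-format tensors, "
            "but the first/last layers produce BF16."
        )
        args.tp_comm_overlap = False
    return args
```

Falling back to the non-overlapped path trades TP-overlap efficiency for correctness until precision-aware communication operators are available.
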
Future Enhancement Consideration:
The ideal long-term solution would be to implement precision-aware communication operators that can:

Dynamically select appropriate communication algorithms based on each layer's actual precision
Use BF16 communication operators for first/last layers when --first-last-layers-bf16 is enabled
Use FP8 communication operators for intermediate layers using FP8 tensorwise quantization
Enable seamless TP overlap regardless of mixed-precision configurations (see the per-layer dtype selection sketch below)
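
One possible shape for that selection logic, as a hypothetical helper (not an existing Megatron-LM or Transformer Engine API):

```python
# Hypothetical helper illustrating precision-aware comm dtype selection:
# first/last layers communicate in BF16 while intermediate FP8 layers keep
# byte-packed FP8 payloads. Not an existing Megatron-LM / TE function.
import torch

def comm_dtype_for_layer(layer_idx: int, num_layers: int,
                         first_last_layers_bf16: bool) -> torch.dtype:
    is_edge_layer = layer_idx in (0, num_layers - 1)
    if first_last_layers_bf16 and is_edge_layer:
        return torch.bfloat16   # BF16 collectives for head/tail layers
    return torch.uint8          # FP8 payloads carried as raw bytes

# Example: a 4-layer model with --first-last-layers-bf16 enabled.
print([comm_dtype_for_layer(i, 4, True) for i in range(4)])
# [torch.bfloat16, torch.uint8, torch.uint8, torch.bfloat16]
```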

@DAISY-gh (Collaborator)

@xiaomin-D we are pulling this PR internally to review and run CI. Thanks.

@sbhavani added the bug (Something isn't working) and module: transformer engine labels on Jul 24, 2025
@xiaomin-D (Author)

> @xiaomin-D we are pulling this PR internally to review and run CI. Thanks.

@DAISY-gh Is there any progress? Thanks.

Labels: bug (Something isn't working), module: transformer engine
3 participants