Skip to content

BugFix: FP8 Communication Mismatch with --first-last-layers-bf16 in tp-comm-overlap #1703

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions megatron/core/transformer/transformer_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,19 @@ def build_layer(layer_spec, layer_number):
else:
layer_config = self.config

# If the first or last layers are bf16, disable tp comm overlap for them
if layer_config.first_last_layers_bf16 and layer_config.tp_comm_overlap:
num_bf16_layers_at_start = layer_config.num_layers_at_start_in_bf16
num_bf16_layers_at_end = layer_config.num_layers_at_end_in_bf16
# Since global layer number is 1-based, check if current layer is first or last
is_first_layer = global_layer_number <= num_bf16_layers_at_start
is_last_layer = global_layer_number > (layer_config.num_layers - num_bf16_layers_at_end)

if is_first_layer or is_last_layer:
# Create a copy of config with tp comm overlap disabled for this layer
from dataclasses import replace
layer_config = replace(layer_config, tp_comm_overlap=False)

fp8_init_context = get_fp8_context(layer_config, global_layer_number - 1, is_init=True)
with fp8_init_context:
module = build_module(
Expand Down