Commit 3967229

Merge branch 'skierat/avoid_prefix' into 'main'

Fix reshardable checkpoint format by removing 'chained_*.' prefixes

Closes dl/JoC/nemo-ci#1995

See merge request ADLR/megatron-lm!3455

2 parents 2a97a16 + 8e9d56d

File tree

megatron/core/optimizer/optimizer.py
megatron/training/checkpointing.py

2 files changed: +12 / -1 lines

megatron/core/optimizer/optimizer.py

Lines changed: 11 additions & 1 deletion
```diff
@@ -1136,6 +1136,15 @@ def state_dict(self):
     def sharded_state_dict(
         self, model_sharded_state_dict: ShardedStateDict, is_loading: bool = False, **kwargs
     ):
+        metadata = kwargs.get('metadata') or {}
+        should_add_prefix = True  # Backward-compatibility
+        if (
+            metadata.get('chained_optim_avoid_prefix', False)
+            # This condition should be True if a distributed optimizer isn't used
+            and metadata.get('distrib_optim_sharding_type') != 'dp_zero_gather_scatter'
+        ):
+            should_add_prefix = False
+
         if len(self.chained_optimizers) == 1:
             return self.chained_optimizers[0].sharded_state_dict(
                 model_sharded_state_dict, is_loading, **kwargs
@@ -1147,7 +1156,8 @@ def sharded_state_dict(
             optim_state_dict = optimizer.sharded_state_dict(
                 model_sharded_state_dict, is_loading, **kwargs
             )
-            add_prefix_for_sharding(optim_state_dict, f'chained_{optimizer_idx}.')
+            if should_add_prefix:
+                add_prefix_for_sharding(optim_state_dict, f'chained_{optimizer_idx}.')
             sharded_state_dict[optimizer_idx] = optim_state_dict
         return sharded_state_dict
```
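For illustration, here is a minimal, self-contained sketch of the gating logic above, assuming plain dicts in place of Megatron's ShardedStateDict. The `add_prefix_for_sharding` below is a simplified stand-in for the real helper, and `chained_sharded_state_dict` is a hypothetical driver, not the actual method:

```python
# Toy sketch: plain dicts stand in for Megatron's ShardedStateDict, and
# this add_prefix_for_sharding is a simplified stand-in for the real
# helper, which rewrites sharded-tensor keys in place.

def add_prefix_for_sharding(state_dict: dict, prefix: str) -> dict:
    """Prepend `prefix` to every checkpoint key (simplified stand-in)."""
    return {prefix + key: value for key, value in state_dict.items()}


def chained_sharded_state_dict(per_optimizer_states: list, metadata: dict | None = None) -> dict:
    """Hypothetical driver mirroring the gating logic in the diff above."""
    metadata = metadata or {}
    should_add_prefix = True  # Backward-compatibility default
    if (
        metadata.get('chained_optim_avoid_prefix', False)
        # The legacy 'dp_zero_gather_scatter' sharding type keeps the prefix.
        and metadata.get('distrib_optim_sharding_type') != 'dp_zero_gather_scatter'
    ):
        should_add_prefix = False

    sharded_state_dict = {}
    for idx, optim_state in enumerate(per_optimizer_states):
        if should_add_prefix:
            optim_state = add_prefix_for_sharding(optim_state, f'chained_{idx}.')
        sharded_state_dict[idx] = optim_state
    return sharded_state_dict


states = [{'exp_avg': 0.1}, {'grad_norm': 0.2}]
print(chained_sharded_state_dict(states))
# {0: {'chained_0.exp_avg': 0.1}, 1: {'chained_1.grad_norm': 0.2}}
print(chained_sharded_state_dict(states, {'chained_optim_avoid_prefix': True}))
# {0: {'exp_avg': 0.1}, 1: {'grad_norm': 0.2}}
```

With `chained_optim_avoid_prefix` set (and any sharding type other than 'dp_zero_gather_scatter'), the chained optimizer emits the same key names a single optimizer would, which is what removing the 'chained_*.' prefixes buys for the reshardable checkpoint format.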

megatron/training/checkpointing.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -336,6 +336,7 @@ def _build_sharded_state_dict_metadata(args: Namespace) -> dict:
         metadata['distrib_optim_sharding_type'] = 'fully_sharded_model_space'
     else:
         metadata['distrib_optim_sharding_type'] = 'dp_zero_gather_scatter'
+    metadata['chained_optim_avoid_prefix'] = True
     return metadata
 
 def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floating_point_operations_so_far,
```
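For context, a hedged sketch of the save-time side: the metadata built here is forwarded (as the `metadata` kwarg) into the optimizer's `sharded_state_dict`, where the flag gates the prefix as shown above. The branching condition name below is an assumption; only the two metadata keys are confirmed by the diff:

```python
from argparse import Namespace

def _build_sharded_state_dict_metadata(args: Namespace) -> dict:
    """Sketch of the metadata builder; the condition name
    (ckpt_fully_parallel_save) is assumed, not shown in the diff."""
    metadata = {}
    if args.ckpt_fully_parallel_save:  # assumed condition
        metadata['distrib_optim_sharding_type'] = 'fully_sharded_model_space'
    else:
        metadata['distrib_optim_sharding_type'] = 'dp_zero_gather_scatter'
    # The new line: every checkpoint saved from now on requests
    # prefix-free chained-optimizer keys; optimizer.py decides whether
    # the request is honored (see the gate above).
    metadata['chained_optim_avoid_prefix'] = True
    return metadata

args = Namespace(ckpt_fully_parallel_save=True)
print(_build_sharded_state_dict_metadata(args))
# {'distrib_optim_sharding_type': 'fully_sharded_model_space',
#  'chained_optim_avoid_prefix': True}
```

The `should_add_prefix = True` default in optimizer.py keeps older callers, which pass no such flag, on the legacy 'chained_*.' layout, per the "Backward-compatibility" comment in the diff.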
