Commit 2a97a16

Merge branch 'mblaz/mlp-glu-oom-fix' into 'main'

Fix DCP OOM during SwiGLU load

See merge request ADLR/megatron-lm!3642

2 parents: 2fbd646 + fad83fc
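
For context (a sketch, not part of the commit): with a gated MLP, Megatron-Core stores the fused fc1 weight with the two GLU projections stacked along the first dimension; apply_swiglu_sharded_factory splits that weight into two halves for checkpointing, and its merge_fn concatenates them back during load. That torch.cat is the allocation this merge request guards against OOM. A rough illustration with made-up shapes:

import torch

# Illustrative shapes only; the real weight is (2 * ffn_hidden_size, hidden_size).
ffn_hidden, hidden = 8, 4
fc1_weight = torch.randn(2 * ffn_hidden, hidden)

# Roughly what the factory's build_fn exposes: the two halves as separate shards.
w, v = torch.chunk(fc1_weight, 2, dim=0)

# Roughly what merge_fn does on load: concatenate the halves back into one tensor.
# The cat needs a fresh buffer as large as both halves combined, which is what
# can exceed GPU memory during checkpoint load.
merged = torch.cat([w, v])
assert torch.equal(merged, fc1_weight)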

File tree

3 files changed (+71, -9 lines)

megatron/core/dist_checkpointing/strategies/torch.py

Lines changed: 2 additions & 1 deletion
@@ -374,7 +374,8 @@ def _unwrap_pyt_sharded_tensor(sh_ten: TorchShardedTensor) -> List[torch.Tensor]
             ten = ten.view(-1)
         else:
             for _ in range(mcore_sh_ten.prepend_axis_num):
-                ten = ten.squeeze(0)
+                assert ten.size(0) == 1
+                ten = ten[0]  # NOTE: ten.squeeze(0) uses more memory for FP8 tensors
         ret_tensors.append(ten)
     return ret_tensors
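
The NOTE above concerns FP8 tensor subclasses, where squeeze() is reported to allocate extra memory; for plain dense tensors, indexing with [0] and squeeze(0) return the same view. A minimal sketch (plain torch, hypothetical shapes) of the prepend-axis unwrapping the loop performs:

import torch

# Simulate a local shard saved with two prepended singleton axes
# (what prepend_axis_num describes for the mcore sharded tensor).
prepend_axis_num = 2
ten = torch.arange(12, dtype=torch.float32).reshape(1, 1, 3, 4)

for _ in range(prepend_axis_num):
    assert ten.size(0) == 1  # the guard added by this commit
    ten = ten[0]             # same view as ten.squeeze(0) for dense tensors

assert ten.shape == (3, 4)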

megatron/core/transformer/mlp.py

Lines changed: 16 additions & 2 deletions
@@ -1,5 +1,6 @@
 # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
-
+import gc
+import logging
 import warnings
 from dataclasses import dataclass
 from typing import Optional, Union
@@ -34,6 +35,9 @@
 HAVE_TE = False


+logger = logging.getLogger(__name__)
+
+
 # pylint: disable=missing-class-docstring
 @dataclass
 class MLPSubmodules:
@@ -311,7 +315,17 @@ def sh_ten_build_fn(

     def sh_ten_merge_fn(sub_state_dict):
         with torch.no_grad():
-            return torch.cat(sub_state_dict)
+            try:
+                return torch.cat(sub_state_dict)
+            except (RuntimeError, torch.cuda.OutOfMemoryError) as e:
+                logger.warning(
+                    f"CUDA OutOfMemoryError encountered during tensors merging."
+                    f" Switching to CPU merge. (Error: {e})"
+                )
+                merged_sub_state_dict = torch.cat([t.cpu() for t in sub_state_dict])
+                gc.collect()
+                torch.cuda.empty_cache()
+                return merged_sub_state_dict

     return ShardedTensorFactory(
         original_sh_ten.key,
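
The hunk above boils down to a "concatenate on GPU, fall back to CPU on OOM" pattern. A self-contained sketch of the same idea, with a hypothetical helper name (cat_with_cpu_fallback is not part of the Megatron API); the RuntimeError in the except tuple also covers PyTorch builds where a CUDA OOM surfaces as a plain RuntimeError:

import gc
import logging
from typing import List

import torch

logger = logging.getLogger(__name__)


def cat_with_cpu_fallback(tensors: List[torch.Tensor]) -> torch.Tensor:
    """Concatenate tensors on their current device; retry on CPU if CUDA runs out of memory."""
    try:
        return torch.cat(tensors)
    except (RuntimeError, torch.cuda.OutOfMemoryError) as e:
        logger.warning("torch.cat ran out of GPU memory, retrying on CPU (%s)", e)
        merged = torch.cat([t.cpu() for t in tensors])
        # Release cached CUDA blocks left over from the failed GPU allocation attempt.
        gc.collect()
        torch.cuda.empty_cache()
        return merged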

tests/unit_tests/dist_checkpointing/models/test_mlp_glu.py

Lines changed: 53 additions & 6 deletions
@@ -1,4 +1,6 @@
 # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+import inspect
+import logging

 import pytest
 import torch
@@ -13,7 +15,7 @@
 )
 from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
 from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
-from megatron.core.transformer.mlp import MLP
+from megatron.core.transformer.mlp import MLP, apply_swiglu_sharded_factory
 from megatron.core.transformer.transformer_config import TransformerConfig
 from tests.unit_tests.dist_checkpointing import TempNamedDir
 from tests.unit_tests.test_utilities import Utils
@@ -61,11 +63,10 @@ def test_parallel_reconfiguration_e2e(self, tmp_path_dist_ckpt, src_tp_pp, dest_
         """Test module saving and loading with different TP/PP"""
         Utils.initialize_model_parallel(*src_tp_pp)

-        with TempNamedDir(
-            tmp_path_dist_ckpt / 'test_mlp_glu_reconfiguration_model_A'
-        ) as ckpt_dir_A, TempNamedDir(
-            tmp_path_dist_ckpt / 'test_mlp_glu_reconfiguration_model_B'
-        ) as ckpt_dir_B:
+        with (
+            TempNamedDir(tmp_path_dist_ckpt / 'test_mlp_glu_reconfiguration_model_A') as ckpt_dir_A,
+            TempNamedDir(tmp_path_dist_ckpt / 'test_mlp_glu_reconfiguration_model_B') as ckpt_dir_B,
+        ):
             # Save checkpoint A
             mlp_A = initialize_mlp()
             save(mlp_A.sharded_state_dict(sharded_offsets=get_pp_offsets()), ckpt_dir_A)
@@ -87,3 +88,49 @@ def test_parallel_reconfiguration_e2e(self, tmp_path_dist_ckpt, src_tp_pp, dest_
         state_dict_B = load_plain_tensors(ckpt_dir_B)
         diffs = diff(state_dict_A, state_dict_B)
         assert not any(map(bool, diffs)), diffs
+
+    def test_oom_is_handled(self, caplog):
+        Utils.initialize_model_parallel(Utils.world_size, 1)
+        dtype = torch.bfloat16
+
+        # Compute free memory in bytes
+        device = torch.cuda.current_device()
+        allocated = torch.cuda.memory_allocated(device)
+        total = torch.cuda.get_device_properties(device).total_memory
+        free = total - allocated
+
+        # We should create two tensors which take up between 50% and 100% of free memory,
+        # so that the torch.cat tries to allocate twice as much and OOMs.
+        expected_local_num_bytes = free * 0.6
+
+        local_num_elems = expected_local_num_bytes // torch._utils._element_size(dtype)
+        local_num_elems = int(local_num_elems // 1024 * 1024)
+        assert local_num_elems % 1024 == 0
+
+        local_w_plus_v_shape = (local_num_elems // 512, 512)
+        local_w_or_v_shape = (local_num_elems // 1024, 512)
+
+        fc1_weight_sh_ten = ShardedTensor.from_rank_offsets(
+            'a',
+            torch.ones(local_w_plus_v_shape, device='cuda', dtype=dtype),
+            (0, Utils.rank, Utils.world_size),
+        )
+        fc1_factory = apply_swiglu_sharded_factory(fc1_weight_sh_ten, ())
+        sharded_state_dict = fc1_factory.build()
+        assert len(sharded_state_dict) == 2
+        assert sharded_state_dict[0].data.shape == local_w_or_v_shape
+        # NOTE: with singleton_local_shards=True this assert would fail - global shape is
+        # `(Utils.world_size * local_w_or_v_shape[0], local_w_or_v_shape[1])`
+        assert sharded_state_dict[0].global_shape[-2:] == (
+            Utils.world_size * local_w_plus_v_shape[0],
+            local_w_or_v_shape[1],
+        )
+
+        # Checkpoint load replaces ShardedTensors with tensors.
+        # Load happens in-place, so we can just use the same tensors.
+        loaded_state_dict = [sh_ten.data for sh_ten in sharded_state_dict]
+
+        # The critical part that should OOM:
+        with caplog.at_level(logging.WARNING):
+            fc1_factory.merge_fn(loaded_state_dict)
+        assert "CUDA OutOfMemoryError encountered during tensors merging" in caplog.text
