# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+import inspect
+import logging

import pytest
import torch
)
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
-from megatron.core.transformer.mlp import MLP
+from megatron.core.transformer.mlp import MLP, apply_swiglu_sharded_factory
from megatron.core.transformer.transformer_config import TransformerConfig
from tests.unit_tests.dist_checkpointing import TempNamedDir
from tests.unit_tests.test_utilities import Utils
@@ -61,11 +63,10 @@ def test_parallel_reconfiguration_e2e(self, tmp_path_dist_ckpt, src_tp_pp, dest_
        """Test module saving and loading with different TP/PP"""
        Utils.initialize_model_parallel(*src_tp_pp)

-        with TempNamedDir(
-            tmp_path_dist_ckpt / 'test_mlp_glu_reconfiguration_model_A'
-        ) as ckpt_dir_A, TempNamedDir(
-            tmp_path_dist_ckpt / 'test_mlp_glu_reconfiguration_model_B'
-        ) as ckpt_dir_B:
+        with (
+            TempNamedDir(tmp_path_dist_ckpt / 'test_mlp_glu_reconfiguration_model_A') as ckpt_dir_A,
+            TempNamedDir(tmp_path_dist_ckpt / 'test_mlp_glu_reconfiguration_model_B') as ckpt_dir_B,
+        ):
            # Save checkpoint A
            mlp_A = initialize_mlp()
            save(mlp_A.sharded_state_dict(sharded_offsets=get_pp_offsets()), ckpt_dir_A)
@@ -87,3 +88,49 @@ def test_parallel_reconfiguration_e2e(self, tmp_path_dist_ckpt, src_tp_pp, dest_
        state_dict_B = load_plain_tensors(ckpt_dir_B)
        diffs = diff(state_dict_A, state_dict_B)
        assert not any(map(bool, diffs)), diffs
+
+    def test_oom_is_handled(self, caplog):
+        Utils.initialize_model_parallel(Utils.world_size, 1)
+        dtype = torch.bfloat16
+
+        # Compute free memory in bytes
+        device = torch.cuda.current_device()
+        allocated = torch.cuda.memory_allocated(device)
+        total = torch.cuda.get_device_properties(device).total_memory
+        free = total - allocated
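+        # NOTE: total - allocated ignores memory reserved by the caching allocator (and any
+        # other processes on the device), so this is only an approximate free-memory figure.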
+
+        # We should create two tensors which take up between 50% and 100% of free memory,
+        # so that the torch.cat tries to allocate twice as much and OOMs.
+        expected_local_num_bytes = free * 0.6
105
+
106
+ local_num_elems = expected_local_num_bytes // torch ._utils ._element_size (dtype )
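+        # Round down to a multiple of 1024 so the element count divides evenly into the
+        # (N // 512, 512) fused shape and the (N // 1024, 512) half shapes below.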
107
+ local_num_elems = int (local_num_elems // 1024 * 1024 )
108
+ assert local_num_elems % 1024 == 0
109
+
110
+ local_w_plus_v_shape = (local_num_elems // 512 , 512 )
111
+ local_w_or_v_shape = (local_num_elems // 1024 , 512 )
112
+
113
+ fc1_weight_sh_ten = ShardedTensor .from_rank_offsets (
114
+ 'a' ,
115
+ torch .ones (local_w_plus_v_shape , device = 'cuda' , dtype = dtype ),
116
+ (0 , Utils .rank , Utils .world_size ),
117
+ )
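+        # The SwiGLU factory should split the fused fc1 weight into its two halves
+        # (w and v); the shape asserts below check exactly that.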
+        fc1_factory = apply_swiglu_sharded_factory(fc1_weight_sh_ten, ())
+        sharded_state_dict = fc1_factory.build()
+        assert len(sharded_state_dict) == 2
+        assert sharded_state_dict[0].data.shape == local_w_or_v_shape
+        # NOTE: with singleton_local_shards=True this assert would fail - global shape is
+        # `(Utils.world_size * local_w_or_v_shape[0], local_w_or_v_shape[1])`
+        assert sharded_state_dict[0].global_shape[-2:] == (
+            Utils.world_size * local_w_plus_v_shape[0],
+            local_w_or_v_shape[1],
+        )
+
+        # Checkpoint load replaces ShardedTensors with tensors.
+        # Load happens in-place, so we can just use the same tensors
+        loaded_state_dict = [sh_ten.data for sh_ten in sharded_state_dict]
+
+        # The critical part that should OOM:
+        with caplog.at_level(logging.WARNING):
+            fc1_factory.merge_fn(loaded_state_dict)
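+        # merge_fn is expected to catch the CUDA OOM raised by the concatenation and log a
+        # warning instead of propagating the error, which the assert below verifies.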
+        assert "CUDA OutOfMemoryError encountered during tensors merging" in caplog.text