
Commit 84cf979

Merge branch 'fix-ep-tp' into 'main'

Fix Issue with EP and TP

See merge request ADLR/megatron-lm!3699

2 parents 8ee323a + 84cc522

File tree

11 files changed: +305 −29 lines

megatron/core/inference/contexts/dynamic_context.py

Lines changed: 35 additions & 14 deletions
@@ -155,7 +155,6 @@ def __init__(
         tp_size = tensor_model_parallel_size
         hidden_size_per_attention_head = core_divide(projection_size, num_attention_heads)
         num_attention_heads_per_partition = core_divide(num_attention_heads, tp_size)
-
         # Chunk size tokens, bytes.
         dtype_size_bytes = params_dtype.itemsize
         self.chunk_size_tokens = chunk_size_tokens
@@ -177,23 +176,24 @@ def __init__(
         def bytes_to_max_requests_and_tokens(n_bytes):
             n_tokens = n_bytes / self.chunk_size_bytes * self.chunk_size_tokens
             n_requests = n_tokens / max_sequence_length
-            return int(n_requests), int(n_tokens)
+            return self.round_up_requests(int(n_requests), tp_size=tp_size), self.round_up_tokens(
+                int(n_tokens), tp_size=tp_size
+            )

         self.max_requests, self.max_tokens = bytes_to_max_requests_and_tokens(buffer_size_bytes)
-
         if buffer_overflow_factor is not None:
             self.max_requests = self.round_up_requests(
-                int(self.max_requests * buffer_overflow_factor)
+                int(self.max_requests * buffer_overflow_factor), tp_size=tp_size
             )
             self.max_tokens = self.round_up_tokens(
-                int(self.max_tokens * buffer_overflow_factor / 50.0)
+                int(self.max_tokens * buffer_overflow_factor / 50.0), tp_size=tp_size
             )

         if max_requests_override is not None:
-            self.max_requests = self.round_up_requests(max_requests_override)
+            self.max_requests = self.round_up_requests(max_requests_override, tp_size=tp_size)

         if max_tokens_override is not None:
-            self.max_tokens = self.round_up_tokens(max_tokens_override)
+            self.max_tokens = self.round_up_tokens(max_tokens_override, tp_size=tp_size)

         self.max_requests = min(self.max_requests, self.max_tokens)  # e.g., decode only.
@@ -277,7 +277,8 @@ def bytes_to_max_requests_and_tokens(n_bytes):
         self.cuda_graph_step_size = cuda_graph_rounder * int(
             math.ceil(int(self.cuda_graph_step_size) / cuda_graph_rounder)
         )
-
+        # Make sure divisible by TP size
+        self.cuda_graph_step_size = math.ceil(self.cuda_graph_step_size / tp_size) * tp_size
         # Cuda graph request counts.
         if num_cuda_graphs == 1:
             self.cuda_graph_request_counts = [self.max_requests]
@@ -355,26 +356,46 @@ def bytes_to_max_requests_and_tokens(n_bytes):
     REQUEST_ROUNDER = 4

     @classmethod
-    def round_up_tokens(cls, value):
-        """Round up to nearest multiple of `TOKEN_ROUNDER` (above)."""
+    def round_up_tokens(cls, value, tp_size=None):
+        """Round up to the nearest multiple of `TOKEN_ROUNDER` (above) that is also divisible by the tensor model parallel size."""
         if not HAVE_PACKAGING:
             raise ImportError(
                 "`packaging` is required for this functionality, please install it with `pip install packaging`"
             )
         if PkgVersion(mcore_version) < PkgVersion("0.13"):
             return cls.round_up(value)
-        return cls.TOKEN_ROUNDER * int(math.ceil(int(value) / cls.TOKEN_ROUNDER))
+
+        # Make sure divisible by TP size
+        if tp_size is None:
+            # Check if parallel state is initialized before trying to get TP size
+            if parallel_state.is_initialized():
+                tp_size = parallel_state.get_tensor_model_parallel_world_size()
+            else:
+                tp_size = 1
+        token_rounder = math.ceil(cls.TOKEN_ROUNDER / tp_size) * tp_size
+
+        return token_rounder * int(math.ceil(int(value) / token_rounder))

     @classmethod
-    def round_up_requests(cls, value):
-        """Round up to nearest multiple of `REQUEST_ROUNDER` (above)."""
+    def round_up_requests(cls, value, tp_size=None):
+        """Round up to the nearest multiple of `REQUEST_ROUNDER` (above) that is also divisible by the tensor model parallel size."""
         if not HAVE_PACKAGING:
             raise ImportError(
                 "`packaging` is required for this functionality, please install it with `pip install packaging`"
             )
         if PkgVersion(mcore_version) < PkgVersion("0.13"):
             return cls.round_up(value)
-        return cls.REQUEST_ROUNDER * int(math.ceil(int(value) / cls.REQUEST_ROUNDER))
+
+        # Make sure divisible by TP size
+        if tp_size is None:
+            # Check if parallel state is initialized before trying to get TP size
+            if parallel_state.is_initialized():
+                tp_size = parallel_state.get_tensor_model_parallel_world_size()
+            else:
+                tp_size = 1
+        request_rounder = math.ceil(cls.REQUEST_ROUNDER / tp_size) * tp_size
+
+        return request_rounder * int(math.ceil(int(value) / request_rounder))

     @classmethod
     def round_up(cls, value):
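
For intuition, here is a minimal standalone sketch (not part of the commit) of the rounding rule the new `tp_size` argument introduces: the rounder is first lifted to the smallest multiple of `tp_size` at or above it, so every rounded value comes out divisible by the TP size. `REQUEST_ROUNDER = 4` appears in the diff above; the `TOKEN_ROUNDER` value below is an assumption for illustration only.

import math

REQUEST_ROUNDER = 4   # from the diff above
TOKEN_ROUNDER = 64    # assumed value, for illustration only

def round_up_tp_aware(value, rounder, tp_size):
    # Lift the rounder to the smallest multiple of tp_size at or above it...
    effective_rounder = math.ceil(rounder / tp_size) * tp_size
    # ...then round the value up to a multiple of that effective rounder.
    return effective_rounder * math.ceil(int(value) / effective_rounder)

# With tp_size=8 the request rounder 4 is lifted to 8, so 13 requests
# round up to 16, which is divisible by both 4 and 8.
assert round_up_tp_aware(13, REQUEST_ROUNDER, tp_size=8) == 16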

megatron/core/inference/text_generation_controllers/text_generation_controller.py

Lines changed: 20 additions & 1 deletion
@@ -26,6 +26,8 @@
 from megatron.core.inference.sampling_params import SamplingParams
 from megatron.core.inference.utils import get_attention_mask
 from megatron.core.transformer.cuda_graphs import create_cudagraphs
+from megatron.core.transformer.moe.moe_layer import BaseMoELayer
+from megatron.core.transformer.utils import set_model_to_sequence_parallel
 from megatron.core.utils import get_model_config

 try:
@@ -429,9 +431,11 @@ def generate_output_tokens_dynamic_batch(
         # Get flat tokens, position ids.
         input_ids, position_ids = context.current_input_and_position_ids()

+        model_config = get_model_config(self.inference_wrapped_model.model)
+
         # If using symmetric kernels and we are using nccl
         # for prefill, turn off symmetric kernels
-        symmetric_ar_type = get_model_config(self.inference_wrapped_model.model).symmetric_ar_type
+        symmetric_ar_type = model_config.symmetric_ar_type
         nccl_all_reduce_for_prefill = (
             self.inference_wrapped_model.inference_wrapper_config.nccl_all_reduce_for_prefill
         )
@@ -681,6 +685,21 @@ def generate_all_output_tokens_static_batch(
             not self.inference_wrapped_model.inference_context.is_decode_only()
         ), f"Generation must start in prefill mode"

+        # Sequence parallelism is required for MoE layers when using expert parallelism (EP)
+        # because the expert routing mechanism relies on sequence parallelism's communication
+        # infrastructure to distribute tokens across expert ranks. However, sequence parallelism
+        # is not currently supported for non-MoE layers during inference, so we selectively
+        # disable it for all other layer types. This is safe because MoE layers perform an
+        # all-gather operation on sequences before passing data to subsequent layers, ensuring
+        # that each rank has the complete sequence data needed for the next non-MoE layer.
+        tp_size = model_config.tensor_model_parallel_size
+        ep_size = model_config.expert_model_parallel_size
+        model_is_tp_ep = tp_size > 1 and ep_size > 1
+        if model_is_tp_ep:
+            set_model_to_sequence_parallel(
+                self.inference_wrapped_model.model.module, False, exclude_modules=[BaseMoELayer]
+            )
+
         # If using symmetric kernels and we are using nccl
         # for prefill, turn off symmetric kernels
         symmetric_ar_type = model_config.symmetric_ar_type
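
The net effect of this hunk: when a model runs with both tensor parallelism and expert parallelism, sequence parallelism stays on only inside MoE layers and is switched off everywhere else before static-batch generation. A hedged sketch of that call pattern (the helper name and the bare `model`/`config` objects are stand-ins, not the controller's actual wiring):

from megatron.core.transformer.moe.moe_layer import BaseMoELayer
from megatron.core.transformer.utils import set_model_to_sequence_parallel

def disable_sp_outside_moe(model, config):
    # Only TP+EP models need this: MoE layers keep sequence parallelism
    # because token routing to expert ranks rides on its collectives, while
    # the remaining layers have it turned off for inference.
    if config.tensor_model_parallel_size > 1 and config.expert_model_parallel_size > 1:
        set_model_to_sequence_parallel(model, False, exclude_modules=[BaseMoELayer])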

megatron/core/transformer/utils.py

Lines changed: 17 additions & 11 deletions
@@ -199,10 +199,14 @@ def sharded_state_dict_default(
 _sequence_parallel_attr_cache = None


-def _init_sequence_parallel_cache(model):
+def _init_sequence_parallel_cache(model, exclude_modules):
     """
     Initialize the cache of modules with sequence parallel attributes.
     Only needs to be called once; subsequent calls have no effect.
+
+    Args:
+        model: Model whose sequence parallelism attributes will be changed
+        exclude_modules: Modules to exclude from changing sequence parallelism
     """
     global _sequence_parallel_attr_cache
     model_id = id(model)
@@ -229,33 +233,35 @@ def _init_sequence_parallel_cache(model):

     # Recursive function to find all modules with our target attributes
     def find_modules_with_attrs(module):
-        # Check if this module has any of our target attributes
-        for attr in sequence_parallel_attrs:
-            if hasattr(module, attr):
-                _sequence_parallel_attr_cache[model_id][attr].append(module)
+        if exclude_modules is None or module not in exclude_modules:
+            # Check if this module has any of our target attributes
+            for attr in sequence_parallel_attrs:
+                if hasattr(module, attr):
+                    _sequence_parallel_attr_cache[model_id][attr].append(module)

-        # Check all children modules recursively
-        for child in module._modules.values():
-            if child is not None:
-                find_modules_with_attrs(child)
+            # Check all children modules recursively
+            for child in module._modules.values():
+                if child is not None:
+                    find_modules_with_attrs(child)

     # Start the search from each major component
     find_modules_with_attrs(model_modules)


-def set_model_to_sequence_parallel(model, set_to=False):
+def set_model_to_sequence_parallel(model, set_to=False, exclude_modules=None):
     """
     Set sequence parallel attributes for the model.

     Args:
         set_to: Value to set for sequence_parallel attributes
+        exclude_modules: Modules to exclude from changing sequence parallelism
     """
     global _sequence_parallel_attr_cache
     model_id = id(model)

     # Initialize cache if needed
     if _sequence_parallel_attr_cache is None or model_id not in _sequence_parallel_attr_cache:
-        _init_sequence_parallel_cache(model)
+        _init_sequence_parallel_cache(model, exclude_modules)

     model.config.sequence_parallel = set_to
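
Note that the new exclusion check wraps the recursive descent as well, so an excluded module prunes its entire subtree, not just itself. A self-contained illustration of that traversal behavior with plain `torch.nn` modules (the helper below is a simplified stand-in, not the cached Megatron implementation):

import torch.nn as nn

def find_with_attr(module, attr, exclude_modules=None, found=None):
    # Simplified mirror of find_modules_with_attrs: recursion only continues
    # for modules that are not excluded, so exclusion prunes whole subtrees.
    found = [] if found is None else found
    if exclude_modules is None or module not in exclude_modules:
        if hasattr(module, attr):
            found.append(module)
        for child in module._modules.values():
            if child is not None:
                find_with_attr(child, attr, exclude_modules, found)
    return found

inner = nn.Linear(4, 4)
inner.sequence_parallel = True  # mark a leaf with the target attribute
outer = nn.Sequential(inner)

assert len(find_with_attr(outer, "sequence_parallel")) == 1
# Excluding the root prunes the whole tree, so nothing is found.
assert len(find_with_attr(outer, "sequence_parallel", exclude_modules=[outer])) == 0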

@@ -0,0 +1 @@
{"0": {"input_prompt": "Time travel to 2008, and go to a bar or a club or one of the myriad disco-basements on the Lower East Side that does not quite know which of those it is. Dance awkwardly in a room full of other glittered-up nerds, and wait for something to happen, buoyed on the feeling that this is the big swollen heart of life, that this is New York like the movies.", "generated_text": " Wait for the moment when the music stops, and the lights come up, and the DJ says, \"I'm going to play a song for you", "generated_tokens": [32844, 1394, 1278, 4735, 2200, 1278, 7146, 30774, 1044, 1321, 1278, 26466, 3930, 2015, 1044, 1321, 1278, 30245, 8223, 1044, 1429, 1073, 4525, 4670, 1317, 3354, 1261, 6947, 1394, 1636], "latency": 19.686351776123047, "logprobs": [-5.292269229888916, -7.716421127319336, -9.068008422851562, -12.118106842041016, -3.741875648498535, -1.8551081418991089, -1.8765699863433838, -9.52701187133789, -15.140599250793457, -9.51123046875, -10.250877380371094, -8.108641624450684, -9.099360466003418, -9.529533386230469, -10.495244979858398, -9.094446182250977, -9.802777290344238, -8.999850273132324, -9.103967666625977, -9.895696640014648, -8.00230884552002, -7.3570966720581055, -7.892587184906006, -12.32270622253418, -20.62922477722168, -9.672601699829102, -8.485877990722656, -10.270708084106445, -11.473578453063965, -15.617767333984375, -7.8881988525390625, -12.872822761535645, -8.940616607666016, -7.3508501052856445, -10.157344818115234, -12.235904693603516, -9.32239818572998, -6.516319751739502, -8.389573097229004, -8.860508918762207, -16.462238311767578, -7.349939346313477, -11.075740814208984, -14.01966667175293, -9.536352157592773, -9.535323143005371, -11.839295387268066, -12.564751625061035, -9.356565475463867, -9.240681648254395, -9.669130325317383, -8.965482711791992, -11.053977012634277, -14.045623779296875, -13.299668312072754, -14.695284843444824, -13.231292724609375, -9.543293952941895, -11.672986030578613, -10.587867736816406, -8.400468826293945, -10.324536323547363, -13.930037498474121, -15.27256965637207, -10.176668167114258, -13.777766227722168, -8.423280715942383, -7.511598110198975, -14.0129976272583, -5.561246871948242, -9.51725959777832, -10.10839557647705, -8.918962478637695, -8.14908218383789, -7.653857707977295, -11.743547439575195, -10.011963844299316, -12.899750709533691, -15.401609420776367, -6.838616847991943, -9.010682106018066, -10.37846565246582, -6.819251537322998, -13.074575424194336, -10.851410865783691, -8.874515533447266, -10.204574584960938, -16.298084259033203, -13.584976196289062, -10.295950889587402, -8.796205520629883, -12.162117004394531, -9.572405815124512, -8.92280101776123, -10.94050407409668, -15.27184772491455, -13.962615966796875, -9.328908920288086, -9.781393051147461, -12.07744026184082, -11.402749061584473, -11.740723609924316, -17.354206085205078, -9.84351634979248, -9.201858520507812, -8.702098846435547, -12.5997314453125, -14.244935989379883, -14.273555755615234, -16.253263473510742, -13.604464530944824, -11.363554000854492, -9.675899505615234, -12.930312156677246, -10.388641357421875, -11.593982696533203, -10.904473304748535]}}
@@ -0,0 +1,79 @@
ENV_VARS:
  CUDA_DEVICE_MAX_CONNECTIONS: 1
  NVTE_ALLOW_NONDETERMINISTIC_ALGO: 0
  NCCL_ALGO: Ring
  CUBLAS_WORKSPACE_CONFIG: :4096:8
TEST_TYPE: frozen-start
MODE: inference
MODEL_ARGS:
  --log-num-zeros-in-grad: true
  --log-validation-ppl-to-tensorboard: true
  --log-timers-to-tensorboard: true
  --log-memory-to-tensorboard: true
  --timing-log-level: 2
  --load: ${CHECKPOINT_LOAD_PATH}/deepseek_16b_pyt/model/checkpoints
  --tokenizer-model: ${DATA_PATH}/deepseek_16b_pyt/tokenizer/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json
  --tokenizer-type: TikTokenizer
  --tiktoken-pattern: v2
  --distributed-backend: nccl
  --log-interval: 1
  --transformer-impl: transformer_engine
  --tensor-model-parallel-size: 4
  --pipeline-model-parallel-size: 1
  --expert-model-parallel-size: 4
  --expert-tensor-parallel-size: 1
  --sequence-parallel: true
  --use-mcore-models: true
  --moe-token-dispatcher-type: alltoall
  --moe-grouped-gemm: true
  --num-experts: 64
  --moe-router-topk: 6
  --moe-z-loss-coeff: 0
  --moe-router-load-balancing-type: seq_aux_loss
  --moe-aux-loss-coeff: 1e-3
  --moe-router-score-function: sigmoid
  --untie-embeddings-and-output-weights: true
  --disable-bias-linear: true
  --init-method-std: 0.014
  --position-embedding-type: rope
  --rotary-base: 1000000
  --rotary-percent: 1.0
  --num-layers: 27
  --hidden-size: 2048
  --moe-ffn-hidden-size: 1408
  --moe-shared-expert-intermediate-size: 2816
  --ffn-hidden-size: 10944
  --num-attention-heads: 16
  --kv-channels: 128
  --normalization: RMSNorm
  --swiglu: true
  --attention-dropout: 0.0
  --hidden-dropout: 0.0
  --seq-length: 4096
  --max-position-embeddings: 4096
  --micro-batch-size: 1
  --ckpt-format: torch_dist
  --ckpt-fully-parallel-save: true
  --ckpt-fully-parallel-load: true
  --ckpt-assume-constant-structure: true
  --dist-ckpt-strictness: log_unexpected
  --bf16: true
  --attention-backend: flash
  --no-create-attention-mask-in-dataloader: true
  --num-workers: 8
  --use-checkpoint-args: true
  --no-use-tokenizer-model-from-checkpoint-args: true
  --no-load-optim: true
  --deterministic-mode: true
  --save-interval: 2000
  --temperature: 1.0
  --top_k: 1
  --return-log-probs: true
  --num-tokens-to-generate: 30
  --max-tokens-to-oom: 3600000
  --inference-max-seq-length: 4096
  --output-path: ${TENSORBOARD_PATH}
  --prompts: "Time travel to 2008, and go to a bar or a club or one of the myriad disco-basements on the Lower East Side that does not quite know which of those it is. Dance awkwardly in a room full of other glittered-up nerds, and wait for something to happen, buoyed on the feeling that this is the big swollen heart of life, that this is New York like the movies."
METRICS:
  - "generated_tokens"
  - "logprobs"
@@ -0,0 +1 @@
{"0": {"input_prompt": "Time travel to 2008, and go to a bar or a club or one of the myriad disco-basements on the Lower East Side that does not quite know which of those it is. Dance awkwardly in a room full of other glittered-up nerds, and wait for something to happen, buoyed on the feeling that this is the big swollen heart of life, that this is New York like the movies.", "generated_text": " Wait for the moment when the music stops, and the lights come up, and the DJ says, \"I'm going to play a song for you", "generated_tokens": [32844, 1394, 1278, 4735, 2200, 1278, 7146, 30774, 1044, 1321, 1278, 26466, 3930, 2015, 1044, 1321, 1278, 30245, 8223, 1044, 1429, 1073, 4525, 4670, 1317, 3354, 1261, 6947, 1394, 1636], "tpot": [11.527735710144043, 0.591648280620575, 0.07851152122020721, 0.08397766202688217, 0.07938633859157562, 0.07749997079372406, 0.078935906291008, 0.08028851449489594, 0.0792686715722084, 0.07622275501489639, 0.07548975944519043, 0.07680464535951614, 0.07543014734983444, 0.07560738921165466, 0.07399769872426987, 0.07543785870075226, 0.07579366117715836, 0.0751631036400795, 0.07548335939645767, 0.07564015686511993, 0.07764911651611328, 0.07656403630971909, 0.07500188797712326, 0.07636498659849167, 0.07543619722127914, 0.07694652676582336, 0.07432099431753159, 0.0751761943101883, 0.07691458612680435, 0.07628953456878662], "latency": 14.289155829232186, "logprobs": [-10.448518753051758, -3.693941593170166, -2.833103656768799, -1.2445695400238037, -0.23799529671669006, -1.7522815465927124, -2.378152370452881, -1.9484899044036865, -2.108924388885498, -6.127920150756836, -0.8197959661483765, -2.477976083755493, -3.492497444152832, -4.170319557189941, -1.9918553829193115, -1.8618279695510864, -2.2335567474365234, -7.071791172027588, -0.039936937391757965, -1.9948835372924805, -5.008172512054443, -8.708097457885742, -9.903486251831055, -0.851460337638855, -4.765171051025391, -0.8707393407821655, -2.219733238220215, -0.01853257417678833, -0.035978663712739944, -3.387631416320801, -8.754067420959473, -1.2686023712158203, -6.662981986999512, -3.7872395515441895, -3.6667354106903076, -4.171259880065918, -2.2128500938415527, -1.091404914855957, -0.22139909863471985, -0.8265669941902161, -4.746159553527832, -9.04170036315918, -0.013459297828376293, -3.17301607131958, -1.3139652013778687, -3.9821701049804688, -0.7707944512367249, -0.002040567807853222, -2.9162371158599854, -10.677328109741211, -3.1504364013671875, -1.1485933065414429, -4.871399402618408, -0.20786719024181366, -0.06325722485780716, -1.3587590456008911, -2.207646369934082, -4.407937049865723, -0.36253970861434937, -4.0189995765686035, -0.3988611698150635, -0.13855230808258057, -2.7199528217315674, -10.558171272277832, -0.04671315476298332, -3.5006980895996094, -0.9756439328193665, -4.673828125, -0.2634696066379547, -2.5747756958007812, -0.8531911969184875, -1.6041897535324097, -5.738401412963867, -16.978456497192383, -2.6206722259521484, -0.14098073542118073, -7.450814247131348, -1.076573371887207, -2.129807472229004, -1.5724716186523438, -0.29326727986335754, -5.609436511993408, -0.0065282415598630905, -7.79502010345459, -2.715085744857788, -3.0889575481414795, -3.0355961322784424, -2.4395439624786377, -0.3983170986175537, -1.5089631080627441, -2.276723861694336, -0.6004312038421631, -1.3054823875427246, -1.9454480409622192, -1.7226327657699585, -0.7742734551429749, -0.49186939001083374, -1.2962923049926758, -1.567298173904419, -1.0149078369140625, -0.40288272500038147, -0.4789682626724243, -0.04533138871192932, 
-1.2695876359939575, -2.223480224609375, -2.6703481674194336, -0.7677091956138611, -0.42749911546707153, -2.8563802242279053, -1.5350499153137207, -1.6456167697906494, -0.05149398744106293, -1.3739523887634277, -1.3543274402618408, -1.2655469179153442, -1.307403326034546, -0.497008740901947]}}
