
Commit 2fbd646

Merge branch 'mla_down_proj_telinear' into 'main'

perf(MLA): MLA down proj switch back to TELinear

See merge request ADLR/megatron-lm!3576

2 parents 7af09be + 720c8b4

File tree

10 files changed: +213 −160 lines


megatron/core/extensions/transformer_engine_spec_provider.py

Lines changed: 5 additions & 0 deletions

@@ -8,6 +8,7 @@
     TEColumnParallelLinear,
     TEDotProductAttention,
     TELayerNormColumnParallelLinear,
+    TELinear,
     TENorm,
     TERowParallelGroupedLinear,
     TERowParallelLinear,
@@ -23,6 +24,10 @@
 class TESpecProvider(BackendSpecProvider):
     """A protocol for providing the submodules used in Spec building."""
 
+    def linear(self) -> type:
+        """Which linear module TE backend uses"""
+        return TELinear
+
     def column_parallel_linear(self) -> type:
         """Which column parallel linear module TE backend uses"""
         return TEColumnParallelLinear
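
For readers unfamiliar with the spec-provider layer: `linear()` is a new accessor alongside `column_parallel_linear()`, so spec-building code can ask the backend for a plain, non-sharded linear class. Below is a minimal, self-contained sketch of that pattern; the `Fake*` classes and the `Protocol` definition are illustrative stand-ins, not Megatron-LM's real types.

```python
# Minimal sketch of the spec-provider pattern this commit extends.
# The Fake* classes are illustrative stand-ins, not real Megatron-LM types.
from typing import Protocol


class BackendSpecProvider(Protocol):
    def linear(self) -> type: ...
    def column_parallel_linear(self) -> type: ...


class FakeTELinear: ...                # stand-in for TELinear (weight replicated)
class FakeTEColumnParallelLinear: ...  # stand-in for TEColumnParallelLinear (output-sharded)


class FakeTESpecProvider:
    """Stand-in for TESpecProvider: maps abstract submodule slots to classes."""

    def linear(self) -> type:
        # New hook added by this commit: the plain TE linear layer.
        return FakeTELinear

    def column_parallel_linear(self) -> type:
        return FakeTEColumnParallelLinear


backend = FakeTESpecProvider()
# The MLA down projections now request the plain linear class ...
assert backend.linear() is FakeTELinear
# ... while other projections keep using the column-parallel class.
assert backend.column_parallel_linear() is FakeTEColumnParallelLinear
```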

megatron/core/models/gpt/gpt_layer_specs.py

Lines changed: 2 additions & 2 deletions

@@ -141,9 +141,9 @@ def get_gpt_layer_with_transformer_engine_spec(
             params={"attn_mask_type": AttnMaskType.causal},
             submodules=MLASelfAttentionSubmodules(
                 linear_q_proj=backend.column_parallel_linear(),
-                linear_q_down_proj=backend.column_parallel_linear(),
+                linear_q_down_proj=backend.linear(),
                 linear_q_up_proj=linear_q_up_proj,
-                linear_kv_down_proj=backend.column_parallel_linear(),
+                linear_kv_down_proj=backend.linear(),
                 linear_kv_up_proj=linear_kv_up_proj,
                 core_attention=backend.core_attention(),
                 linear_proj=backend.row_parallel_linear(),
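
In effect, with the TE backend the two down-projection slots now resolve to `TELinear` instead of `TEColumnParallelLinear`. A reduced, self-contained sketch of that before/after follows; the dataclass below carries only the two fields touched here, and the classes are stand-ins rather than the real `MLASelfAttentionSubmodules` and TE layers.

```python
# Reduced sketch of the two spec fields changed above; stand-in classes only.
from dataclasses import dataclass


class TELinearStandIn: ...                # stands in for TELinear
class TEColumnParallelLinearStandIn: ...  # stands in for TEColumnParallelLinear


@dataclass
class MLADownProjSlots:
    linear_q_down_proj: type
    linear_kv_down_proj: type


# Before this commit, both slots came from backend.column_parallel_linear():
before = MLADownProjSlots(TEColumnParallelLinearStandIn, TEColumnParallelLinearStandIn)
# After this commit, both slots come from the new backend.linear() hook:
after = MLADownProjSlots(TELinearStandIn, TELinearStandIn)
```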

megatron/core/transformer/multi_latent_attention.py

Lines changed: 42 additions & 3 deletions

@@ -15,6 +15,7 @@
     apply_rotary_pos_emb,
 )
 from megatron.core.process_groups_config import ModelCommProcessGroups
+from megatron.core.tensor_parallel.layers import ColumnParallelLinear
 from megatron.core.tensor_parallel.mappings import (
     gather_from_sequence_parallel_region,
     gather_from_tensor_model_parallel_region,
@@ -36,6 +37,16 @@
     fused_apply_mla_rope_for_q = None
 
 
+try:
+    from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TELinear
+    from megatron.core.post_training.modelopt.layers import Linear
+
+    HAVE_TE = True
+except ImportError:
+    TEColumnParallelLinear, TELinear, Linear = None, None, None
+    HAVE_TE = False
+
+
 @dataclass
 class MLASelfAttentionSubmodules:
     """Submodules for the MLA self-attention layer."""
@@ -282,6 +293,17 @@ def __init__(
             )
 
         else:
+            q_down_proj_kwargs = {}
+            if submodules.linear_q_down_proj in [TELinear]:
+                q_down_proj_kwargs['parallel_mode'] = 'duplicated'
+            elif submodules.linear_q_down_proj in [
+                Linear,
+                TEColumnParallelLinear,
+                ColumnParallelLinear,
+            ]:
+                q_down_proj_kwargs['gather_output'] = False
+            else:
+                raise ValueError(f"Unsupported linear_q_down_proj: {submodules.linear_q_down_proj}")
 
             self.linear_q_down_proj = build_module(
                 submodules.linear_q_down_proj,
@@ -291,9 +313,10 @@ def __init__(
                 init_method=self.config.init_method,
                 bias=False,
                 skip_bias_add=False,
-                gather_output=False,
                 is_expert=False,
                 tp_comm_buffer_name='q_down_proj',
+                skip_weight_param_allocation=False,
+                **q_down_proj_kwargs,
             )
 
             self.linear_q_up_proj = build_module(
@@ -309,6 +332,18 @@ def __init__(
                 tp_comm_buffer_name='q_up_proj',
             )
 
+            kv_down_proj_kwargs = {}
+            if submodules.linear_kv_down_proj in [TELinear]:
+                kv_down_proj_kwargs['parallel_mode'] = 'duplicated'
+            elif submodules.linear_kv_down_proj in [
+                Linear,
+                TEColumnParallelLinear,
+                ColumnParallelLinear,
+            ]:
+                kv_down_proj_kwargs['gather_output'] = False
+            else:
+                raise ValueError(f"Unsupported linear_kv_down_proj: {submodules.linear_kv_down_proj}")
+
             self.linear_kv_down_proj = build_module(
                 submodules.linear_kv_down_proj,
                 self.config.hidden_size,
@@ -317,9 +352,10 @@ def __init__(
                 init_method=self.config.init_method,
                 bias=False,
                 skip_bias_add=False,
-                gather_output=False,
                 is_expert=False,
                 tp_comm_buffer_name='kv_down_proj',
+                skip_weight_param_allocation=False,
+                **kv_down_proj_kwargs,
             )
 
             self.linear_kv_up_proj = build_module(
@@ -453,7 +489,10 @@ def get_query_key_value_tensors(
         kv_compressed, k_pos_emb = torch.split(
             kv_combined, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], dim=-1
         )
-        if parallel_state.get_tensor_model_parallel_world_size() > 1:
+        if (
+            parallel_state.get_tensor_model_parallel_world_size() > 1
+            and self.config.sequence_parallel
+        ):
             # k_pos_emb: [s, b, qk_pos_emb_head_dim]
             k_pos_emb = gather_from_sequence_parallel_region(k_pos_emb)

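The per-class keyword dispatch added above is the core of the change: `TELinear` is built with `parallel_mode='duplicated'` (the down-projection weight is replicated on every tensor-parallel rank), while the column-parallel variants keep `gather_output=False`. Below is a self-contained sketch of that dispatch, using stand-in classes; the real code also accepts the ModelOpt `Linear` class in the column-parallel branch.

```python
# Self-contained sketch of the kwargs dispatch used for the MLA down
# projections above. Stand-in classes replace the real TE/Megatron layers.
class TELinear: ...
class TEColumnParallelLinear: ...
class ColumnParallelLinear: ...


def down_proj_kwargs(linear_cls: type) -> dict:
    """Pick the extra build_module kwargs for an MLA down projection."""
    if linear_cls in [TELinear]:
        # Plain TE linear: keep a full (duplicated) copy of the weight per rank.
        return {'parallel_mode': 'duplicated'}
    elif linear_cls in [TEColumnParallelLinear, ColumnParallelLinear]:
        # Column-parallel layers shard the output dim; keep shards local.
        return {'gather_output': False}
    raise ValueError(f"Unsupported down projection class: {linear_cls}")


assert down_proj_kwargs(TELinear) == {'parallel_mode': 'duplicated'}
assert down_proj_kwargs(ColumnParallelLinear) == {'gather_output': False}
```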
megatron/training/arguments.py

Lines changed: 11 additions & 0 deletions

@@ -1135,6 +1135,13 @@ def core_transformer_config_from_args(args, config_class=None):
     else:
         kw_args['num_query_groups'] = None
     kw_args['config_logger_dir'] = args.config_logger_dir
+    if args.rope_type is None:
+        # Pop 'rope_type' to let the config class use the default value.
+        kw_args.pop('rope_type', None)
+    else:
+        assert (args.multi_latent_attention or args.rope_type == 'rope'), (
+            f'Common attention only support rope_type="rope", but got {args.rope_type}.'
+        )
 
     if len(args.cp_comm_type) == 1:
         kw_args['cp_comm_type'] = args.cp_comm_type[0]
@@ -1884,6 +1891,10 @@ def _add_training_args(parser):
                        help='Disable rope fusion, the fusion is available '
                        'only when using megatron-core.',
                        dest='apply_rope_fusion')
+    group.add_argument('--rope-type', type=str, default=None,
+                       choices=['rope', 'yarn'],
+                       help='Type of rope to use. Note that MLA takes yarn by default, '
+                       'and common attention takes rope by default.')
     group.add_argument('--cross-entropy-loss-fusion', action='store_true',
                        help='Enabled fusion of cross entropy loss calculation.',
                        dest='cross_entropy_loss_fusion')
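
Usage-wise, the new `--rope-type` flag stays unset by default (so the config class picks its own default: yarn for MLA, rope otherwise) and is validated against `--multi-latent-attention`. A standalone sketch of that flag and its check follows, using a throwaway `argparse` parser rather than Megatron-LM's real argument builder.

```python
# Standalone sketch of the new --rope-type flag and its validation.
# This is a throwaway parser, not Megatron-LM's _add_training_args().
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--rope-type', type=str, default=None,
                    choices=['rope', 'yarn'],
                    help='MLA defaults to yarn; common attention defaults to rope.')
parser.add_argument('--multi-latent-attention', action='store_true')

args = parser.parse_args(['--multi-latent-attention', '--rope-type', 'yarn'])

kw_args = {'rope_type': args.rope_type}
if args.rope_type is None:
    # Unset: drop the key so the config class falls back to its own default.
    kw_args.pop('rope_type', None)
else:
    # yarn is only valid together with multi-latent attention.
    assert args.multi_latent_attention or args.rope_type == 'rope', (
        f'Common attention only supports rope_type="rope", got {args.rope_type}.'
    )
```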
Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-{"lm loss": {"start_step": 1, "end_step": 50, "step_interval": 5, "values": {"1": 10.90224, "5": 10.9121, "10": 10.89882, "15": 10.90014, "20": 10.87364, "25": 10.86175, "30": 10.79053, "35": 10.76848, "40": 10.63331, "45": 10.54116, "50": 10.54543}}, "num-zeros": {"start_step": 1, "end_step": 50, "step_interval": 5, "values": {"1": 22791116.0, "5": 22778568.0, "10": 22982992.0, "15": 22885372.0, "20": 22758452.0, "25": 22883116.0, "30": 22694696.0, "35": 22851780.0, "40": 22721894.0, "45": 22738960.0, "50": 22968936.0}}, "mem-allocated-bytes": {"start_step": 1, "end_step": 50, "step_interval": 5, "values": {"1": 384918016.0, "5": 384918016.0, "10": 384918016.0, "15": 384918016.0, "20": 384918016.0, "25": 384918016.0, "30": 384918016.0, "35": 384918016.0, "40": 384918016.0, "45": 384918016.0, "50": 384918016.0}}, "mem-max-allocated-bytes": {"start_step": 1, "end_step": 50, "step_interval": 5, "values": {"1": 1123523072.0, "5": 1245691392.0, "10": 1245691392.0, "15": 1245691392.0, "20": 1245691392.0, "25": 1245691392.0, "30": 1245691392.0, "35": 1245691392.0, "40": 1245691392.0, "45": 1245691392.0, "50": 1245691392.0}}, "iteration-time": {"start_step": 1, "end_step": 50, "step_interval": 5, "values": {"1": 9.95327, "5": 0.107, "10": 0.1037, "15": 0.10008, "20": 0.09966, "25": 0.09698, "30": 0.09982, "35": 0.09784, "40": 0.09998, "45": 0.09728, "50": 0.10112}}}
+{"lm loss": {"start_step": 1, "end_step": 50, "step_interval": 5, "values": {"1": 10.92337, "5": 10.92089, "10": 10.92192, "15": 10.92351, "20": 10.90031, "25": 10.87827, "30": 10.81423, "35": 10.78865, "40": 10.65927, "45": 10.56875, "50": 10.55421}}, "num-zeros": {"start_step": 1, "end_step": 50, "step_interval": 5, "values": {"1": 22791636.0, "5": 22778528.0, "10": 22983232.0, "15": 22886400.0, "20": 22758358.0, "25": 22883742.0, "30": 22695256.0, "35": 22851572.0, "40": 22721680.0, "45": 22738904.0, "50": 22968272.0}}, "mem-allocated-bytes": {"start_step": 1, "end_step": 50, "step_interval": 5, "values": {"1": 384918016.0, "5": 384918016.0, "10": 384918016.0, "15": 384918016.0, "20": 384918016.0, "25": 384918016.0, "30": 384918016.0, "35": 384918016.0, "40": 384918016.0, "45": 384918016.0, "50": 384918016.0}}, "mem-max-allocated-bytes": {"start_step": 1, "end_step": 50, "step_interval": 5, "values": {"1": 1123523072.0, "5": 1245691392.0, "10": 1245691392.0, "15": 1245691392.0, "20": 1245691392.0, "25": 1245691392.0, "30": 1245691392.0, "35": 1245691392.0, "40": 1245691392.0, "45": 1245691392.0, "50": 1245691392.0}}, "iteration-time": {"start_step": 1, "end_step": 50, "step_interval": 5, "values": {"1": 9.95327, "5": 0.107, "10": 0.1037, "15": 0.10008, "20": 0.09966, "25": 0.09698, "30": 0.09982, "35": 0.09784, "40": 0.09998, "45": 0.09728, "50": 0.10112}}}

tests/functional_tests/test_cases/moe/gpt3_mcore_te_tp2_pp2_ep4_etp1_memory_speed/golden_values_dev_dgx_h100.json

Lines changed: 33 additions & 33 deletions

@@ -4,35 +4,35 @@
     "end_step": 50,
     "step_interval": 5,
     "values": {
-      "1": 11.11942,
-      "5": 9.28112,
-      "10": 8.96767,
-      "15": 7.98924,
-      "20": 7.80173,
-      "25": 7.63303,
-      "30": 7.57275,
-      "35": 7.16214,
-      "40": 7.49418,
-      "45": 7.13663,
-      "50": 6.97223
+      "1": 11.04747,
+      "5": 9.52402,
+      "10": 9.0596,
+      "15": 8.04431,
+      "20": 7.90653,
+      "25": 7.67312,
+      "30": 7.64496,
+      "35": 7.21326,
+      "40": 7.54337,
+      "45": 7.18518,
+      "50": 7.03308
     }
   },
   "num-zeros": {
     "start_step": 1,
     "end_step": 50,
     "step_interval": 5,
     "values": {
-      "1": 38802576.0,
-      "5": 240322608.0,
-      "10": 627841536.0,
-      "15": 579569024.0,
-      "20": 658931008.0,
-      "25": 509733120.0,
-      "30": 445364352.0,
-      "35": 561495552.0,
-      "40": 311616704.0,
-      "45": 420418656.0,
-      "50": 199242224.0
+      "1": 38802572.0,
+      "5": 252883888.0,
+      "10": 731676032.0,
+      "15": 708564416.0,
+      "20": 989209984.0,
+      "25": 827443712.0,
+      "30": 756766080.0,
+      "35": 709348352.0,
+      "40": 588409600.0,
+      "45": 517940384.0,
+      "50": 409992032.0
     }
   },
   "mem-allocated-bytes": {
@@ -76,17 +76,17 @@
     "end_step": 50,
     "step_interval": 5,
     "values": {
-      "1": 11.04374,
-      "5": 9.86351,
-      "10": 9.02642,
-      "15": 7.99309,
-      "20": 7.4113,
-      "25": 7.73904,
-      "30": 7.48829,
-      "35": 7.54205,
-      "40": 7.94269,
-      "45": 7.3323,
-      "50": 6.83748
+      "1": 11.10137,
+      "5": 9.92778,
+      "10": 9.06984,
+      "15": 8.03354,
+      "20": 7.45652,
+      "25": 7.77087,
+      "30": 7.52221,
+      "35": 7.54715,
+      "40": 7.94738,
+      "45": 7.32562,
+      "50": 6.85517
     }
   },
   "iteration-time": {

tests/functional_tests/test_cases/moe/gpt3_mr_mcore_te_tp2_pp2_ep4_etp1_mtp_resume_torch_dist_fp8/golden_values_dev_dgx_h100.json

Lines changed: 33 additions & 33 deletions

@@ -4,35 +4,35 @@
     "end_step": 50,
     "step_interval": 5,
     "values": {
-      "1": 11.12115,
-      "5": 9.48776,
-      "10": 9.06639,
-      "15": 8.04478,
-      "20": 7.81622,
-      "25": 7.64405,
-      "30": 7.59737,
-      "35": 7.17911,
-      "40": 7.50704,
-      "45": 7.15152,
-      "50": 6.9914
+      "1": 11.0475,
+      "5": 9.47263,
+      "10": 8.90522,
+      "15": 7.94285,
+      "20": 7.7696,
+      "25": 7.60471,
+      "30": 7.56115,
+      "35": 7.14613,
+      "40": 7.47799,
+      "45": 7.11821,
+      "50": 6.96092
     }
   },
   "num-zeros": {
     "start_step": 1,
     "end_step": 50,
     "step_interval": 5,
     "values": {
-      "1": 38802568.0,
-      "5": 290825856.0,
-      "10": 731989952.0,
-      "15": 630001984.0,
-      "20": 677850432.0,
-      "25": 585432256.0,
-      "30": 750669888.0,
-      "35": 618214784.0,
-      "40": 531952480.0,
-      "45": 313537728.0,
-      "50": 394303040.0
+      "1": 38802664.0,
+      "5": 221567312.0,
+      "10": 735002624.0,
+      "15": 611135296.0,
+      "20": 590465536.0,
+      "25": 542043712.0,
+      "30": 429887840.0,
+      "35": 467241280.0,
+      "40": 380798464.0,
+      "45": 329247616.0,
+      "50": 284175040.0
     }
   },
   "mem-allocated-bytes": {
@@ -76,17 +76,17 @@
     "end_step": 50,
     "step_interval": 5,
     "values": {
-      "1": 11.03802,
-      "5": 9.88912,
-      "10": 9.02538,
-      "15": 8.00927,
-      "20": 7.41448,
-      "25": 7.73636,
-      "30": 7.48773,
-      "35": 7.5439,
-      "40": 7.93894,
-      "45": 7.32524,
-      "50": 6.83572
+      "1": 11.10067,
+      "5": 9.99203,
+      "10": 8.95639,
+      "15": 7.95116,
+      "20": 7.37498,
+      "25": 7.71218,
+      "30": 7.46442,
+      "35": 7.5167,
+      "40": 7.91951,
+      "45": 7.30491,
+      "50": 6.82535
     }
   },
   "iteration-time": {

tests/functional_tests/test_cases/moe/gpt3_mr_mcore_te_tp2_pp2_ep4_etp1_resume_torch_dist_attn_cudagraph/golden_values_dev_dgx_h100.json

Lines changed: 22 additions & 22 deletions

@@ -4,35 +4,35 @@
     "end_step": 50,
     "step_interval": 5,
     "values": {
-      "1": 10.96592,
-      "5": 9.91128,
-      "10": 9.79015,
-      "15": 9.01941,
-      "20": 8.83745,
-      "25": 8.65462,
-      "30": 8.67364,
-      "35": 8.08232,
-      "40": 8.36432,
-      "45": 8.09963,
-      "50": 7.81534
+      "1": 10.94995,
+      "5": 9.9238,
+      "10": 9.85512,
+      "15": 9.01582,
+      "20": 8.83018,
+      "25": 8.62061,
+      "30": 8.65266,
+      "35": 8.06408,
+      "40": 8.34095,
+      "45": 8.08321,
+      "50": 7.79855
     }
   },
   "num-zeros": {
     "start_step": 1,
     "end_step": 50,
     "step_interval": 5,
     "values": {
-      "1": 19403588.0,
-      "5": 106158376.0,
-      "10": 167686976.0,
-      "15": 197075792.0,
-      "20": 198976032.0,
-      "25": 270640864.0,
-      "30": 224341904.0,
-      "35": 246167936.0,
-      "40": 174723744.0,
-      "45": 125710632.0,
-      "50": 156288624.0
+      "1": 19403880.0,
+      "5": 142219280.0,
+      "10": 125211168.0,
+      "15": 264681776.0,
+      "20": 217875968.0,
+      "25": 236053968.0,
+      "30": 266788816.0,
+      "35": 243068672.0,
+      "40": 166847584.0,
+      "45": 159913152.0,
+      "50": 165693360.0
     }
   },
   "mem-allocated-bytes": {
