
[BUG] Load DCP OOM #1746

@zjjott

Description


Describe the bug
Loading a sufficiently large model (20B parameters, PP4) from the DCP format causes an OOM error.
I measured memory usage with torch.cuda.memory._dump_snapshot(filename); analyzing the snapshot file
shows that

Megatron-LM/megatron/core/transformer/mlp.py:344 allocates 12344.00 MB
Megatron-LM/megatron/core/optimizer/distrib_optimizer.py:773 allocates 19128.37 MB

Together these cost about 31 GB of memory, which is unnecessary.
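For reference, the snapshot above comes from PyTorch's (private) memory-history API. A minimal sketch of that measurement, with a CPU-only guard added so the helper is safe to run on hosts without a GPU:

```python
import torch


def dump_cuda_memory_snapshot(filename: str) -> bool:
    """Capture a CUDA allocator snapshot for later analysis.

    Returns False when no GPU is available (the snapshot APIs are
    CUDA-only), so this sketch also runs on CPU-only hosts.
    """
    if not torch.cuda.is_available():
        return False
    # Start recording allocation call stacks (private API, PyTorch >= 2.1).
    torch.cuda.memory._record_memory_history(max_entries=100000)
    # ... run the code under investigation here, e.g. the DCP resume path ...
    torch.cuda.memory._dump_snapshot(filename)
    torch.cuda.memory._record_memory_history(enabled=None)  # stop recording
    return True
```

The resulting file can be inspected with the pytorch.org/memory_viz viewer, which attributes each allocation to a source line, as in the figures above.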

I see @mikolajblaz has a fix here, but it does not fully work: mlp.py still uses a lot of memory without OOMing, and the OOM now happens while resuming the optimizer states.

In distrib_optimizer.py, the state_dict should stay on the CPU until self.optimizer.load_state_dict is called; holding it on the GPU before that point is also unnecessary.
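A minimal sketch of that direction, using a hypothetical helper (not the actual Megatron-LM code path): keep every tensor in the loaded state dict on the CPU, and only move it to the GPU once the optimizer actually consumes it:

```python
import torch


def state_dict_to_device(obj, device):
    """Recursively move every tensor in a (possibly nested) state dict.

    Calling this with device="cpu" right after loading the checkpoint
    avoids holding a second full copy of the optimizer states in GPU
    memory before load_state_dict runs.
    """
    if isinstance(obj, torch.Tensor):
        return obj.to(device, non_blocking=True)
    if isinstance(obj, dict):
        return {k: state_dict_to_device(v, device) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(state_dict_to_device(v, device) for v in obj)
    return obj
```

Under this sketch, the resume path would load the DCP state dict, pass it through state_dict_to_device(loaded, "cpu"), and let self.optimizer.load_state_dict move states to the device lazily.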

To Reproduce
Using a GPT model with PP4 DP2 EP1 TP1, train a 20B model (each rank holds 5B model parameters), save to the DCP format, and resume from it.

Expected behavior
Resuming 5B parameters per rank should fit in memory comfortably.

Stack trace/logs

Environment (please complete the following information):

  • Megatron-LM commit ID: 2f1027d
  • PyTorch version: torch==2.6.0+cu126
  • CUDA version: 12.6
  • NCCL version

Proposed fix
I will send a PR to fix it. Where should I add test cases for this?

