
Commit 992a79e

Memory budget strategy for activation checkpointing (#297)
See https://pytorch.org/blog/activation-checkpointing-techniques/ for more details, but essentially this is an easy way to enable selective activation checkpointing without fiddling with a bunch of different options: you set a single memory budget and let the compiler decide what to recompute, keeping training fast while staying within your GPU memory allowance.

![image](https://github.com/user-attachments/assets/5e17af03-aa43-489e-b30e-471ee3025c7e)

> We observe a 50% memory reduction by recomputing only pointwise ops, with a steady drop-off as you recompute more and more of your matmuls. Attention is the most expensive, so you tend to want to recompute those last.
1 parent 0dda3ec commit 992a79e
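
For reference, the mechanism behind this mode is the memory-budget knob described in the linked blog post, `torch._functorch.config.activation_memory_budget`, which the compiled graph's partitioner consults. Below is a minimal sketch of that raw PyTorch usage; the toy model and the 0.5 budget are illustrative, not code from this repo.

```python
# Minimal sketch of the PyTorch knob this commit wires up (illustrative toy
# model and budget value; not code from this repo).
import torch

# 1.0 = recompute nothing (default memory usage), 0.0 = recompute every
# activation. The compiled graph's partitioner decides what to recompute to
# stay under the budget, so this only has an effect under torch.compile.
torch._functorch.config.activation_memory_budget = 0.5

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096),
    torch.nn.GELU(),
    torch.nn.Linear(4096, 1024),
).to(device)
compiled = torch.compile(model)

x = torch.randn(8, 1024, device=device, requires_grad=True)
compiled(x).sum().backward()
```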

File tree

7 files changed: +37 −1 lines changed


.github/workflows/main.yml

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ jobs:
       - name: Test
         run: |
-          pytest -v --color=yes --durations=3 -n auto --dist=loadfile \
+          pytest -v --color=yes --durations=3 -n auto --dist=load \
             --ignore-glob='src/test/distributed/checkpoint*' \
             src/test/

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added support for accessing Google on non-Google clusters via auth with service account keys.
 - Added support for revisions in `convert_checkpoint_from_hf.py` and the `load_hf_model` method of `olmo_core.nn.hf.checkpoint`.
 - `foreach` support in `SkipStepAdamW`.
+- Added `budget` mode for activation checkpointing configuration.

 ### Changed

src/olmo_core/nn/transformer/config.py

Lines changed: 2 additions & 0 deletions
@@ -59,6 +59,8 @@ class TransformerActivationCheckpointingMode(StrEnum):
     """Checkpoint only selected modules."""
     selected_ops = "selected_ops"
     """Checkpoint only a specific set of operations."""
+    budget = "budget"
+    """Checkpoint based on a budget."""


 class TransformerType(StrEnum):

src/olmo_core/nn/transformer/model.py

Lines changed: 14 additions & 0 deletions
@@ -542,6 +542,7 @@ def apply_activation_checkpointing(
         mode: TransformerActivationCheckpointingMode,
         block_interval: Optional[int] = None,
         modules: Optional[List[str]] = None,
+        activation_memory_budget: Optional[float] = None,
     ):
         """
         Apply activation checkpointing to the model.
@@ -551,7 +552,20 @@ def apply_activation_checkpointing(
             which blocks are wrapped.
         :param modules: Required when :data:`mode` is "selected_modules". A list of modules names
             to wrap for activation checkpointing. Globs are supported.
+        :param activation_memory_budget: The memory budget for activation checkpointing in the range
+            [0, 1]. 0 corresponds to the memory usage when recomputing all activations, and 1
+            corresponds to the memory usage when recomputing no activations (which is the default).
+            Requires compilation to be enabled.
         """
+
+        if mode == TransformerActivationCheckpointingMode.budget:
+            if activation_memory_budget is None:
+                raise ValueError("'activation_memory_budget' is required for 'budget' mode")
+            if activation_memory_budget < 0 or activation_memory_budget > 1:
+                raise ValueError("'activation_memory_budget' must be in the range [0, 1]")
+            torch._functorch.config.activation_memory_budget = activation_memory_budget
+            return
+
         from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
             checkpoint_wrapper as ptd_checkpoint_wrapper,
         )
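
For illustration, assuming `apply_activation_checkpointing` is exposed as a method on the `Transformer` model (the diff only shows its parameter list), the new mode would be invoked roughly like the sketch below; the budget value is illustrative.

```python
# Illustrative sketch, not code from this commit.
from olmo_core.nn.transformer import Transformer
from olmo_core.nn.transformer.config import TransformerActivationCheckpointingMode


def enable_budget_ac(model: Transformer, budget: float = 0.5) -> None:
    # In 'budget' mode the method only sets
    # torch._functorch.config.activation_memory_budget and returns early,
    # so no modules get wrapped; the budget only takes effect once the
    # model is run under torch.compile.
    model.apply_activation_checkpointing(
        TransformerActivationCheckpointingMode.budget,
        activation_memory_budget=budget,  # must lie in [0, 1]
    )
```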

src/olmo_core/train/train_module/transformer/common.py

Lines changed: 1 addition & 0 deletions
@@ -98,6 +98,7 @@ def parallelize_model(
         ac_config.mode,
         block_interval=ac_config.block_interval,
         modules=ac_config.modules,
+        activation_memory_budget=ac_config.activation_memory_budget,
     )
     log.info(f"Applied '{ac_config.mode}' activation checkpointing to the model")

src/olmo_core/train/train_module/transformer/config.py

Lines changed: 8 additions & 0 deletions
@@ -221,6 +221,14 @@ class TransformerActivationCheckpointingConfig(Config):
     activation checkpointing. Globs are supported.
     """

+    activation_memory_budget: Optional[float] = None
+    """
+    Required when :data:`mode` is "budget". Memory budget for activation checkpointing in range [0, 1].
+    0 = recompute all activations, 1 = recompute none (default). Requires compilation to be enabled.
+
+    See https://pytorch.org/blog/activation-checkpointing-techniques/ for more details.
+    """
+
     def __post_init__(self):
         if (
             self.mode == TransformerActivationCheckpointingMode.selected_blocks
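
At the config level, wiring the new field together might look like the sketch below, assuming the usual dataclass-style constructor and importing from the modules shown in this diff; the 0.5 value is illustrative, and compilation must be enabled or the train module raises (see the check in train_module.py further down).

```python
# Illustrative sketch: budget-based activation checkpointing via config.
from olmo_core.nn.transformer.config import TransformerActivationCheckpointingMode
from olmo_core.train.train_module.transformer.config import (
    TransformerActivationCheckpointingConfig,
)

ac_config = TransformerActivationCheckpointingConfig(
    mode=TransformerActivationCheckpointingMode.budget,
    # Fraction of the "recompute nothing" activation memory to allow:
    # 1.0 recomputes nothing (default usage), 0.0 recomputes everything.
    activation_memory_budget=0.5,
)
# Pass ac_config to the transformer train module with compile_model=True,
# otherwise 'budget' mode raises an OLMoConfigurationError.
```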

src/olmo_core/train/train_module/transformer/train_module.py

Lines changed: 10 additions & 0 deletions
@@ -37,6 +37,7 @@
 from olmo_core.float8 import Float8Config
 from olmo_core.nn.lm_head import LMOutputWithLoss
 from olmo_core.nn.transformer import Transformer
+from olmo_core.nn.transformer.config import TransformerActivationCheckpointingMode
 from olmo_core.optim import OptimConfig, SkipStepOptimizer
 from olmo_core.optim.scheduler import Scheduler
 from olmo_core.utils import gc_cuda, get_default_device, log_once, move_to_device
@@ -142,6 +143,15 @@ def __init__(
                 "Training parallelism configs are only valid for distributed training"
             )

+        if (
+            ac_config is not None
+            and ac_config.mode == TransformerActivationCheckpointingMode.budget
+            and not compile_model
+        ):
+            raise OLMoConfigurationError(
+                "Activation checkpointing with 'budget' mode requires compilation to be enabled"
+            )
+
         # Parallelize model.
         self.model = parallelize_model(
             model,
