Memory budget strategy for activation checkpointing #297

Merged
merged 9 commits on Jul 8, 2025
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -29,6 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added support for accessing Google on non-Google clusters via auth with service account keys.
- Added support for revisions in `convert_checkpoint_from_hf.py` and the `load_hf_model` method of `olmo_core.nn.hf.checkpoint`.
- `foreach` support in `SkipStepAdamW`.
- Added `budget` mode for activation checkpointing configuration.

### Changed

2 changes: 2 additions & 0 deletions src/olmo_core/nn/transformer/config.py
@@ -59,6 +59,8 @@ class TransformerActivationCheckpointingMode(StrEnum):
"""Checkpoint only selected modules."""
selected_ops = "selected_ops"
"""Checkpoint only a specific set of operations."""
budget = "budget"
"""Checkpoint based on a budget."""


class TransformerType(StrEnum):
14 changes: 14 additions & 0 deletions src/olmo_core/nn/transformer/model.py
@@ -542,6 +542,7 @@ def apply_activation_checkpointing(
mode: TransformerActivationCheckpointingMode,
block_interval: Optional[int] = None,
modules: Optional[List[str]] = None,
activation_memory_budget: Optional[float] = None,
):
"""
Apply activation checkpointing to the model.
@@ -551,7 +552,20 @@
which blocks are wrapped.
:param modules: Required when :data:`mode` is "selected_modules". A list of module names
to wrap for activation checkpointing. Globs are supported.
:param activation_memory_budget: The memory budget for activation checkpointing in the range
[0, 1]. 0 corresponds to the memory usage when recomputing all activations, and 1
corresponds to the memory usage when recomputing no activations (which is the default).
Requires compilation to be enabled.
"""

if mode == TransformerActivationCheckpointingMode.budget:
if activation_memory_budget is None:
raise ValueError("'activation_memory_budget' is required for 'budget' mode")
if activation_memory_budget < 0 or activation_memory_budget > 1:
raise ValueError("'activation_memory_budget' must be in the range [0, 1]")
torch._functorch.config.activation_memory_budget = activation_memory_budget
return

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper as ptd_checkpoint_wrapper,
)
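Unlike the other modes, `budget` does not wrap any modules: it simply hands the fraction to PyTorch's partitioner via `torch._functorch.config.activation_memory_budget`, which is only consulted for compiled graphs. A minimal standalone sketch of that underlying PyTorch behavior (not olmo-core code; assumes a recent PyTorch where this knob is honored and a working `torch.compile` backend):

```python
import torch
import torch._functorch.config  # explicit import so the knob is available in a standalone script
import torch.nn as nn

# 0.0 ~ recompute all activations, 1.0 ~ recompute none (the default).
torch._functorch.config.activation_memory_budget = 0.5

model = nn.Sequential(nn.Linear(256, 1024), nn.GELU(), nn.Linear(1024, 256))
compiled = torch.compile(model)  # the budget only takes effect on compiled graphs

x = torch.randn(8, 256, requires_grad=True)
compiled(x).sum().backward()  # the partitioner decides which activations to save vs. recompute
```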
1 change: 1 addition & 0 deletions src/olmo_core/train/train_module/transformer/common.py
@@ -98,6 +98,7 @@ def parallelize_model(
ac_config.mode,
block_interval=ac_config.block_interval,
modules=ac_config.modules,
activation_memory_budget=ac_config.activation_memory_budget,
)
log.info(f"Applied '{ac_config.mode}' activation checkpointing to the model")

8 changes: 8 additions & 0 deletions src/olmo_core/train/train_module/transformer/config.py
@@ -221,6 +221,14 @@ class TransformerActivationCheckpointingConfig(Config):
activation checkpointing. Globs are supported.
"""

activation_memory_budget: Optional[float] = None
"""
Required when :data:`mode` is "budget". Memory budget for activation checkpointing in range [0, 1].
0 = recompute all activations, 1 = recompute none (default). Requires compilation to be enabled.

See https://pytorch.org/blog/activation-checkpointing-techniques/ for more details.
"""

def __post_init__(self):
if (
self.mode == TransformerActivationCheckpointingMode.selected_blocks
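For reference, a sketch of selecting the new mode through this config. The enum import path matches the one added in train_module.py below; the config class's path is inferred from the file path in this diff and may also be re-exported elsewhere:

```python
from olmo_core.nn.transformer.config import TransformerActivationCheckpointingMode
from olmo_core.train.train_module.transformer.config import (
    TransformerActivationCheckpointingConfig,
)

# Keep roughly half of the activation memory; the rest is recomputed in backward.
ac_config = TransformerActivationCheckpointingConfig(
    mode=TransformerActivationCheckpointingMode.budget,
    activation_memory_budget=0.5,
)
```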
10 changes: 10 additions & 0 deletions src/olmo_core/train/train_module/transformer/train_module.py
@@ -37,6 +37,7 @@
from olmo_core.float8 import Float8Config
from olmo_core.nn.lm_head import LMOutputWithLoss
from olmo_core.nn.transformer import Transformer
from olmo_core.nn.transformer.config import TransformerActivationCheckpointingMode
from olmo_core.optim import OptimConfig, SkipStepOptimizer
from olmo_core.optim.scheduler import Scheduler
from olmo_core.utils import gc_cuda, get_default_device, log_once, move_to_device
@@ -142,6 +143,15 @@ def __init__(
"Training parallelism configs are only valid for distributed training"
)

if (
ac_config is not None
and ac_config.mode == TransformerActivationCheckpointingMode.budget
and not compile_model
):
raise OLMoConfigurationError(
"Activation checkpointing with 'budget' mode requires compilation to be enabled"
)

# Parallelize model.
self.model = parallelize_model(
model,
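Putting it together, budget mode has to be paired with compilation or the constructor above raises `OLMoConfigurationError`. A rough usage sketch: only `ac_config` and `compile_model` are parameter names confirmed by this diff; the class name `TransformerTrainModule` and the other arguments are assumptions.

```python
# Hypothetical sketch; see the caveats above.
train_module = TransformerTrainModule(
    model=model,          # an olmo_core Transformer, constructed elsewhere
    optim=optim_config,
    ac_config=ac_config,  # budget mode with activation_memory_budget=0.5
    compile_model=True,   # required for budget mode, otherwise OLMoConfigurationError is raised
)
```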