
[QUESTION] Is there a way for Megatron to recompute the whole transformer layer except for the flash-attn part? #1732

@xUhEngwAng

Description


Hi, as stated in the title, I'm wondering whether Megatron provides native functionality like `context_fn` in `torch.utils.checkpoint.checkpoint`, so that the flash-attn computation can be excluded from the recomputation of a transformer layer.
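For reference, this is roughly the behaviour I have in mind, written as a minimal sketch in plain PyTorch rather than Megatron. It assumes the selective-checkpointing helpers `create_selective_checkpoint_contexts` and `CheckpointPolicy` from recent PyTorch releases, and it uses the aten SDPA flash-attention op as a stand-in; the corresponding flash-attn custom op would be listed there instead (its registered name depends on the flash-attn version):

```python
import torch
from torch.utils.checkpoint import (
    checkpoint,
    create_selective_checkpoint_contexts,
    CheckpointPolicy,
)

# Ops whose outputs should be saved in the forward pass instead of being
# recomputed in the backward pass. PyTorch's SDPA flash kernel is used here
# as a stand-in for the flash-attn library's custom op.
_SAVE_OPS = {
    torch.ops.aten._scaled_dot_product_flash_attention.default,
}

def _policy(ctx, op, *args, **kwargs):
    # Save flash-attention outputs, recompute everything else.
    if op in _SAVE_OPS:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

def checkpointed_layer(layer, *inputs):
    # context_fn must return a fresh (forward, recompute) context pair
    # on every call; it is only supported with use_reentrant=False.
    return checkpoint(
        layer,
        *inputs,
        use_reentrant=False,
        context_fn=lambda: create_selective_checkpoint_contexts(_policy),
    )
```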

For now, I manually modified `tensor_parallel.checkpoint` to accept such an argument. However, issues remain when I try to capture the activations saved inside flash-attn and offload them to the CPU.
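For the offloading part, what I have in mind conceptually looks like PyTorch's built-in `torch.autograd.graph.save_on_cpu` hook (a sketch, not my actual Megatron patch); the difficulty is making something like this compose with Megatron's `tensor_parallel.checkpoint` so that only the flash-attn activations take this path:

```python
import torch
from torch.autograd.graph import save_on_cpu

def forward_with_cpu_offload(layer, *inputs):
    # Tensors saved for backward inside this context are copied to pinned
    # CPU memory and moved back to the GPU on demand during backward.
    with save_on_cpu(pin_memory=True):
        return layer(*inputs)
```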
