Hi, as stated in the title, I'm wondering whether Megatron provides native support for something like the `context_fn` argument of `torch.utils.checkpoint.checkpoint`, so that the flash-attn computation can be excluded from the recomputation of a transformer layer.
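For reference, this is the kind of behavior I mean, using the upstream PyTorch selective-checkpointing API (not anything Megatron-native). This is only a minimal sketch assuming PyTorch >= 2.4 and attention going through the aten SDPA flash kernel; the external `flash_attn` package registers its own custom op, so the op list would need adjusting for Megatron's flash-attn path.

```python
# Sketch of upstream PyTorch selective activation checkpointing
# (not Megatron code). Assumes PyTorch >= 2.4 and that attention is
# dispatched through the aten SDPA flash kernel.
import functools
import torch
from torch.utils.checkpoint import (
    checkpoint,
    create_selective_checkpoint_contexts,
    CheckpointPolicy,
)

# Ops whose outputs should be saved during the forward pass instead of recomputed.
ops_to_save = [
    torch.ops.aten._scaled_dot_product_flash_attention.default,
]

def policy_fn(ctx, op, *args, **kwargs):
    # Save attention results; recompute everything else in the layer.
    if op in ops_to_save:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)

def run_layer(layer, hidden_states):
    # use_reentrant=False is required for context_fn to take effect.
    return checkpoint(layer, hidden_states, use_reentrant=False, context_fn=context_fn)
```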
For now, I manually modified `tensor_parallel.checkpoint` to accept such an argument. However, issues remain when I try to capture the activations saved by flash-attn and offload them to CPU.
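To be concrete about the offloading part, something along these lines is what I have in mind; the wrapper names below (`offload_to_cpu`, `attn_with_offload`, `attn_fn`) are purely illustrative, not my actual code, and `torch.autograd.graph.save_on_cpu` would be the built-in equivalent.

```python
# Hypothetical sketch of offloading only the attention call's saved tensors
# to CPU via saved_tensors_hooks. Names are illustrative, not Megatron code.
import torch
from torch.autograd.graph import saved_tensors_hooks

class offload_to_cpu(saved_tensors_hooks):
    def __init__(self, pin_memory: bool = True):
        def pack(tensor: torch.Tensor):
            # Copy the saved activation into (optionally pinned) CPU memory.
            cpu_copy = torch.empty(
                tensor.size(), dtype=tensor.dtype, pin_memory=pin_memory
            )
            cpu_copy.copy_(tensor)
            return tensor.device, cpu_copy

        def unpack(packed):
            device, cpu_copy = packed
            # Bring the activation back to its original device for backward.
            return cpu_copy.to(device, non_blocking=True)

        super().__init__(pack, unpack)

def attn_with_offload(attn_fn, q, k, v):
    # Only tensors saved inside this context are offloaded.
    with offload_to_cpu():
        return attn_fn(q, k, v)
```

The difficulty is making this interact cleanly with the checkpointed layer, i.e. ensuring the tensors flash-attn saves for backward are the ones captured by the hooks rather than being dropped and recomputed.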