Description
This issue highlights a potential inaccuracy in the FLOPs calculation for decoder-based models with non-standard attention mechanisms. The current formula in megatron/training/training.py appears to assume a full causal attention pattern, which can lead to an overestimation of FLOPs for models that use more efficient, specialized attention structures.
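For context, the attention-score matmuls (QK^T and softmax(QK^T)·V) of a standard dense layer cost on the order of 2 · 2 · b · s² · h FLOPs for batch size b, sequence length s, and hidden size h, with a causal mask covering roughly half of those score entries. The exact constants and bookkeeping in training.py may differ, but the key property is the quadratic dependence on s, which the specialized attention variants discussed below do not share.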
Accurate FLOPs calculation is crucial for performance analysis and model comparison. This has been a recent focus in other major frameworks, with Google's MaxText, for instance, refining its formulas for better precision (see PRs #2009 and #2030).
The Core Issue
The current attention FLOPs calculation is well-suited for models with standard dense attention.
However, it does not account for architectures that use more computationally efficient attention variants, such as:
- Sliding Window Local Attention (e.g., used in Gemma 3)
- Chunked Attention (e.g., Llama 4)
For these models, the current formula will overestimate the true computational cost, with the margin of error increasing with sequence length.
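To illustrate the size of the gap, here is a minimal sketch (not Megatron-LM code; the function name and parameters are hypothetical) that compares the attention-matmul FLOPs under dense causal attention against a sliding window or a chunk limit:

```python
def attention_matmul_flops(batch, seq_len, hidden, avg_visible_keys):
    """Approximate FLOPs of the QK^T and scores*V matmuls in one attention layer.

    Each query/key pair costs about 2 * hidden FLOPs per matmul
    (multiply + add), and there are two such matmuls per layer.
    """
    return 2 * 2 * batch * seq_len * avg_visible_keys * hidden


batch, seq_len, hidden = 1, 8192, 4096
causal_dense = attention_matmul_flops(batch, seq_len, hidden, seq_len // 2)  # each query sees ~s/2 keys
sliding_1k = attention_matmul_flops(batch, seq_len, hidden, 1024)            # 1024-token sliding window
chunked_2k = attention_matmul_flops(batch, seq_len, hidden, 2048 // 2)       # causal within 2048-token chunks

print(causal_dense / sliding_1k)  # ~4x overestimate at 8k tokens
print(causal_dense / chunked_2k)  # ~4x as well
```

The gap scales roughly as seq_len / (2 · window), so the same 1024-token window at 32k tokens would be overestimated by about 16x.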
Proposed Solution
We recommend enhancing the FLOPs calculation logic to be architecture-aware. This could be achieved by introducing a dispatch mechanism that selects a specific FLOPs formula based on the model's configuration.
For example, the calculation function could check for arguments like attention_type
in the model configuration and apply the appropriate formula for "local," "chunked," or "standard" attention.
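A minimal sketch of what such a dispatcher could look like, assuming hypothetical config fields (attention_type, window_size, chunk_size) that are not existing Megatron-LM arguments, and counting only the attention-matmul term:

```python
from types import SimpleNamespace


def attention_score_flops(config, batch, seq_len):
    """Pick an attention-matmul FLOPs formula based on the attention variant.

    `attention_type`, `window_size`, and `chunk_size` are hypothetical config
    fields used for illustration only. Only the score/value matmuls are
    counted here; the real formula also covers projections, MLPs, embeddings, etc.
    """
    h = config.hidden_size
    if config.attention_type == "local":        # sliding-window attention
        visible = min(config.window_size, seq_len)
    elif config.attention_type == "chunked":    # causal within fixed-size chunks
        visible = min(config.chunk_size, seq_len) // 2
    else:                                       # standard dense causal attention
        visible = seq_len // 2
    return 2 * 2 * batch * seq_len * visible * h  # two matmuls, 2 FLOPs per MAC


# Example: a local-attention layer at 8k tokens with a 1024-token window.
cfg = SimpleNamespace(attention_type="local", window_size=1024, hidden_size=4096)
print(attention_score_flops(cfg, batch=1, seq_len=8192))
```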
This would create a more flexible and accurate framework for evaluating both current and future models within Megatron-LM.
Thank you for your consideration.