Describe the bug
In `megatron/model/multi_token_prediction.py`, the MTP loss is not accumulated correctly for logging purposes when `log_interval` is set to a value greater than 1.

The current implementation overwrites the MTP loss value in `total_loss_dict` at each step, so only the loss from the specific iteration that triggers the logging event is stored. That single loss value is then incorrectly divided by `log_interval`, resulting in a reported MTP loss that is artificially low and does not reflect the true average loss over the interval.

For example, with `log_interval=10`, the value logged at step 10 is effectively `loss_from_step_10 / 10` instead of the correct `(sum_of_losses_from_steps_1_to_10) / 10`.
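For concreteness, here is a minimal, self-contained sketch of the faulty logging arithmetic; the variable and key names are illustrative only, not the actual Megatron-LM code:

```python
# Assume a steady true MTP loss of ~1.0 per step and log_interval = 10.
log_interval = 10
per_step_loss = [1.0] * log_interval

total_loss_dict = {}
for step, loss in enumerate(per_step_loss, start=1):
    # Buggy pattern: plain assignment overwrites whatever was stored previously.
    total_loss_dict['mtp_1 loss'] = loss
    if step % log_interval == 0:
        # At report time the single surviving value is divided by the interval length.
        print(total_loss_dict['mtp_1 loss'] / log_interval)  # prints 0.1 instead of 1.0
```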
To Reproduce
Steps to reproduce the behavior:
- Enable Multi-Token Prediction by setting the `--mtp-num-layers` argument to a value greater than 0 (e.g., `--mtp-num-layers 2`).
- Set the logging interval to a value greater than 1 (e.g., `--log-interval 10` or `--log-interval 100`).
- Run training.
- Observe the logged MTP loss values (e.g., `'mtp_1 loss'`, `'mtp_2 loss'`, etc.). The reported values will be significantly smaller than their expected true averages.
Expected behavior
The MTP loss should be accumulated across all steps within a given `log_interval`. The logged value should then represent the true average MTP loss over that period, allowing for accurate monitoring of the training process.
Stack trace/logs
Not applicable, as this is a silent numerical error in logging, not a crash. However, a comparison of log outputs illustrates the issue:

Observed log (with `log_interval=10` and an assumed true average loss of ~1.0):
`'mtp_1 loss': 0.1` (roughly `loss_from_step_10 / 10`)

Expected log:
`'mtp_1 loss': 1.0` (the correct average over 10 steps)
Environment (please complete the following information):
- Megatron-LM commit ID: `3a2a972d9ffd5cee534b761f0803ea1b07389d2f`
- Docker version: `nvcr.io/nvidia/pytorch:25.04-py3`
Proposed fix
The issue can be resolved by changing the direct assignment of the MTP loss in `total_loss_dict` to a conditional accumulation, so that the loss from every step in the interval is summed up before averaging.
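As a rough sketch of that change (the helper name `accumulate_loss_for_logging` and the surrounding structure are assumptions for illustration; the exact edit is in the linked PR):

```python
import torch

def accumulate_loss_for_logging(total_loss_dict: dict, key: str, loss: torch.Tensor) -> None:
    """Sum the per-step MTP loss into the logging dict instead of overwriting it."""
    if key in total_loss_dict:
        total_loss_dict[key] = total_loss_dict[key] + loss.detach()
    else:
        total_loss_dict[key] = loss.detach().clone()
```

With accumulation in place, dividing the stored value by `log_interval` at report time yields the true average over the interval; the entry would then be reset for the next interval, as the existing logging path is assumed to do for the other tracked losses.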
I have submitted a Pull Request with the proposed fix: #1684
Additional context
This is purely a logging/monitoring bug. It does not affect the actual model gradients, weight updates, or the main training loss. However, it provides a misleadingly optimistic view of the MTP loss, which can hinder proper model analysis and debugging.