
[BUG] MTP loss is not accumulated correctly for logging when log_interval > 1 #1686

@Luowaterbi

Description


Describe the bug
In megatron/model/multi_token_prediction.py, the MTP loss is not accumulated correctly for logging purposes when the log_interval is set to a value greater than 1.

The current implementation overwrites the MTP loss value in the total_loss_dict at each step. Consequently, only the loss from the specific iteration that triggers the logging event is stored. This single loss value is then incorrectly divided by log_interval, resulting in a reported MTP loss that is artificially low and does not reflect the true average loss over the interval.

For example, with log_interval=10, the value logged at step 10 is effectively loss_from_step_10 / 10 instead of the correct (sum_of_losses_from_steps_1_to_10) / 10.
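For clarity, here is a minimal, self-contained sketch of the overwrite behavior described above (the dictionary name total_loss_dict matches the report; the key name, loop, and loss values are illustrative, not the actual Megatron-LM code):

```python
# Illustrative sketch of the reported logging bug, not the exact Megatron-LM code.
import torch

total_loss_dict = {}
log_interval = 10

for step in range(1, log_interval + 1):
    mtp_loss = torch.tensor(1.0)  # assume the true per-step MTP loss is ~1.0

    # Buggy pattern: the value is overwritten every step, so only the loss
    # from the step that triggers logging survives in the dictionary.
    total_loss_dict["mtp_1 loss"] = mtp_loss

    if step % log_interval == 0:
        # The single retained value is then divided by log_interval,
        # printing ~0.1 instead of the true average of ~1.0.
        print(total_loss_dict["mtp_1 loss"].item() / log_interval)
```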

To Reproduce
Steps to reproduce the behavior:

  1. Enable Multi-Token Prediction by setting the --mtp-num-layers argument to a value greater than 0 (e.g., --mtp-num-layers 2).
  2. Set the logging interval to be greater than 1 (e.g., --log-interval 10 or --log-interval 100).
  3. Run training.
  4. Observe the logged MTP loss values (e.g., 'mtp_1 loss', 'mtp_2 loss'). The reported values will be significantly smaller than their true average over the interval.

Expected behavior
The MTP loss should be accumulated across all steps within a given log_interval. The logged value should represent the true average MTP loss over that period, allowing for accurate monitoring of the training process.

Stack trace/logs
Not applicable, as this is a silent numerical error in logging, not a crash.

However, a comparison of log outputs illustrates the issue:

Observed Log (with log_interval=10 and an assumed true average loss of ~1.0):

'mtp_1 loss': 0.1

(This is roughly loss_from_step_10 / 10)

Expected Log:

'mtp_1 loss': 1.0

(This would be the correct average over 10 steps)

Environment (please complete the following information):

  • Megatron-LM commit ID: 3a2a972d9ffd5cee534b761f0803ea1b07389d2f
  • Docker version: nvcr.io/nvidia/pytorch:25.04-py3

Proposed fix
The issue can be resolved by changing the direct assignment of the MTP loss in total_loss_dict to a conditional accumulation. This ensures that the loss from every step in the interval is summed up before averaging.
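A minimal sketch of the conditional-accumulation pattern described above (the helper name and key are illustrative; see the linked PR for the actual change):

```python
# Illustrative sketch of the proposed accumulation, not the exact patch.
import torch

def accumulate_mtp_loss(total_loss_dict, key, mtp_loss):
    """Sum the per-step MTP loss into total_loss_dict instead of overwriting it."""
    if key in total_loss_dict:
        total_loss_dict[key] = total_loss_dict[key] + mtp_loss.detach()
    else:
        total_loss_dict[key] = mtp_loss.detach().clone()

# At logging time, dividing the accumulated sum by log_interval
# now yields the true average over the interval.
total_loss_dict = {}
log_interval = 10
for step in range(1, log_interval + 1):
    accumulate_mtp_loss(total_loss_dict, "mtp_1 loss", torch.tensor(1.0))
    if step % log_interval == 0:
        print(total_loss_dict["mtp_1 loss"].item() / log_interval)  # ~1.0
        total_loss_dict.clear()  # reset for the next logging interval
```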

I have submitted a Pull Request with the proposed fix: #1684

Additional context
This is purely a logging/monitoring bug. It does not affect the actual model gradients, weight updates, or the main training loss. However, it provides a misleadingly optimistic view of the MTP loss, which can hinder proper model analysis and debugging.
