fix(mtp logging): Correctly accumulate MTP loss for logging when log_interval > 1 #1684


Status: Open — wants to merge 1 commit into main

Conversation

Luowaterbi (Contributor)

In `multi_token_prediction.py`, the MTP loss is not correctly accumulated for logging when the `log_interval` is greater than 1.

The current implementation overwrites `total_loss_dict[f"mtp_{i+1} loss"]` at each step. This means only the loss from the logging step itself (e.g., the 10th step if `log_interval=10`) is stored. This single value is then incorrectly divided by `log_interval`, resulting in a reported MTP loss that is artificially low and does not reflect the true average.

This commit modifies the logic to correctly accumulate `f"mtp_{i+1} loss"` across all steps within a logging interval by checking for the key's existence and using addition (`+=`) instead of assignment.

This fix ensures the reported MTP loss is accurate.
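The accumulate-then-average pattern described above can be sketched as follows. This is a minimal illustration of the fix, not the actual Megatron-LM code: the helper names `track_mtp_loss` and `report_avg` and the dict layout are hypothetical, though the `f"mtp_{i+1} loss"` key format matches the PR description.

```python
def track_mtp_loss(total_loss_dict, mtp_depth, step_losses):
    """Accumulate per-depth MTP losses across every step in a logging interval.

    Buggy version: total_loss_dict[key] = step_losses[i]
    (overwrites the running value, so only the last step survives).
    """
    for i in range(mtp_depth):
        key = f"mtp_{i+1} loss"
        if key in total_loss_dict:
            total_loss_dict[key] += step_losses[i]  # accumulate across steps
        else:
            total_loss_dict[key] = step_losses[i]   # first step of the interval


def report_avg(total_loss_dict, key, log_interval):
    """At a logging step, average the accumulated sum and reset for the next interval."""
    avg = total_loss_dict[key] / log_interval
    total_loss_dict[key] = 0.0
    return avg
```

With accumulation, ten steps of constant loss 2.0 at `log_interval=10` report an average of 2.0; the overwriting version would divide a single step's 2.0 by 10 and report 0.2.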

yanring (Collaborator) commented Jul 21, 2025

Thank you for the PR. We'll help merge it.
