fix(mtp logging): Correctly accumulate MTP loss for logging when log_interval > 1 #1684
In `multi_token_prediction.py`, the MTP loss is not correctly accumulated for logging when `log_interval` is greater than 1.

The current implementation overwrites `total_loss_dict[f"mtp_{i+1} loss"]` at each step. This means only the loss from the logging step itself (e.g., the 10th step if `log_interval=10`) is stored. This single value is then incorrectly divided by `log_interval`, resulting in a reported MTP loss that is artificially low (roughly `1/log_interval` of the true value when the per-step loss is stable) and does not reflect the true average.

This commit modifies the logic to accumulate `f"mtp_{i+1} loss"` across all steps within a logging interval by checking for the key's existence and using addition (`+=`) instead of assignment.

This fix ensures the reported MTP loss is accurate.
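
The difference between the two behaviours can be illustrated with a minimal, self-contained sketch. This is not the code from `multi_token_prediction.py`: the helper names (`track_mtp_overwrite`, `track_mtp_accumulate`, `report`) are hypothetical, plain floats stand in for loss tensors, and a single MTP layer with a constant loss of 2.0 is assumed purely for the arithmetic.

```python
# Illustrative sketch only; helper names are hypothetical and floats
# stand in for the loss tensors used in the real training loop.

def track_mtp_overwrite(total_loss_dict, mtp_losses):
    """Old pattern: assignment keeps only the most recent step's loss."""
    for i, loss in enumerate(mtp_losses):
        total_loss_dict[f"mtp_{i+1} loss"] = loss


def track_mtp_accumulate(total_loss_dict, mtp_losses):
    """Fixed pattern: add to the running sum if the key already exists."""
    for i, loss in enumerate(mtp_losses):
        name = f"mtp_{i+1} loss"
        if name in total_loss_dict:
            total_loss_dict[name] += loss
        else:
            total_loss_dict[name] = loss


def report(total_loss_dict, log_interval):
    """Average over the logging interval, as done at each logging step."""
    return {k: v / log_interval for k, v in total_loss_dict.items()}


log_interval = 10
# Assumed data: one MTP layer whose loss is 2.0 on every step.
per_step_losses = [[2.0] for _ in range(log_interval)]

buggy, fixed = {}, {}
for step_losses in per_step_losses:
    track_mtp_overwrite(buggy, step_losses)
    track_mtp_accumulate(fixed, step_losses)

buggy_report = report(buggy, log_interval)
fixed_report = report(fixed, log_interval)

print(buggy_report)  # {'mtp_1 loss': 0.2}
print(fixed_report)  # {'mtp_1 loss': 2.0}
```

With overwriting, only the last step's 2.0 survives and is divided by 10, giving 0.2. With accumulation, the sum of 20.0 is divided by 10, giving the true average of 2.0.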