Description
Is your feature request related to a problem? Please describe.
We already support FusedCrossEntropy losses; it would be even better to support fusing the last linear layer with the CrossEntropy layer for even greater memory savings. See pytorch/pytorch#124480 for more discussion. There is talk of adding a native PyTorch operator to help facilitate this. Maybe this issue would be better opened in TransformerEngine, though. This would massively reduce the memory needed for finetuning small models with large vocabularies, or indeed any model with a large vocabulary.
Describe the solution you'd like
Support fusing the final linear layer with the cross-entropy layer in a dtype-aware way that avoids unnecessary computation based on the floating-point precision in use.
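To illustrate why the fusion saves memory: the unfused path materializes the full (N, V) logits matrix before the loss, which dominates peak memory when the vocabulary V is large. Below is a minimal NumPy sketch of the chunking idea only — the function name and `chunk_size` parameter are made up for illustration, and real implementations (Liger, Cut Cross-Entropy) fuse this at the kernel level on GPU and handle the backward pass as well:

```python
import numpy as np

def chunked_linear_cross_entropy(hidden, weight, targets, chunk_size=128):
    """Mean cross-entropy of (hidden @ weight.T) against integer targets,
    computed chunk by chunk so only a (chunk_size, V) slice of the logits
    is ever materialized instead of the full (N, V) matrix.

    hidden:  (N, d) final hidden states
    weight:  (V, d) unembedding / LM-head weight
    targets: (N,)   integer class ids in [0, V)
    """
    n = hidden.shape[0]
    total = 0.0
    for start in range(0, n, chunk_size):
        h = hidden[start:start + chunk_size]   # (c, d)
        t = targets[start:start + chunk_size]  # (c,)
        logits = h @ weight.T                  # (c, V) -- chunk-sized, not (N, V)
        # numerically stable log-sum-exp per row
        m = logits.max(axis=1, keepdims=True)
        lse = m.squeeze(1) + np.log(np.exp(logits - m).sum(axis=1))
        # cross-entropy per row: logsumexp(logits) - logit of the target class
        total += (lse - logits[np.arange(len(t)), t]).sum()
    return total / n
```

The fused kernels go further than this sketch: they also avoid storing chunk logits for the backward pass by recomputing them, and they can skip work depending on dtype, which is where the dtype-aware behavior requested above comes in.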
Describe alternatives you've considered
Proposed implementation
See the linked Liger kernel or Apple's Cut Cross Entropy kernel.
Additional context