
[ENHANCEMENT] Support for FusedLinearCrossEntropy Loss for memory savings #1738

@Skylion007

Description


Is your feature request related to a problem? Please describe.
We already support FusedCrossEntropy losses; it would be even better to fuse the last linear layer with the cross-entropy loss for further memory savings. See pytorch/pytorch#124480 for more discussion. There is talk of adding this as a native operator in PyTorch to help facilitate this. Maybe this issue would be better opened in TransformerEngine, though. This would massively reduce the amount of memory needed for finetuning small models with large vocabularies, or for models with large vocabularies in general.

Describe the solution you'd like
Support fusing the final linear layer with the cross-entropy layer in a dtype-aware way that avoids unnecessary computation based on the floating-point precision of the inputs.
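For illustration, here is a minimal pure-PyTorch sketch of the idea, not a proposed API (the function names and the `chunk_size` default are made up for this example): chunk over the token dimension and checkpoint each chunk so the full `[num_tokens, vocab_size]` logits tensor is never live all at once, and upcast each chunk's logits to fp32 only for the loss. A real fused Triton/CUDA kernel would do the dtype handling inside the kernel without the extra fp32 copy.

```python
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

def _chunk_ce(hidden_chunk, weight, target_chunk):
    # Materialize logits only for this chunk, upcast to fp32 for a
    # numerically stable softmax, and sum-reduce the loss.
    logits = (hidden_chunk @ weight.T).float()
    return F.cross_entropy(logits, target_chunk, reduction="sum")

def chunked_linear_cross_entropy(hidden, weight, targets, chunk_size=4096):
    """CE(hidden @ weight.T, targets) without keeping the full
    [num_tokens, vocab_size] logits tensor alive.

    hidden:  [num_tokens, hidden_dim]  (e.g. bf16)
    weight:  [vocab_size, hidden_dim]  (the lm_head weight)
    targets: [num_tokens] int64 class indices
    """
    num_tokens = hidden.shape[0]
    total = hidden.new_zeros((), dtype=torch.float32)
    for start in range(0, num_tokens, chunk_size):
        h = hidden[start:start + chunk_size]
        t = targets[start:start + chunk_size]
        # Checkpointing recomputes this chunk's logits during backward,
        # trading one extra matmul for activation memory.
        total = total + checkpoint(_chunk_ce, h, weight, t,
                                   use_reentrant=False)
    return total / num_tokens
```

At peak, only one chunk's logits (`chunk_size x vocab_size`) are resident instead of `num_tokens x vocab_size`; dedicated kernels go further by computing the log-sum-exp blockwise inside a single fused kernel.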


Proposed implementation

See the linked Liger kernel or Apple's Cut Cross Entropy kernel.
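For reference, a hedged usage sketch of those two kernels; the import paths and argument order below are written from memory of their documentation and should be verified against the current releases:

```python
import torch

# Illustrative shapes: 8192 tokens, hidden dim 4096, vocab 128256.
hidden = torch.randn(8192, 4096, device="cuda",
                     dtype=torch.bfloat16, requires_grad=True)
lm_head_weight = torch.randn(128256, 4096, device="cuda",
                             dtype=torch.bfloat16, requires_grad=True)
targets = torch.randint(0, 128256, (8192,), device="cuda")

# Liger kernel (Triton): the fused loss takes the lm_head weight and the
# pre-projection hidden states, so the full logits tensor is never formed.
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss
loss = LigerFusedLinearCrossEntropyLoss()(lm_head_weight, hidden, targets)

# Apple's Cut Cross Entropy: same idea, functional interface.
from cut_cross_entropy import linear_cross_entropy
loss = linear_cross_entropy(hidden, lm_head_weight, targets)
```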

