[BUG] Learning rate not overridden when --override-opt_param-scheduler is set #1138

@TissueC

Description

Describe the bug
When setting --override-opt_param-scheduler (while still loading the optimizer and RNG state) together with a new learning rate schedule (max lr, min lr, decay style, etc.), the learning rate keeps following the original schedule stored in the checkpoint.

A related issue could be #963

To Reproduce

  1. Set max_lr to 6e-4 with a constant learning rate schedule.
  2. Train for some steps and save a checkpoint.
  3. Load the checkpoint (including the optimizer state) and override the scheduler (e.g. cosine decay from 6e-4 to 6e-5).
  4. The bug appears: the learning rate stays constant at 6e-4.

Expected behavior
The learning rate scheduler should be overridden.

Proposed fix
The following code (https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/optimizer_param_scheduler.py#L121) makes the learning rate bounds restored into the optimizer's param groups take precedence, even when --override-opt_param-scheduler is set:

def get_lr(self, param_group: dict) -> float:
        """Learning rate decay functions from:
        https://openreview.net/pdf?id=BJYwwY9ll pg. 4

        Args:
            param_group (dict): parameter group from the optimizer.
        """

        # The 'max_lr'/'min_lr' values restored into the checkpointed param
        # groups win over the newly configured self.max_lr/self.min_lr.
        max_lr = param_group.get('max_lr', self.max_lr)
        min_lr = param_group.get('min_lr', self.min_lr)
A possible solution could be

def get_lr(self, param_group: dict) -> float:
        """Learning rate decay functions from:
        https://openreview.net/pdf?id=BJYwwY9ll pg. 4

        Args:
            param_group (dict): parameter group from the optimizer.
        """

        # Always use the scheduler's own (possibly overridden) bounds.
        max_lr = self.max_lr
        min_lr = self.min_lr
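
Dropping the param_group lookup entirely also drops support for per-group max_lr/min_lr values. A narrower variant, sketched below, would only ignore the restored values when the schedule is actually being overridden; this is only a sketch and assumes the constructor's override flag is kept on the instance as self.override_opt_param_scheduler, as it is elsewhere in the class.

def get_lr(self, param_group: dict) -> float:
        """Learning rate decay functions from:
        https://openreview.net/pdf?id=BJYwwY9ll pg. 4

        Args:
            param_group (dict): parameter group from the optimizer.
        """

        if self.override_opt_param_scheduler:
            # --override-opt_param-scheduler: ignore whatever was restored
            # into the param group and use the newly configured bounds.
            max_lr = self.max_lr
            min_lr = self.min_lr
        else:
            # Existing behaviour: per-group values take precedence.
            max_lr = param_group.get('max_lr', self.max_lr)
            min_lr = param_group.get('min_lr', self.min_lr)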
