| """This file contains code to run different learning rate schedulers. | |
| Copyright (2024) Bytedance Ltd. and/or its affiliates | |
| Licensed under the Apache License, Version 2.0 (the "License"); | |
| you may not use this file except in compliance with the License. | |
| You may obtain a copy of the License at | |
| http://www.apache.org/licenses/LICENSE-2.0 | |
| Unless required by applicable law or agreed to in writing, software | |
| distributed under the License is distributed on an "AS IS" BASIS, | |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| See the License for the specific language governing permissions and | |
| limitations under the License. | |
| Reference: | |
| https://raw.githubusercontent.com/huggingface/open-muse/vqgan-finetuning/muse/lr_schedulers.py | |
| """ | |
import math
from enum import Enum
from typing import Optional, Union

import torch


class SchedulerType(Enum):
    COSINE = "cosine"


def get_cosine_schedule_with_warmup(
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
    last_epoch: int = -1,
    base_lr: float = 1e-4,
    end_lr: float = 0.0,
):
| """Creates a cosine learning rate schedule with warm-up and ending learning rate. | |
| Args: | |
| optimizer: A torch.optim.Optimizer, the optimizer for which to schedule the learning rate. | |
| num_warmup_steps: An integer, the number of steps for the warmup phase. | |
| num_training_steps: An integer, the total number of training steps. | |
| num_cycles : A float, the number of periods of the cosine function in a schedule (the default is to | |
| just decrease from the max value to 0 following a half-cosine). | |
| last_epoch: An integer, the index of the last epoch when resuming training. | |
| base_lr: A float, the base learning rate. | |
| end_lr: A float, the final learning rate. | |
| Return: | |
| `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. | |
| """ | |
    def lr_lambda(current_step):
        # Linear warmup: scale from 0 to 1 over the first `num_warmup_steps` steps.
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # Cosine decay: `ratio` falls from 1 to 0 as training progresses.
        progress = float(current_step - num_warmup_steps) / \
            float(max(1, num_training_steps - num_warmup_steps))
        ratio = max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
        # LambdaLR multiplies the optimizer's initial lr by the returned factor,
        # so divide by `base_lr` to express the interpolated lr as a multiplier.
        return (end_lr + (base_lr - end_lr) * ratio) / base_lr

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)
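

# Worked example (illustrative values, not from the original file): with
# base_lr=1e-4, end_lr=0.0, num_warmup_steps=10, num_training_steps=110 and
# the default num_cycles=0.5, the multiplier returned by lr_lambda is:
#   step 0   -> 0.0  (lr = 0)
#   step 5   -> 0.5  (lr = 5e-5, halfway through warmup)
#   step 10  -> 1.0  (lr = base_lr, warmup complete)
#   step 60  -> 0.5  (lr = 5e-5, since cos(pi/2) = 0 gives ratio 0.5)
#   step 110 -> 0.0  (lr = end_lr)
# Note that the optimizer should be constructed with lr == base_lr: LambdaLR
# scales the optimizer's initial lr rather than setting it directly.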


TYPE_TO_SCHEDULER_FUNCTION = {
    SchedulerType.COSINE: get_cosine_schedule_with_warmup,
}


def get_scheduler(
    name: Union[str, SchedulerType],
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
    base_lr: float = 1e-4,
    end_lr: float = 0.0,
):
| """Retrieves a learning rate scheduler from the given name and optimizer. | |
| Args: | |
| name: A string or SchedulerType, the name of the scheduler to retrieve. | |
| optimizer: torch.optim.Optimizer. The optimizer to use with the scheduler. | |
| num_warmup_steps: An integer, the number of warmup steps. | |
| num_training_steps: An integer, the total number of training steps. | |
| base_lr: A float, the base learning rate. | |
| end_lr: A float, the final learning rate. | |
| Returns: | |
| A instance of torch.optim.lr_scheduler.LambdaLR | |
| Raises: | |
| ValueError: If num_warmup_steps or num_training_steps is not provided. | |
| """ | |
    name = SchedulerType(name)
    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
    if num_warmup_steps is None:
        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
    if num_training_steps is None:
        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
    return schedule_func(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        base_lr=base_lr,
        end_lr=end_lr,
    )
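

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): build a dummy
    # optimizer, request the cosine schedule by name, and print the learning
    # rate at a few steps to see the warmup/decay shape. The model and the
    # step counts below are illustrative assumptions.
    model = torch.nn.Linear(4, 4)
    base_lr = 1e-4
    optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr)
    scheduler = get_scheduler(
        "cosine",
        optimizer,
        num_warmup_steps=10,
        num_training_steps=110,
        base_lr=base_lr,
        end_lr=0.0,
    )
    for step in range(110):
        optimizer.step()  # would normally follow loss.backward()
        scheduler.step()
        if step % 20 == 0:
            print(f"step {step:4d}  lr {scheduler.get_last_lr()[0]:.2e}")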