from typing import override

import torch


# Custom learning-rate schedule from "Attention Is All You Need"
# (Vaswani et al., 2017): linear warmup followed by inverse-square-root decay.
class WarmupLR(torch.optim.lr_scheduler.LRScheduler):

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_steps: int,
        embedding_size: int,
        warming_multiplier: float = -1.5,
        decaying_multiplier: float = -0.5,
        multiplicative_factor: float = 1.0,
        last_epoch: int = -1,
    ) -> None:
        self.__warmup_steps = warmup_steps
        self.__embedding_size = embedding_size
        self.__warming_multiplier = warming_multiplier
        self.__decaying_multiplier = decaying_multiplier
        self.__multiplicative_factor = multiplicative_factor
        # The base class calls `step()` (and hence `get_lr()`) once during
        # construction, so every attribute above must be set before this call.
        super().__init__(optimizer, last_epoch)

    def __scale_at(self, step: int) -> float:
        # With the default exponents this is the schedule from the paper:
        #   factor * embedding_size**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)
        # i.e. linear growth over the first `warmup_steps` steps, then
        # inverse-square-root decay.
        step = max(step, 1)
        return (
            self.__multiplicative_factor
            * (self.__embedding_size**self.__decaying_multiplier)
            * min(
                step**self.__decaying_multiplier,
                step * (self.__warmup_steps**self.__warming_multiplier),
            )
        )

    @override
    def get_lr(self) -> list[float]:
        # Private PyTorch helper that warns when `get_lr()` is called directly
        # rather than through `scheduler.step()`.
        torch.optim.lr_scheduler._warn_get_lr_called_within_step(self)

        step = max(self.last_epoch, 1)
        scale = self.__scale_at(step)
        return [base_lr * scale for base_lr in self.base_lrs]

    def _get_closed_form_lr(self):
        # Closed-form variant: the rate depends only on the step count, so this
        # mirrors `get_lr()` exactly.
        step = max(self.last_epoch, 1)
        scale = self.__scale_at(step)
        return [base_lr * scale for base_lr in self.base_lrs]
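

# Minimal usage sketch (an assumption, not part of the original file): the model,
# optimizer hyperparameters, and step count below are illustrative only. The base
# learning rate is set to 1.0 so that the scheduler's scale is the effective rate,
# mirroring the setup described in "Attention Is All You Need".
if __name__ == "__main__":
    model = torch.nn.Linear(512, 512)
    optimizer = torch.optim.Adam(
        model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9
    )
    scheduler = WarmupLR(optimizer, warmup_steps=4000, embedding_size=512)

    for _ in range(5):
        optimizer.zero_grad()
        loss = model(torch.randn(8, 512)).sum()
        loss.backward()
        optimizer.step()
        scheduler.step()  # advance the schedule once per optimizer step
        print(scheduler.get_last_lr())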