from typing import override

import torch


class WarmupLR(torch.optim.lr_scheduler.LRScheduler):
    """Custom LR schedule from "Attention Is All You Need" (Vaswani et al., 2017):
    linear warmup for ``warmup_steps`` steps, then decay proportional to the
    inverse square root of the step count, scaled by ``embedding_size**-0.5``.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_steps: int,
        embedding_size: int,
        warming_multiplier: float = -1.5,
        decaying_multiplier: float = -0.5,
        multiplicative_factor: float = 1.0,
        last_epoch: int = -1,
    ) -> None:
        self.__warmup_steps = warmup_steps
        self.__embedding_size = embedding_size
        self.__warming_multiplier = warming_multiplier
        self.__decaying_multiplier = decaying_multiplier
        self.__multiplicative_factor = multiplicative_factor
        super().__init__(optimizer, last_epoch)

    def __scale_at(self, step: int) -> float:
        # scale = factor * d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)
        step = max(step, 1)  # avoid 0**-0.5 on the very first call
        return (
            self.__multiplicative_factor
            * (self.__embedding_size**self.__decaying_multiplier)
            * min(
                step**self.__decaying_multiplier,
                step * (self.__warmup_steps**self.__warming_multiplier),
            )
        )

    @override
    def get_lr(self) -> list[float]:
        torch.optim.lr_scheduler._warn_get_lr_called_within_step(self)
        scale = self.__scale_at(self.last_epoch)
        return [base_lr * scale for base_lr in self.base_lrs]

    def _get_closed_form_lr(self) -> list[float]:
        # Same formula as get_lr: the schedule depends only on the step count,
        # so it can be evaluated in closed form for any step.
        scale = self.__scale_at(self.last_epoch)
        return [base_lr * scale for base_lr in self.base_lrs]
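
# A minimal usage sketch. The model, optimizer settings, and hyperparameter
# values below are illustrative assumptions, not part of the scheduler itself.
if __name__ == "__main__":
    model = torch.nn.Linear(512, 512)
    # The optimizer's base lr is multiplied by the schedule's scale, so lr=1.0
    # reproduces the paper's formula directly.
    optimizer = torch.optim.Adam(
        model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9
    )
    scheduler = WarmupLR(optimizer, warmup_steps=4000, embedding_size=512)

    for _ in range(5):
        optimizer.step()  # normally preceded by a forward/backward pass
        scheduler.step()
        print(scheduler.get_last_lr())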