2025-10-09 11:36:40 +02:00

48 lines
1.5 KiB
Python

from typing import override
import torch
# Learning-rate schedule from "Attention Is All You Need" (Vaswani et al., 2017), Sec. 5.3.
class WarmupLR(torch.optim.lr_scheduler.LRScheduler):
def __init__(
self,
optimizer: torch.optim.Optimizer,
warmup_steps: int,
embedding_size: int,
warming_multiplier: float = -1.5,
decaying_multiplier: float = -0.5,
multiplicative_factor: float = 1.0,
last_epoch: int = -1,
) -> None:
self.__warmup_steps = warmup_steps
self.__embedding_size = embedding_size
self.__warming_multiplier = warming_multiplier
self.__decaying_multiplier = decaying_multiplier
self.__multiplicative_factor = multiplicative_factor
super().__init__(optimizer, last_epoch)
def __scale_at(self, step: int) -> float:
step = max(step, 1)
return (
self.__multiplicative_factor
* (self.__embedding_size**self.__decaying_multiplier)
* min(
step**self.__decaying_multiplier,
step * (self.__warmup_steps**self.__warming_multiplier),
)
)
@override
def get_lr(self) -> list[float]:
torch.optim.lr_scheduler._warn_get_lr_called_within_step(self)
step = max(self.last_epoch, 1)
scale = self.__scale_at(step)
return [base_lr * scale for base_lr in self.base_lrs]
def _get_closed_form_lr(self):
step = max(self.last_epoch, 1)
scale = self.__scale_at(step)
return [base_lr * scale for base_lr in self.base_lrs]