Added Custom Learning Rate
This commit is contained in:
47
Project_Model/Libs/Transformer/Classes/WarmupLR.py
Normal file
47
Project_Model/Libs/Transformer/Classes/WarmupLR.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from typing import override
|
||||
import torch
|
||||
|
||||
|
||||
# custom LR from attention is all you need
|
||||
class WarmupLR(torch.optim.lr_scheduler.LRScheduler):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
warmup_steps: int,
|
||||
embedding_size: int,
|
||||
warming_multiplier: float = -1.5,
|
||||
decaying_multiplier: float = -0.5,
|
||||
multiplicative_factor: float = 1.0,
|
||||
last_epoch: int = -1,
|
||||
) -> None:
|
||||
self.__warmup_steps = warmup_steps
|
||||
self.__embedding_size = embedding_size
|
||||
self.__warming_multiplier = warming_multiplier
|
||||
self.__decaying_multiplier = decaying_multiplier
|
||||
self.__multiplicative_factor = multiplicative_factor
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
||||
def __scale_at(self, step: int) -> float:
|
||||
step = max(step, 1)
|
||||
return (
|
||||
self.__multiplicative_factor
|
||||
* (self.__embedding_size**self.__decaying_multiplier)
|
||||
* min(
|
||||
step**self.__decaying_multiplier,
|
||||
step * (self.__warmup_steps**self.__warming_multiplier),
|
||||
)
|
||||
)
|
||||
|
||||
@override
|
||||
def get_lr(self) -> list[float]:
|
||||
torch.optim.lr_scheduler._warn_get_lr_called_within_step(self)
|
||||
|
||||
step = max(self.last_epoch, 1)
|
||||
scale = self.__scale_at(step)
|
||||
return [base_lr * scale for base_lr in self.base_lrs]
|
||||
|
||||
def _get_closed_form_lr(self):
|
||||
step = max(self.last_epoch, 1)
|
||||
scale = self.__scale_at(step)
|
||||
return [base_lr * scale for base_lr in self.base_lrs]
|
||||
@@ -5,6 +5,7 @@ from .FeedForwardNetwork import FeedForwardNetwork
|
||||
from .TorchMultiHeadAttention import TorchMultiHeadAttention
|
||||
from .SpannedMasker import SpannedMasker
|
||||
from .DeToken import DeToken
|
||||
from .WarmupLR import WarmupLR
|
||||
|
||||
__all__ = [
|
||||
"Decoder",
|
||||
@@ -12,5 +13,6 @@ __all__ = [
|
||||
"FeedForwardNetwork",
|
||||
"TorchMultiHeadAttention",
|
||||
"SpannedMasker",
|
||||
"DeToken"
|
||||
"DeToken",
|
||||
"WarmupLR"
|
||||
]
|
||||
Reference in New Issue
Block a user