Added Custom Learning Rate
parent b805dc538e
commit 1f9c30b531
@@ -1,41 +0,0 @@
-import numpy as np
-# custom LR from attention is all you need
-class Custom_lr():
-    def __init__(self, d_model: int, warmup_step: int) -> None:
-
-        self.__d_model = d_model
-        self.__warmup_step = warmup_step
-        self.__epoch = 0
-
-
-    def step(self) -> int:
-        self.__epoch += 1
-        return (self.__d_model ** -0.5) * min(self.__epoch ** -0.5,
-                                              self.__epoch * (self.__warmup_step ** -1.5))
-
-# OTHER LR
-
-# Learning rate schedules (matching visualization parameters)
-def step_lr(epoch, lr):
-    # StepLR: step_size=20, gamma=0.5 (from visualization)
-    return lr * 0.5 if epoch % 20 == 0 and epoch > 0 else lr
-
-def exp_lr(epoch, lr):
-    # ExponentialLR: gamma=0.95 (from visualization)
-    return lr * 0.95
-
-def cosine_lr(epoch, lr):
-    # CosineAnnealingLR: lr_min=0.001, lr_max=0.1, max_epochs=100 (from visualization)
-    lr_min, lr_max = 0.001, 0.1
-    max_epochs = 100
-    return lr_min + 0.5 * (lr_max - lr_min) * (1 + np.cos(epoch * np.pi / max_epochs))
-
-def cyclical_lr(epoch, lr):
-    # CyclicalLR: base_lr=0.001, max_lr=0.1, step_size=20 (from visualization)
-    base_lr = 0.001
-    max_lr = 0.1
-    step_size = 20
-
-    cycle = np.floor(1 + epoch / (2 * step_size))
-    x = np.abs(epoch / step_size - 2 * cycle + 1)
-    return base_lr + (max_lr - base_lr) * max(0, (1 - x))
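Both the removed Custom_lr above and the new WarmupLR below implement the warmup schedule from "Attention Is All You Need": lr = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5), i.e. a linear ramp for the first warmup_steps steps followed by inverse-square-root decay. A minimal sketch of the bare formula, for illustration only (the function name and the d_model/warmup values are placeholders, not part of this commit):

# Bare "Attention Is All You Need" schedule, for illustration only.
def noam_lr(step: int, d_model: int = 512, warmup_steps: int = 4000) -> float:
    step = max(step, 1)
    return (d_model ** -0.5) * min(step ** -0.5, step * (warmup_steps ** -1.5))

# The two branches of the min() cross exactly at step == warmup_steps,
# so the schedule peaks there and decays afterwards.
assert abs(noam_lr(4000) - (512 ** -0.5) * (4000 ** -0.5)) < 1e-12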
Project_Model/Libs/Transformer/Classes/WarmupLR.py  (new file, 47 lines)
@@ -0,0 +1,47 @@
+from typing import override
+import torch
+
+
+# custom LR from attention is all you need
+class WarmupLR(torch.optim.lr_scheduler.LRScheduler):
+
+    def __init__(
+        self,
+        optimizer: torch.optim.Optimizer,
+        warmup_steps: int,
+        embedding_size: int,
+        warming_multiplier: float = -1.5,
+        decaying_multiplier: float = -0.5,
+        multiplicative_factor: float = 1.0,
+        last_epoch: int = -1,
+    ) -> None:
+        self.__warmup_steps = warmup_steps
+        self.__embedding_size = embedding_size
+        self.__warming_multiplier = warming_multiplier
+        self.__decaying_multiplier = decaying_multiplier
+        self.__multiplicative_factor = multiplicative_factor
+        super().__init__(optimizer, last_epoch)
+
+    def __scale_at(self, step: int) -> float:
+        step = max(step, 1)
+        return (
+            self.__multiplicative_factor
+            * (self.__embedding_size**self.__decaying_multiplier)
+            * min(
+                step**self.__decaying_multiplier,
+                step * (self.__warmup_steps**self.__warming_multiplier),
+            )
+        )
+
+    @override
+    def get_lr(self) -> list[float]:
+        torch.optim.lr_scheduler._warn_get_lr_called_within_step(self)
+
+        step = max(self.last_epoch, 1)
+        scale = self.__scale_at(step)
+        return [base_lr * scale for base_lr in self.base_lrs]
+
+    def _get_closed_form_lr(self):
+        step = max(self.last_epoch, 1)
+        scale = self.__scale_at(step)
+        return [base_lr * scale for base_lr in self.base_lrs]
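A minimal usage sketch for the new scheduler follows. The model, the base learning rate, and the warmup/embedding values are placeholders, and the dotted import path is assumed from the file location above; none of this is part of the commit:

# Hypothetical wiring of WarmupLR into a training loop (values are placeholders).
import torch
from Project_Model.Libs.Transformer.Classes import WarmupLR  # path assumed

model = torch.nn.Linear(512, 512)                          # stand-in for the real model
optimizer = torch.optim.Adam(model.parameters(), lr=1.0)   # base lr acts as a multiplier
scheduler = WarmupLR(optimizer, warmup_steps=4000, embedding_size=512)

for step in range(10):
    optimizer.step()   # would normally follow loss.backward()
    scheduler.step()   # applies the warmup/decay scale to every param group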
@@ -5,6 +5,7 @@ from .FeedForwardNetwork import FeedForwardNetwork
 from .TorchMultiHeadAttention import TorchMultiHeadAttention
 from .SpannedMasker import SpannedMasker
 from .DeToken import DeToken
+from .WarmupLR import WarmupLR

 __all__ = [
     "Decoder",
@@ -12,5 +13,6 @@ __all__ = [
     "FeedForwardNetwork",
     "TorchMultiHeadAttention",
     "SpannedMasker",
-    "DeToken"
+    "DeToken",
+    "WarmupLR"
 ]
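With the re-export and the __all__ entry in place, the scheduler can be imported alongside the package's other classes; a hypothetical consumer import (module path assumed from the file location above):

# Hypothetical consumer import after this commit (path assumed, not part of the diff).
from Project_Model.Libs.Transformer.Classes import DeToken, SpannedMasker, WarmupLR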