diff --git a/Project_Model/Libs/Training/learning_rade_shedulers.py b/Project_Model/Libs/Training/learning_rade_shedulers.py index 08bc319..0a98bba 100644 --- a/Project_Model/Libs/Training/learning_rade_shedulers.py +++ b/Project_Model/Libs/Training/learning_rade_shedulers.py @@ -5,11 +5,13 @@ class Custom_lr(): self.__d_model = d_model self.__warmup_step = warmup_step + self.__epoch = 0 - def get_lr(self,epoch) -> int: - return (self.__d_model ** -0.5) * min(epoch ** -0.5, - epoch * (self.__warmup_step ** -1.5)) + def step(self) -> int: + self.__epoch += 1 + return (self.__d_model ** -0.5) * min(self.__epoch ** -0.5, + self.__epoch * (self.__warmup_step ** -1.5)) # OTHER LR