Added a loss_saver file to save the losses
parent 80fd7fd600
commit f33d4f1db6
@@ -9,6 +9,7 @@ import Project_Model.Libs.Transformer as Transformer
 import Project_Model.Libs.TransformerUtils as TUtils
 import Project_Model.Libs.TorchShims as torch_shims
 import Project_Model.Libs.Batch as Batch
+from Project_Model.Libs.Training.loss_saver import Log
 
 # set a fixed seed
 torch.manual_seed(0)
@@ -33,6 +34,9 @@ ENC_OPTIM_PATH = Path(f"{CHECKPOINT_DIR}/enc_optim.zip")
 DEC_OPTIM_PATH = Path(f"{CHECKPOINT_DIR}/dec_optim.zip")
 LAST_EPOCH_PATH = Path(f"{CHECKPOINT_DIR}/last_epoch.txt")
 
+# log saver:
+loss_saver = Log(f"{CHECKPOINT_DIR}/log_loss.csv")
+
 
 
 # BPE Init
@@ -140,6 +144,7 @@ average_loss_validation = {
     "decoder_only": float("inf"),
 }
 
 
 while current_epoch < MAX_EPOCHS:
+
     NANOSOCRATES.train()
@@ -373,6 +378,8 @@ while current_epoch < MAX_EPOCHS:
     except:
         pass
 
+    # write on log
+    loss_saver.write([txt_train_avg_loss,enc_avg_train_loss,dec_avg_train_loss,txt_avg_loss,enc_avg_loss,dec_avg_loss])
     SEPARATOR = "================================================================================================================"
     DEBUG_TEXT = "".join(
         [
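For reference, the six values handed to loss_saver.write each epoch line up positionally with the CSV header that the new loss_saver.py (shown below) writes. A minimal sketch of that row layout, with made-up numbers standing in for the script's running averages (the variable-to-column mapping is my reading of the names, not something the commit states):

# sketch only: placeholder floats stand in for txt_train_avg_loss, enc_avg_train_loss,
# dec_avg_train_loss, txt_avg_loss, enc_avg_loss and dec_avg_loss
row = [0.91, 1.42, 1.18, 1.05, 1.63, 1.37]
# columns: avg_txt, avg_enc, avg_dec, txt_loss, masking_loss, prediction_loss
print(",".join(str(float(x)) for x in row))  # same formatting that Log.write applies to each row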
Project_Model/Libs/Training/loss_saver.py (new file, 16 lines)
@@ -0,0 +1,16 @@
+import os
+
+class Log:
+    def __init__(self, path):
+        self.path = path
+        header = ["avg_txt","avg_enc","avg_dec","txt_loss","masking_loss","prediction_loss"]
+
+        with open(self.path, "w", encoding="utf-8", newline="") as f:
+            f.write(",".join(header) + "\n")
+
+    def write(self, loss: list[float]):
+        line = ",".join(str(float(x)) for x in loss) + "\n"
+        with open(self.path, "a", encoding="utf-8", newline="") as f:
+            f.write(line)
+            f.flush()
+            os.fsync(f.fileno())  # extra durability per write: guards against losing rows on a sudden crash
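A minimal usage sketch of the new class (not part of the commit; the path and the numbers are illustrative), assuming the package layout above is importable:

from Project_Model.Libs.Training.loss_saver import Log

log = Log("log_loss.csv")  # constructor truncates the file and writes the header row
# one call per logging step; six values in header order:
# avg_txt, avg_enc, avg_dec, txt_loss, masking_loss, prediction_loss
log.write([0.91, 1.42, 1.18, 1.05, 1.63, 1.37])

Because write reopens the file in append mode and flushes plus fsyncs on every call, an interrupted training run keeps all rows written before the crash; the trade-off is one extra open and fsync per epoch, which is negligible next to a training step.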