From 540b78204cbf1e1c1e8001609f4a2579ce4057eb Mon Sep 17 00:00:00 2001 From: Christian Risi <75698846+CnF-Gris@users.noreply.github.com> Date: Fri, 17 Oct 2025 17:06:42 +0200 Subject: [PATCH] Added epochs --- Playgrounds/nanosocrates-train-experiment-2.py | 6 +++--- Project_Model/Libs/Training/loss_saver.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Playgrounds/nanosocrates-train-experiment-2.py b/Playgrounds/nanosocrates-train-experiment-2.py index f8c8a56..ce81588 100644 --- a/Playgrounds/nanosocrates-train-experiment-2.py +++ b/Playgrounds/nanosocrates-train-experiment-2.py @@ -9,7 +9,7 @@ import Project_Model.Libs.Transformer as Transformer import Project_Model.Libs.TransformerUtils as TUtils import Project_Model.Libs.TorchShims as torch_shims import Project_Model.Libs.Batch as Batch -from Project_Model.Libs.Training.loss_saver import Log +from Project_Model.Libs.Training.loss_saver import Log # set a fixed seed torch.manual_seed(0) @@ -419,8 +419,8 @@ while current_epoch < MAX_EPOCHS: except: pass - # write on log - loss_saver.write([txt_train_avg_loss,enc_avg_train_loss,dec_avg_train_loss,txt_avg_loss,enc_avg_loss,dec_avg_loss]) + # write on log + loss_saver.write([current_epoch, txt_train_avg_loss,enc_avg_train_loss,dec_avg_train_loss,txt_avg_loss,enc_avg_loss,dec_avg_loss]) SEPARATOR = "================================================================================================================" DEBUG_TEXT = "".join( [ diff --git a/Project_Model/Libs/Training/loss_saver.py b/Project_Model/Libs/Training/loss_saver.py index f3f5d08..3183fac 100644 --- a/Project_Model/Libs/Training/loss_saver.py +++ b/Project_Model/Libs/Training/loss_saver.py @@ -3,8 +3,8 @@ import os class Log: def __init__(self, path): self.path = path - header = ["avg_txt","avg_enc","avg_dec","txt_loss","masking_loss","prediction_loss"] - + header = ["epoch","avg_txt","avg_enc","avg_dec","txt_loss","masking_loss","prediction_loss"] + with open(self.path, "w", encoding="utf-8", newline="") as f: f.write(",".join(header) + "\n")