Added epochs
This commit is contained in:
parent
86a063591e
commit
540b78204c
@ -9,7 +9,7 @@ import Project_Model.Libs.Transformer as Transformer
|
||||
import Project_Model.Libs.TransformerUtils as TUtils
|
||||
import Project_Model.Libs.TorchShims as torch_shims
|
||||
import Project_Model.Libs.Batch as Batch
|
||||
from Project_Model.Libs.Training.loss_saver import Log
|
||||
from Project_Model.Libs.Training.loss_saver import Log
|
||||
|
||||
# set a fixed seed
|
||||
torch.manual_seed(0)
|
||||
@ -419,8 +419,8 @@ while current_epoch < MAX_EPOCHS:
|
||||
except:
|
||||
pass
|
||||
|
||||
# write on log
|
||||
loss_saver.write([txt_train_avg_loss,enc_avg_train_loss,dec_avg_train_loss,txt_avg_loss,enc_avg_loss,dec_avg_loss])
|
||||
# write on log
|
||||
loss_saver.write([current_epoch, txt_train_avg_loss,enc_avg_train_loss,dec_avg_train_loss,txt_avg_loss,enc_avg_loss,dec_avg_loss])
|
||||
SEPARATOR = "================================================================================================================"
|
||||
DEBUG_TEXT = "".join(
|
||||
[
|
||||
|
||||
@ -3,8 +3,8 @@ import os
|
||||
class Log:
|
||||
def __init__(self, path):
|
||||
self.path = path
|
||||
header = ["avg_txt","avg_enc","avg_dec","txt_loss","masking_loss","prediction_loss"]
|
||||
|
||||
header = ["epoch","avg_txt","avg_enc","avg_dec","txt_loss","masking_loss","prediction_loss"]
|
||||
|
||||
with open(self.path, "w", encoding="utf-8", newline="") as f:
|
||||
f.write(",".join(header) + "\n")
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user