Last fixes

This commit is contained in:
Christian Risi
2025-10-17 22:17:24 +02:00
parent 540b78204c
commit b79521995c
17 changed files with 50 additions and 6041 deletions

View File

@@ -1,10 +1,14 @@
import os
from pathlib import Path
class Log:
    """Handle to a CSV log file that is guaranteed to start with a header row.

    Creating an instance stores the target path and makes sure the file
    begins with the column names listed in ``header`` below. A file that
    already has content is left untouched so that earlier rows are kept.
    """

    def __init__(self, path):
        """Remember *path* and write the CSV header if the file needs one.

        Args:
            path: Location of the CSV log file (string or path-like).
                Stored unchanged on ``self.path``.

        Raises:
            OSError: If the parent directory cannot be created or the file
                cannot be opened for writing.
        """
        self.path = path
        header = ["epoch","avg_txt","avg_enc","avg_dec","txt_loss","masking_loss","prediction_loss"]

        target = Path(path)

        # Only a file that already holds data counts as initialised. An
        # existing zero-byte file (e.g. pre-created by `touch` or by an
        # interrupted run) still needs its header, otherwise every row
        # written later would belong to a header-less CSV.
        if target.is_file() and target.stat().st_size > 0:
            return

        # Make sure the destination directory exists so that the first run
        # does not fail with FileNotFoundError when opening the file.
        target.parent.mkdir(parents=True, exist_ok=True)

        # newline="" keeps the "\n" terminator exactly as written on every
        # platform.
        with open(self.path, "w", encoding="utf-8", newline="") as f:
            f.write(",".join(header) + "\n")

View File

@@ -14,6 +14,7 @@ class NanoSocratesCore(torch.nn.Module):
sos: int,
pad: int,
eos: int,
continuerdf: int,
latent_space: int = 256,
feed_forward_multiplier: int = 4,
attention_heads: int = 4,
@@ -24,6 +25,7 @@ class NanoSocratesCore(torch.nn.Module):
self.__sos = sos
self.__pad = pad
self.__eos = eos
self.__continuerdf = continuerdf
self.__sentence_len = sentence_max_length
feed_forward_latent_space = latent_space * feed_forward_multiplier
@@ -156,7 +158,9 @@ class NanoSocratesCore(torch.nn.Module):
decoder_in_pad_mask = decoder_in.eq(self.__pad)
continue_generating = True
token_idx = 0
token_idx: int= int((decoder_in[0] == self.__continuerdf).nonzero()[0].item()) + 1
while continue_generating:

File diff suppressed because it is too large Load Diff