dev.train #8

Merged
gape_01 merged 50 commits from dev.train into dev 2025-10-17 22:20:14 +02:00
2 changed files with 4 additions and 5 deletions
Showing only changes of commit f463f699cf - Show all commits

View File

@ -203,7 +203,7 @@ while current_epoch < MAX_EPOCHS:
decoder_only_optim.zero_grad()
pred_logits = DECODER_ONLY((enc_x, enc_x_pad))
pred_logits = DECODER_ONLY((dec_x, dec_x_pad))
pred_logits = pred_logits.permute(0, 2, 1)
loss: torch.Tensor = decoder_ce(pred_logits, tgt)
@ -291,7 +291,7 @@ while current_epoch < MAX_EPOCHS:
pred_logits = DECODER_ONLY((enc_x, enc_x_pad))
pred_logits = DECODER_ONLY((dec_x, dec_x_pad))
pred_logits = pred_logits.permute(0, 2, 1)

View File

@ -27,7 +27,6 @@ def truncate_rdf_list(
END_OF_TRIPLES.append(i + 1)
TRIPLES_TOKENS: list[int] = []
TARGET_TRIPLES: list[int] = []
start_of_triple = 0
exit_loop = False
@ -56,10 +55,10 @@ def truncate_rdf_list(
EOT = END_OF_TRIPLES.popleft()
TRIPLE = sequence[start_of_triple:EOT]
TARGET_TRIPLES.extend(TRIPLE)
TRIPLES_TOKENS.extend(TRIPLE)
start_of_triple = EOT
return (TRIPLES_TOKENS, TARGET_TRIPLES)
return (TRIPLES_TOKENS, TRIPLES_TOKENS)