Fixed a bug over task 4
This commit is contained in:
parent
ab3d68bc13
commit
f463f699cf
@ -203,7 +203,7 @@ while current_epoch < MAX_EPOCHS:
|
|||||||
|
|
||||||
decoder_only_optim.zero_grad()
|
decoder_only_optim.zero_grad()
|
||||||
|
|
||||||
pred_logits = DECODER_ONLY((enc_x, enc_x_pad))
|
pred_logits = DECODER_ONLY((dec_x, dec_x_pad))
|
||||||
pred_logits = pred_logits.permute(0, 2, 1)
|
pred_logits = pred_logits.permute(0, 2, 1)
|
||||||
|
|
||||||
loss: torch.Tensor = decoder_ce(pred_logits, tgt)
|
loss: torch.Tensor = decoder_ce(pred_logits, tgt)
|
||||||
@ -291,7 +291,7 @@ while current_epoch < MAX_EPOCHS:
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
pred_logits = DECODER_ONLY((enc_x, enc_x_pad))
|
pred_logits = DECODER_ONLY((dec_x, dec_x_pad))
|
||||||
|
|
||||||
pred_logits = pred_logits.permute(0, 2, 1)
|
pred_logits = pred_logits.permute(0, 2, 1)
|
||||||
|
|
||||||
|
|||||||
@ -27,7 +27,6 @@ def truncate_rdf_list(
|
|||||||
END_OF_TRIPLES.append(i + 1)
|
END_OF_TRIPLES.append(i + 1)
|
||||||
|
|
||||||
TRIPLES_TOKENS: list[int] = []
|
TRIPLES_TOKENS: list[int] = []
|
||||||
TARGET_TRIPLES: list[int] = []
|
|
||||||
|
|
||||||
start_of_triple = 0
|
start_of_triple = 0
|
||||||
exit_loop = False
|
exit_loop = False
|
||||||
@ -56,10 +55,10 @@ def truncate_rdf_list(
|
|||||||
EOT = END_OF_TRIPLES.popleft()
|
EOT = END_OF_TRIPLES.popleft()
|
||||||
|
|
||||||
TRIPLE = sequence[start_of_triple:EOT]
|
TRIPLE = sequence[start_of_triple:EOT]
|
||||||
TARGET_TRIPLES.extend(TRIPLE)
|
TRIPLES_TOKENS.extend(TRIPLE)
|
||||||
|
|
||||||
start_of_triple = EOT
|
start_of_triple = EOT
|
||||||
|
|
||||||
|
|
||||||
return (TRIPLES_TOKENS, TARGET_TRIPLES)
|
return (TRIPLES_TOKENS, TRIPLES_TOKENS)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user