Merge branch 'dev' of https://repositories.communitynotfound.work/PoliBa-DeepLearning/NanoSocrates into dev
This commit is contained in:
commit
79d3fb9ff8
@ -1,4 +1,5 @@
|
|||||||
import random
|
import random
|
||||||
|
import time
|
||||||
import torch
|
import torch
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -98,7 +99,7 @@ NANOSOCRATES = Transformer.TrainingModel(
|
|||||||
EMBEDDED_SIZE,
|
EMBEDDED_SIZE,
|
||||||
FEED_FORWARD_MULTIPLIER,
|
FEED_FORWARD_MULTIPLIER,
|
||||||
ATTENTION_HEADS,
|
ATTENTION_HEADS,
|
||||||
NUMBER_OF_BLOCKS
|
NUMBER_OF_BLOCKS,
|
||||||
)
|
)
|
||||||
cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
|
cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
|
||||||
optimizer = torch.optim.AdamW(NANOSOCRATES.parameters())
|
optimizer = torch.optim.AdamW(NANOSOCRATES.parameters())
|
||||||
@ -120,21 +121,31 @@ while current_epoch < MAX_EPOCHS:
|
|||||||
last_loss = 0
|
last_loss = 0
|
||||||
last_prediction: torch.Tensor
|
last_prediction: torch.Tensor
|
||||||
|
|
||||||
|
LOSS_HISTORY = []
|
||||||
|
|
||||||
|
start = time.time_ns()
|
||||||
|
|
||||||
|
|
||||||
for i in range(0, SENTENCE_LENGTH):
|
for i in range(0, SENTENCE_LENGTH):
|
||||||
|
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
tgt_padding = decoder_list.eq(PAD_TOKEN)
|
tgt_padding = decoder_list.eq(PAD_TOKEN)
|
||||||
|
|
||||||
logits: torch.Tensor = NANOSOCRATES((encoder_list, src_padding, decoder_list, tgt_padding))
|
logits: torch.Tensor = NANOSOCRATES(
|
||||||
|
(encoder_list, src_padding, decoder_list, tgt_padding)
|
||||||
|
)
|
||||||
prob = torch.softmax(logits, 2)
|
prob = torch.softmax(logits, 2)
|
||||||
|
|
||||||
most_probable_tokens = torch.argmax(prob, 2)
|
most_probable_tokens = torch.argmax(prob, 2)
|
||||||
last_prediction = most_probable_tokens
|
last_prediction = most_probable_tokens
|
||||||
|
|
||||||
logits = logits[:,:i,:]
|
logits = logits[:, i, :]
|
||||||
logits = logits.permute(0, 2, 1)
|
# logits = logits.permute(0, 2, 1)
|
||||||
|
|
||||||
loss : torch.Tensor = cross_entropy(logits, target_logits[:, 0:i])
|
loss: torch.Tensor = cross_entropy(logits, target_logits[:, i])
|
||||||
|
LOSS_HISTORY.append(loss.item())
|
||||||
|
# loss : torch.Tensor = cross_entropy(logits, target_logits[:, 0:i])
|
||||||
# loss : torch.Tensor = cross_entropy(logits, target_logits)
|
# loss : torch.Tensor = cross_entropy(logits, target_logits)
|
||||||
|
|
||||||
last_loss = loss
|
last_loss = loss
|
||||||
@ -145,26 +156,22 @@ while current_epoch < MAX_EPOCHS:
|
|||||||
if i < SENTENCE_LENGTH - 1:
|
if i < SENTENCE_LENGTH - 1:
|
||||||
decoder_list[:, i + 1] = target_logits[:, i]
|
decoder_list[:, i + 1] = target_logits[:, i]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
current_epoch += 1
|
current_epoch += 1
|
||||||
|
|
||||||
|
end = time.time_ns()
|
||||||
|
|
||||||
if current_epoch % 1 == 0:
|
if current_epoch % 1 == 0:
|
||||||
print(f"EPOCH {current_epoch}\n\tLoss: {last_loss}")
|
MIN_LOSS = min(LOSS_HISTORY)
|
||||||
|
MAX_LOSS = max(LOSS_HISTORY)
|
||||||
for encoded_sentence, expected_sentence in zip(
|
AVERAGE_LOSS = sum(LOSS_HISTORY)/len(LOSS_HISTORY)
|
||||||
Transformer.tensor2token(last_prediction[:,:], END_TOKEN), # type: ignore
|
print(f"EPOCH {current_epoch}\n\tTime: {(end-start)/1E9}s\n\tLoss: {last_loss}")
|
||||||
Transformer.tensor2token(target_logits[:,:], END_TOKEN)
|
print(f"\tMin Loss: {MIN_LOSS}\tAvg Loss: {AVERAGE_LOSS}\tMax Loss: {MAX_LOSS}\n")
|
||||||
):
|
# for encoded_sentence, expected_sentence in zip(
|
||||||
decoded_sentence = TOKENANO.decode(encoded_sentence)
|
# Transformer.tensor2token(last_prediction[:, :], END_TOKEN), # type: ignore
|
||||||
decoded_target = TOKENANO.decode(expected_sentence)
|
# Transformer.tensor2token(target_logits[:, :], END_TOKEN),
|
||||||
print(f"\tACTUAL:\n\t\t{decoded_sentence}\n\tEXPECTED:\n\t\t{decoded_target}\n")
|
# ):
|
||||||
|
# decoded_sentence = TOKENANO.decode(encoded_sentence)
|
||||||
|
# decoded_target = TOKENANO.decode(expected_sentence)
|
||||||
|
# print(
|
||||||
|
# f"\tACTUAL:\n\t\t{decoded_sentence}\n\tEXPECTED:\n\t\t{decoded_target}\n"
|
||||||
|
# )
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user