32 lines
904 B
Python
32 lines
904 B
Python
import torch
|
|
import Project_Model.Libs.Embedder as Embedder
|
|
from ..Classes import DeToken
|
|
|
|
class NanoSocraDecoder(torch.nn.Module):
    """Chain an embedder, a stack of decoder layers and a detokener.

    ``forward`` embeds its input, pushes it through the decoder stack and
    hands the stack's primary output to the detokener, returning whatever
    the detokener produces. No encoder output is involved, so this is
    presumably used as a decoder-only model — confirm against callers.
    """

    def __init__(
        self,
        decoder_embedder: Embedder.NanoSocratesEmbedder,
        decoder_layers: torch.nn.Sequential,
        detokener: DeToken
    ) -> None:
        """Keep references to the three components used by ``forward``.

        Args:
            decoder_embedder: applied first, to the raw decoder input.
            decoder_layers: ``torch.nn.Sequential`` stack; it receives the
                6-element tuple built in ``forward`` and must return a
                6-element tuple whose first item is the decoded tensor.
            detokener: applied last, to the decoded tensor; its result is
                what ``forward`` returns.
        """
        super().__init__()

        # Keep this assignment order and these (name-mangled) attribute names:
        # components that are torch.nn.Module instances get registered as
        # submodules here, so both the order and the names show up in
        # parameter iteration and in state_dict keys.
        self.__decoder_embedder = decoder_embedder
        self.__decoder = decoder_layers
        self.__detokener = detokener

    def forward(self, args: tuple[torch.Tensor, torch.Tensor]):
        """Embed, decode and detokenize one input batch.

        Args:
            args: exactly two items, ``(decoder_input, tgt_padding)``;
                unpacking raises ``ValueError`` for any other length.

        Returns:
            The detokener's output (the logits).
        """
        decoder_input, padding_mask = args

        embedded = self.__decoder_embedder(decoder_input)

        # torch.nn.Sequential threads this single tuple through every layer.
        # NOTE(review): slot meanings are defined by the layer implementation,
        # presumably (query src, key src, value src, padding mask, padding
        # mask, flag). Here all three sources are the embedded input and both
        # masks are the target padding, i.e. self-attention only. Confirm
        # against the layer class.
        layer_input = (embedded, embedded, embedded, padding_mask, padding_mask, True)
        (decoded, _, _, _, _, _) = self.__decoder(layer_input)

        return self.__detokener(decoded)