33 lines
931 B
Python
Raw Normal View History

import torch
import Project_Model.Libs.Embedder as Embedder
from ..Classes import DeToken
class NanoSocraDecoder(torch.nn.Module):
    """Decoder pipeline: embedder -> decoder layer stack -> detokener.

    The three stages are injected at construction time and stored as
    submodules. Note: they are kept under name-mangled private attributes
    (e.g. ``_NanoSocraDecoder__decoder``); those names are what
    ``torch.nn.Module`` uses as ``state_dict`` keys, so they must not be
    renamed without migrating existing checkpoints.
    """

    def __init__(
        self,
        decoder_embedder: Embedder.NanoSocratesEmbedder,
        decoder_layers: torch.nn.Sequential,
        detokener: DeToken
    ) -> None:
        """Store the embedder, decoder stack and detokener as submodules.

        Args:
            decoder_embedder: module applied to the raw decoder input.
            decoder_layers: stack called with a single 6-element tuple and
                expected to return a 6-element tuple (see ``forward``).
            detokener: module applied to the decoder output to produce the
                returned logits.
        """
        super().__init__()
        # Assignment order fixes submodule registration order; keep it.
        self.__decoder_embedder = decoder_embedder
        self.__decoder = decoder_layers
        self.__detokener = detokener

    def forward(self, args: tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
        """Run embed -> decode -> detokenize on one packed input tuple.

        Args:
            args: ``(embedder_input, prefix_mask, tgt_padding)``. Shapes and
                dtypes are defined by the injected modules and are not
                visible here (TODO confirm against the embedder/decoder).

        Returns:
            torch.Tensor: the detokener's output (the logits).
        """
        embedder_input, prefix_mask, tgt_padding = args
        embedded = self.__decoder_embedder(embedder_input)

        # The embedded tensor fills all three leading tensor slots of the
        # decoder call. NOTE(review): presumably a decoder-only setup where
        # the "encoder" inputs are the decoder's own embeddings, and the
        # trailing True is a mode flag of the decoder layers -- confirm
        # against the decoder implementation.
        decoder_call = (embedded, embedded, embedded, prefix_mask, tgt_padding, True)

        # The stack hands back a 6-tuple; only its first element is used.
        # The explicit 6-way unpack deliberately errors on any other arity.
        decoded, _, _, _, _, _ = self.__decoder(decoder_call)

        logits: torch.Tensor = self.__detokener(decoded)
        return logits