import torch
import Project_Model.Libs.Embedder as Embedder
from ..Classes import DeToken


class NanoSocraDecoder(torch.nn.Module):
    """Chain an embedder, a stack of decoder layers, and a detokener.

    ``forward`` embeds the incoming tokens, pushes them through the decoder
    stack, and projects the decoder output with the detokener.
    """

    def __init__(
        self,
        decoder_embedder: Embedder.NanoSocratesEmbedder,
        decoder_layers: torch.nn.Sequential,
        detokener: DeToken,
    ) -> None:
        """Store the three sub-modules.

        Args:
            decoder_embedder: module applied to the raw decoder input.
            decoder_layers: decoder stack; called with a single 6-tuple and
                expected to return a 6-tuple.
            detokener: module applied to the decoder output.
        """
        super().__init__()
        # NOTE: these double-underscore names are name-mangled to
        # ``_NanoSocraDecoder__...`` and registered as sub-modules by
        # nn.Module, so they define state_dict keys. Their assignment order
        # defines parameters() order. Keep both as-is for checkpoint and
        # optimizer-state compatibility.
        self.__decoder_embedder = decoder_embedder
        self.__decoder = decoder_layers
        self.__detokener = detokener

    def forward(
        self, args: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    ) -> torch.Tensor:
        """Run embed -> decoder stack -> detokener.

        Args:
            args: ``(decoder_input, prefix_mask, tgt_padding)``. Shapes and
                dtypes are defined by the embedder and decoder layers (not
                visible here).

        Returns:
            Whatever the detokener produces from the decoder output. The local
            name suggests logits; TODO confirm against DeToken.

        Raises:
            ValueError: if the decoder stack does not return exactly six items.
        """
        tokens, prefix_mask, tgt_padding = args

        embedded = self.__decoder_embedder(tokens)

        # The same embedded tensor fills the first three slots of the decoder
        # tuple. This presumably means "no encoder": decoder input plus
        # stand-ins for encoder keys/values. TODO confirm against the decoder
        # layer signature. The meaning of the trailing ``True`` flag is not
        # visible here.
        layer_input = (
            embedded,
            embedded,
            embedded,
            prefix_mask,
            tgt_padding,
            True,
        )
        # Strict 6-way unpack: raises if the decoder returns another arity.
        hidden, _, _, _, _, _ = self.__decoder(layer_input)

        return self.__detokener(hidden)