import torch
import Project_Model.Libs.Embedder as Embedder
from ..Classes import DeToken


class NanoSocratEncoder(torch.nn.Module):
    """Encoder-only pipeline: embedder -> encoder stack -> detokener.

    The three stages are supplied by the caller and chained in ``forward``.
    """

    def __init__(
        self,
        encoder_embedder: Embedder.NanoSocratesEmbedder,
        encoder_layers: torch.nn.Sequential,
        detokener: DeToken
    ) -> None:
        """Store the three pipeline stages.

        Args:
            encoder_embedder: module applied first to the raw encoder input.
            encoder_layers: sequential stack that is called with a
                ``(tensor, padding)`` tuple and returns a 2-tuple.
            detokener: callable applied last to the encoder output
                (presumably an ``nn.Module`` projecting to vocabulary
                logits — TODO confirm against ``DeToken``).
        """
        super().__init__()
        # NOTE: the double-underscore names are name-mangled to
        # ``_NanoSocratEncoder__<name>``. Any nn.Module values assigned here
        # are registered under those mangled names, so they show up in
        # ``state_dict`` keys. Renaming them, or reordering these assignments
        # (which changes ``parameters()`` order), would break saved checkpoints.
        self.__encoder_embedder = encoder_embedder
        self.__encoder = encoder_layers
        self.__detokener = detokener

    def forward(self, args: tuple[torch.Tensor, torch.Tensor]):
        """Run embedder, encoder stack and detokener in sequence.

        Args:
            args: pair ``(encoder_input, src_padding)``. The first element is
                fed to the embedder. The second is passed through unchanged
                to the encoder stack alongside the embeddings.

        Returns:
            Whatever the detokener produces from the encoder's first output
            (named ``logits`` in this module).
        """
        tokens, padding_mask = args
        embedded = self.__encoder_embedder(tokens)

        # The encoder stack is called with an (embeddings, padding) pair and
        # returns a pair; only the first element (the hidden states) is used.
        encoded = self.__encoder((embedded, padding_mask))
        hidden_states = encoded[0] if len(encoded) == 2 else None
        if hidden_states is None:
            # Preserve the original's strict 2-element unpacking behavior.
            hidden_states, _ = encoded

        logits: torch.Tensor = self.__detokener(hidden_states)
        return logits