30 lines
821 B
Python
30 lines
821 B
Python
|
|
import torch
|
||
|
|
import Project_Model.Libs.Embedder as Embedder
|
||
|
|
from ..Classes import DeToken
|
||
|
|
|
||
|
|
class NanoSocratEncoder(torch.nn.Module):
    """Pipeline of three caller-supplied modules: embedder -> encoder stack -> detokener.

    The stages are kept in double-underscore attributes. Python name-mangles
    them, so ``torch.nn.Module`` registers them as children named
    ``_NanoSocratEncoder__encoder_embedder``, ``_NanoSocratEncoder__encoder``
    and ``_NanoSocratEncoder__detokener``. Those names (and their order) show
    up in ``state_dict`` keys, so renaming them would break saved checkpoints.
    """

    def __init__(
        self,
        encoder_embedder: Embedder.NanoSocratesEmbedder,
        encoder_layers: torch.nn.Sequential,
        detokener: DeToken,
    ) -> None:
        """Register the three stages as submodules.

        Args:
            encoder_embedder: Module applied first, to the raw encoder input.
            encoder_layers: Stack called with an ``(embedded, src_padding)``
                pair; it must return a two-item result whose first item is
                the encoded representation.
            detokener: Module applied last, to the encoder stack's output.
        """
        super().__init__()

        # Assignment order fixes the order of children / state_dict entries:
        # embedder, encoder stack, detokener.
        self.__encoder_embedder = encoder_embedder
        self.__encoder = encoder_layers
        self.__detokener = detokener

    def forward(self, args: tuple[torch.Tensor, torch.Tensor]):
        """Run embedder, encoder stack and detokener in sequence.

        Args:
            args: Pair ``(encoder_input, src_padding)``, unpacked into exactly
                two items. ``src_padding`` is handed to the encoder stack
                unchanged. NOTE(review): presumably a source key-padding
                mask — confirm against the encoder layers.

        Returns:
            The detokener's output for the encoded sequence. The original code
            annotates it as ``torch.Tensor`` named ``logits``; presumably
            per-position vocabulary scores — confirm against ``DeToken``.
        """
        embedder_input, source_padding = args

        embedded = self.__encoder_embedder(embedder_input)

        # The encoder stack takes and yields a pair; only the first item
        # (the encoded sequence) feeds the detokener, the rest is dropped.
        encoded, _ = self.__encoder((embedded, source_padding))

        return self.__detokener(encoded)
|