# Source viewer metadata: 23 lines, 634 B, Python.
import torch
from NanoSocratesCore import NanoSocratesCore


class NanoSocrates(torch.nn.Module):
    """``torch.nn.Module`` wrapper that owns a single ``NanoSocratesCore``.

    The constructor forwards its hyper-parameters to ``NanoSocratesCore``,
    and ``forward`` delegates every call to that core model unchanged.

    Args:
        embedded_size: Passed through to ``NanoSocratesCore`` (presumably the
            embedding width; see that class for the exact meaning).
        feed_forward_dim: Passed through to ``NanoSocratesCore`` (presumably
            the feed-forward hidden size).
        encoder_layers: Passed through to ``NanoSocratesCore`` (presumably
            the encoder depth).
        decoder_layers: Passed through to ``NanoSocratesCore`` (presumably
            the decoder depth).
        attention_heads: Passed through to ``NanoSocratesCore`` (presumably
            the number of attention heads).
        vocab_size: Passed through to ``NanoSocratesCore`` (presumably the
            token vocabulary size).
    """

    def __init__(self,
                 embedded_size: int,
                 feed_forward_dim: int,
                 encoder_layers: int,
                 decoder_layers: int,
                 attention_heads: int,
                 vocab_size: int) -> None:
        super().__init__()

        # The arguments are passed positionally, so this order must stay in
        # sync with NanoSocratesCore's constructor signature.
        # The attribute name determines the state_dict key prefix
        # ("_model."); renaming it would break existing checkpoints.
        self._model = NanoSocratesCore(
            embedded_size,
            feed_forward_dim,
            encoder_layers,
            decoder_layers,
            attention_heads,
            vocab_size)

    def forward(self, *args, **kwargs):
        """Run the wrapped ``NanoSocratesCore`` and return its output.

        ``torch.nn.Module`` provides no usable default ``forward``. Without
        this override, calling a ``NanoSocrates`` instance raises
        ``NotImplementedError``. All positional and keyword arguments are
        passed to the core model unchanged.
        """
        return self._model(*args, **kwargs)