
import torch
import torch.nn as nn

from Transformer.feed_forward_nn import FeedForwardNetwork
from Transformer.pytorch_multi_head_attention import TorchMultiHeadAttention as MultiHeadAttention


class Decoder(nn.Module):
    def __init__(self, d_model: int, d_ff: int, attention_heads: int) -> None:
        super().__init__()
        # Masked (causal) self-attention over the decoder input
        self._masked_attention = MultiHeadAttention(d_model, attention_heads, dropout=0.1)
        self.norm1 = nn.LayerNorm(d_model)
        # Encoder-decoder (cross) attention: queries from the decoder, keys/values from the encoder
        self.attention = MultiHeadAttention(d_model, attention_heads, dropout=0.1)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.1)
        # Position-wise feed-forward network
        self.ffn = FeedForwardNetwork(d_model, d_ff)
        self.norm3 = nn.LayerNorm(d_model)
    def forward(self, x, k_x, v_x, attention_mask):
        # x provides the queries; k_x and v_x are the encoder output (keys = values)
        # 1) Masked self-attention with a residual connection, then LayerNorm (post-norm)
        x = x + self.dropout(self._masked_attention(x, x, x, attention_mask=attention_mask))
        x = self.norm1(x)
        # 2) Encoder-decoder (cross) attention over the encoder output (no mask applied here)
        x = x + self.dropout(self.attention(x, k_x, v_x))
        x = self.norm2(x)
        # 3) Position-wise feed-forward
        x = x + self.dropout(self.ffn(x))
        x = self.norm3(x)
        return x
# Call model.eval() at inference time to disable dropout, etc.
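

# Minimal usage sketch. Assumptions (not confirmed by this file): the imported
# MultiHeadAttention returns a tensor shaped like its query input, and
# attention_mask is a boolean causal mask broadcastable to the attention scores;
# adjust shapes and mask conventions to the actual interfaces of those modules.
if __name__ == "__main__":
    d_model, d_ff, heads = 512, 2048, 8
    batch, tgt_len, src_len = 2, 10, 12

    decoder = Decoder(d_model, d_ff, heads)
    decoder.eval()  # disable dropout for a deterministic forward pass

    tgt = torch.randn(batch, tgt_len, d_model)      # decoder input embeddings
    enc_out = torch.randn(batch, src_len, d_model)  # encoder output (keys and values)

    # Causal mask: position i may only attend to positions <= i (shape is an assumption)
    causal_mask = torch.tril(torch.ones(tgt_len, tgt_len, dtype=torch.bool))

    with torch.no_grad():
        out = decoder(tgt, enc_out, enc_out, attention_mask=causal_mask)
    print(out.shape)  # expected: torch.Size([2, 10, 512])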