import torch
import torch.nn as nn

from Transformer.feed_forward_nn import FeedForwardNetwork
from Transformer.pytorch_multi_head_attention import TorchMultiHeadAttention as MultiHeadAttention


class Decoder(nn.Module):
    """A single Transformer decoder layer (post-norm): masked self-attention,
    encoder-decoder cross-attention, and a position-wise feed-forward network,
    each wrapped in a residual connection followed by LayerNorm."""

    def __init__(self, d_model: int, d_ff: int, attention_heads: int) -> None:
        super().__init__()
        # Masked (causal) self-attention over the decoder's own sequence.
        self._masked_attention = MultiHeadAttention(d_model, attention_heads, dropout=0.1)
        self.norm1 = nn.LayerNorm(d_model)

        # Encoder-decoder (cross) attention: queries from the decoder, keys/values from the encoder.
        self.attention = MultiHeadAttention(d_model, attention_heads, dropout=0.1)
        self.norm2 = nn.LayerNorm(d_model)

        # Dropout applied to each sub-layer's output before the residual addition.
        self.dropout = nn.Dropout(0.1)

        # Position-wise feed-forward network.
        self.ffn = FeedForwardNetwork(d_model, d_ff)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, k_x: torch.Tensor, v_x: torch.Tensor,
                attention_mask: torch.Tensor) -> torch.Tensor:
        # x is the decoder input (the queries); k_x and v_x are the encoder
        # output used as keys and values in cross-attention (k_x == v_x).

        # 1) Masked self-attention, then residual + LayerNorm
        x = x + self.dropout(self._masked_attention(x, x, x, attention_mask=attention_mask))
        x = self.norm1(x)

        # 2) Encoder-decoder (cross) attention, then residual + LayerNorm
        x = x + self.dropout(self.attention(x, k_x, v_x))
        x = self.norm2(x)

        # 3) Position-wise feed-forward, then residual + LayerNorm
        x = x + self.dropout(self.ffn(x))
        x = self.norm3(x)

        return x
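

# Illustrative helper (not in the original file): builds a causal (look-ahead)
# mask for the masked self-attention above. The exact shape and boolean
# convention (True = mask out vs. True = keep) depend on how
# TorchMultiHeadAttention interprets `attention_mask`, so treat this as a
# sketch and adapt it to that implementation.
def make_causal_mask(seq_len: int) -> torch.Tensor:
    # Upper-triangular True entries mark "future" positions, assuming the
    # attention implementation drops positions where the mask is True.
    return torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)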
# Note: call model.eval() at inference time to disable dropout (and other train-only behavior).
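# Minimal usage sketch (illustrative hyperparameters and batch-first shapes;
# assumes FeedForwardNetwork and TorchMultiHeadAttention are importable as
# above and that `attention_mask` follows the make_causal_mask convention).
if __name__ == "__main__":
    d_model, d_ff, heads, seq_len, batch = 512, 2048, 8, 10, 2

    decoder = Decoder(d_model, d_ff, attention_heads=heads)
    decoder.eval()  # disable dropout for inference

    x = torch.randn(batch, seq_len, d_model)            # decoder input (queries)
    encoder_out = torch.randn(batch, seq_len, d_model)  # encoder output (keys/values)
    causal_mask = make_causal_mask(seq_len)

    with torch.no_grad():
        out = decoder(x, encoder_out, encoder_out, attention_mask=causal_mask)
    print(out.shape)  # expected: (batch, seq_len, d_model)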