2025-10-11 15:31:43 +02:00

117 lines
3.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import Optional
import torch
import torch.nn as nn
from .FeedForwardNetwork import FeedForwardNetwork
from .TorchMultiHeadAttention import TorchMultiHeadAttention as MultiHeadAttention
from ..Utils.attention_mask import get_causal_attention_mask, get_prefix_causal_mask_from_padding_mask
# Tensor layout convention: B (batch), L or T (sequence length), E_D (embedding dimension)
class Decoder(nn.Module):
    """One Transformer decoder block.

    The block applies, in order:

    1. masked self-attention over ``x``,
    2. cross-attention against ``k_x`` / ``v_x`` (skipped when
       ``decoder_only`` is truthy),
    3. a position-wise feed-forward network.

    Every sub-layer output is added back to its input (residual connection)
    and then layer-normalized.

    Both attention modules are built with ``dropout=0.1``. ``self.__dropout``
    is also constructed, but ``forward`` does not currently apply it.
    """

    def __init__(
        self,
        embedding_dimension: int,
        feed_forward_hidden_layer_dimension: int,
        number_of_attention_heads: int,
    ) -> None:
        """Build the sub-layers of the block.

        Args:
            embedding_dimension: Size passed to the attention modules, the
                feed-forward network and every ``LayerNorm``.
            feed_forward_hidden_layer_dimension: Hidden size passed to the
                feed-forward network.
            number_of_attention_heads: Head count passed to both attention
                modules; also forwarded to the prefix-causal mask builder.
        """
        # Initialise nn.Module first so attribute registration is in place
        # before anything is assigned on ``self``.
        super().__init__()
        self.__attention_heads = number_of_attention_heads

        self.__masked_attention = MultiHeadAttention(
            embedding_dimension, number_of_attention_heads, dropout=0.1
        )
        self.__layer_norm_1 = nn.LayerNorm(embedding_dimension)

        self.__cross_attention = MultiHeadAttention(
            embedding_dimension, number_of_attention_heads, dropout=0.1
        )
        self.__layer_norm_2 = nn.LayerNorm(embedding_dimension)

        # Kept for module-structure compatibility; not applied in forward().
        self.__dropout = nn.Dropout(0.1)

        self.__feed_forward_network = FeedForwardNetwork(
            embedding_dimension, feed_forward_hidden_layer_dimension
        )
        self.__layer_norm_3 = nn.LayerNorm(embedding_dimension)

    def forward(
        self,
        args: tuple[
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
            Optional[bool],
        ],
    ) -> tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        Optional[bool],
    ]:
        """Run the decoder block on a packed argument tuple.

        All inputs travel in one sequence, and the same structure is returned,
        so that the output of one block can be fed directly to the next
        (e.g. when blocks are chained in a sequential container).

        Args:
            args: ``(x, k_x, v_x, src_padding_mask, tgt_padding_mask,
                decoder_only)``. ``x`` is the decoder input and also the
                query of the cross-attention; ``k_x`` / ``v_x`` are the
                cross-attention keys and values. ``decoder_only`` may be
                omitted, in which case it defaults to ``False``.

        Returns:
            A 6-tuple with the same layout as ``args`` in which ``x`` has
            been replaced by the block output; the other elements are passed
            through unchanged.
        """
        # Default the optional flag. Note the tuple must be built explicitly:
        # ``(False)`` is just ``False`` and cannot be concatenated to a tuple.
        if len(args) < 6:
            args = (*args, False)
        x, k_x, v_x, src_padding_mask, tgt_padding_mask, decoder_only = args

        # Build the self-attention mask; x.size(1) is used as the target
        # sequence length.
        if decoder_only:
            attention_mask = get_prefix_causal_mask_from_padding_mask(
                x.size(1), tgt_padding_mask, self.__attention_heads
            )
        else:
            attention_mask = get_causal_attention_mask(x.size(1))

        # 1) Masked self-attention -> residual connection -> layer norm
        x = x + self.__masked_attention(
            x,
            x,
            x,
            key_padding_mask=tgt_padding_mask,
            attention_mask=attention_mask,
        )
        x = self.__layer_norm_1(x)

        if not decoder_only:
            # 2) Encoder-decoder (cross) attention -> residual -> layer norm
            x = x + self.__cross_attention(
                x, k_x, v_x, key_padding_mask=src_padding_mask
            )
            x = self.__layer_norm_2(x)

        # 3) Position-wise feed-forward -> residual -> layer norm
        x = x + self.__feed_forward_network(x)
        x = self.__layer_norm_3(x)

        return (x, k_x, v_x, src_padding_mask, tgt_padding_mask, decoder_only)
# Call .eval() on the module to disable dropout, etc.