WIP decoder with prefix mask

This commit is contained in:
GassiGiuseppe 2025-10-11 15:31:43 +02:00
parent ff721107b9
commit 443f54fffd

View File

@ -3,7 +3,7 @@ import torch
import torch.nn as nn
from .FeedForwardNetwork import FeedForwardNetwork
from .TorchMultiHeadAttention import TorchMultiHeadAttention as MultiHeadAttention
from ..Utils.attention_mask import get_causal_attention_mask
from ..Utils.attention_mask import get_causal_attention_mask, get_prefix_causal_mask_from_padding_mask
# B, L(T), E_D
@ -16,8 +16,11 @@ class Decoder(nn.Module):
feed_forward_hidden_layer_dimension: int,
number_of_attention_heads: int,
) -> None:
self.__attention_heads = number_of_attention_heads
super().__init__()
self.__masked_attention = MultiHeadAttention(
embedding_dimension, number_of_attention_heads, dropout=0.1
)
@ -54,7 +57,10 @@ class Decoder(nn.Module):
# build the attention mask
# TODO: create a prefix causal mask if needed
attention_mask = get_causal_attention_mask(x.size(1))
if decoder_only:
attention_mask = get_prefix_causal_mask_from_padding_mask(x.size(1),tgt_padding_mask,self.__attention_heads)
else:
attention_mask = get_causal_attention_mask(x.size(1))
# 1) Masked Attention
MASKED_ATTENTION = self.__masked_attention(