WIP decoder with prefix mask

GassiGiuseppe 2025-10-11 15:31:43 +02:00
parent ff721107b9
commit 443f54fffd


@@ -3,7 +3,7 @@ import torch
 import torch.nn as nn
 from .FeedForwardNetwork import FeedForwardNetwork
 from .TorchMultiHeadAttention import TorchMultiHeadAttention as MultiHeadAttention
-from ..Utils.attention_mask import get_causal_attention_mask
+from ..Utils.attention_mask import get_causal_attention_mask, get_prefix_causal_mask_from_padding_mask

 # B, L(T), E_D
@ -16,8 +16,11 @@ class Decoder(nn.Module):
feed_forward_hidden_layer_dimension: int, feed_forward_hidden_layer_dimension: int,
number_of_attention_heads: int, number_of_attention_heads: int,
) -> None: ) -> None:
self.__attention_heads = number_of_attention_heads
super().__init__() super().__init__()
self.__masked_attention = MultiHeadAttention( self.__masked_attention = MultiHeadAttention(
embedding_dimension, number_of_attention_heads, dropout=0.1 embedding_dimension, number_of_attention_heads, dropout=0.1
) )
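The head count is presumably stashed on the module because the new mask helper needs it in forward(): torch.nn.MultiheadAttention accepts a batched boolean attn_mask only in the shape (B * num_heads, L, L), so a per-batch mask has to be expanded per head before it is passed in.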
@@ -54,7 +57,10 @@ class Decoder(nn.Module):
         # build of attention mask
         # TODO: create a prefix causal mask if needed
-        attention_mask = get_causal_attention_mask(x.size(1))
+        if decoder_only:
+            attention_mask = get_prefix_causal_mask_from_padding_mask(x.size(1), tgt_padding_mask, self.__attention_heads)
+        else:
+            attention_mask = get_causal_attention_mask(x.size(1))
         # 1) Masked Attention
         MASKED_ATTENTION = self.__masked_attention(
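
The body of get_prefix_causal_mask_from_padding_mask is not part of this diff. Below is a minimal sketch of what it might look like, under two assumptions: that tgt_padding_mask is a (B, L) boolean tensor that is True on the prefix positions (the argument names and semantics are guesses, not taken from the repository), and that the mask follows the usual prefix-causal (PrefixLM) rule, where prefix tokens attend to each other bidirectionally and all later tokens attend causally.

# Sketch only, not the actual implementation from this repository.
import torch

def get_prefix_causal_mask_from_padding_mask(
    sequence_length: int,
    prefix_mask: torch.Tensor,  # assumed (B, L) bool, True where the token belongs to the prefix
    number_of_attention_heads: int,
) -> torch.Tensor:
    batch_size = prefix_mask.size(0)
    # Ordinary causal mask: True = attention disallowed.
    causal = torch.triu(
        torch.ones(sequence_length, sequence_length, dtype=torch.bool), diagonal=1
    )
    # Key positions inside the prefix are visible to every query position,
    # so they override the causal restriction.
    prefix_visible = prefix_mask[:, None, :].expand(batch_size, sequence_length, sequence_length)
    mask = causal.unsqueeze(0) & ~prefix_visible  # (B, L, L)
    # torch.nn.MultiheadAttention expects a batched 3-D mask of shape
    # (B * num_heads, L, L), hence the per-head expansion.
    return mask.repeat_interleave(number_of_attention_heads, dim=0)

With a helper along these lines, the same forward pass can serve both modes: the else branch keeps the plain causal mask, while the decoder_only branch lets the prefix (for example, the prompt) be encoded bidirectionally while generation beyond it stays autoregressive.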