dev.train #8

Merged
gape_01 merged 50 commits from dev.train into dev 2025-10-17 22:20:14 +02:00
Showing only changes of commit 443f54fffd

@@ -3,7 +3,7 @@ import torch
 import torch.nn as nn
 from .FeedForwardNetwork import FeedForwardNetwork
 from .TorchMultiHeadAttention import TorchMultiHeadAttention as MultiHeadAttention
-from ..Utils.attention_mask import get_causal_attention_mask
+from ..Utils.attention_mask import get_causal_attention_mask, get_prefix_causal_mask_from_padding_mask
 # B, L(T), E_D
@@ -16,8 +16,11 @@ class Decoder(nn.Module):
         feed_forward_hidden_layer_dimension: int,
         number_of_attention_heads: int,
     ) -> None:
         super().__init__()
+        self.__attention_heads = number_of_attention_heads
         self.__masked_attention = MultiHeadAttention(
             embedding_dimension, number_of_attention_heads, dropout=0.1
         )
@@ -54,6 +57,9 @@ class Decoder(nn.Module):
         # build of attention mask
-        # TODO: create a prefix causal mask if needed
-        attention_mask = get_causal_attention_mask(x.size(1))
+        if decoder_only:
+            attention_mask = get_prefix_causal_mask_from_padding_mask(x.size(1), tgt_padding_mask, self.__attention_heads)
+        else:
+            attention_mask = get_causal_attention_mask(x.size(1))
         # 1) Masked Attention
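
The body of get_prefix_causal_mask_from_padding_mask lies outside this diff. Below is a minimal sketch of what such a helper could look like, under two assumptions not confirmed by the commit: that True entries in the mask argument flag the bidirectionally visible prefix positions (the name tgt_padding_mask usually means padding, so this is a guess), and that the result is a boolean (B * num_heads, L, L) attn_mask for torch.nn.MultiheadAttention, which would explain why the head count is stored and passed in.

import torch

def get_prefix_causal_mask_from_padding_mask(
    sequence_length: int,
    prefix_mask: torch.Tensor,  # (B, L) bool; assumed True on prefix positions
    number_of_attention_heads: int,
) -> torch.Tensor:
    # Standard causal mask: True above the diagonal means "blocked".
    causal = torch.triu(
        torch.ones(sequence_length, sequence_length, dtype=torch.bool),
        diagonal=1,
    )  # (L, L)

    # Inside the prefix, attention is bidirectional: unblock (query, key)
    # pairs where both positions belong to the prefix.
    prefix_pairs = prefix_mask.unsqueeze(2) & prefix_mask.unsqueeze(1)  # (B, L, L)
    mask = causal.unsqueeze(0) & ~prefix_pairs  # (B, L, L)

    # torch.nn.MultiheadAttention accepts a 3-D attn_mask of shape
    # (B * num_heads, L, L), with heads varying fastest per batch element.
    return mask.repeat_interleave(number_of_attention_heads, dim=0)

With this layout, True entries are blocked: prefix tokens attend to each other in both directions, while every position after the prefix remains strictly causal. If tgt_padding_mask instead follows the usual True-means-padding convention, the prefix region would have to be derived differently; the sketch only illustrates the prefix-causal structure itself.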