WIP decoder with prefix mask
parent ff721107b9
commit 443f54fffd
@@ -3,7 +3,7 @@ import torch
 import torch.nn as nn
 from .FeedForwardNetwork import FeedForwardNetwork
 from .TorchMultiHeadAttention import TorchMultiHeadAttention as MultiHeadAttention
-from ..Utils.attention_mask import get_causal_attention_mask
+from ..Utils.attention_mask import get_causal_attention_mask, get_prefix_causal_mask_from_padding_mask
 
 # B, L(T), E_D
 
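The helper get_prefix_causal_mask_from_padding_mask is imported above but implemented in ..Utils.attention_mask, which is not part of this diff. A minimal sketch of what the pair of helpers could look like, assuming the padding mask is (B, L) with True on padded positions and the boolean convention of torch.nn.MultiheadAttention (True = attention blocked); the per-head repeat at the end is why the call site passes the head count:

    import torch

    def get_causal_attention_mask(sequence_length: int) -> torch.Tensor:
        # (L, L) boolean mask; True above the diagonal blocks attention
        # to future positions.
        return torch.triu(
            torch.ones(sequence_length, sequence_length, dtype=torch.bool),
            diagonal=1,
        )

    def get_prefix_causal_mask_from_padding_mask(
        sequence_length: int,
        padding_mask: torch.Tensor,  # (B, L); True on padding (assumed)
        number_of_attention_heads: int,
    ) -> torch.Tensor:
        causal = get_causal_attention_mask(sequence_length)  # (L, L)
        # Prefix-LM behaviour: positions not flagged as padding form the
        # prefix and attend bidirectionally; everything else stays causal.
        # Treating the non-padded span as the prefix is an assumption.
        prefix = ~padding_mask                                     # (B, L)
        bidirectional = prefix.unsqueeze(1) & prefix.unsqueeze(2)  # (B, L, L)
        mask = causal.unsqueeze(0) & ~bidirectional
        # torch.nn.MultiheadAttention expects a batch-dependent mask as
        # (B * num_heads, L, L), grouped per sample.
        return mask.repeat_interleave(number_of_attention_heads, dim=0)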
@@ -16,8 +16,11 @@ class Decoder(nn.Module):
         feed_forward_hidden_layer_dimension: int,
         number_of_attention_heads: int,
     ) -> None:
+        self.__attention_heads = number_of_attention_heads
         super().__init__()
+
+
 
         self.__masked_attention = MultiHeadAttention(
             embedding_dimension, number_of_attention_heads, dropout=0.1
         )
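The new attribute is assigned before super().__init__(). nn.Module happens to tolerate plain Python attributes (ints, strings) at that point, but Parameters and sub-Modules cannot be set until after the base initializer runs, so the conventional ordering, sketched here with a reduced signature, calls super().__init__() first:

    import torch.nn as nn

    class Decoder(nn.Module):
        def __init__(self, number_of_attention_heads: int) -> None:
            super().__init__()  # initialize nn.Module state first
            # plain attribute; works either way, but kept after super()
            self.__attention_heads = number_of_attention_heads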
@@ -54,6 +57,9 @@
 
-        # build of attention mask
+        # TODO: create a prefix causal mask if needed
+        if decoder_only:
+            attention_mask = get_prefix_causal_mask_from_padding_mask(x.size(1), tgt_padding_mask, self.__attention_heads)
+        else:
             attention_mask = get_causal_attention_mask(x.size(1))
 
         # 1) Masked Attention
 
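A rough shape check for the two branches, using the sketched helpers above (the decoder_only flag, the tgt_padding_mask convention, and the helper semantics are all assumptions, not confirmed by this diff):

    import torch

    batch_size, seq_len, heads = 2, 5, 4
    # True marks padded positions (assumed convention)
    tgt_padding_mask = torch.tensor([
        [False, False, True,  True,  True],   # 2 real tokens
        [False, False, False, True,  True],   # 3 real tokens
    ])

    causal = get_causal_attention_mask(seq_len)
    prefix = get_prefix_causal_mask_from_padding_mask(seq_len, tgt_padding_mask, heads)

    # torch.nn.MultiheadAttention accepts attn_mask either as a shared
    # (L, L) mask or as a batch-dependent (B * num_heads, L, L) mask;
    # the second form is why the decoder stores its head count.
    assert causal.shape == (seq_len, seq_len)
    assert prefix.shape == (batch_size * heads, seq_len, seq_len)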