WIP decoder with prefix mask

parent ff721107b9
commit 443f54fffd
@@ -3,7 +3,7 @@ import torch
 import torch.nn as nn
 from .FeedForwardNetwork import FeedForwardNetwork
 from .TorchMultiHeadAttention import TorchMultiHeadAttention as MultiHeadAttention
-from ..Utils.attention_mask import get_causal_attention_mask
+from ..Utils.attention_mask import get_causal_attention_mask, get_prefix_causal_mask_from_padding_mask

 # B, L(T), E_D

@@ -16,8 +16,11 @@ class Decoder(nn.Module):
         feed_forward_hidden_layer_dimension: int,
         number_of_attention_heads: int,
     ) -> None:
         super().__init__()
+
+        self.__attention_heads = number_of_attention_heads
+

         self.__masked_attention = MultiHeadAttention(
             embedding_dimension, number_of_attention_heads, dropout=0.1
         )
@@ -54,7 +57,10 @@ class Decoder(nn.Module):

         # build the attention mask
         # TODO: create a prefix causal mask if needed
-        attention_mask = get_causal_attention_mask(x.size(1))
+        if decoder_only:
+            attention_mask = get_prefix_causal_mask_from_padding_mask(x.size(1), tgt_padding_mask, self.__attention_heads)
+        else:
+            attention_mask = get_causal_attention_mask(x.size(1))

         # 1) Masked Attention
         MASKED_ATTENTION = self.__masked_attention(
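The new helper get_prefix_causal_mask_from_padding_mask lives in ..Utils/attention_mask.py and its body is not part of this diff. Below is a minimal sketch of what such a builder could look like, assuming the second argument is a boolean (B, L) tensor that is True at prefix positions (the commit passes tgt_padding_mask, so the real convention may be inverted) and that the result is a boolean (B * num_heads, L, L) attn_mask for torch.nn.MultiheadAttention, where True means "may not attend":

import torch

def get_causal_attention_mask(sequence_length: int) -> torch.Tensor:
    # True above the diagonal: query i may not attend to any position j > i.
    return torch.triu(
        torch.ones(sequence_length, sequence_length, dtype=torch.bool), diagonal=1
    )

def get_prefix_causal_mask_from_padding_mask(
    sequence_length: int,
    prefix_mask: torch.Tensor,  # (B, L) bool, True at prefix positions -- assumed semantics
    number_of_attention_heads: int,
) -> torch.Tensor:
    causal = get_causal_attention_mask(sequence_length)  # (L, L)
    # Unmask every column that belongs to the prefix, so prefix tokens are
    # visible bidirectionally; the non-prefix remainder stays causal.
    prefix_columns = prefix_mask[:, None, :].expand(-1, sequence_length, -1)  # (B, L, L)
    mask = causal[None, :, :] & ~prefix_columns  # (B, L, L)
    # nn.MultiheadAttention expects a (B * num_heads, L, L) mask, batch-major
    # with heads contiguous per sample, which repeat_interleave produces.
    return mask.repeat_interleave(number_of_attention_heads, dim=0)

repeat_interleave keeps the heads contiguous per batch element, matching the layout nn.MultiheadAttention uses internally for a 3D attn_mask; a plain repeat would pair masks with the wrong samples.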
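For reference, a hedged usage sketch of the decoder_only branch, assuming the Decoder's MultiHeadAttention wrapper forwards attn_mask to a batch_first torch.nn.MultiheadAttention (the wrapper's actual interface is not shown in this commit); all shapes and values below are illustrative only:

import torch

# Illustrative sizes: B=2, L=10, E=64, 4 heads.
mha = torch.nn.MultiheadAttention(embed_dim=64, num_heads=4, dropout=0.1, batch_first=True)
x = torch.randn(2, 10, 64)                          # B, L(T), E_D
prefix_mask = torch.zeros(2, 10, dtype=torch.bool)
prefix_mask[:, :4] = True                           # first 4 tokens form the bidirectional prefix
attention_mask = get_prefix_causal_mask_from_padding_mask(x.size(1), prefix_mask, 4)
out, _ = mha(x, x, x, attn_mask=attention_mask)     # out: (2, 10, 64)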