Made model Batch ready
This commit is contained in:
@@ -8,17 +8,16 @@ class TorchMultiHeadAttention(nn.Module):
|
||||
self,
|
||||
embedding_dimension: int,
|
||||
number_of_attention_heads: int,
|
||||
dropout: float = 0.0,
|
||||
dropout: float = 0.0
|
||||
):
|
||||
super().__init__()
|
||||
self.attention = nn.MultiheadAttention(
|
||||
self.attention = torch.nn.MultiheadAttention(
|
||||
embedding_dimension,
|
||||
number_of_attention_heads,
|
||||
num_heads=number_of_attention_heads,
|
||||
dropout=dropout,
|
||||
batch_first=True,
|
||||
)
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x_q: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user