import torch


def fixed_positional_encoding(
    sentence_dimension: int,
    embedding_dimension: int,
) -> torch.Tensor:
    """Return the sinusoidal positional encoding of shape
    (sentence_dimension, embedding_dimension)."""

    BIG_CONST = int(1e4)

    # Token positions 0 .. sentence_dimension - 1.
    INITIAL_ENCODING = torch.arange(sentence_dimension)

    ENCODINGS: list[torch.Tensor] = []

    for i in range(embedding_dimension):
        EMBEDDING_POSITION = i

        # Note: the paper's formula indexes sin/cos *pairs*, writing the
        # exponent as 2i / d_model; when looping over raw embedding
        # dimensions, the pair index has to be recovered with i // 2,
        # which the paper does not spell out.
        DIVISOR = BIG_CONST ** ((2 * (EMBEDDING_POSITION // 2)) / embedding_dimension)
        INTERMEDIATE_ENCODING = INITIAL_ENCODING / DIVISOR

        # Even dimensions get sine, odd dimensions get cosine.
        if EMBEDDING_POSITION % 2 == 0:
            ENCODINGS.append(torch.sin(INTERMEDIATE_ENCODING))
            continue

        ENCODINGS.append(torch.cos(INTERMEDIATE_ENCODING))

    # Stacking gives (embedding_dimension, sentence_dimension); transpose to
    # (sentence_dimension, embedding_dimension) so rows index token positions.
    return torch.stack(ENCODINGS).transpose(0, 1)
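

# A minimal usage sketch (not part of the original file): the dimensions below
# are illustrative, and the printed shape/dtype assume PyTorch's default
# float32 dtype.
if __name__ == "__main__":
    encoding = fixed_positional_encoding(
        sentence_dimension=8,      # number of token positions
        embedding_dimension=16,    # embedding / model width
    )
    print(encoding.shape)  # torch.Size([8, 16])
    print(encoding.dtype)  # torch.float32

    # Dimension 0 uses the base frequency: sin(pos / 10000^0) = sin(pos).
    print(torch.allclose(encoding[:, 0], torch.sin(torch.arange(8).float())))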