Added new playgrounds

This commit is contained in:
Christian Risi 2025-10-07 20:45:04 +02:00
parent 533347ee22
commit 8adacdb08c
3 changed files with 318 additions and 2 deletions

View File

@ -2,7 +2,7 @@
"cells": [ "cells": [
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": null,
"id": "f5762da9", "id": "f5762da9",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -127,6 +127,8 @@
"\n", "\n",
"\n", "\n",
"\n", "\n",
"\n",
"\n",
"\n" "\n"
] ]
} }

File diff suppressed because one or more lines are too long

View File

@ -2,7 +2,7 @@
"cells": [ "cells": [
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": 2,
"id": "4ae47336", "id": "4ae47336",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -15,6 +15,77 @@
"mha = torch.nn.MultiheadAttention(D, num_heads=4, batch_first=True)\n", "mha = torch.nn.MultiheadAttention(D, num_heads=4, batch_first=True)\n",
"y, _ = mha(x, x, x, attn_mask=attn_mask, key_padding_mask=pad_mask) # should work\n" "y, _ = mha(x, x, x, attn_mask=attn_mask, key_padding_mask=pad_mask) # should work\n"
] ]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "e38e3fb5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[0, 0, 0, 0, 1, 0, 0, 0, 0, 0],\n",
" [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]],\n",
"\n",
" [[0, 0, 1, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]]])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.nn.functional.one_hot(torch.tensor([\n",
" [4, 1, 9],\n",
" [2,4,5]\n",
"]))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "7119ad53",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"device(type='cpu')"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.get_default_device()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "8c95691a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"xpu\n"
]
}
],
"source": [
"from Project_Model.Libs.TorchShims import get_default_device\n",
"\n",
"print(get_default_device())"
]
} }
], ],
"metadata": { "metadata": {