Added shims for utils on using Pytorch

This commit is contained in:
Christian Risi
2025-10-03 20:11:14 +02:00
parent 999141f886
commit 87f24878f4
4 changed files with 31 additions and 1 deletions

View File

@@ -0,0 +1,17 @@
import torch
def get_default_device() -> torch.device:
# Cuda or ROCm
if torch.cuda.is_available():
return torch.device("cuda")
# Intel GPUs
if torch.xpu.is_available():
return torch.device("xpu")
# Apple GPUs
if torch.backends.mps.is_available():
return torch.device("mps")
return torch.device("cpu")