Added shims for utils on using Pytorch

This commit is contained in:
Christian Risi 2025-10-03 20:11:14 +02:00
parent 999141f886
commit 87f24878f4
4 changed files with 31 additions and 1 deletions

View File

@ -0,0 +1,5 @@
from .get_default_device import get_default_device
__all__ = [
"get_default_device"
]

View File

@ -0,0 +1,17 @@
import torch
def get_default_device() -> torch.device:
# Cuda or ROCm
if torch.cuda.is_available():
return torch.device("cuda")
# Intel GPUs
if torch.xpu.is_available():
return torch.device("xpu")
# Apple GPUs
if torch.backends.mps.is_available():
return torch.device("mps")
return torch.device("cpu")

View File

@ -0,0 +1,7 @@
from .Utils import *
from .Utils import get_default_device
__all__ = [
"get_default_device"
]

View File

@ -1 +1,2 @@
from . import BPE
from . import BPE
from . import TorchShims