From 87f24878f4ae2594fcf449bc59b108257d76aaf0 Mon Sep 17 00:00:00 2001 From: Christian Risi <75698846+CnF-Gris@users.noreply.github.com> Date: Fri, 3 Oct 2025 20:11:14 +0200 Subject: [PATCH] Added shims for utils on using Pytorch --- Project_Model/Libs/TorchShims/Utils/__init__.py | 5 +++++ .../Libs/TorchShims/Utils/get_default_device.py | 17 +++++++++++++++++ Project_Model/Libs/TorchShims/__init__.py | 7 +++++++ Project_Model/Libs/__init__.py | 3 ++- 4 files changed, 31 insertions(+), 1 deletion(-) create mode 100644 Project_Model/Libs/TorchShims/Utils/__init__.py create mode 100644 Project_Model/Libs/TorchShims/Utils/get_default_device.py create mode 100644 Project_Model/Libs/TorchShims/__init__.py diff --git a/Project_Model/Libs/TorchShims/Utils/__init__.py b/Project_Model/Libs/TorchShims/Utils/__init__.py new file mode 100644 index 0000000..5838211 --- /dev/null +++ b/Project_Model/Libs/TorchShims/Utils/__init__.py @@ -0,0 +1,5 @@ +from .get_default_device import get_default_device + +__all__ = [ + "get_default_device" +] \ No newline at end of file diff --git a/Project_Model/Libs/TorchShims/Utils/get_default_device.py b/Project_Model/Libs/TorchShims/Utils/get_default_device.py new file mode 100644 index 0000000..38e6d20 --- /dev/null +++ b/Project_Model/Libs/TorchShims/Utils/get_default_device.py @@ -0,0 +1,17 @@ +import torch + +def get_default_device() -> torch.device: + + # Cuda or ROCm + if torch.cuda.is_available(): + return torch.device("cuda") + + # Intel GPUs + if torch.xpu.is_available(): + return torch.device("xpu") + + # Apple GPUs + if torch.backends.mps.is_available(): + return torch.device("mps") + + return torch.device("cpu") diff --git a/Project_Model/Libs/TorchShims/__init__.py b/Project_Model/Libs/TorchShims/__init__.py new file mode 100644 index 0000000..0920233 --- /dev/null +++ b/Project_Model/Libs/TorchShims/__init__.py @@ -0,0 +1,7 @@ +from .Utils import * + +from .Utils import get_default_device + +__all__ = [ + "get_default_device" +] \ No newline at end of file diff --git a/Project_Model/Libs/__init__.py b/Project_Model/Libs/__init__.py index 39fcdff..abc0042 100644 --- a/Project_Model/Libs/__init__.py +++ b/Project_Model/Libs/__init__.py @@ -1 +1,2 @@ -from . import BPE \ No newline at end of file +from . import BPE +from . import TorchShims \ No newline at end of file