diff --git a/Project_Model/Libs/TorchShims/Utils/__init__.py b/Project_Model/Libs/TorchShims/Utils/__init__.py new file mode 100644 index 0000000..5838211 --- /dev/null +++ b/Project_Model/Libs/TorchShims/Utils/__init__.py @@ -0,0 +1,5 @@ +from .get_default_device import get_default_device + +__all__ = [ + "get_default_device" +] \ No newline at end of file diff --git a/Project_Model/Libs/TorchShims/Utils/get_default_device.py b/Project_Model/Libs/TorchShims/Utils/get_default_device.py new file mode 100644 index 0000000..38e6d20 --- /dev/null +++ b/Project_Model/Libs/TorchShims/Utils/get_default_device.py @@ -0,0 +1,17 @@ +import torch + +def get_default_device() -> torch.device: + + # Cuda or ROCm + if torch.cuda.is_available(): + return torch.device("cuda") + + # Intel GPUs + if torch.xpu.is_available(): + return torch.device("xpu") + + # Apple GPUs + if torch.backends.mps.is_available(): + return torch.device("mps") + + return torch.device("cpu") diff --git a/Project_Model/Libs/TorchShims/__init__.py b/Project_Model/Libs/TorchShims/__init__.py new file mode 100644 index 0000000..0920233 --- /dev/null +++ b/Project_Model/Libs/TorchShims/__init__.py @@ -0,0 +1,7 @@ +from .Utils import * + +from .Utils import get_default_device + +__all__ = [ + "get_default_device" +] \ No newline at end of file diff --git a/Project_Model/Libs/__init__.py b/Project_Model/Libs/__init__.py index 39fcdff..abc0042 100644 --- a/Project_Model/Libs/__init__.py +++ b/Project_Model/Libs/__init__.py @@ -1 +1,2 @@ -from . import BPE \ No newline at end of file +from . import BPE +from . import TorchShims \ No newline at end of file