Added shims for utils on using Pytorch
This commit is contained in:
parent
999141f886
commit
87f24878f4
5
Project_Model/Libs/TorchShims/Utils/__init__.py
Normal file
5
Project_Model/Libs/TorchShims/Utils/__init__.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
from .get_default_device import get_default_device
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"get_default_device"
|
||||||
|
]
|
||||||
17
Project_Model/Libs/TorchShims/Utils/get_default_device.py
Normal file
17
Project_Model/Libs/TorchShims/Utils/get_default_device.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
def get_default_device() -> torch.device:
|
||||||
|
|
||||||
|
# Cuda or ROCm
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
return torch.device("cuda")
|
||||||
|
|
||||||
|
# Intel GPUs
|
||||||
|
if torch.xpu.is_available():
|
||||||
|
return torch.device("xpu")
|
||||||
|
|
||||||
|
# Apple GPUs
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
return torch.device("mps")
|
||||||
|
|
||||||
|
return torch.device("cpu")
|
||||||
7
Project_Model/Libs/TorchShims/__init__.py
Normal file
7
Project_Model/Libs/TorchShims/__init__.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
from .Utils import *
|
||||||
|
|
||||||
|
from .Utils import get_default_device
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"get_default_device"
|
||||||
|
]
|
||||||
@ -1 +1,2 @@
|
|||||||
from . import BPE
|
from . import BPE
|
||||||
|
from . import TorchShims
|
||||||
Loading…
x
Reference in New Issue
Block a user