Added shims for utils on using Pytorch
This commit is contained in:
17
Project_Model/Libs/TorchShims/Utils/get_default_device.py
Normal file
17
Project_Model/Libs/TorchShims/Utils/get_default_device.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import torch
|
||||
|
||||
def get_default_device() -> torch.device:
|
||||
|
||||
# Cuda or ROCm
|
||||
if torch.cuda.is_available():
|
||||
return torch.device("cuda")
|
||||
|
||||
# Intel GPUs
|
||||
if torch.xpu.is_available():
|
||||
return torch.device("xpu")
|
||||
|
||||
# Apple GPUs
|
||||
if torch.backends.mps.is_available():
|
||||
return torch.device("mps")
|
||||
|
||||
return torch.device("cpu")
|
||||
Reference in New Issue
Block a user