18 lines
359 B
Python
18 lines
359 B
Python
import torch
|
|
|
|
def get_default_device() -> torch.device:
    """Return the first available accelerator device, falling back to CPU.

    Backends are probed in a fixed priority order: ``cuda``, then ``xpu``,
    then ``mps``. The first one whose ``is_available()`` returns a truthy
    value is returned; if none is available the CPU device is returned.

    The ``xpu`` and ``mps`` backend modules are looked up with ``getattr``
    so that a PyTorch build which does not expose ``torch.xpu`` or
    ``torch.backends.mps`` falls through to the next candidate instead of
    raising ``AttributeError``.

    Returns:
        A ``torch.device`` of type ``"cuda"``, ``"xpu"``, ``"mps"`` or
        ``"cpu"``.
    """
    # CUDA or ROCm
    if torch.cuda.is_available():
        return torch.device("cuda")

    # Intel GPUs (module may be absent from this torch build)
    xpu_backend = getattr(torch, "xpu", None)
    if xpu_backend is not None and xpu_backend.is_available():
        return torch.device("xpu")

    # Apple GPUs (module may be absent from this torch build)
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")

    return torch.device("cpu")
|