import torch


def get_default_device() -> torch.device:
    """Return the preferred available accelerator device, falling back to CPU.

    Backends are probed in this priority order:

    1. ``cuda`` -- NVIDIA CUDA or AMD ROCm (both are exposed via ``torch.cuda``).
    2. ``xpu``  -- Intel GPUs.
    3. ``mps``  -- Apple GPUs.
    4. ``cpu``  -- always available; used when no accelerator is found.

    Returns:
        A ``torch.device`` for the first backend whose availability check
        succeeds.
    """
    # CUDA or ROCm.
    if torch.cuda.is_available():
        return torch.device("cuda")

    # Intel GPUs. Older PyTorch releases have no ``torch.xpu`` module, and
    # reading the attribute there would raise AttributeError instead of
    # falling through to the next backend, so probe it defensively.
    xpu = getattr(torch, "xpu", None)
    if xpu is not None and xpu.is_available():
        return torch.device("xpu")

    # Apple GPUs. ``torch.backends.mps`` is likewise missing on older
    # releases, so guard it the same way.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")

    return torch.device("cpu")