"""Compare matrix-multiply wall-clock time on CPU vs. Apple MPS."""
import time

import torch


def _synchronize(device: torch.device) -> None:
    """Block until all work queued on *device* has finished (no-op on CPU)."""
    if device.type == "mps":
        torch.mps.synchronize()
    elif device.type == "cuda":
        torch.cuda.synchronize(device)


def timed_matmul(a: torch.Tensor, b: torch.Tensor) -> "tuple[float, torch.Tensor]":
    """Multiply ``a @ b`` and measure how long it takes to produce the result.

    GPU backends (MPS/CUDA) launch kernels asynchronously, so the device is
    synchronized both before starting and before stopping the timer. Without
    this, the measured time would mostly reflect kernel launch, not execution.

    Args:
        a: Left operand.
        b: Right operand, on the same device as ``a``.

    Returns:
        ``(elapsed_seconds, product)``.
    """
    device = a.device
    # Don't charge previously queued work (e.g. a host->device copy) to this matmul.
    _synchronize(device)
    t0 = time.perf_counter()
    c = torch.matmul(a, b)
    _synchronize(device)
    return time.perf_counter() - t0, c


def main(m: int = 10000, k: int = 1000, n: int = 2000) -> None:
    """Benchmark an (m x k) @ (k x n) random matmul on CPU, then on MPS if present."""
    print(torch.__version__)
    print(torch.backends.mps.is_available())
    print(torch.cuda.is_available())

    a = torch.randn(m, k)
    b = torch.randn(k, n)
    elapsed, c = timed_matmul(a, b)
    print(a.device, elapsed, c.norm(2))

    if not torch.backends.mps.is_available():
        print("MPS not available; skipping MPS benchmark")
        return

    device = torch.device("mps")
    a = a.to(device)
    b = b.to(device)
    # Run twice: the first MPS call presumably includes one-time setup cost,
    # and the second is closer to steady state.
    for _ in range(2):
        elapsed, c = timed_matmul(a, b)
        print(a.device, elapsed, c.norm(2))


if __name__ == "__main__":
    main()