230523
This commit is contained in:
29
test/performance.py
Normal file
29
test/performance.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import torch
|
||||
import time
|
||||
|
||||
print(torch.__version__)
|
||||
print(torch.backends.mps.is_available())
|
||||
print(torch.cuda.is_available())
|
||||
|
||||
a = torch.randn(10000,1000)
|
||||
b = torch.randn(1000,2000)
|
||||
|
||||
t0 = time.time()
|
||||
c = torch.matmul(a, b)
|
||||
t1 = time.time()
|
||||
print(a.device,t1-t0,c.norm(2))
|
||||
|
||||
device = torch.device('mps')
|
||||
|
||||
a = a.to(device)
|
||||
b = b.to(device)
|
||||
|
||||
t0 = time.time()
|
||||
c = torch.matmul(a, b)
|
||||
t1 = time.time()
|
||||
print(a.device,t1-t0,c.norm(2))
|
||||
|
||||
t0 = time.time()
|
||||
c = torch.matmul(a, b)
|
||||
t1 = time.time()
|
||||
print(a.device,t1-t0,c.norm(2))
|
||||
Reference in New Issue
Block a user