# Source-viewer metadata: 29 lines, 515 B, Python
"""Compare matrix-multiplication wall time on the CPU and on an accelerator.

Prints the PyTorch version and MPS/CUDA availability, times one matmul on
the CPU, then (when an accelerator is present) times the same matmul twice
on that device so the first-call warm-up cost can be told apart from the
steady-state cost.
"""
import time
from typing import Optional, Tuple

import torch


def pick_accelerator() -> Optional[torch.device]:
    """Return the MPS device if available, else CUDA, else ``None``."""
    if torch.backends.mps.is_available():
        return torch.device('mps')
    if torch.cuda.is_available():
        return torch.device('cuda')
    return None


def _synchronize(device: torch.device) -> None:
    """Wait for work queued on *device* to finish; no-op for the CPU."""
    if device.type == 'cuda':
        torch.cuda.synchronize(device)
    elif device.type == 'mps' and hasattr(torch, 'mps'):
        # torch.mps is absent from older releases; on those the timing
        # below stays approximate rather than crashing.
        torch.mps.synchronize()


def timed_matmul(a: torch.Tensor, b: torch.Tensor) -> Tuple[float, torch.Tensor]:
    """Multiply *a* by *b* and return ``(elapsed_seconds, product)``.

    The device is synchronized before starting and before stopping the
    clock: accelerator kernels are launched asynchronously, so without
    this the measurement would cover only the launch, not the computation.
    """
    _synchronize(a.device)
    start = time.perf_counter()  # monotonic, high-resolution clock
    product = torch.matmul(a, b)
    _synchronize(a.device)
    elapsed = time.perf_counter() - start
    return elapsed, product


def main(m: int = 10000, k: int = 1000, n: int = 2000) -> None:
    """Run the benchmark on an ``(m, k) @ (k, n)`` product and print results.

    Each result line is ``<device> <seconds> <2-norm of the product>``.
    """
    print(torch.__version__)
    print(torch.backends.mps.is_available())
    print(torch.cuda.is_available())

    a = torch.randn(m, k)
    b = torch.randn(k, n)

    elapsed, c = timed_matmul(a, b)
    print(a.device, elapsed, c.norm(2))

    device = pick_accelerator()
    if device is None:
        # Previously the script crashed here on machines without MPS.
        print('no MPS or CUDA device available; skipping accelerator run')
        return

    a = a.to(device)
    b = b.to(device)

    # Two runs: the first includes one-time warm-up costs, the second
    # reflects steady-state performance.
    for _ in range(2):
        elapsed, c = timed_matmul(a, b)
        print(a.device, elapsed, c.norm(2))


if __name__ == '__main__':
    main()