230523
This commit is contained in:
		
							
								
								
									
										17
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										17
									
								
								README.md
									
									
									
									
									
								
							@@ -1,8 +1,23 @@
 | 
				
			|||||||
# pytorch study
 | 
					# pytorch study
 | 
				
			||||||
 | 
					
 | 
				
			||||||
## ENV
 | 
					## BASE ENV
 | 
				
			||||||
```shell
 | 
					```shell
 | 
				
			||||||
conda create -n pt python=3.10 -y
 | 
					conda create -n pt python=3.10 -y
 | 
				
			||||||
 | 
					
 | 
				
			||||||
conda activate pt
 | 
					conda activate pt
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## MAC
 | 
				
			||||||
 | 
					```shell
 | 
				
			||||||
 | 
					# 安装 pytorch v1.12版本已经正式支持了用于mac m1芯片gpu加速的mps后端
 | 
				
			||||||
 | 
					conda install pytorch::pytorch torchvision torchaudio -c pytorch -y
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					pip install numpy
 | 
				
			||||||
 | 
					pip install pandas
 | 
				
			||||||
 | 
					pip install matplotlib
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## gpt4free
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					pip install -U g4f[all]
 | 
				
			||||||
```
 | 
					```
 | 
				
			||||||
							
								
								
									
										15
									
								
								autograd.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								autograd.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,15 @@
 | 
				
			|||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from torch import autograd
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					device = torch.device('mps')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					x = torch.tensor(1.)
 | 
				
			||||||
 | 
					a = torch.tensor(2., requires_grad=True)
 | 
				
			||||||
 | 
					b = torch.tensor(2., requires_grad=True)
 | 
				
			||||||
 | 
					c = torch.tensor(3., requires_grad=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					y = a ** 2 * x + b * x + c ** 3
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					print('before:', a.grad, b.grad, c.grad)
 | 
				
			||||||
 | 
					grads = autograd.grad(y, [a, b, c])
 | 
				
			||||||
 | 
					print('after:', grads[0], grads[1], grads[2])
 | 
				
			||||||
							
								
								
									
										11
									
								
								gpt.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								gpt.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,11 @@
 | 
				
			|||||||
 | 
					from g4f.client import Client
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					content = "张量在机器学习中的主要用途"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					client = Client()
 | 
				
			||||||
 | 
					response = client.chat.completions.create(
 | 
				
			||||||
 | 
					    model="gpt-4o",
 | 
				
			||||||
 | 
					    messages=[{"role": "user", "content": content}],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					print(response.choices[0].message.content)
 | 
				
			||||||
							
								
								
									
										56
									
								
								linear regression/1.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										56
									
								
								linear regression/1.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,56 @@
 | 
				
			|||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def compute_error_for_line_given_points(b, w, points):
 | 
				
			||||||
 | 
					    totalError = 0
 | 
				
			||||||
 | 
					    N = float(len(points))
 | 
				
			||||||
 | 
					    for i in range(len(points)):
 | 
				
			||||||
 | 
					        x = points[i][0]
 | 
				
			||||||
 | 
					        y = points[i][1]
 | 
				
			||||||
 | 
					        totalError += (y - (w * x + b)) ** 2
 | 
				
			||||||
 | 
					    return totalError / N
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def step_gradient(b_current, w_current, points, learningRate):
 | 
				
			||||||
 | 
					    b_gradient = torch.tensor(0.0, device=points.device, dtype=torch.float32)
 | 
				
			||||||
 | 
					    w_gradient = torch.tensor(0.0, device=points.device, dtype=torch.float32)
 | 
				
			||||||
 | 
					    N = float(len(points))
 | 
				
			||||||
 | 
					    for i in range(len(points)):
 | 
				
			||||||
 | 
					        x = points[i][0]
 | 
				
			||||||
 | 
					        y = points[i][1]
 | 
				
			||||||
 | 
					        b_gradient += -(2 / N) * (y - (w_current * x + b_current) + b_current)
 | 
				
			||||||
 | 
					        w_gradient += -(2 / N) * x * (y - (w_current * x + b_current + b_current))
 | 
				
			||||||
 | 
					    new_b = b_current - (learningRate * b_gradient)
 | 
				
			||||||
 | 
					    new_w = w_current - (learningRate * w_gradient)
 | 
				
			||||||
 | 
					    return [new_b, new_w]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def gradient_descent_runner(points, starting_b, starting_w, learningRate, num_iterations):
 | 
				
			||||||
 | 
					    b = torch.tensor(starting_b, device=points.device, dtype=torch.float32)
 | 
				
			||||||
 | 
					    w = torch.tensor(starting_w, device=points.device, dtype=torch.float32)
 | 
				
			||||||
 | 
					    for i in range(num_iterations):
 | 
				
			||||||
 | 
					        b, w = step_gradient(b, w, points, learningRate)
 | 
				
			||||||
 | 
					    return [b, w]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def run():
 | 
				
			||||||
 | 
					    # 修改为生成数据的文件路径
 | 
				
			||||||
 | 
					    points_np = np.genfromtxt("data1.csv", delimiter=',').astype(np.float32)
 | 
				
			||||||
 | 
					    points = torch.tensor(points_np, device='mps')
 | 
				
			||||||
 | 
					    learning_rate = 0.0001  # 使用较小的学习率
 | 
				
			||||||
 | 
					    initial_b = 0.0
 | 
				
			||||||
 | 
					    initial_w = 0.0
 | 
				
			||||||
 | 
					    num_iterations = 1000
 | 
				
			||||||
 | 
					    print("Starting gradient descent at b={0},w={1},error={2}".format(initial_b, initial_w,
 | 
				
			||||||
 | 
					                                                                      compute_error_for_line_given_points(initial_b,
 | 
				
			||||||
 | 
					                                                                                                          initial_w,
 | 
				
			||||||
 | 
					                                                                                                          points)))
 | 
				
			||||||
 | 
					    print("running...")
 | 
				
			||||||
 | 
					    [b, w] = gradient_descent_runner(points, initial_b, initial_w, learning_rate, num_iterations)
 | 
				
			||||||
 | 
					    print("After gradient descent at b={0},w={1},error={2}".format(b.item(), w.item(),
 | 
				
			||||||
 | 
					                                                                   compute_error_for_line_given_points(b, w, points)))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == '__main__':
 | 
				
			||||||
 | 
					    run()
 | 
				
			||||||
							
								
								
									
										100
									
								
								linear regression/data1.csv
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										100
									
								
								linear regression/data1.csv
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,100 @@
 | 
				
			|||||||
 | 
					0.0,2.4360562173289115
 | 
				
			||||||
 | 
					0.10101010101010101,2.4288710820592065
 | 
				
			||||||
 | 
					0.20202020202020202,3.677943201375977
 | 
				
			||||||
 | 
					0.30303030303030304,3.5029234515863217
 | 
				
			||||||
 | 
					0.40404040404040403,4.007715980839878
 | 
				
			||||||
 | 
					0.5050505050505051,3.95999321461469
 | 
				
			||||||
 | 
					0.6060606060606061,3.220853916066527
 | 
				
			||||||
 | 
					0.7070707070707071,3.2211460206798623
 | 
				
			||||||
 | 
					0.8080808080808081,4.2516957270374505
 | 
				
			||||||
 | 
					0.9090909090909091,4.311826715084292
 | 
				
			||||||
 | 
					1.0101010101010102,4.153966583608258
 | 
				
			||||||
 | 
					1.1111111111111112,4.224290328721461
 | 
				
			||||||
 | 
					1.2121212121212122,4.551324602105953
 | 
				
			||||||
 | 
					1.3131313131313131,5.157200101408408
 | 
				
			||||||
 | 
					1.4141414141414141,5.199011258508288
 | 
				
			||||||
 | 
					1.5151515151515151,5.248911218901843
 | 
				
			||||||
 | 
					1.6161616161616161,5.789628423512512
 | 
				
			||||||
 | 
					1.7171717171717171,5.126592322934872
 | 
				
			||||||
 | 
					1.8181818181818181,4.546631494636344
 | 
				
			||||||
 | 
					1.9191919191919191,5.7260434379514065
 | 
				
			||||||
 | 
					2.0202020202020203,5.607446671816119
 | 
				
			||||||
 | 
					2.121212121212121,5.401744626671172
 | 
				
			||||||
 | 
					2.2222222222222223,5.568078510495838
 | 
				
			||||||
 | 
					2.323232323232323,6.136817713051054
 | 
				
			||||||
 | 
					2.4242424242424243,5.399802896696589
 | 
				
			||||||
 | 
					2.525252525252525,6.7465591899811415
 | 
				
			||||||
 | 
					2.6262626262626263,6.510002771256968
 | 
				
			||||||
 | 
					2.727272727272727,6.194107987238278
 | 
				
			||||||
 | 
					2.8282828282828283,6.280445605022811
 | 
				
			||||||
 | 
					2.929292929292929,6.413289184504817
 | 
				
			||||||
 | 
					3.0303030303030303,8.178951965980268
 | 
				
			||||||
 | 
					3.131313131313131,7.438933818741419
 | 
				
			||||||
 | 
					3.2323232323232323,8.161193108124682
 | 
				
			||||||
 | 
					3.3333333333333335,6.466487953447159
 | 
				
			||||||
 | 
					3.4343434343434343,7.6815673373443385
 | 
				
			||||||
 | 
					3.5353535353535355,7.412509123916619
 | 
				
			||||||
 | 
					3.6363636363636362,7.712231039046388
 | 
				
			||||||
 | 
					3.7373737373737375,7.512155302443977
 | 
				
			||||||
 | 
					3.8383838383838382,8.169468174953455
 | 
				
			||||||
 | 
					3.9393939393939394,8.201406255891817
 | 
				
			||||||
 | 
					4.040404040404041,9.413915839209679
 | 
				
			||||||
 | 
					4.141414141414141,7.2131607261403
 | 
				
			||||||
 | 
					4.242424242424242,8.244196707034996
 | 
				
			||||||
 | 
					4.343434343434343,8.059400613529792
 | 
				
			||||||
 | 
					4.444444444444445,9.127093042087843
 | 
				
			||||||
 | 
					4.545454545454545,8.232456814994503
 | 
				
			||||||
 | 
					4.646464646464646,9.026988954051767
 | 
				
			||||||
 | 
					4.747474747474747,8.936405824368308
 | 
				
			||||||
 | 
					4.848484848484849,8.838334259675397
 | 
				
			||||||
 | 
					4.94949494949495,9.717080564295035
 | 
				
			||||||
 | 
					5.05050505050505,9.635892324495916
 | 
				
			||||||
 | 
					5.151515151515151,10.802758752616178
 | 
				
			||||||
 | 
					5.252525252525253,9.889268431487773
 | 
				
			||||||
 | 
					5.353535353535354,9.262021983987134
 | 
				
			||||||
 | 
					5.454545454545454,9.905732041295009
 | 
				
			||||||
 | 
					5.555555555555555,9.697006564677089
 | 
				
			||||||
 | 
					5.656565656565657,10.435437946557755
 | 
				
			||||||
 | 
					5.757575757575758,10.257651825530608
 | 
				
			||||||
 | 
					5.858585858585858,11.394734709569004
 | 
				
			||||||
 | 
					5.959595959595959,10.872621683473387
 | 
				
			||||||
 | 
					6.0606060606060606,10.750944058491058
 | 
				
			||||||
 | 
					6.161616161616162,11.375400587831757
 | 
				
			||||||
 | 
					6.262626262626262,11.834436555701465
 | 
				
			||||||
 | 
					6.363636363636363,11.536088544119654
 | 
				
			||||||
 | 
					6.4646464646464645,11.261555999325722
 | 
				
			||||||
 | 
					6.565656565656566,12.529961808490153
 | 
				
			||||||
 | 
					6.666666666666667,12.19345219105891
 | 
				
			||||||
 | 
					6.767676767676767,11.950653180245155
 | 
				
			||||||
 | 
					6.8686868686868685,12.176773142948385
 | 
				
			||||||
 | 
					6.96969696969697,12.055083206520518
 | 
				
			||||||
 | 
					7.070707070707071,13.498633194384489
 | 
				
			||||||
 | 
					7.171717171717171,12.542518727882712
 | 
				
			||||||
 | 
					7.2727272727272725,13.318372269865769
 | 
				
			||||||
 | 
					7.373737373737374,12.542630883166513
 | 
				
			||||||
 | 
					7.474747474747475,12.93490122675753
 | 
				
			||||||
 | 
					7.575757575757575,14.4040220344926
 | 
				
			||||||
 | 
					7.6767676767676765,13.314367294113964
 | 
				
			||||||
 | 
					7.777777777777778,14.061236496574551
 | 
				
			||||||
 | 
					7.878787878787879,12.686346979737731
 | 
				
			||||||
 | 
					7.979797979797979,14.024375221983842
 | 
				
			||||||
 | 
					8.080808080808081,13.7042096336008
 | 
				
			||||||
 | 
					8.181818181818182,13.342730021126272
 | 
				
			||||||
 | 
					8.282828282828282,14.136548357864573
 | 
				
			||||||
 | 
					8.383838383838384,14.619569834949138
 | 
				
			||||||
 | 
					8.484848484848484,14.01453898823226
 | 
				
			||||||
 | 
					8.585858585858587,15.154877807203663
 | 
				
			||||||
 | 
					8.686868686868687,14.081910297898048
 | 
				
			||||||
 | 
					8.787878787878787,14.474564310016353
 | 
				
			||||||
 | 
					8.88888888888889,14.966525346412723
 | 
				
			||||||
 | 
					8.98989898989899,15.526107019435932
 | 
				
			||||||
 | 
					9.09090909090909,14.352357736719853
 | 
				
			||||||
 | 
					9.191919191919192,15.843742065895144
 | 
				
			||||||
 | 
					9.292929292929292,15.787083172159111
 | 
				
			||||||
 | 
					9.393939393939394,15.211828607109144
 | 
				
			||||||
 | 
					9.494949494949495,15.845176532492374
 | 
				
			||||||
 | 
					9.595959595959595,15.622518083688107
 | 
				
			||||||
 | 
					9.696969696969697,15.589081237426006
 | 
				
			||||||
 | 
					9.797979797979798,15.511248085690712
 | 
				
			||||||
 | 
					9.8989898989899,16.27050774059862
 | 
				
			||||||
 | 
					10.0,16.1105549896166
 | 
				
			||||||
		
		
			
  | 
							
								
								
									
										30
									
								
								linear regression/np_genPoints.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								linear regression/np_genPoints.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,30 @@
 | 
				
			|||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					import csv
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 定义回归方程参数
 | 
				
			||||||
 | 
					w = 1.35
 | 
				
			||||||
 | 
					b = 2.89
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 生成x值范围
 | 
				
			||||||
 | 
					x_min = 0
 | 
				
			||||||
 | 
					x_max = 10
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 生成100个在x轴附近的点
 | 
				
			||||||
 | 
					x = np.linspace(x_min, x_max, 100)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 根据回归方程计算y值
 | 
				
			||||||
 | 
					y = w * x + b
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 添加一些噪声,使数据更真实
 | 
				
			||||||
 | 
					y += np.random.normal(scale=0.5, size=y.shape)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 将x和y合并成一个二维数组
 | 
				
			||||||
 | 
					data = np.column_stack((x, y))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 将数据保存到CSV文件
 | 
				
			||||||
 | 
					with open('data1.csv', 'w', newline='') as csvfile:
 | 
				
			||||||
 | 
					    writer = csv.writer(csvfile)
 | 
				
			||||||
 | 
					    # 写入表头
 | 
				
			||||||
 | 
					    # writer.writerow(['x', 'y'])
 | 
				
			||||||
 | 
					    # 写入数据
 | 
				
			||||||
 | 
					    writer.writerows(data)
 | 
				
			||||||
							
								
								
									
										22
									
								
								linear regression/plt_print.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								linear regression/plt_print.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,22 @@
 | 
				
			|||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					import matplotlib.pyplot as plt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 原始数据
 | 
				
			||||||
 | 
					points = np.genfromtxt("data1.csv", delimiter=',')
 | 
				
			||||||
 | 
					x = points[:, 0]
 | 
				
			||||||
 | 
					y = points[:, 1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 拟合直线
 | 
				
			||||||
 | 
					x_range = np.linspace(min(x), max(x), 100)
 | 
				
			||||||
 | 
					y_pred = 1.6455038785934448 * x_range + 1.827562689781189
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 绘图
 | 
				
			||||||
 | 
					plt.figure(figsize=(8, 6))
 | 
				
			||||||
 | 
					plt.scatter(x, y, color='blue', label='Original data')
 | 
				
			||||||
 | 
					plt.plot(x_range, y_pred, color='red', label='Fitted line')
 | 
				
			||||||
 | 
					plt.xlabel('X')
 | 
				
			||||||
 | 
					plt.ylabel('Y')
 | 
				
			||||||
 | 
					plt.title('Fitting a line to random data')
 | 
				
			||||||
 | 
					plt.legend()
 | 
				
			||||||
 | 
					plt.grid(True)
 | 
				
			||||||
 | 
					plt.savefig('print1.png')
 | 
				
			||||||
							
								
								
									
										
											BIN
										
									
								
								linear regression/print1.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								linear regression/print1.png
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 36 KiB  | 
							
								
								
									
										7
									
								
								requirements.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								requirements.txt
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,7 @@
 | 
				
			|||||||
 | 
					pytorch::pytorch
 | 
				
			||||||
 | 
					torchvision
 | 
				
			||||||
 | 
					torchaudio
 | 
				
			||||||
 | 
					pandas
 | 
				
			||||||
 | 
					matplotlib
 | 
				
			||||||
 | 
					numpy
 | 
				
			||||||
 | 
					g4f
 | 
				
			||||||
							
								
								
									
										6
									
								
								test/macTest.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								test/macTest.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,6 @@
 | 
				
			|||||||
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					print(torch.backends.mps.is_available())
 | 
				
			||||||
 | 
					print(torch.backends.mps.is_built())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					print(torch.device("mps"))
 | 
				
			||||||
							
								
								
									
										29
									
								
								test/performance.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								test/performance.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,29 @@
 | 
				
			|||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import time
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					print(torch.__version__)
 | 
				
			||||||
 | 
					print(torch.backends.mps.is_available())
 | 
				
			||||||
 | 
					print(torch.cuda.is_available())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					a = torch.randn(10000,1000)
 | 
				
			||||||
 | 
					b = torch.randn(1000,2000)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					t0 = time.time()
 | 
				
			||||||
 | 
					c = torch.matmul(a, b)
 | 
				
			||||||
 | 
					t1 = time.time()
 | 
				
			||||||
 | 
					print(a.device,t1-t0,c.norm(2))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					device = torch.device('mps')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					a = a.to(device)
 | 
				
			||||||
 | 
					b = b.to(device)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					t0 = time.time()
 | 
				
			||||||
 | 
					c = torch.matmul(a, b)
 | 
				
			||||||
 | 
					t1 = time.time()
 | 
				
			||||||
 | 
					print(a.device,t1-t0,c.norm(2))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					t0 = time.time()
 | 
				
			||||||
 | 
					c = torch.matmul(a, b)
 | 
				
			||||||
 | 
					t1 = time.time()
 | 
				
			||||||
 | 
					print(a.device,t1-t0,c.norm(2))
 | 
				
			||||||
		Reference in New Issue
	
	Block a user