1
0
This commit is contained in:
2024-05-23 20:43:52 +08:00
parent e6da567437
commit 9a33bd6f49
11 changed files with 292 additions and 1 deletions

View File

@@ -1,8 +1,23 @@
# pytorch study # pytorch study
## ENV ## BASE ENV
```shell ```shell
conda create -n pt python=3.10 -y conda create -n pt python=3.10 -y
conda activate pt conda activate pt
``` ```
## MAC
```shell
# 安装 pytorch v1.12版本已经正式支持了用于mac m1芯片gpu加速的mps后端
conda install pytorch::pytorch torchvision torchaudio -c pytorch -y
pip install numpy
pip install pandas
pip install matplotlib
```
## gpt4free
```
pip install -U g4f[all]
```

15
autograd.py Normal file
View File

@@ -0,0 +1,15 @@
import torch
from torch import autograd
device = torch.device('mps')
x = torch.tensor(1.)
a = torch.tensor(2., requires_grad=True)
b = torch.tensor(2., requires_grad=True)
c = torch.tensor(3., requires_grad=True)
y = a ** 2 * x + b * x + c ** 3
print('before:', a.grad, b.grad, c.grad)
grads = autograd.grad(y, [a, b, c])
print('after:', grads[0], grads[1], grads[2])

11
gpt.py Normal file
View File

@@ -0,0 +1,11 @@
from g4f.client import Client
content = "张量在机器学习中的主要用途"
client = Client()
response = client.chat.completions.create(
model="gpt-4o",
messages=[{"role": "user", "content": content}],
)
print(response.choices[0].message.content)

56
linear regression/1.py Normal file
View File

@@ -0,0 +1,56 @@
import numpy as np
import torch
def compute_error_for_line_given_points(b, w, points):
totalError = 0
N = float(len(points))
for i in range(len(points)):
x = points[i][0]
y = points[i][1]
totalError += (y - (w * x + b)) ** 2
return totalError / N
def step_gradient(b_current, w_current, points, learningRate):
b_gradient = torch.tensor(0.0, device=points.device, dtype=torch.float32)
w_gradient = torch.tensor(0.0, device=points.device, dtype=torch.float32)
N = float(len(points))
for i in range(len(points)):
x = points[i][0]
y = points[i][1]
b_gradient += -(2 / N) * (y - (w_current * x + b_current) + b_current)
w_gradient += -(2 / N) * x * (y - (w_current * x + b_current + b_current))
new_b = b_current - (learningRate * b_gradient)
new_w = w_current - (learningRate * w_gradient)
return [new_b, new_w]
def gradient_descent_runner(points, starting_b, starting_w, learningRate, num_iterations):
b = torch.tensor(starting_b, device=points.device, dtype=torch.float32)
w = torch.tensor(starting_w, device=points.device, dtype=torch.float32)
for i in range(num_iterations):
b, w = step_gradient(b, w, points, learningRate)
return [b, w]
def run():
# 修改为生成数据的文件路径
points_np = np.genfromtxt("data1.csv", delimiter=',').astype(np.float32)
points = torch.tensor(points_np, device='mps')
learning_rate = 0.0001 # 使用较小的学习率
initial_b = 0.0
initial_w = 0.0
num_iterations = 1000
print("Starting gradient descent at b={0},w={1},error={2}".format(initial_b, initial_w,
compute_error_for_line_given_points(initial_b,
initial_w,
points)))
print("running...")
[b, w] = gradient_descent_runner(points, initial_b, initial_w, learning_rate, num_iterations)
print("After gradient descent at b={0},w={1},error={2}".format(b.item(), w.item(),
compute_error_for_line_given_points(b, w, points)))
if __name__ == '__main__':
run()

100
linear regression/data1.csv Normal file
View File

@@ -0,0 +1,100 @@
0.0,2.4360562173289115
0.10101010101010101,2.4288710820592065
0.20202020202020202,3.677943201375977
0.30303030303030304,3.5029234515863217
0.40404040404040403,4.007715980839878
0.5050505050505051,3.95999321461469
0.6060606060606061,3.220853916066527
0.7070707070707071,3.2211460206798623
0.8080808080808081,4.2516957270374505
0.9090909090909091,4.311826715084292
1.0101010101010102,4.153966583608258
1.1111111111111112,4.224290328721461
1.2121212121212122,4.551324602105953
1.3131313131313131,5.157200101408408
1.4141414141414141,5.199011258508288
1.5151515151515151,5.248911218901843
1.6161616161616161,5.789628423512512
1.7171717171717171,5.126592322934872
1.8181818181818181,4.546631494636344
1.9191919191919191,5.7260434379514065
2.0202020202020203,5.607446671816119
2.121212121212121,5.401744626671172
2.2222222222222223,5.568078510495838
2.323232323232323,6.136817713051054
2.4242424242424243,5.399802896696589
2.525252525252525,6.7465591899811415
2.6262626262626263,6.510002771256968
2.727272727272727,6.194107987238278
2.8282828282828283,6.280445605022811
2.929292929292929,6.413289184504817
3.0303030303030303,8.178951965980268
3.131313131313131,7.438933818741419
3.2323232323232323,8.161193108124682
3.3333333333333335,6.466487953447159
3.4343434343434343,7.6815673373443385
3.5353535353535355,7.412509123916619
3.6363636363636362,7.712231039046388
3.7373737373737375,7.512155302443977
3.8383838383838382,8.169468174953455
3.9393939393939394,8.201406255891817
4.040404040404041,9.413915839209679
4.141414141414141,7.2131607261403
4.242424242424242,8.244196707034996
4.343434343434343,8.059400613529792
4.444444444444445,9.127093042087843
4.545454545454545,8.232456814994503
4.646464646464646,9.026988954051767
4.747474747474747,8.936405824368308
4.848484848484849,8.838334259675397
4.94949494949495,9.717080564295035
5.05050505050505,9.635892324495916
5.151515151515151,10.802758752616178
5.252525252525253,9.889268431487773
5.353535353535354,9.262021983987134
5.454545454545454,9.905732041295009
5.555555555555555,9.697006564677089
5.656565656565657,10.435437946557755
5.757575757575758,10.257651825530608
5.858585858585858,11.394734709569004
5.959595959595959,10.872621683473387
6.0606060606060606,10.750944058491058
6.161616161616162,11.375400587831757
6.262626262626262,11.834436555701465
6.363636363636363,11.536088544119654
6.4646464646464645,11.261555999325722
6.565656565656566,12.529961808490153
6.666666666666667,12.19345219105891
6.767676767676767,11.950653180245155
6.8686868686868685,12.176773142948385
6.96969696969697,12.055083206520518
7.070707070707071,13.498633194384489
7.171717171717171,12.542518727882712
7.2727272727272725,13.318372269865769
7.373737373737374,12.542630883166513
7.474747474747475,12.93490122675753
7.575757575757575,14.4040220344926
7.6767676767676765,13.314367294113964
7.777777777777778,14.061236496574551
7.878787878787879,12.686346979737731
7.979797979797979,14.024375221983842
8.080808080808081,13.7042096336008
8.181818181818182,13.342730021126272
8.282828282828282,14.136548357864573
8.383838383838384,14.619569834949138
8.484848484848484,14.01453898823226
8.585858585858587,15.154877807203663
8.686868686868687,14.081910297898048
8.787878787878787,14.474564310016353
8.88888888888889,14.966525346412723
8.98989898989899,15.526107019435932
9.09090909090909,14.352357736719853
9.191919191919192,15.843742065895144
9.292929292929292,15.787083172159111
9.393939393939394,15.211828607109144
9.494949494949495,15.845176532492374
9.595959595959595,15.622518083688107
9.696969696969697,15.589081237426006
9.797979797979798,15.511248085690712
9.8989898989899,16.27050774059862
10.0,16.1105549896166
1 0.0 2.4360562173289115
2 0.10101010101010101 2.4288710820592065
3 0.20202020202020202 3.677943201375977
4 0.30303030303030304 3.5029234515863217
5 0.40404040404040403 4.007715980839878
6 0.5050505050505051 3.95999321461469
7 0.6060606060606061 3.220853916066527
8 0.7070707070707071 3.2211460206798623
9 0.8080808080808081 4.2516957270374505
10 0.9090909090909091 4.311826715084292
11 1.0101010101010102 4.153966583608258
12 1.1111111111111112 4.224290328721461
13 1.2121212121212122 4.551324602105953
14 1.3131313131313131 5.157200101408408
15 1.4141414141414141 5.199011258508288
16 1.5151515151515151 5.248911218901843
17 1.6161616161616161 5.789628423512512
18 1.7171717171717171 5.126592322934872
19 1.8181818181818181 4.546631494636344
20 1.9191919191919191 5.7260434379514065
21 2.0202020202020203 5.607446671816119
22 2.121212121212121 5.401744626671172
23 2.2222222222222223 5.568078510495838
24 2.323232323232323 6.136817713051054
25 2.4242424242424243 5.399802896696589
26 2.525252525252525 6.7465591899811415
27 2.6262626262626263 6.510002771256968
28 2.727272727272727 6.194107987238278
29 2.8282828282828283 6.280445605022811
30 2.929292929292929 6.413289184504817
31 3.0303030303030303 8.178951965980268
32 3.131313131313131 7.438933818741419
33 3.2323232323232323 8.161193108124682
34 3.3333333333333335 6.466487953447159
35 3.4343434343434343 7.6815673373443385
36 3.5353535353535355 7.412509123916619
37 3.6363636363636362 7.712231039046388
38 3.7373737373737375 7.512155302443977
39 3.8383838383838382 8.169468174953455
40 3.9393939393939394 8.201406255891817
41 4.040404040404041 9.413915839209679
42 4.141414141414141 7.2131607261403
43 4.242424242424242 8.244196707034996
44 4.343434343434343 8.059400613529792
45 4.444444444444445 9.127093042087843
46 4.545454545454545 8.232456814994503
47 4.646464646464646 9.026988954051767
48 4.747474747474747 8.936405824368308
49 4.848484848484849 8.838334259675397
50 4.94949494949495 9.717080564295035
51 5.05050505050505 9.635892324495916
52 5.151515151515151 10.802758752616178
53 5.252525252525253 9.889268431487773
54 5.353535353535354 9.262021983987134
55 5.454545454545454 9.905732041295009
56 5.555555555555555 9.697006564677089
57 5.656565656565657 10.435437946557755
58 5.757575757575758 10.257651825530608
59 5.858585858585858 11.394734709569004
60 5.959595959595959 10.872621683473387
61 6.0606060606060606 10.750944058491058
62 6.161616161616162 11.375400587831757
63 6.262626262626262 11.834436555701465
64 6.363636363636363 11.536088544119654
65 6.4646464646464645 11.261555999325722
66 6.565656565656566 12.529961808490153
67 6.666666666666667 12.19345219105891
68 6.767676767676767 11.950653180245155
69 6.8686868686868685 12.176773142948385
70 6.96969696969697 12.055083206520518
71 7.070707070707071 13.498633194384489
72 7.171717171717171 12.542518727882712
73 7.2727272727272725 13.318372269865769
74 7.373737373737374 12.542630883166513
75 7.474747474747475 12.93490122675753
76 7.575757575757575 14.4040220344926
77 7.6767676767676765 13.314367294113964
78 7.777777777777778 14.061236496574551
79 7.878787878787879 12.686346979737731
80 7.979797979797979 14.024375221983842
81 8.080808080808081 13.7042096336008
82 8.181818181818182 13.342730021126272
83 8.282828282828282 14.136548357864573
84 8.383838383838384 14.619569834949138
85 8.484848484848484 14.01453898823226
86 8.585858585858587 15.154877807203663
87 8.686868686868687 14.081910297898048
88 8.787878787878787 14.474564310016353
89 8.88888888888889 14.966525346412723
90 8.98989898989899 15.526107019435932
91 9.09090909090909 14.352357736719853
92 9.191919191919192 15.843742065895144
93 9.292929292929292 15.787083172159111
94 9.393939393939394 15.211828607109144
95 9.494949494949495 15.845176532492374
96 9.595959595959595 15.622518083688107
97 9.696969696969697 15.589081237426006
98 9.797979797979798 15.511248085690712
99 9.8989898989899 16.27050774059862
100 10.0 16.1105549896166

View File

@@ -0,0 +1,30 @@
import numpy as np
import csv
# 定义回归方程参数
w = 1.35
b = 2.89
# 生成x值范围
x_min = 0
x_max = 10
# 生成100个在x轴附近的点
x = np.linspace(x_min, x_max, 100)
# 根据回归方程计算y值
y = w * x + b
# 添加一些噪声,使数据更真实
y += np.random.normal(scale=0.5, size=y.shape)
# 将x和y合并成一个二维数组
data = np.column_stack((x, y))
# 将数据保存到CSV文件
with open('data1.csv', 'w', newline='') as csvfile:
writer = csv.writer(csvfile)
# 写入表头
# writer.writerow(['x', 'y'])
# 写入数据
writer.writerows(data)

View File

@@ -0,0 +1,22 @@
import numpy as np
import matplotlib.pyplot as plt
# 原始数据
points = np.genfromtxt("data1.csv", delimiter=',')
x = points[:, 0]
y = points[:, 1]
# 拟合直线
x_range = np.linspace(min(x), max(x), 100)
y_pred = 1.6455038785934448 * x_range + 1.827562689781189
# 绘图
plt.figure(figsize=(8, 6))
plt.scatter(x, y, color='blue', label='Original data')
plt.plot(x_range, y_pred, color='red', label='Fitted line')
plt.xlabel('X')
plt.ylabel('Y')
plt.title('Fitting a line to random data')
plt.legend()
plt.grid(True)
plt.savefig('print1.png')

Binary file not shown.

After

Width:  |  Height:  |  Size: 36 KiB

7
requirements.txt Normal file
View File

@@ -0,0 +1,7 @@
pytorch::pytorch
torchvision
torchaudio
pandas
matplotlib
numpy
g4f

6
test/macTest.py Normal file
View File

@@ -0,0 +1,6 @@
import torch
print(torch.backends.mps.is_available())
print(torch.backends.mps.is_built())
print(torch.device("mps"))

29
test/performance.py Normal file
View File

@@ -0,0 +1,29 @@
import torch
import time
print(torch.__version__)
print(torch.backends.mps.is_available())
print(torch.cuda.is_available())
a = torch.randn(10000,1000)
b = torch.randn(1000,2000)
t0 = time.time()
c = torch.matmul(a, b)
t1 = time.time()
print(a.device,t1-t0,c.norm(2))
device = torch.device('mps')
a = a.to(device)
b = b.to(device)
t0 = time.time()
c = torch.matmul(a, b)
t1 = time.time()
print(a.device,t1-t0,c.norm(2))
t0 = time.time()
c = torch.matmul(a, b)
t1 = time.time()
print(a.device,t1-t0,c.norm(2))