# Goal: fit the linear model y = X @ w
import torch
from torch import nn
# Module 1: data initialisation.
# Build a synthetic regression set: n samples with 2 features each, and targets
# generated from a known weight vector plus noise.
n = 100
X = torch.rand(n, 2)
true_w = torch.tensor([[-1.0], [2.0]])
# torch.rand draws from the uniform distribution on [0, 1), so the noise is
# non-negative with mean 0.5; it behaves like an intercept plus jitter.
y = X @ true_w + torch.rand(n, 1)
# Hand-rolled weight with autograd enabled.
# NOTE(review): not referenced by the nn.Linear-based training below -- confirm
# whether it is still needed.
w = torch.tensor([[1.0], [1.0]], requires_grad=True)

# Alternative (disabled) two-layer model with tanh activations. Kept as a real
# comment instead of a bare string statement; the activation module is
# `nn.Tanh` (there is no `nn.tanh` module).
# model = nn.Sequential(
#     nn.Linear(2, 3),
#     nn.Tanh(),
#     nn.Linear(3, 1),
#     nn.Tanh(),
# )
# Module 2: define the model, the loss function (loss_func) and the optimizer (optim).
# A single linear layer mapping 2 input features to 1 output.
model = nn.Linear(in_features=2, out_features=1)
# Mean-squared-error criterion between predictions and targets.
loss_func = nn.MSELoss()
# Plain SGD over the layer's parameters with learning rate 0.1.
optim = torch.optim.SGD(params=model.parameters(), lr=0.1)
# Module 3: training loop. Each epoch computes y_hat and the loss, then runs the
# three steps: zero the optimizer's parameter gradients, backpropagate from the
# loss, and let the optimizer update the parameters.
print("epoch loss w ")
epochs = 100
for i in range(epochs):
    # Forward pass over the full data set (no mini-batching).
    y_hat = model(X)
    loss = loss_func(y_hat, y)

    # Gradients accumulate across backward() calls, so clear them first.
    optim.zero_grad()
    loss.backward()
    optim.step()

    # The loss shown was computed before this epoch's step; the weights shown
    # are the values after it. detach() drops the autograd graph for printing.
    print(f"{i} {loss} {model.weight.reshape(2).detach()}")
epoch loss w
0 3.0193979740142822 tensor([-0.4203, -0.2541])
1 1.724941372871399 tensor([-0.3228, -0.1113])
2 1.0764166116714478 tensor([-0.2580, -0.0022])
3 0.7487826347351074 tensor([-0.2163, 0.0831])
4 0.5806469917297363 tensor([-0.1907, 0.1515])
5 0.49187105894088745 tensor([-0.1765, 0.2078])
6 0.44265982508659363 tensor([-0.1703, 0.2556])
7 0.41324958205223083 tensor([-0.1695, 0.2971])
8 0.39382266998291016 tensor([-0.1726, 0.3341])
9 0.3794998526573181 tensor([-0.1784, 0.3679])
10 0.3678540289402008 tensor([-0.1860, 0.3993])
11 0.3576757311820984 tensor([-0.1948, 0.4288])
12 0.34836021065711975 tensor([-0.2044, 0.4569])
13 0.3396031856536865 tensor([-0.2145, 0.4839])
14 0.331249475479126 tensor([-0.2250, 0.5100])
15 0.32321831583976746 tensor([-0.2356, 0.5354])
16 0.3154657781124115 tensor([-0.2463, 0.5601])
17 0.3079665005207062 tensor([-0.2570, 0.5843])
18 0.30070433020591736 tensor([-0.2677, 0.6080])
19 0.29366785287857056 tensor([-0.2783, 0.6313])
20 0.28684815764427185 tensor([-0.2888, 0.6542])
21 0.2802375555038452 tensor([-0.2992, 0.6766])
22 0.2738291919231415 tensor([-0.3095, 0.6987])
23 0.2676165997982025 tensor([-0.3196, 0.7204])
24 0.2615937292575836 tensor([-0.3296, 0.7418])
25 0.255754679441452 tensor([-0.3394, 0.7628])
26 0.25009381771087646 tensor([-0.3492, 0.7835])
27 0.24460570514202118 tensor([-0.3587, 0.8039])
28 0.23928505182266235 tensor([-0.3682, 0.8239])
29 0.2341267466545105 tensor([-0.3775, 0.8437])
30 0.22912582755088806 tensor([-0.3867, 0.8631])
31 0.22427748143672943 tensor([-0.3957, 0.8822])
32 0.21957707405090332 tensor([-0.4046, 0.9011])
33 0.2150200754404068 tensor([-0.4134, 0.9196])
34 0.21060210466384888 tensor([-0.4220, 0.9379])
35 0.2063189297914505 tensor([-0.4305, 0.9558])
36 0.20216642320156097 tensor([-0.4389, 0.9735])
37 0.19814060628414154 tensor([-0.4472, 0.9910])
38 0.194237619638443 tensor([-0.4553, 1.0081])
39 0.1904536783695221 tensor([-0.4633, 1.0250])
40 0.18678519129753113 tensor([-0.4712, 1.0416])
41 0.18322861194610596 tensor([-0.4790, 1.0580])
42 0.17978054285049438 tensor([-0.4867, 1.0741])
43 0.176437646150589 tensor([-0.4943, 1.0900])
44 0.17319674789905548 tensor([-0.5017, 1.1056])
45 0.17005468904972076 tensor([-0.5091, 1.1210])
46 0.16700850427150726 tensor([-0.5163, 1.1361])
47 0.1640552133321762 tensor([-0.5234, 1.1510])
48 0.16119202971458435 tensor([-0.5304, 1.1657])
49 0.15841616690158844 tensor([-0.5374, 1.1802])
50 0.15572498738765717 tensor([-0.5442, 1.1944])
51 0.1531158685684204 tensor([-0.5509, 1.2084])
52 0.15058636665344238 tensor([-0.5575, 1.2222])
53 0.1481340080499649 tensor([-0.5640, 1.2358])
54 0.14575643837451935 tensor([-0.5704, 1.2491])
55 0.14345139265060425 tensor([-0.5768, 1.2623])
56 0.14121665060520172 tensor([-0.5830, 1.2753])
57 0.13905006647109985 tensor([-0.5892, 1.2880])
58 0.1369495391845703 tensor([-0.5952, 1.3006])
59 0.1349131017923355 tensor([-0.6012, 1.3129])
60 0.13293875753879547 tensor([-0.6071, 1.3251])
61 0.13102462887763977 tensor([-0.6129, 1.3371])
62 0.12916886806488037 tensor([-0.6186, 1.3489])
63 0.1273697167634964 tensor([-0.6242, 1.3605])
64 0.1256254017353058 tensor([-0.6297, 1.3719])
65 0.123934306204319 tensor([-0.6352, 1.3832])
66 0.12229477614164352 tensor([-0.6406, 1.3943])
67 0.12070523947477341 tensor([-0.6459, 1.4052])
68 0.11916416138410568 tensor([-0.6511, 1.4159])
69 0.11767008155584335 tensor([-0.6562, 1.4265])
70 0.11622155457735062 tensor([-0.6613, 1.4369])
71 0.11481721699237823 tensor([-0.6663, 1.4472])
72 0.11345569044351578 tensor([-0.6712, 1.4573])
73 0.11213566362857819 tensor([-0.6761, 1.4672])
74 0.11085589230060577 tensor([-0.6809, 1.4770])
75 0.10961514711380005 tensor([-0.6856, 1.4867])
76 0.10841222107410431 tensor([-0.6902, 1.4961])
77 0.10724597424268723 tensor([-0.6948, 1.5055])
78 0.10611527413129807 tensor([-0.6993, 1.5147])
79 0.10501907020807266 tensor([-0.7038, 1.5237])
80 0.10395626723766327 tensor([-0.7081, 1.5326])
81 0.1029258519411087 tensor([-0.7124, 1.5414])
82 0.10192685574293137 tensor([-0.7167, 1.5500])
83 0.10095832496881485 tensor([-0.7209, 1.5585])
84 0.10001931339502335 tensor([-0.7250, 1.5669])
85 0.09910892695188522 tensor([-0.7291, 1.5752])
86 0.09822628647089005 tensor([-0.7331, 1.5833])
87 0.09737054258584976 tensor([-0.7370, 1.5913])
88 0.0965408906340599 tensor([-0.7409, 1.5991])
89 0.09573652595281601 tensor([-0.7448, 1.6069])
90 0.09495667368173599 tensor([-0.7485, 1.6145])
91 0.09420059621334076 tensor([-0.7523, 1.6220])
92 0.09346755594015121 tensor([-0.7559, 1.6294])
93 0.09275685995817184 tensor([-0.7596, 1.6367])
94 0.09206782281398773 tensor([-0.7631, 1.6438])
95 0.09139978885650635 tensor([-0.7666, 1.6509])
96 0.09075210243463516 tensor([-0.7701, 1.6578])
97 0.09012416005134583 tensor([-0.7735, 1.6647])
98 0.08951534330844879 tensor([-0.7769, 1.6714])
99 0.08892509341239929 tensor([-0.7802, 1.6780])