• 2、pytorch——Linear模型(最基础版,理解框架,背诵记忆)(调用nn.Modules模块)


    #define y = X @ w
    import torch from torch import nn
    #第一模块,数据初始化 n
    = 100 X = torch.rand(n,2) true_w = torch.tensor([[-1.],[2]]) y = X @ true_w + torch.rand(n,1) w = torch.tensor([[1.],[1.]], requires_grad = True) """model = nn.Sequential(nn.Linear(2,3), nn.tanh(), nn.Linear(3,1), nn.tanh() )"""
    #第二模块,定义model,定义loss_func,定义优化器optim model = nn.Linear(2,1) loss_func = nn.MSELoss() optim = torch.optim.SGD(model.parameters(), 0.1)
    #第三模块,for循环,定义y_hat,定义loss,三步走:优化器参数梯度清零,从loss出发计算梯度,优化器更新各参数
    print("epoch loss w ") epochs = 100 for i in range(epochs): y_hat = model(X) loss = loss_func(y_hat, y) optim.zero_grad() loss.backward() optim.step() print(f"{i} {loss} {model.weight.reshape(2).detach()}")

    epoch	 loss	 w	
    0	 3.0193979740142822	 tensor([-0.4203, -0.2541])
    1	 1.724941372871399	 tensor([-0.3228, -0.1113])
    2	 1.0764166116714478	 tensor([-0.2580, -0.0022])
    3	 0.7487826347351074	 tensor([-0.2163,  0.0831])
    4	 0.5806469917297363	 tensor([-0.1907,  0.1515])
    5	 0.49187105894088745	 tensor([-0.1765,  0.2078])
    6	 0.44265982508659363	 tensor([-0.1703,  0.2556])
    7	 0.41324958205223083	 tensor([-0.1695,  0.2971])
    8	 0.39382266998291016	 tensor([-0.1726,  0.3341])
    9	 0.3794998526573181	 tensor([-0.1784,  0.3679])
    10	 0.3678540289402008	 tensor([-0.1860,  0.3993])
    11	 0.3576757311820984	 tensor([-0.1948,  0.4288])
    12	 0.34836021065711975	 tensor([-0.2044,  0.4569])
    13	 0.3396031856536865	 tensor([-0.2145,  0.4839])
    14	 0.331249475479126	 tensor([-0.2250,  0.5100])
    15	 0.32321831583976746	 tensor([-0.2356,  0.5354])
    16	 0.3154657781124115	 tensor([-0.2463,  0.5601])
    17	 0.3079665005207062	 tensor([-0.2570,  0.5843])
    18	 0.30070433020591736	 tensor([-0.2677,  0.6080])
    19	 0.29366785287857056	 tensor([-0.2783,  0.6313])
    20	 0.28684815764427185	 tensor([-0.2888,  0.6542])
    21	 0.2802375555038452	 tensor([-0.2992,  0.6766])
    22	 0.2738291919231415	 tensor([-0.3095,  0.6987])
    23	 0.2676165997982025	 tensor([-0.3196,  0.7204])
    24	 0.2615937292575836	 tensor([-0.3296,  0.7418])
    25	 0.255754679441452	 tensor([-0.3394,  0.7628])
    26	 0.25009381771087646	 tensor([-0.3492,  0.7835])
    27	 0.24460570514202118	 tensor([-0.3587,  0.8039])
    28	 0.23928505182266235	 tensor([-0.3682,  0.8239])
    29	 0.2341267466545105	 tensor([-0.3775,  0.8437])
    30	 0.22912582755088806	 tensor([-0.3867,  0.8631])
    31	 0.22427748143672943	 tensor([-0.3957,  0.8822])
    32	 0.21957707405090332	 tensor([-0.4046,  0.9011])
    33	 0.2150200754404068	 tensor([-0.4134,  0.9196])
    34	 0.21060210466384888	 tensor([-0.4220,  0.9379])
    35	 0.2063189297914505	 tensor([-0.4305,  0.9558])
    36	 0.20216642320156097	 tensor([-0.4389,  0.9735])
    37	 0.19814060628414154	 tensor([-0.4472,  0.9910])
    38	 0.194237619638443	 tensor([-0.4553,  1.0081])
    39	 0.1904536783695221	 tensor([-0.4633,  1.0250])
    40	 0.18678519129753113	 tensor([-0.4712,  1.0416])
    41	 0.18322861194610596	 tensor([-0.4790,  1.0580])
    42	 0.17978054285049438	 tensor([-0.4867,  1.0741])
    43	 0.176437646150589	 tensor([-0.4943,  1.0900])
    44	 0.17319674789905548	 tensor([-0.5017,  1.1056])
    45	 0.17005468904972076	 tensor([-0.5091,  1.1210])
    46	 0.16700850427150726	 tensor([-0.5163,  1.1361])
    47	 0.1640552133321762	 tensor([-0.5234,  1.1510])
    48	 0.16119202971458435	 tensor([-0.5304,  1.1657])
    49	 0.15841616690158844	 tensor([-0.5374,  1.1802])
    50	 0.15572498738765717	 tensor([-0.5442,  1.1944])
    51	 0.1531158685684204	 tensor([-0.5509,  1.2084])
    52	 0.15058636665344238	 tensor([-0.5575,  1.2222])
    53	 0.1481340080499649	 tensor([-0.5640,  1.2358])
    54	 0.14575643837451935	 tensor([-0.5704,  1.2491])
    55	 0.14345139265060425	 tensor([-0.5768,  1.2623])
    56	 0.14121665060520172	 tensor([-0.5830,  1.2753])
    57	 0.13905006647109985	 tensor([-0.5892,  1.2880])
    58	 0.1369495391845703	 tensor([-0.5952,  1.3006])
    59	 0.1349131017923355	 tensor([-0.6012,  1.3129])
    60	 0.13293875753879547	 tensor([-0.6071,  1.3251])
    61	 0.13102462887763977	 tensor([-0.6129,  1.3371])
    62	 0.12916886806488037	 tensor([-0.6186,  1.3489])
    63	 0.1273697167634964	 tensor([-0.6242,  1.3605])
    64	 0.1256254017353058	 tensor([-0.6297,  1.3719])
    65	 0.123934306204319	 tensor([-0.6352,  1.3832])
    66	 0.12229477614164352	 tensor([-0.6406,  1.3943])
    67	 0.12070523947477341	 tensor([-0.6459,  1.4052])
    68	 0.11916416138410568	 tensor([-0.6511,  1.4159])
    69	 0.11767008155584335	 tensor([-0.6562,  1.4265])
    70	 0.11622155457735062	 tensor([-0.6613,  1.4369])
    71	 0.11481721699237823	 tensor([-0.6663,  1.4472])
    72	 0.11345569044351578	 tensor([-0.6712,  1.4573])
    73	 0.11213566362857819	 tensor([-0.6761,  1.4672])
    74	 0.11085589230060577	 tensor([-0.6809,  1.4770])
    75	 0.10961514711380005	 tensor([-0.6856,  1.4867])
    76	 0.10841222107410431	 tensor([-0.6902,  1.4961])
    77	 0.10724597424268723	 tensor([-0.6948,  1.5055])
    78	 0.10611527413129807	 tensor([-0.6993,  1.5147])
    79	 0.10501907020807266	 tensor([-0.7038,  1.5237])
    80	 0.10395626723766327	 tensor([-0.7081,  1.5326])
    81	 0.1029258519411087	 tensor([-0.7124,  1.5414])
    82	 0.10192685574293137	 tensor([-0.7167,  1.5500])
    83	 0.10095832496881485	 tensor([-0.7209,  1.5585])
    84	 0.10001931339502335	 tensor([-0.7250,  1.5669])
    85	 0.09910892695188522	 tensor([-0.7291,  1.5752])
    86	 0.09822628647089005	 tensor([-0.7331,  1.5833])
    87	 0.09737054258584976	 tensor([-0.7370,  1.5913])
    88	 0.0965408906340599	 tensor([-0.7409,  1.5991])
    89	 0.09573652595281601	 tensor([-0.7448,  1.6069])
    90	 0.09495667368173599	 tensor([-0.7485,  1.6145])
    91	 0.09420059621334076	 tensor([-0.7523,  1.6220])
    92	 0.09346755594015121	 tensor([-0.7559,  1.6294])
    93	 0.09275685995817184	 tensor([-0.7596,  1.6367])
    94	 0.09206782281398773	 tensor([-0.7631,  1.6438])
    95	 0.09139978885650635	 tensor([-0.7666,  1.6509])
    96	 0.09075210243463516	 tensor([-0.7701,  1.6578])
    97	 0.09012416005134583	 tensor([-0.7735,  1.6647])
    98	 0.08951534330844879	 tensor([-0.7769,  1.6714])
    99	 0.08892509341239929	 tensor([-0.7802,  1.6780])
  • 相关阅读:
    几种php加速器比较
    细说firewalld和iptables
    Linux上iptables防火墙的基本应用教程
    mysql 字符串按照数字类型排序
    《设计模式之禅》之六大设计原则下篇
    《设计模式之禅》之六大设计原则中篇
    《设计模式之禅》之六大设计原则上篇
    git bash 乱码问题之解决方案
    nexus没有授权导致的错误
    Java之微信公众号开发
  • 原文地址:https://www.cnblogs.com/qiezi-online/p/13947702.html
Copyright © 2020-2023  润新知