• torch_02_Polynomial Regression


     1 """
     2 torch.float64对应torch.DoubleTensor
     3 torch.float32对应torch.FloatTensor
     4 将真实函数的数据点能够拟合成一个多项式
     5 eg:y = 0.9 +0.5×x + 3×x*x + 2.4 ×x*x*x
     6 """
     7 import torch
     8 
     9 from torch import nn
    10 
    11 def make_features(x):
    12     x = x.unsqueeze(1)#在原来的基础上扩充了一维
    13     return torch.cat([x ** i for i in range(1,4)], 1)
    14 
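# Illustration (a small sketch added for clarity, not part of the original script):
# for an assumed input of torch.tensor([2.0]), make_features returns a single
# row containing the powers x, x^2, x^3.
# >>> make_features(torch.tensor([2.0]))
# tensor([[2., 4., 8.]])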

def get_batch(batch_size=32):
    random = torch.randn(batch_size)  # draw batch_size samples from N(0, 1)
    x = make_features(random)  # expand to (batch_size, 1), take powers 1-3, concatenate -> (batch_size, 3)
    # Compute the ground-truth targets
    y = f(x)  # (batch_size, 3) @ (3, 1) -> (batch_size, 1)
    # torch.autograd.Variable has been a no-op since PyTorch 0.4, so plain tensors are returned
    if torch.cuda.is_available():
        return x.cuda(), y.cuda()
    return x, y


w_target = torch.FloatTensor([0.5, 3, 2.4]).unsqueeze(1)  # shape (3, 1): coefficients of x, x^2, x^3
b_target = torch.FloatTensor([0.9])  # bias term


def f(x):
    """Ground-truth function: y = 0.9 + 0.5*x + 3*x^2 + 2.4*x^3."""
    return x.mm(w_target) + b_target[0]

class poly_model(nn.Module):
    def __init__(self):
        super(poly_model, self).__init__()
        self.poly = nn.Linear(3, 1)  # 3 input features, 1 output

    def forward(self, x):
        out = self.poly(x)
        return out

if torch.cuda.is_available():
    model = poly_model().cuda()
else:
    model = poly_model()

criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

for epoch in range(20):
    batch_x, batch_y = get_batch()  # batch_x is the same x built inside get_batch
    output = model(batch_x)
    loss = criterion(output, batch_y)
    print(loss.item())  # since PyTorch 0.4, loss.item() extracts the Python number from a scalar tensor
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
print('finished')
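
After training, a quick way to gauge the fit is to compare the learned nn.Linear parameters with the target coefficients. The snippet below is a minimal sketch added for illustration; it only assumes the model, w_target, and b_target defined above.

# Sketch: compare learned coefficients with the target polynomial.
with torch.no_grad():
    learned_w = model.poly.weight.view(-1).cpu()  # learned coefficients of x, x^2, x^3
    learned_b = model.poly.bias.cpu()             # learned bias
print('learned w:', learned_w.tolist(), '| target w:', w_target.view(-1).tolist())
print('learned b:', learned_b.tolist(), '| target b:', b_target.tolist())

With only 20 batches at lr=1e-3 the learned values will usually still be far from the targets; running more iterations (or looping until the loss falls below a threshold) brings them much closer.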

• Original post: https://www.cnblogs.com/shuangcao/p/11711640.html