• pytorch深度学习:一般分类器


    使用的criterion不是MSE而是交叉熵。

    查看形状时,numpy 数组用 .shape,tensor 用 .size();先弄清各维的大小,才能正确遍历变量。

    另外,CrossEntropyLoss 的参数要求真是有够绕人的:输入应为形状 (N, C) 的未归一化 logits,目标应为形状 (N,) 的 LongTensor 类别索引。

     1 import torch
     2 from torch import nn,optim
     3 import matplotlib.pyplot as plt
     4 
     class Classifier(nn.Module):
         """Single linear layer that maps input features to one raw score per class.

         The scores returned by ``forward`` are unnormalized logits, which is the
         input ``nn.CrossEntropyLoss`` expects (it applies log-softmax itself).

         Note that ``train`` doubles as the training loop *and* as the standard
         ``nn.Module.train(mode)`` switch; see its docstring.
         """

         def __init__(self, input_feature, output_size):
             """Create the layer.

             Args:
                 input_feature: number of features in each input sample.
                 output_size: number of classes (one output score per class).
             """
             super().__init__()
             self.linear = nn.Linear(input_feature, output_size)

         def forward(self, x):
             """Return the raw class scores for ``x``.

             No sigmoid is applied here: squashing the scores into (0, 1) before
             ``nn.CrossEntropyLoss`` would bound how confident the softmax can
             become and keep the loss from approaching zero.
             """
             return self.linear(x)

         def train(self, inp=True, target=None, criterion=None, optimizer=None,
                   epoches=0):
             """Run the training loop, or toggle training mode like ``nn.Module.train``.

             Because this method has the same name as ``nn.Module.train``, it must
             keep accepting a single boolean: ``nn.Module.eval()`` calls
             ``self.train(False)``. When ``inp`` is a bool the call is forwarded to
             the base class and the module itself is returned.

             Args:
                 inp: input batch passed to the model, or a bool training-mode flag.
                 target: targets passed to ``criterion`` together with the output.
                 criterion: callable ``criterion(output, target)`` returning the loss.
                 optimizer: optimizer providing ``zero_grad()`` and ``step()``.
                 epoches: number of full-batch update steps to run.

             Returns:
                 ``(self, loss)`` where ``loss`` is the loss of the last step, or
                 ``None`` when ``epoches`` is 0. In mode-switch form, ``self``.
             """
             if isinstance(inp, bool):
                 return super().train(inp)

             # Stays None when the loop body never runs (epoches == 0).
             loss = None
             for _ in range(epoches):
                 # Call the module rather than forward() so registered hooks run.
                 output = self(inp)
                 loss = criterion(output, target)
                 optimizer.zero_grad()
                 loss.backward()
                 optimizer.step()
             return self, loss
    31 
    # Toy data: two Gaussian blobs with unit standard deviation, centred at
    # (1, 1) and (-1, -1), 100 points each. Blob 0 is labelled 0, blob 1 is 1.
    cluster = torch.ones(100, 2)
    data0 = torch.normal(cluster, 1)
    data1 = torch.normal(-cluster, 1)
    target0 = torch.zeros(100, 1)
    target1 = torch.ones(100, 1)
    inputs = torch.cat((data0, data1), dim=0)
    target = torch.cat((target0, target1), dim=0)
    print(target.size())
    # Drop the trailing singleton dimension, (200, 1) -> (200,): the loss used
    # below takes one class index per sample.
    target = torch.squeeze(target)
    print(inputs.size())
    print(target.size())

    # target is 1-D after the squeeze, so it is passed whole as the colour
    # array; indexing it with [:, 0] would raise IndexError.
    plt.scatter(inputs.numpy()[:, 0], inputs.numpy()[:, 1], c=target.numpy(), s=10, cmap='RdYlGn')
    plt.show()

    model = Classifier(2, 2)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-2)

    # The targets are cast to LongTensor because CrossEntropyLoss takes integer
    # class indices.
    new_model, loss = model.train(inputs, target.type(torch.LongTensor), criterion, optimizer, 100)
    print(loss)
  • 相关阅读:
    工厂模式--工厂方法模式(Factory Method Pattern)
    工厂模式--简单工厂模式( Simple Factory Pattern )
    blog2.0--Springboot添加redis缓存(注解方式)
    blog2.0--JSR303参数校验+全局异常处理器
    高并发秒杀——SpringBoot集成redis
    SpringBoot报错HHH000206: hibernate.properties not found
    blog2.0--保存博客文章到本地磁盘
    Swagger注解 传参
    设计模式--单例模式
    跳表
  • 原文地址:https://www.cnblogs.com/St-Lovaer/p/13696443.html
Copyright © 2020-2023  润新知