• PyTorch:全连接多分类小网络代码实现


    import torch
    from torch import nn
    from torch.nn import init
    import numpy as np
    import sys
    import torchvision
    import torchvision.transforms as transforms
    
    # Fashion-MNIST training and test splits. Both are stored under the same
    # root directory, requested with download=True, and converted to tensors
    # by ToTensor().
    mnist_train = torchvision.datasets.FashionMNIST(
        root='./Datasets/FashionMNIST',
        train=True,
        download=True,
        transform=transforms.ToTensor(),
    )
    mnist_test = torchvision.datasets.FashionMNIST(
        root='./Datasets/FashionMNIST',
        train=False,
        download=True,
        transform=transforms.ToTensor(),
    )
    
    # feature_train,label_train=mnist_train[0]
    # feature_test,label_test=mnist_test[0]
    # print(len(mnist_train))
    # print(len(mnist_test))
    # print(feature_train.size(),label_train)
    # print(feature_test.size(),label_test)
    
    # Mini-batch size, plus the input/output widths handed to the model below
    # (28*28 input features, 10 output scores).
    batch_size = 256
    num_inputs = 28 * 28
    num_outputs = 10

    # Load data in the main process on Windows; use four worker processes elsewhere.
    num_workers = 0 if sys.platform.startswith('win') else 4

    # Only the training loader is shuffled; the test loader keeps dataset order.
    train_iter = torch.utils.data.DataLoader(
        mnist_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    test_iter = torch.utils.data.DataLoader(
        mnist_test,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    
    class LinearNet(nn.Module):
        """Single fully connected layer producing ``num_outputs`` scores per sample.

        Each input sample is flattened to a vector of ``num_inputs`` features and
        passed through one ``nn.Linear`` layer. No activation is applied to the
        output.

        Args:
            num_inputs: Number of features per sample after flattening.
            num_outputs: Number of output scores per sample.
        """

        def __init__(self, num_inputs, num_outputs):
            super().__init__()
            self.linear = nn.Linear(num_inputs, num_outputs)

            # Parameter initialisation: small Gaussian weights, zero bias.
            init.normal_(self.linear.weight, mean=0, std=0.01)
            init.constant_(self.linear.bias, val=0)

        def forward(self, x):
            """Flatten ``x`` to ``(batch, num_inputs)`` and apply the linear layer.

            Args:
                x: Tensor whose first dimension is the batch; the remaining
                    dimensions must hold ``num_inputs`` elements per sample.

            Returns:
                Tensor of shape ``(batch, num_outputs)``.
            """
            # reshape (rather than view) also accepts non-contiguous tensors, and
            # naming the feature width explicitly (instead of -1) keeps the shape
            # unambiguous for an empty batch.
            flat = x.reshape(x.shape[0], self.linear.in_features)
            return self.linear(flat)
    
    # Model, cross-entropy criterion and plain SGD optimiser (learning rate 0.1).
    net = LinearNet(num_inputs, num_outputs)
    loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(params=net.parameters(), lr=0.1)

    # Number of full passes over the training data.
    num_epochs = 20
    
    def evaluate_accuracy(data_iter, net):
        """Return the fraction of samples in ``data_iter`` that ``net`` classifies correctly.

        The prediction for a sample is the argmax of ``net``'s output along
        dimension 1. Evaluation runs with gradient tracking disabled. If ``net``
        is an ``nn.Module`` it is switched to eval mode for the duration of the
        call and its previous training/eval mode is restored afterwards.

        Args:
            data_iter: Iterable yielding ``(X, y)`` batches, where ``y`` holds the
                class index of each sample.
            net: Callable mapping a batch ``X`` to per-class scores.

        Returns:
            Accuracy as a float in ``[0, 1]``.

        Raises:
            ZeroDivisionError: If ``data_iter`` yields no samples.
        """
        is_module = isinstance(net, torch.nn.Module)
        was_training = net.training if is_module else False
        if is_module:
            # Layers such as dropout / batch norm behave differently in eval mode.
            net.eval()

        acc_sum, n = 0., 0
        try:
            # No graph is needed for scoring; skip autograd bookkeeping.
            with torch.no_grad():
                for X, y in data_iter:
                    acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
                    n += y.shape[0]
        finally:
            if is_module:
                # Hand the model back in whatever mode the caller had it in.
                net.train(was_training)
        return acc_sum / n
    
    # Training loop: one SGD step per mini-batch, then report the epoch's average
    # training loss, training accuracy and test accuracy.
    for epoch in range(num_epochs):
        # Per-epoch running totals: summed per-sample loss, correct predictions,
        # and number of samples seen.
        train_l_sum, train_acc_sum, n = 0., 0., 0
        for X, y in train_iter:
            y_hat = net(X)
            # `loss` uses the default reduction, so this is already the mean
            # loss over the batch (a scalar).
            l = loss(y_hat, y)

            optimizer.zero_grad()
            l.backward()
            optimizer.step()

            batch_n = y.shape[0]
            # Weight the batch mean by the batch size so that train_l_sum / n
            # is the true per-sample average (the last batch may be smaller).
            train_l_sum += l.item() * batch_n
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
            n += batch_n

        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train_acc %.3f, test_acc %.3f,'
              % (epoch, train_l_sum / n, train_acc_sum / n, test_acc))
    

      

  • 相关阅读:
    Server.UrlEncode UrlDecode 动态绑定gridview列发送接收乱码的问题
    gridview新用法,一直不知道gridview可以这么用
    vm workstation15 迁移至ESXi6.7步骤
    http 502与504的区别
    Asp.net项目部署ActiveReport
    不能在 Page 回调中调用 Response.Redirect 解决方法
    JQuery TextExt 控件使用
    通过ashx获取JSON数据的两种方式
    jQuery Mobile对话框插件
    替换文本框title提示文本
  • 原文地址:https://www.cnblogs.com/liutianrui1/p/13824863.html
Copyright © 2020-2023  润新知