LeNet：CNN 界的"果蝇"（就像果蝇之于生物学研究，它是最经典的入门模型），哈哈。
import torch
from torchvision import datasets, transforms
from torch import nn, optim
import torch.nn.functional as F

# Preprocessing pipeline: PIL image -> float tensor, then normalization with
# mean 0.1307 / std 0.3081 (the constants commonly used for MNIST).
# The two steps must be chained with transforms.Compose: a plain tuple is not
# callable, so the dataset would raise TypeError as soon as a sample is indexed.
trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
trainset = datasets.MNIST('data', train=True, download=True, transform=trans)
testset = datasets.MNIST('data', train=False, download=True, transform=trans)


class LeNet(nn.Module):
    """LeNet-style CNN: two conv/max-pool stages followed by three linear layers.

    The input is a batch of single-channel images whose spatial size reduces to
    4x4 after the two conv(5x5) + max-pool(2) stages. For example, 28x28 MNIST
    digits go 28 -> 24 -> 12 -> 8 -> 4. The output is 10 raw logits per sample.
    """

    def __init__(self):
        super().__init__()
        self.c1 = nn.Conv2d(1, 6, (5, 5))      # 1 -> 6 channels, 5x5 kernel
        self.c3 = nn.Conv2d(6, 16, 5)          # 6 -> 16 channels, 5x5 kernel
        self.fc1 = nn.Linear(16 * 4 * 4, 120)  # 16 feature maps of 4x4 after the second pool
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return class logits of shape (N, 10) for the input batch ``x``.

        No softmax is applied. Assumes ``x`` is (N, 1, H, W) with H and W such
        that the feature maps are 4x4 before flattening (e.g. 28x28).
        """
        x = F.max_pool2d(F.relu(self.c1(x)), 2)
        x = F.max_pool2d(F.relu(self.c3(x)), 2)
        # Flatten everything except the batch dimension for the linear layers.
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

    def num_flat_features(self, x):
        """Return the number of elements per sample, i.e. the product of x.size()[1:]."""
        # torch.Size.numel() multiplies the dimensions together (1 for an
        # empty size), matching a manual multiply loop.
        return x.size()[1:].numel()

    # forward() calls num_flat_features, but the helper was originally defined
    # under this name, which made every forward pass raise AttributeError.
    # Keep the old name as an alias so any external callers still work.
    sum_flat_features = num_flat_features