import torch from torch import nn,optim import torchvision import torchvision.transforms as transforms import sys #params batch_size=256 device=torch.device('cuda' if torch.cuda.is_available() else 'cpu') num_epochs=10 #dataset mnist_train=torchvision.datasets.FashionMNIST(root='./Datasets/FashionMNIST',train=True,download=True,transform=transforms.ToTensor()) mnist_test=torchvision.datasets.FashionMNIST(root='./Datasets/FashionMNIST',train=False,download=True,transform=transforms.ToTensor()) if sys.platform.startswith('win'): num_workers=0 else: num_workers=4 train_iter=torch.utils.data.DataLoader(mnist_train,batch_size=batch_size,shuffle=True,num_workers=num_workers) test_iter=torch.utils.data.DataLoader(mnist_test,batch_size=batch_size,shuffle=False,num_workers=num_workers) #net class LeNet(nn.Module): def __init__(self): super().__init__() self.conv=nn.Sequential( nn.Conv2d(1,6,5), nn.Sigmoid(), nn.MaxPool2d(2,2), nn.Conv2d(6, 16, 5), nn.Sigmoid(), nn.MaxPool2d(2, 2) ) self.fc=nn.Sequential( nn.Linear(16*4*4,120), nn.Sigmoid(), nn.Linear(120,84), nn.Sigmoid(), nn.Linear(84,10) ) def forward(self,img): feature=self.conv(img) output=self.fc(feature.view(img.shape[0],-1)) return output net=LeNet().to(device) def evaluate_accuracy(data_iter,net,device): acc_sum,n=0.,0 with torch.no_grad(): for X,y in data_iter: if isinstance(net,torch.nn.Module): net.eval() acc_sum+=(net(X.to(device)).argmax(dim=1)==y.to(device)).float().sum().cpu().item() net.train() n+=y.shape[0] return acc_sum/n loss=nn.CrossEntropyLoss() optimizer=torch.optim.Adam(net.parameters(),lr=0.001) for epoch in range(num_epochs): train_l_sum,train_acc_sum,n=0.,0.,0 for X,y in train_iter: X,y=X.to(device),y.to(device) y_hat=net(X) l=loss(y_hat,y).sum() optimizer.zero_grad() l.backward() optimizer.step() train_l_sum+=l.cpu().item() train_acc_sum+=(y_hat.argmax(dim=1)==y).sum().cpu().item() n+=y_hat.shape[0] test_acc= evaluate_accuracy(test_iter,net,device) print('epoch %d, loss %.4f, train_acc %.3f, test_acc %.3f,' %(epoch,train_l_sum/n,train_acc_sum/n,test_acc))