# 使用数据集：猫狗大战（Kaggle "Dogs vs. Cats"）
import time
import torch
import torchvision
from torchvision import datasets, transforms
import os
import matplotlib.pyplot as plt
from torch.autograd import Variable  # no longer used: plain tensors track gradients since PyTorch 0.4

# Restrict this process to GPU 0 (takes effect as long as CUDA has not been initialised yet).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

data_dir = "DogsVsCats"
# Input resolution; the classifier head in Models assumes 64x64 input
# (four 2x2 max-pools reduce 64 -> 4, giving 512 maps of 4x4).
IMAGE_SIZE = 64
BATCH_SIZE = 16

# Preprocessing: resize every image to 64x64 and convert it to a float tensor.
# transforms.Resize replaces transforms.Scale, which is deprecated and missing
# from newer torchvision releases.
data_transform = {
    x: transforms.Compose([
        transforms.Resize([IMAGE_SIZE, IMAGE_SIZE]),
        transforms.ToTensor(),
    ])
    for x in ["train", "valid"]
}

# ImageFolder reads DogsVsCats/train and DogsVsCats/valid, one sub-directory per class.
image_datasets = {
    x: datasets.ImageFolder(root=os.path.join(data_dir, x), transform=data_transform[x])
    for x in ["train", "valid"]
}

# DataLoaders yield shuffled mini-batches of BATCH_SIZE samples until the dataset is exhausted.
dataloader = {
    x: torch.utils.data.DataLoader(dataset=image_datasets[x], batch_size=BATCH_SIZE, shuffle=True)
    for x in ["train", "valid"]
}

# One sample batch for inspection: x_example is (BATCH_SIZE, 3, 64, 64);
# y_example holds integer class indices (ImageFolder labels are not one-hot).
x_example, y_example = next(iter(dataloader["train"]))
# Class-name -> index mapping, e.g. {'cat': 0, 'dog': 1}.
index_classes = image_datasets["train"].class_to_idx
# Class names in index order, e.g. ['cat', 'dog'].
example_clasees = image_datasets["train"].classes
# Tile the batch into a single image grid and reorder CHW -> HWC for matplotlib.
img = torchvision.utils.make_grid(x_example)
img = img.numpy().transpose([1, 2, 0])
# Uncomment to preview the sample batch:
# print([example_clasees[i] for i in y_example])
# plt.imshow(img)
# plt.show()


class Models(torch.nn.Module):
    """VGG-style CNN that classifies 64x64 RGB images into 2 classes."""

    def __init__(self):
        super(Models, self).__init__()
        # Four conv stages (2, 2, 3, 3 conv layers), each ending in a 2x2 max-pool.
        self.Conv = torch.nn.Sequential(
            torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),

            torch.nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),

            torch.nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),

            torch.nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Fully connected head on the flattened 512x4x4 feature maps.
        self.Classes = torch.nn.Sequential(
            torch.nn.Linear(4 * 4 * 512, 1024),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=0.5),
            torch.nn.Linear(1024, 1024),
            # Without this activation the two hidden Linear layers collapse into a
            # single linear map. NOTE: adding it shifts the final Linear from
            # Classes.5 to Classes.6, so state_dicts saved with the old layout
            # need their keys remapped.
            torch.nn.ReLU(),
            torch.nn.Dropout(p=0.5),
            torch.nn.Linear(1024, 2),
        )

    def forward(self, input):
        """Return unnormalised class scores of shape (batch, 2)."""
        x = self.Conv(input)
        # Flatten per sample; unlike view(-1, 4*4*512) this cannot silently fold
        # an unexpected spatial size into the batch dimension.
        x = torch.flatten(x, 1)
        x = self.Classes(x)
        return x


model = Models()
# print(model)

loss_f = torch.nn.CrossEntropyLoss()
Use_gpu = torch.cuda.is_available()  # train on GPU when CUDA is available
if Use_gpu:
    model = model.cuda()
# Build the optimizer after the model is on its final device.
optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)

epoch_n = 10
time_open = time.time()
for epoch in range(epoch_n):
    print("Epoch{}/{}".format(epoch, epoch_n - 1))
    print("-" * 10)
    for phase in ["train", "valid"]:
        is_train = phase == "train"
        if is_train:
            print("Training...")
            model.train()
        else:
            print("Validing...")
            model.eval()

        running_loss = 0.0    # sum of per-sample losses
        running_corrects = 0  # number of correct predictions
        samples_seen = 0      # samples processed so far in this phase

        for batch, (X, y) in enumerate(dataloader[phase], 1):
            if Use_gpu:
                X, y = X.cuda(), y.cuda()

            # Only build the autograd graph while training; validation needs no gradients.
            with torch.set_grad_enabled(is_train):
                y_pred = model(X)
                loss = loss_f(y_pred, y)
            _, pred = torch.max(y_pred, 1)

            if is_train:
                optimizer.zero_grad()  # clear gradients from the previous step
                loss.backward()        # back-propagate
                optimizer.step()       # update parameters

            # CrossEntropyLoss returns the batch mean; weight it by the batch size
            # so a smaller final batch does not skew the averages.
            batch_len = X.size(0)
            running_loss += loss.item() * batch_len
            running_corrects += (pred == y).sum().item()
            samples_seen += batch_len

            if batch % 500 == 0 and is_train:
                print("Batch{},TrainLoss:{:.4f},Train ACC:{:.4f}".format(
                    batch, running_loss / samples_seen, 100 * running_corrects / samples_seen))

        epoch_loss = running_loss / len(image_datasets[phase])
        epoch_acc = 100 * running_corrects / len(image_datasets[phase])
        print("{} Loss:{:.4f} Acc:{:.4f}%".format(phase, epoch_loss, epoch_acc))

# Total wall-clock training time in seconds.
time_end = time.time() - time_open
print(time_end)