• Implementing VGG in PyTorch


    1. Network structure and parameters

    Key idea: stack multiple small convolution kernels to obtain the same receptive field as a single large kernel. This reduces the number of network parameters while increasing network depth. A quick parameter comparison follows.
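    As a rough illustration (not from the original post; C = 64 is an arbitrary example channel count): for C input and output channels, two stacked 3x3 convolutions need 2·(9C² + C) weights but cover the same 5x5 receptive field as a single 5x5 convolution, which needs 25C² + C weights.

import torch.nn as nn


def params(m):
    # total number of learnable parameters in a module
    return sum(p.numel() for p in m.parameters())


C = 64  # example channel count (hypothetical)

# two stacked 3x3 convolutions: 5x5 receptive field, fewer weights
stacked = nn.Sequential(
    nn.Conv2d(C, C, kernel_size=3, padding=1),
    nn.ReLU(True),
    nn.Conv2d(C, C, kernel_size=3, padding=1),
)

# a single 5x5 convolution covering the same receptive field
single = nn.Conv2d(C, C, kernel_size=5, padding=2)

print(params(stacked), params(single))  # 73856 vs 102464 for C = 64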

    2. Model definition and training code

    model.py

import torch.nn as nn
import torch


class VGG(nn.Module):
    def __init__(self, features, num_classes=1000, init_weights=False):
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(512*7*7, 2048),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(True),
            nn.Linear(2048, num_classes)
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        # N x 3 x 224 x 224
        x = self.features(x)
        # N x 512 x 7 x 7
        x = torch.flatten(x, start_dim=1)
        # N x 512*7*7
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                # nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


def make_features(cfg: list):
    # build the convolutional part from a configuration list:
    # integers are output channels of 3x3 convs, "M" marks a 2x2 max-pool
    layers = []
    in_channels = 3
    for v in cfg:
        if v == "M":
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            layers += [conv2d, nn.ReLU(True)]
            in_channels = v
    return nn.Sequential(*layers)


cfgs = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


def vgg(model_name="vgg16", **kwargs):
    try:
        cfg = cfgs[model_name]
    except KeyError:
        print("Warning: model name {} not in cfgs dict!".format(model_name))
        exit(-1)
    model = VGG(make_features(cfg), **kwargs)
    return model
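    A quick sanity check (a minimal sketch, not part of the original model.py) to confirm that the assembled network accepts a 224x224 input and produces one logit per class:

import torch
from model import vgg

# hypothetical test: batch of 2 RGB images, 224x224, 5 flower classes
net = vgg(model_name="vgg16", num_classes=5, init_weights=True)
x = torch.randn(2, 3, 224, 224)
print(net(x).shape)  # expected: torch.Size([2, 5])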

    train.py

import torch.nn as nn
from torchvision import transforms, datasets
import json
import os
import torch.optim as optim
from model import vgg
import torch


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers per process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)
    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    # test_data_iter = iter(validate_loader)
    # test_image, test_label = test_data_iter.next()

    model_name = "vgg16"
    net = vgg(model_name=model_name, num_classes=5, init_weights=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0001)

    best_acc = 0.0
    save_path = './{}Net.pth'.format(model_name)
    for epoch in range(30):
        # train
        net.train()
        running_loss = 0.0
        for step, data in enumerate(train_loader, start=0):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            # print training progress
            rate = (step + 1) / len(train_loader)
            a = "*" * int(rate * 50)
            b = "." * int((1 - rate) * 50)
            print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")
        print()

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            for val_data in validate_loader:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += (predict_y == val_labels.to(device)).sum().item()
            val_accurate = acc / val_num
            if val_accurate > best_acc:
                best_acc = val_accurate
                torch.save(net.state_dict(), save_path)
            print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
                  (epoch + 1, running_loss / len(train_loader), val_accurate))

    print('Finished Training')


if __name__ == '__main__':
    main()
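    For reference, a hedged inference sketch that loads the weights and class mapping saved by train.py. This script is an assumption rather than part of the original post, and "test.jpg" is a placeholder path for any image you want to classify:

import json
import torch
from PIL import Image
from torchvision import transforms
from model import vgg

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# same preprocessing as the "val" transform used during training
transform = transforms.Compose([transforms.Resize((224, 224)),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# "test.jpg" is a placeholder image path
img = transform(Image.open("test.jpg").convert("RGB")).unsqueeze(0)  # add batch dim

# class index -> class name mapping written by train.py
with open("class_indices.json", "r") as f:
    class_indict = json.load(f)

net = vgg(model_name="vgg16", num_classes=5).to(device)
net.load_state_dict(torch.load("./vgg16Net.pth", map_location=device))
net.eval()
with torch.no_grad():
    output = net(img.to(device))
    predict = torch.softmax(output, dim=1)
    cla = torch.argmax(predict, dim=1).item()
print(class_indict[str(cla)], predict[0, cla].item())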
  • Original post: https://www.cnblogs.com/sclu/p/14163969.html