一。GoogLeNet网络结构
1.特点:
采用Inception结构和2个辅助分类器。Inception结构是并行结构,其中加入了1x1的卷积核来实现降维,能够减少训练参数。
2.网络结构
3.Inception结构
4.参数列表
二。训练代码
model.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class GoogLeNet(nn.Module):
    """GoogLeNet (Inception v1) classifier with two optional auxiliary heads.

    Args:
        num_classes: Size of the final classification layer.
        aux_logits: When True, build the two auxiliary classifiers. They are
            only evaluated while the module is in training mode.
        init_weights: When True, re-initialise conv and linear layers with
            ``_initialize_weights``.
        dropout: Dropout probability applied before the final linear layer.
        dropout_aux: Dropout probability used inside the auxiliary heads.

    ``forward`` returns ``(logits, aux2, aux1)`` in training mode when
    ``aux_logits`` is enabled, and just ``logits`` otherwise.
    """

    def __init__(self, num_classes: int = 1000, aux_logits: bool = True,
                 init_weights: bool = False, dropout: float = 0.4,
                 dropout_aux: float = 0.5):
        super().__init__()
        self.aux_logits = aux_logits

        self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.conv2 = BasicConv2d(64, 64, kernel_size=1)
        self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

        if self.aux_logits:
            # Input channel counts match the outputs of inception4a / 4d.
            self.aux1 = InceptionAux(512, num_classes, dropout=dropout_aux)
            self.aux2 = InceptionAux(528, num_classes, dropout=dropout_aux)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(1024, num_classes)
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Run the network; shape comments assume an N x 3 x 224 x 224 input."""
        # The auxiliary heads only contribute during training.
        use_aux = self.training and self.aux_logits

        x = self.conv1(x)        # N x 64 x 112 x 112
        x = self.maxpool1(x)     # N x 64 x 56 x 56
        x = self.conv2(x)        # N x 64 x 56 x 56
        x = self.conv3(x)        # N x 192 x 56 x 56
        x = self.maxpool2(x)     # N x 192 x 28 x 28

        x = self.inception3a(x)  # N x 256 x 28 x 28
        x = self.inception3b(x)  # N x 480 x 28 x 28
        x = self.maxpool3(x)     # N x 480 x 14 x 14

        x = self.inception4a(x)  # N x 512 x 14 x 14
        if use_aux:
            aux1 = self.aux1(x)

        x = self.inception4b(x)  # N x 512 x 14 x 14
        x = self.inception4c(x)  # N x 512 x 14 x 14
        x = self.inception4d(x)  # N x 528 x 14 x 14
        if use_aux:
            aux2 = self.aux2(x)

        x = self.inception4e(x)  # N x 832 x 14 x 14
        x = self.maxpool4(x)     # N x 832 x 7 x 7
        x = self.inception5a(x)  # N x 832 x 7 x 7
        x = self.inception5b(x)  # N x 1024 x 7 x 7

        x = self.avgpool(x)      # N x 1024 x 1 x 1
        x = torch.flatten(x, 1)  # N x 1024
        x = self.dropout(x)
        x = self.fc(x)           # N x num_classes

        if use_aux:
            # Order matters: callers unpack (logits, aux2, aux1).
            return x, aux2, aux1
        return x

    def _initialize_weights(self):
        """Kaiming-normal init for conv layers, N(0, 0.01) for linear layers."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


class Inception(nn.Module):
    """Inception block: four parallel branches concatenated along channels.

    Output channel count is ``ch1x1 + ch3x3 + ch5x5 + pool_proj``; the spatial
    size is unchanged because every branch uses stride 1 with matching padding.
    """

    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        super().__init__()

        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)

        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3red, kernel_size=1),
            # padding=1 keeps the output size equal to the input size
            BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1),
        )

        self.branch3 = nn.Sequential(
            BasicConv2d(in_channels, ch5x5red, kernel_size=1),
            # padding=2 keeps the output size equal to the input size
            BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2),
        )

        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channels, pool_proj, kernel_size=1),
        )

    def forward(self, x):
        """Apply all four branches to ``x`` and concatenate on dim 1."""
        outputs = [
            self.branch1(x),
            self.branch2(x),
            self.branch3(x),
            self.branch4(x),
        ]
        return torch.cat(outputs, 1)


class InceptionAux(nn.Module):
    """Auxiliary classifier head attached to an intermediate feature map.

    ``fc1`` expects 2048 = 128 * 4 * 4 features, so the pooled map must be
    4 x 4 (which is what a 14 x 14 input produces).

    Args:
        in_channels: Channels of the incoming feature map.
        num_classes: Size of the output layer.
        dropout: Dropout probability applied before each linear layer.
    """

    def __init__(self, in_channels, num_classes, dropout: float = 0.5):
        super().__init__()
        self.averagePool = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)  # output [batch, 128, 4, 4]

        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)
        self.dropout_p = dropout

    def forward(self, x):
        # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
        x = self.averagePool(x)
        # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
        x = self.conv(x)
        # N x 128 x 4 x 4
        x = torch.flatten(x, 1)
        x = F.dropout(x, self.dropout_p, training=self.training)
        # N x 2048
        x = F.relu(self.fc1(x), inplace=True)
        x = F.dropout(x, self.dropout_p, training=self.training)
        # N x 1024
        x = self.fc2(x)
        # N x num_classes
        return x


class BasicConv2d(nn.Module):
    """``nn.Conv2d`` followed by an in-place ReLU.

    Extra keyword arguments are forwarded unchanged to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.conv(x))
train.py
import json
import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms, datasets

from model import GoogLeNet


def main():
    """Train GoogLeNet on the flower dataset and keep the best checkpoint.

    Reads images from ``<cwd>/../../data_set/flower_data/{train,val}``,
    writes the index-to-class mapping to ``class_indices.json`` and saves the
    weights with the highest validation accuracy to ``./googleNet.pth``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # class_to_idx maps class name -> index, e.g.
    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4};
    # invert it so predictions can be mapped back to names.
    flower_list = train_dataset.class_to_idx
    cla_dict = {val: key for key, val in flower_list.items()}
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    # os.cpu_count() may return None, which min() cannot compare with ints.
    nw = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    # Use the worker count that was just computed and reported.
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images fot validation.".format(train_num,
                                                                           val_num))

    # NOTE: to fine-tune from torchvision's pretrained GoogLeNet instead,
    # build torchvision.models.googlenet(num_classes=5), load "googlenet.pth"
    # and drop the classifier entries (aux1.fc2.*, aux2.fc2.*, fc.*) from the
    # pretrained dict before calling load_state_dict.
    net = GoogLeNet(num_classes=5, aux_logits=True, init_weights=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0003)

    best_acc = 0.0
    save_path = './googleNet.pth'
    train_steps = len(train_loader)
    for epoch in range(30):
        # train
        net.train()
        running_loss = 0.0
        for step, (images, labels) in enumerate(train_loader, start=0):
            images = images.to(device)
            labels = labels.to(device)  # moved once, reused by all three losses
            optimizer.zero_grad()
            logits, aux_logits2, aux_logits1 = net(images)
            loss0 = loss_function(logits, labels)
            loss1 = loss_function(aux_logits1, labels)
            loss2 = loss_function(aux_logits2, labels)
            # Auxiliary losses are down-weighted relative to the main loss.
            loss = loss0 + loss1 * 0.3 + loss2 * 0.3
            loss.backward()
            optimizer.step()

            # print statistics
            loss_value = loss.item()
            running_loss += loss_value
            # print train process
            rate = (step + 1) / train_steps
            a = "*" * int(rate * 50)
            b = "." * int((1 - rate) * 50)
            # "\r" rewinds to the line start so the bar redraws in place
            # instead of appending to one ever-growing line.
            print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss_value), end="")
        print()

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            for val_data in validate_loader:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))  # eval model only have last output layer
                predict_y = torch.max(outputs, dim=1)[1]
                acc += (predict_y == val_labels.to(device)).sum().item()
        val_accurate = acc / val_num
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
        # Average over the number of batches, not the last batch index.
        print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

    print('Finished Training')


if __name__ == '__main__':
    main()
predict.py
import json

import matplotlib.pyplot as plt
import torch
from PIL import Image
from torchvision import transforms

from model import GoogLeNet

# Single-image inference script: loads ../rose.jpg, classifies it with the
# weights in ./googleNet.pth and prints the predicted class name.

data_transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# load image
img = Image.open("../rose.jpg")
plt.imshow(img)
# [C, H, W]; force 3 channels so grayscale/RGBA/CMYK inputs match Normalize
img = data_transform(img.convert("RGB"))
# expand batch dimension -> [N, C, H, W]
img = torch.unsqueeze(img, dim=0)

# read class_indict
try:
    with open('./class_indices.json', 'r') as json_file:
        class_indict = json.load(json_file)
except (OSError, ValueError) as e:  # missing/unreadable file or invalid JSON
    print(e)
    raise SystemExit(-1)

# create model
model = GoogLeNet(num_classes=5, aux_logits=False)
# load model weights
model_weight_path = "./googleNet.pth"
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
state_dict = torch.load(model_weight_path, map_location="cpu")
# strict=False: the checkpoint contains aux-head weights this model lacks.
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
if missing_keys:
    # Missing entries would leave those layers with their default init.
    print("warning: weights missing from checkpoint: {}".format(missing_keys))
model.eval()
with torch.no_grad():
    # predict class
    output = torch.squeeze(model(img))
    predict = torch.softmax(output, dim=0)
    predict_cla = torch.argmax(predict).item()
# JSON object keys are strings, so the index is looked up as str.
print(class_indict[str(predict_cla)])
plt.show()