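Network structure
Since my main goal here is to learn PyTorch, I did not dig into the underlying theory. If I end up working on CV someday, I may come back and understand it properly.
Roughly speaking, the network extracts features with several convolution kernels of different sizes in parallel, then concatenates the extracted features along the channel dimension.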
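To make the concatenation step concrete, here is a minimal sketch that uses random tensors as stand-ins for the outputs of two branches. The channel counts (64 and 96) and the 28x28 feature map size are assumptions chosen only for illustration.

```python
import torch

# Stand-ins for two branch outputs: same batch size and spatial size,
# different channel counts (illustrative values, not from a real network).
branch_a = torch.randn(2, 64, 28, 28)
branch_b = torch.randn(2, 96, 28, 28)

# Concatenate along dim=1, the channel dimension in NCHW layout.
merged = torch.cat([branch_a, branch_b], dim=1)
print(merged.shape)  # torch.Size([2, 160, 28, 28])
```

Concatenation only works when every branch keeps the same height and width, which is why the branches later in this post choose their padding to match the kernel size.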
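The network structure is as follows:
The implementation plan is to first define a convolution module (a convolution layer followed by a BN layer), and then build the Inception block (the structure shown in the figure) on top of it.
Convolution module implementation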
import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicConv(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        # bias is disabled because the following BatchNorm layer has its own learnable shift
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.conv(x)
        x = F.relu(self.bn(x), inplace=True)
        return x
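Because the kernel size and padding are supplied later by each branch, the constructor accepts variable keyword arguments (**kwargs) and forwards them unchanged to nn.Conv2d.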
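As a quick check of how the forwarded keyword arguments behave, the sketch below repeats the BasicConv definition so it runs on its own, then builds three instances with different kernel settings. The input size (16 channels, 32x32) and the stride-2 case are assumptions added for illustration; they do not appear in the original network.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicConv(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        # kwargs (kernel_size, padding, stride, ...) go straight to nn.Conv2d
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        return F.relu(self.bn(self.conv(x)), inplace=True)

x = torch.randn(1, 16, 32, 32)  # illustrative input

# padding = (kernel_size - 1) // 2 keeps height and width unchanged at stride 1
y3 = BasicConv(16, 8, kernel_size=3, padding=1)(x)
y5 = BasicConv(16, 8, kernel_size=5, padding=2)(x)
# any other Conv2d argument can be passed the same way, e.g. stride
y_s2 = BasicConv(16, 8, kernel_size=3, stride=2)(x)

print(y3.shape)    # torch.Size([1, 8, 32, 32])
print(y5.shape)    # torch.Size([1, 8, 32, 32])
print(y_s2.shape)  # torch.Size([1, 8, 15, 15])
```

The first two cases are the ones the Inception block relies on: a 3x3 kernel with padding 1 and a 5x5 kernel with padding 2 both preserve the spatial size, so their outputs can be concatenated later.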
Inception block implementation
class InceptionBlock(nn.Module):
    def __init__(self, in_channels, pool_features):
        super().__init__()
        # Branch 1: 1x1 convolution
        self.b1x1 = BasicConv(in_channels, 64, kernel_size=1)
        # Branch 2: 1x1 reduction, then 3x3 convolution
        self.b3x3_1 = BasicConv(in_channels, 64, kernel_size=1)
        self.b3x3_2 = BasicConv(64, 96, kernel_size=3, padding=1)
        # Branch 3: 1x1 reduction, then 5x5 convolution
        self.b5x5_1 = BasicConv(in_channels, 48, kernel_size=1)
        self.b5x5_2 = BasicConv(48, 64, kernel_size=5, padding=2)
        # Branch 4: 3x3 max pooling (applied in forward), then 1x1 convolution
        self.bpool = BasicConv(in_channels, pool_features, kernel_size=1)

    def forward(self, x):
        b1x1_out = self.b1x1(x)
        b3x3_out = self.b3x3_2(self.b3x3_1(x))
        b5x5_out = self.b5x5_2(self.b5x5_1(x))
        # stride=1 and padding=1 keep the spatial size so all branches can be concatenated
        bpool_out = self.bpool(F.max_pool2d(x, kernel_size=3, stride=1, padding=1))
        out = [b1x1_out, b3x3_out, b5x5_out, bpool_out]
        # Concatenate along the channel dimension: 64 + 96 + 64 + pool_features channels
        return torch.cat(out, dim=1)
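To sanity-check the block, the following self-contained sketch repeats both classes in compact form and runs one forward pass. The input shape (192 channels, 28x28) and pool_features = 32 are assumed values picked for the example; the expected output has 64 + 96 + 64 + 32 = 256 channels with height and width unchanged.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicConv(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        return F.relu(self.bn(self.conv(x)), inplace=True)

class InceptionBlock(nn.Module):
    def __init__(self, in_channels, pool_features):
        super().__init__()
        self.b1x1 = BasicConv(in_channels, 64, kernel_size=1)
        self.b3x3_1 = BasicConv(in_channels, 64, kernel_size=1)
        self.b3x3_2 = BasicConv(64, 96, kernel_size=3, padding=1)
        self.b5x5_1 = BasicConv(in_channels, 48, kernel_size=1)
        self.b5x5_2 = BasicConv(48, 64, kernel_size=5, padding=2)
        self.bpool = BasicConv(in_channels, pool_features, kernel_size=1)

    def forward(self, x):
        b1x1_out = self.b1x1(x)
        b3x3_out = self.b3x3_2(self.b3x3_1(x))
        b5x5_out = self.b5x5_2(self.b5x5_1(x))
        pooled = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
        bpool_out = self.bpool(pooled)
        return torch.cat([b1x1_out, b3x3_out, b5x5_out, bpool_out], dim=1)

# Assumed example values: 192 input channels, 32 channels for the pooling branch
block = InceptionBlock(in_channels=192, pool_features=32)
x = torch.randn(2, 192, 28, 28)
y = block(x)
print(y.shape)  # torch.Size([2, 256, 28, 28])
```

Only pool_features changes the output width of the block, since the other three branches have fixed channel counts of 64, 96, and 64.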