pytorch里面一切自定义操作基本上都是继承nn.Module类来实现的。
我们在定义自已的网络的时候,需要继承nn.Module类,并重新实现构造函数__init__构造函数和forward这两个方法。但有一些注意技巧:
- 一般把网络中具有可学习参数的层(如全连接层、卷积层等)放在构造函数__init__()中,当然我也可以吧不具有参数的层也放在里面;
- 一般把不具有可学习参数的层(如ReLU、dropout、BatchNormanation层)可放在构造函数中,也可不放在构造函数中,如果不放在构造函数__init__里面,则在forward方法里面可以使用nn.functional来代替
- forward方法是必须要重写的,它是实现模型的功能,实现各个层之间的连接关系的核心。
所有放在构造函数__init__里面的层的都是这个模型的“固有属性”。