自定义层
本节将介绍如何使用NDArray
来自定义一个Gluon
的层,从而可以被重复调用。
不含模型参数的自定义层
下面的CenteredLayer
类通过继承Block
类自定义了一个将输入减掉均值后输出的层,并将层的计算定义在了forward
函数里。这个层里不含模型参数。
#导包
from mxnet import gluon, nd
from mxnet.gluon import nn
#声明CenteredLayer类继承Block类
class CenteredLayer(nn.Block):
#初始化
def __init__(self, **kwargs):
super(CenteredLayer, self).__init__(**kwargs)
#定义正向传播,这里是定义一个减掉均值的层
def forward(self, x):
return x - x.mean()
实例化这个层,然后做前向计算:
layer = CenteredLayer()
layer(nd.array([1, 2, 3, 4, 5]))
用它来构造更复杂的模型:
net = nn.Sequential()
net.add(nn.Dense(128),
CenteredLayer())
下面打印自定义层各个输出的均值。因为均值是浮点数,所以它的值是一个很接近0的数:
net.initialize()
y = net(nd.random.uniform(shape=(4, 8)))
y.mean().asscalar()
含模型参数的自定义层
可以自定义含模型参数的自定义层。其中的模型参数可以通过训练学出。在自定义含模型参数的层时,我们可以利用Block
类自带的ParameterDict
类型的成员变量params
。它是一个由字符串类型的参数名字映射到Parameter
类型的模型参数的字典。我们可以通过get
函数从ParameterDict
创建Parameter
实例。
params = gluon.ParameterDict()
params.get('param2', shape=(2, 3))
params
尝试实现一个含权重参数和偏差参数的全连接层。它使用ReLU
函数作为激活函数。其中in_units
和units
分别代表输入个数和输出个数。
#申明MyDense继承Block类
class MyDense(nn.Block):
# units为该层的输出个数,in_units为该层的输入个数
def __init__(self, units, in_units, **kwargs):
#初始化父类
super(MyDense, self).__init__(**kwargs)
#初始化权重参数,使用get函数创建Parameter实例,shape=(输入个数,输出个数)
self.weight = self.params.get('weight', shape=(in_units, units))
#初始化偏置参数,同样使用get函数创建Parameter实例
self.bias = self.params.get('bias', shape=(units,))
#定义正向传播
def forward(self, x):
#Y=Xw+b
linear = nd.dot(x, self.weight.data()) + self.bias.data()
#使用激活函数
return nd.relu(linear)
实例化MyDense
类并访问它的模型参数。
#实例化MyDense,输入个数:5,输出个数:3
dense = MyDense(units=3, in_units=5)
dense.params
#输出:
mydense0_ (
Parameter mydense0_weight (shape=(5, 3), dtype=<class 'numpy.float32'>)
Parameter mydense0_bias (shape=(3,), dtype=<class 'numpy.float32'>)
)
可以直接使用自定义层做前向计算:
#初始化模型参数
dense.initialize()
#正向传播
dense(nd.random.uniform(shape=(2, 5)))
也可以使用自定义层构造模型。它和Gluon
的其他层在使用上很类似:
#实例化Sequential类
net = nn.Sequential()
#添加隐藏层shape(64,8),添加输出层shape(8,1)
net.add(MyDense(8, in_units=64),
MyDense(1, in_units=8))
#初始化模型参数
net.initialize()
#正向传播
net(nd.random.uniform(shape=(2, 64)))