# Reference: https://mxnet.incubator.apache.org/tutorials/basic/module.html
"""Train and checkpoint a small MLP on letter-recognition data with MXNet's Module API.

Steps: download and split the data, build a two-layer MLP symbol, train it
with ``Module.fit``, predict/score on the validation split, then save
per-epoch checkpoints, reload one and resume training from it.

The ``assert`` statements are this script's self-checks; note they are
skipped when Python runs with ``-O``.
"""
import logging
import random

import mxnet as mx
import numpy as np

logging.getLogger().setLevel(logging.INFO)

# Seed every RNG used so runs are reproducible.
mx.random.seed(1234)
np.random.seed(1234)
random.seed(1234)

NUM_CLASSES = 26  # labels are ord(letter) - ord('A'), i.e. one per letter A..Z
BATCH_SIZE = 32
TRAIN_FRACTION = 0.8
DATA_URL = ('https://s3.us-east-2.amazonaws.com/mxnet-public/'
            'letter_recognition/letter-recognition.data')


def _load_letter_data(fname):
    """Parse the comma-separated data file into ``(features, labels)``.

    Each line is ``letter,f1,f2,...``. The leading letter becomes an integer
    label (``'A'`` -> 0); the remaining columns become float features.

    Raises:
        ValueError: if the number of parsed labels and feature rows differ.
    """
    # genfromtxt cannot parse the letter column as a float; drop it.
    features = np.genfromtxt(fname=fname, delimiter=',')[:, 1:]
    # Close the file deterministically; skip blank lines (genfromtxt skips
    # them too) so labels stay row-aligned with features.
    with open(fname, 'r') as f:
        labels = np.array([ord(line.split(',')[0]) - ord('A')
                           for line in f if line.strip()])
    if features.shape[0] != labels.shape[0]:
        raise ValueError('parsed %d feature rows but %d labels from %s'
                         % (features.shape[0], labels.shape[0], fname))
    return features, labels


def _build_mlp(num_classes=NUM_CLASSES):
    """Return an MLP symbol: fc(64) -> relu -> fc(num_classes) -> softmax."""
    net = mx.sym.Variable('data')
    net = mx.sym.FullyConnected(net, name='fc1', num_hidden=64)
    net = mx.sym.Activation(net, name='relu1', act_type="relu")
    net = mx.sym.FullyConnected(net, name='fc2', num_hidden=num_classes)
    return mx.sym.SoftmaxOutput(net, name='softmax')


def _train_with_intermediate_api(symbol, train_iter, num_epoch=100):
    """Train ``symbol`` with the intermediate-level Module API and return the module.

    Writes out the bind / init / forward / backward / update loop that
    ``Module.fit`` performs internally. Not called by this script; kept as a
    runnable reference alternative to the ``fit`` call below.
    """
    module = mx.mod.Module(symbol=symbol, context=mx.cpu(),
                           data_names=['data'], label_names=['softmax_label'])
    module.bind(data_shapes=train_iter.provide_data,
                label_shapes=train_iter.provide_label)
    module.init_params(initializer=mx.init.Uniform(scale=.1))
    module.init_optimizer(optimizer='sgd',
                          optimizer_params=(('learning_rate', 0.1),))
    metric = mx.metric.create('acc')
    for epoch in range(num_epoch):
        train_iter.reset()
        metric.reset()
        for batch in train_iter:
            module.forward(batch, is_train=True)
            module.update_metric(metric, batch.label)
            module.backward()
            module.update()
        print('Epoch %d,Training %s' % (epoch, metric.get()))
    return module


# --- Prepare data ---
fname = mx.test_utils.download(DATA_URL)
data, label = _load_letter_data(fname)
ntrain = int(data.shape[0] * TRAIN_FRACTION)
nval = data.shape[0] - ntrain
train_iter = mx.io.NDArrayIter(data[:ntrain, :], label[:ntrain], BATCH_SIZE,
                               shuffle=True)
val_iter = mx.io.NDArrayIter(data[ntrain:, :], label[ntrain:], BATCH_SIZE)

# --- Define the network ---
net = _build_mlp()
# The returned graphviz object is unused in a plain script run.
mx.viz.plot_network(net, node_attrs={"shape": "oval", "fixedsize": "false"})

# --- Train with the high-level fit() interface ---
train_iter.reset()
mod = mx.mod.Module(symbol=net, context=mx.cpu(),
                    data_names=['data'], label_names=['softmax_label'])
mod.fit(train_iter, eval_data=val_iter, optimizer='sgd',
        optimizer_params={'learning_rate': 0.1}, eval_metric='acc',
        num_epoch=10)

# --- Predict and evaluate ---
y = mod.predict(val_iter)
# One row of class scores per validation example.
assert y.shape == (nval, NUM_CLASSES)
score = mod.score(val_iter, ['acc'])
print("Accuracy score is %f" % (score[0][1]))
assert score[0][1] > 0.76, \
    "Achieved accuracy (%f) is less than expected (0.76)" % score[0][1]

# --- Save and load checkpoints ---
model_prefix = 'mx_mlp'
# Callback that saves a checkpoint under model_prefix at the end of every
# epoch; one of them is loaded back by epoch number below.
checkpoint = mx.callback.do_checkpoint(model_prefix)
mod = mx.mod.Module(symbol=net)
mod.fit(train_iter, num_epoch=5, epoch_end_callback=checkpoint)
sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3)
assert sym.tojson() == net.tojson()
# Demonstrates assigning loaded parameters to an already-bound module.
mod.set_params(arg_params, aux_params)

# Resume training from the epoch-3 checkpoint on a fresh module.
mod = mx.mod.Module(symbol=sym)
mod.fit(train_iter, num_epoch=21, arg_params=arg_params,
        aux_params=aux_params, begin_epoch=3)
# Re-score the resumed model; the earlier `score` belongs to the first model.
score = mod.score(val_iter, ['acc'])
print("Accuracy score is %f" % (score[0][1]))
assert score[0][1] > 0.77, \
    "Achieved accuracy (%f) is less than expected (0.77)" % score[0][1]