• MXNet 神经网络训练和预测


    https://mxnet.incubator.apache.org/tutorials/basic/module.html

    import logging
    import random
    # Raise the root logger to INFO level
    # (presumably so MXNet's training progress messages are shown -- TODO confirm).
    logging.getLogger().setLevel(logging.INFO)
    
    import mxnet as mx
    import numpy as np
    
    # Seed the MXNet, NumPy and stdlib random generators with the same fixed
    # value before any data shuffling or training code runs.
    mx.random.seed(1234)
    np.random.seed(1234)
    random.seed(1234)
    
    # Prepare the data
    # Download the letter-recognition CSV. In every row the first field is an
    # uppercase letter (the class) and the remaining fields are the features.
    fname = mx.test_utils.download('https://s3.us-east-2.amazonaws.com/mxnet-public/letter_recognition/letter-recognition.data')
    # Feature matrix: every column except the first one.
    data = np.genfromtxt(fname=fname, delimiter=',')[:, 1:]
    # Labels: the leading letter of each row mapped to its offset from 'A'.
    # The file is opened in a `with` block so the handle is closed as soon as
    # the labels have been read instead of being left to the garbage collector.
    with open(fname, 'r') as label_file:
        label = np.array([ord(line.split(',')[0]) - ord('A') for line in label_file])

    batch_size = 32
    # The first 80% of the rows are used for training, the rest for validation.
    ntrain = int(data.shape[0] * 0.8)

    # Only the training iterator shuffles its samples.
    train_iter = mx.io.NDArrayIter(data[:ntrain, :], label[:ntrain], batch_size, shuffle=True)
    val_iter = mx.io.NDArrayIter(data[ntrain:, :], label[ntrain:], batch_size)
    
    
    # Define the network: data -> fc1 (64 units) -> relu1 -> fc2 (26 units) -> softmax
    inputs = mx.sym.Variable('data')
    hidden = mx.sym.FullyConnected(inputs, name='fc1', num_hidden=64)
    hidden = mx.sym.Activation(hidden, name='relu1', act_type="relu")
    logits = mx.sym.FullyConnected(hidden, name='fc2', num_hidden=26)
    net = mx.sym.SoftmaxOutput(logits, name='softmax')

    # Build a visualization of the symbol graph. The return value is discarded
    # here (presumably it is only displayed when run in a notebook -- TODO confirm).
    plot_attrs = {"shape":"oval","fixedsize":"false"}
    mx.viz.plot_network(net, node_attrs=plot_attrs)
    
    
    
    # Create the module
    # Wraps the `net` symbol in a Module on the CPU context, naming its data
    # input 'data' and its label input 'softmax_label'.
    # NOTE(review): `mod` is rebound to a fresh Module before the fit() call
    # further down, so this instance only matters if the commented-out
    # training loop below is re-enabled.
    mod = mx.mod.Module(symbol=net,
                        context=mx.cpu(),
                        data_names=['data'],
                        label_names=['softmax_label'])
    
    # Intermediate-level interface (kept for reference, currently disabled)
    # Train the model
    # mod.bind(data_shapes=train_iter.provide_data,label_shapes=train_iter.provide_label)
    # mod.init_params(initializer=mx.init.Uniform(scale=.1))
    # mod.init_optimizer(optimizer='sgd',optimizer_params=(('learning_rate',0.1),))
    # metric = mx.metric.create('acc')
    #
    # for epoch in range(100):
    #     train_iter.reset()
    #     metric.reset()
    #     for batch in train_iter:
    #         mod.forward(batch,is_train=True)
    #         mod.update_metric(metric,batch.label)
    #         mod.backward()
    #         mod.update()
    #     print('Epoch %d,Training %s' % (epoch,metric.get()))
    
    # High-level interface: fit
    # Start from a fresh Module and a rewound training iterator.
    mod = mx.mod.Module(
        symbol=net,
        context=mx.cpu(),
        data_names=['data'],
        label_names=['softmax_label'],
    )
    train_iter.reset()

    # Train for 10 epochs with SGD (learning rate 0.1), reporting accuracy
    # and evaluating on the validation iterator.
    sgd_params = {'learning_rate': 0.1}
    mod.fit(
        train_iter,
        eval_data=val_iter,
        eval_metric='acc',
        optimizer='sgd',
        optimizer_params=sgd_params,
        num_epoch=10,
    )
    
    
    
    # Predict and evaluate
    # Run the trained module over the validation iterator and check the
    # shape of the collected outputs.
    expected_shape = (4000, 26)
    y = mod.predict(val_iter)
    assert y.shape == expected_shape

    # Score
    # score is indexed as score[0][1] to obtain the accuracy value.
    score = mod.score(val_iter, ['acc'])
    val_acc = score[0][1]
    print("Accuracy score is %f" % (val_acc))
    assert val_acc > 0.76, "Achieved accuracy (%f) is less than expected (0.76)" % val_acc
    
    # Save and load
    # Build a checkpoint callback; it is passed to fit() as the
    # epoch-end callback and writes files under the 'mx_mlp' prefix.
    model_prefix = 'mx_mlp'
    checkpoint = mx.callback.do_checkpoint(model_prefix)

    mod = mx.mod.Module(symbol=net)
    mod.fit(train_iter, num_epoch=5, epoch_end_callback=checkpoint)

    # Load the checkpoint saved for epoch 3 and verify that the stored
    # symbol matches the network defined in this script.
    sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3)
    assert sym.tojson() == net.tojson()

    # Option 1: assign the loaded parameters to the existing module.
    mod.set_params(arg_params, aux_params)

    # Option 2: resume training from the checkpoint by handing the loaded
    # parameters directly to fit() on a new module, starting at epoch 3.
    mod = mx.mod.Module(symbol=sym)
    mod.fit(train_iter,
            num_epoch=21,
            arg_params=arg_params,
            aux_params=aux_params,
            begin_epoch=3)

    # Re-score the resumed module before asserting. Without this the check
    # would only re-test the accuracy measured before this section ran and
    # would say nothing about the model that was just trained.
    score = mod.score(val_iter, ['acc'])
    assert score[0][1] > 0.77, "Achieved accuracy (%f) is less than expected (0.77)" % score[0][1]

  • 相关阅读:
    【UNR#3】白鸽
    【POI2011】Garbage
    【NOI2010】海拔
    【HNOI2012】矿场搭建
    【UOJ#177】欧拉回路
    【BZOJ4500】矩阵
    【CF429E】Points and Segments
    【agc001F】Wide Swap
    【BZOJ2138】stone
    【JSOI2009】游戏
  • 原文地址:https://www.cnblogs.com/TreeDream/p/10065616.html
Copyright © 2020-2023  润新知