• Machine Learning Notes (4): Multiclass Logistic Regression with gluon


    Continuing from Machine Learning Notes (3): Multiclass Logistic Regression, this post reimplements the key steps with gluon. The original tutorial is here; the code is as follows:

    import matplotlib.pyplot as plt
    import mxnet as mx
    from mxnet import gluon
    from mxnet import ndarray as nd
    from mxnet import autograd
    
    def transform(data, label):
        return data.astype('float32')/255, label.astype('float32')
    
    mnist_train = gluon.data.vision.FashionMNIST(train=True, transform=transform)
    mnist_test = gluon.data.vision.FashionMNIST(train=False, transform=transform)
    
    def show_images(images):
        n = images.shape[0]
        _, figs = plt.subplots(1, n, figsize=(15, 15))
        for i in range(n):
            figs[i].imshow(images[i].reshape((28, 28)).asnumpy())
            figs[i].axes.get_xaxis().set_visible(False)
            figs[i].axes.get_yaxis().set_visible(False)
        plt.show()
    
    def get_text_labels(label):
        text_labels = [
            't-shirt', 'trouser', 'pullover', 'dress', 'coat',
            'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot'
        ]
        return [text_labels[int(i)] for i in label]
    
    data, label = mnist_train[0:10]
    
    print('example shape: ', data.shape, 'label:', label)
    
    show_images(data)
    
    print(get_text_labels(label))
    
    batch_size = 256
    
    train_data = gluon.data.DataLoader(mnist_train, batch_size, shuffle=True)
    test_data = gluon.data.DataLoader(mnist_test, batch_size, shuffle=False)
    
    num_inputs = 784
    num_outputs = 10
    
    # Note: these hand-initialized parameters are left over from the previous
    # manual version; the gluon model below creates and manages its own
    # parameters, so W and b are not actually used in this script.
    W = nd.random_normal(shape=(num_inputs, num_outputs))
    b = nd.random_normal(shape=num_outputs)
    params = [W, b]
    
    for param in params:
        param.attach_grad()
    
    def accuracy(output, label):
        return nd.mean(output.argmax(axis=1) == label).asscalar()
    
    def _get_batch(batch):
        if isinstance(batch, mx.io.DataBatch):
            data = batch.data[0]
            label = batch.label[0]
        else:
            data, label = batch
        return data, label
    
    def evaluate_accuracy(data_iterator, net):
        acc = 0.
        if isinstance(data_iterator, mx.io.MXDataIter):
            data_iterator.reset()
        for i, batch in enumerate(data_iterator):
            data, label = _get_batch(batch)
            output = net(data)
            acc += accuracy(output, label)
        return acc / (i+1)
    
    # Define the model with gluon: flatten the 28x28 input, then a dense layer with 10 outputs
    net = gluon.nn.Sequential()
    with net.name_scope():
        net.add(gluon.nn.Flatten())
        net.add(gluon.nn.Dense(10))
    net.initialize()
    
    # Loss function: softmax cross-entropy
    softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
    
    # Create an SGD trainer over the model's parameters, with learning rate 0.1
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})
    
    for epoch in range(5):
        train_loss = 0.
        train_acc = 0.
        for data, label in train_data:
            with autograd.record():
                output = net(data)
                # compute the loss
                loss = softmax_cross_entropy(output, label) 
            loss.backward()
            # take one SGD step; passing batch_size makes the update average the gradient over the batch
            trainer.step(batch_size)
            
            train_loss += nd.mean(loss).asscalar()
            train_acc += accuracy(output, label)
    
        test_acc = evaluate_accuracy(test_data, net)
        print("Epoch %d. Loss: %f, Train acc %f, Test acc %f" % (
            epoch, train_loss / len(train_data), train_acc / len(train_data), test_acc))
    
    data, label = mnist_test[0:10]
    show_images(data)
    print('true labels')
    print(get_text_labels(label))
    
    predicted_labels = net(data).argmax(axis=1)
    print('predicted labels')
    print(get_text_labels(predicted_labels.asnumpy()))
    

    Compared with the previous hand-written version, every place that was changed to use gluon is marked with a comment, so I won't go over them one by one. Running the script prints the loss, training accuracy and test accuracy for each epoch.
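    For reference, the hand-written pieces from the previous note that these gluon calls replace look roughly like the sketch below (reconstructed from the earlier post, so the exact details may differ):

    from mxnet import ndarray as nd

    # Rough mapping between the gluon calls above and the hand-written version:
    #   gluon.nn.Flatten() + gluon.nn.Dense(10)  ->  the linear part of net_manual
    #   gluon.loss.SoftmaxCrossEntropyLoss()     ->  softmax + cross_entropy combined
    #   gluon.Trainer(..., 'sgd', ...)           ->  the sgd update below
    def softmax(X):
        exp = nd.exp(X)
        return exp / exp.sum(axis=1, keepdims=True)

    def net_manual(X, W, b):
        # flatten the (batch, 28, 28, 1) images and apply one linear layer
        return softmax(nd.dot(X.reshape((-1, 784)), W) + b)

    def cross_entropy(yhat, y):
        return -nd.log(nd.pick(yhat, y))

    def sgd(params, lr, batch_size):
        # plain gradient descent step, averaging the gradient over the batch
        for param in params:
            param[:] = param - lr * param.grad / batch_size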

    Compared with the previous version, with almost the same settings the accuracy improves noticeably: it goes from roughly 0.7x up to 0.8x, and the number of wrong predictions among the 10 sample images drops from 4 to 3. This suggests that gluon handles some of the details better than the hand-written code; there is some discussion of those details here, for reference.
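    One detail that is often pointed out for this example (and quite possibly what that discussion refers to) is numerical stability: gluon.loss.SoftmaxCrossEntropyLoss computes the log-softmax and the cross-entropy together, so it never has to exponentiate large raw scores the way a hand-written softmax does. A minimal sketch of the difference, with helper names that are purely illustrative:

    from mxnet import ndarray as nd

    # Naive softmax cross-entropy, in the style of the hand-written version:
    # nd.exp() overflows float32 once the raw scores get large.
    def naive_softmax_ce(output, label):
        exp = nd.exp(output)
        prob = exp / exp.sum(axis=1, keepdims=True)
        return -nd.log(nd.pick(prob, label))

    # Numerically stable form: shift by the row-wise maximum and stay in log
    # space, which is essentially what SoftmaxCrossEntropyLoss does internally.
    def stable_softmax_ce(output, label):
        shifted = output - output.max(axis=1, keepdims=True)
        log_prob = shifted - nd.log(nd.exp(shifted).sum(axis=1, keepdims=True))
        return -nd.pick(log_prob, label)

    scores = nd.array([[100.0, 10.0, 1.0]])
    label = nd.array([0])
    print(naive_softmax_ce(scores, label))   # nan: exp(100) overflows float32
    print(stable_softmax_ce(scores, label))  # close to 0, as expected

    Another difference worth noting is initialization: net.initialize() uses a small-scale uniform initializer by default, while the hand-written version drew its weights from a standard normal, so the two runs also start from rather different weight scales.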

  • Original post: https://www.cnblogs.com/yjmyzz/p/8034597.html