• Implementing a simple CNN with Gluon (Part 2)


    import numpy as np
    import mxnet as mx
    from mxnet import ndarray as nd
    from mxnet import gluon
    from mxnet import autograd
    from mxnet.gluon import nn
    
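    # transform: convert each HWC uint8 image to CHW float32 and scale pixel values to [0, 1]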
    def transform(data, label):
        return nd.transpose(data.astype(np.float32), (2,0,1))/255, label.astype(np.float32)
    mnist_train = gluon.data.vision.FashionMNIST(train=True, transform=transform)
    mnist_test = gluon.data.vision.FashionMNIST(train=False, transform=transform)
    
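    # DataLoaders yield shuffled mini-batches for training and ordered batches for evaluation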
    batch_size = 256
    train_data = gluon.data.DataLoader(mnist_train, batch_size, shuffle=True)
    test_data = gluon.data.DataLoader(mnist_test, batch_size, shuffle=False)
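    # Use the GPU if one is available; otherwise fall back to the CPU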
    try:
        ctx = mx.gpu()
        _ = nd.zeros((1,), ctx=ctx)
    except mx.MXNetError:
        ctx = mx.cpu()
    print(ctx)
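    # accuracy: fraction of correct predictions in one batch;
    # evaluate_accuracy: mean accuracy over all batches of a data iterator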
    def accuracy(output, label):
        return nd.mean(output.argmax(axis=1)==label).asscalar()
    
    def evaluate_accuracy(data_iterator, net):
        acc = 0.
        for data, label in data_iterator:
            data = data.as_in_context(ctx)
            label = label.as_in_context(ctx)
            output = net(data)
            acc += accuracy(output, label)
        return acc / len(data_iterator)
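    # LeNet-style network: two convolution + max-pooling stages followed by two dense layers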
    net = nn.Sequential()
    with net.name_scope():
        net.add(
            nn.Conv2D(channels=20, kernel_size=5, activation='relu'),
            nn.MaxPool2D(pool_size=2, strides=2),
            nn.Conv2D(channels=50, kernel_size=3, activation='relu'),
            nn.MaxPool2D(pool_size=2, strides=2),
            nn.Flatten(),
            nn.Dense(128, activation="relu"),
            nn.Dense(10))
    net.initialize(ctx=ctx)
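    # Softmax cross-entropy loss and a plain SGD trainer with learning rate 0.2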
    softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.2})
    
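    # Training loop: for each mini-batch, record the forward pass, backpropagate the loss,
    # and let trainer.step(batch_size) update the parameters (batch_size normalizes the gradient)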
    for epoch in range(5):
        train_loss = 0.
        train_acc = 0.
        for data, label in train_data:
            data = data.as_in_context(ctx)
            label = label.as_in_context(ctx)
            with autograd.record():
                output = net(data)
                loss = softmax_cross_entropy(output, label)
            loss.backward()
            trainer.step(batch_size)
            
            train_loss += nd.mean(loss).asscalar()
            train_acc += accuracy(output, label)
            
        test_acc = evaluate_accuracy(test_data, net)
        print("Epoch %d. Loss: %f, Train acc %f, Test acc %f" % (epoch, train_loss/len(train_data),train_acc/len(train_data), test_acc))
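
After the training loop finishes, the trained net can be used for inference on the test set. The sketch below is a minimal example that is not part of the original post: it pulls one batch from test_data, runs it through net, and prints the first ten predictions next to the true labels. The text_labels list of FashionMNIST class names is an assumption added here only to make the output readable.

    # Minimal inference sketch (assumption: runs after the training loop above;
    # text_labels is a hypothetical helper listing the FashionMNIST class names).
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    
    for data, label in test_data:
        data = data.as_in_context(ctx)
        preds = net(data).argmax(axis=1)   # predicted class index for each image
        for pred, true in zip(preds[:10].asnumpy(), label[:10].asnumpy()):
            print('predicted: %-10s true: %s' % (text_labels[int(pred)], text_labels[int(true)]))
        break                              # one batch is enough for a quick sanity check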
  • Original post: https://www.cnblogs.com/hxjbc/p/7908443.html