• mxnet 线性模型


    mxnet 线性模型：使用 Gluon 实现线性回归

    完整代码：
     
     
     
    1
    import mxnet
    2
    import mxnet.ndarray as nd
    3
    from mxnet import gluon
    4
    from mxnet import autograd
    5
    
    
    6
    
    
    7
    # create data
    8
    
    
    9
    def set_data(true_w, true_b, num_examples, *args, **kwargs):
        """Generate a synthetic linear-regression dataset.

        Features are drawn from a standard normal distribution with shape
        ``(num_examples, len(true_w))``. Each label is
        ``sum_i true_w[i] * X[:, i] + true_b`` plus Gaussian noise scaled by 0.1.
        Any extra positional or keyword arguments are accepted and ignored.

        Returns:
            Tuple ``(X, y)`` of NDArrays; ``y`` has shape ``(num_examples,)``.
        """
        features = nd.random_normal(shape=(num_examples, len(true_w)))
        # One weighted feature column per true weight, accumulated into the labels.
        labels = sum(weight * features[:, col] for col, weight in enumerate(true_w))
        labels = labels + true_b
        # Noise is drawn after the features so the random stream order is unchanged.
        labels = labels + 0.1 * nd.random_normal(shape=labels.shape)
        return features, labels
    19
    
    
    20
    
    
    21
    # create data loader
    22
    def data_loader(batch_size, X, y, shuffle=False):
        """Return a Gluon ``DataLoader`` over the paired samples ``(X, y)``.

        Each iteration yields ``(data, label)`` mini-batches of up to
        ``batch_size`` samples, shuffled when ``shuffle`` is true.
        """
        return gluon.data.DataLoader(
            dataset=gluon.data.ArrayDataset(X, y),
            batch_size=batch_size,
            shuffle=shuffle,
        )
    26
    
    
    27
    
    
    28
    # create net
    29
    def set_net(node_num):
        """Build a one-layer network: a ``Sequential`` holding ``Dense(node_num)``.

        Parameters are initialised with Gluon's default initializer before the
        network is returned.
        """
        model = gluon.nn.Sequential()
        output_layer = gluon.nn.Dense(units=node_num)
        model.add(output_layer)
        model.initialize()
        return model
    34
    
    
    35
    
    
    36
    # create trainer
    37
    def trainer(net, loss_method, learning_rate):
        """Create a ``gluon.Trainer`` over every parameter of ``net``.

        Note: despite its name, ``loss_method`` is the *optimizer* given to
        ``gluon.Trainer`` (the script passes ``'sgd'``), not a loss function.
        The name is kept because callers pass it by keyword.
        """
        optimizer_params = {'learning_rate': learning_rate}
        return gluon.Trainer(
            params=net.collect_params(),
            optimizer=loss_method,
            optimizer_params=optimizer_params,
        )
    42
    
    
    43
    
    
    44
    # Squared-error loss; Gluon's L2Loss defaults (weight=1.0, batch_axis=0) spelled out.
    square_loss = gluon.loss.L2Loss(weight=1.0, batch_axis=0)
    45
    
    
    46
    
    
    47
    # start train
    48
    def start_train(epochs, batch_size, data_iter, net, loss_method, tariner, num_examples):
        """Train ``net`` with mini-batch gradient descent and report the learned parameters.

        Args:
            epochs: number of passes over ``data_iter``.
            batch_size: nominal mini-batch size. It is kept for backward
                compatibility. Each update is normalised by the length of the
                actual batch, so a shorter final batch is scaled correctly.
            data_iter: iterable yielding ``(data, label)`` NDArray batches.
            net: Gluon network whose first child (``net[0]``) exposes ``weight``
                and ``bias`` parameters (e.g. a ``Dense`` layer).
            loss_method: callable ``loss_method(output, label)`` returning the
                losses of a batch; they are summed for reporting.
            tariner: the ``gluon.Trainer`` that updates ``net``'s parameters.
                The misspelled name is kept because callers pass it by keyword.
            num_examples: total number of samples in ``data_iter``. It is used
                to average the loss printed after each epoch.

        Returns:
            Tuple ``(weight, bias)`` of the first layer's parameter NDArrays.
        """
        for e in range(epochs):
            total_loss = 0
            for data, label in data_iter:
                with autograd.record():
                    output = net(data)
                    loss = loss_method(output, label)
                loss.backward()
                # Use the trainer passed in (previously a global was used), and
                # normalise by the real batch length: the DataLoader may yield a
                # short final batch when num_examples % batch_size != 0.
                tariner.step(data.shape[0])
                total_loss += nd.sum(loss).asscalar()
            # Average over the actual dataset size rather than a hard-coded 1000.
            print("第 %d次训练, 平均损失: %f" % (e, total_loss / num_examples))
        dense = net[0]
        print(dense.weight.data())
        print(dense.bias.data())
        return dense.weight.data(), dense.bias.data()
    64
    
    
    65
    
    
    66
    # Experiment configuration, named once so related calls stay in sync.
    NUM_EXAMPLES = 1000
    BATCH_SIZE = 10
    true_w = [5, 8, 6]
    true_b = 6

    X, y = set_data(true_w=true_w, true_b=true_b, num_examples=NUM_EXAMPLES)
    data_iter = data_loader(batch_size=BATCH_SIZE, X=X, y=y, shuffle=True)
    net = set_net(1)
    # 'sgd' is the optimizer name handed to gluon.Trainer. Rebinding `trainer`
    # shadows the factory function above; the name is kept so that any code
    # resolving the global `trainer` at call time sees the Trainer object.
    trainer = trainer(net=net, loss_method='sgd', learning_rate=0.1)
    start_train(
        epochs=5,
        batch_size=BATCH_SIZE,
        data_iter=data_iter,
        net=net,
        loss_method=square_loss,
        tariner=trainer,
        num_examples=NUM_EXAMPLES,
    )
    74
    
    
     
     
  • 相关阅读:
    springmvc结合freemarker,非自定义标签
    springmvc的ModelAndView的简单使用
    tomcat无法正常启动的一个原因
    通过springmvc的RequestMapping的headers属性的使用
    springmvc入门demo
    Redis的入门Demo(java)
    Ubuntu18.0.4查看显示器型号
    APS审核经验+审核资料汇总——计算机科学与技术专业上海德语审核
    Java连接GBase并封装增删改查
    SpringMVC源码阅读:异常解析器
  • 原文地址:https://www.cnblogs.com/liaoxianfu/p/a00a6b34a5de7cdbd50ef91b7121cef3.html
Copyright © 2020-2023  润新知