• MXNet 线性模型


    MXNet 线性模型

    <wiz_code_mirror>
     
     
     
    74
    def data_loader(batch_size, X, y, shuffle=False):
     
     
     
    1
    import mxnet
    2
    import mxnet.ndarray as nd
    3
    from mxnet import gluon
    4
    from mxnet import autograd
    5
    
    
    6
    
    
    7
    # create data
    8
    
    
    9
    def set_data(true_w, true_b, num_examples, *args, noise_std=0.1, **kwargs):
        """Generate a synthetic linear-regression dataset.

        Features are drawn from a standard normal distribution and labels are
        built as ``X . true_w + true_b`` plus Gaussian noise scaled by
        ``noise_std``.

        Args:
            true_w: Sequence of ground-truth weights; its length sets the
                number of input features.
            true_b: Ground-truth bias (scalar).
            num_examples: Number of rows to generate.
            *args, **kwargs: Accepted and ignored (kept so existing callers
                keep working).
            noise_std: Scale of the label noise (default 0.1, the value that
                was previously hard-coded).

        Returns:
            ``(X, y)`` with ``X`` of shape ``(num_examples, len(true_w))`` and
            ``y`` of shape ``(num_examples,)``.

        Raises:
            ValueError: If ``true_w`` is empty.
        """
        num_inputs = len(true_w)
        if num_inputs == 0:
            # Without at least one weight there are no features; fail clearly
            # instead of crashing later on a non-array label.
            raise ValueError("true_w must contain at least one weight")
        X = nd.random_normal(shape=(num_examples, num_inputs))
        # One (n, k) x (k, 1) matrix product replaces the per-feature loop;
        # flatten back to a 1-D label vector.
        w = nd.array(true_w).reshape((num_inputs, 1))
        y = nd.dot(X, w).reshape((num_examples,)) + true_b
        y += noise_std * nd.random_normal(shape=y.shape)
        return X, y
    19
    
    
    20
    
    
    21
    # create data loader
    22
    def data_loader(batch_size, X, y, shuffle=False):
        """Wrap features and labels in a Gluon ``DataLoader``.

        Args:
            batch_size: Number of samples per yielded mini-batch.
            X: Feature array.
            y: Label array, aligned with ``X`` along the first axis.
            shuffle: Whether the loader shuffles the samples.

        Returns:
            A ``gluon.data.DataLoader`` yielding ``(data, label)`` batches.
        """
        return gluon.data.DataLoader(
            dataset=gluon.data.ArrayDataset(X, y),
            batch_size=batch_size,
            shuffle=shuffle,
        )
    26
    
    
    27
    
    
    28
    # create net
    29
    def set_net(node_num):
        """Build and initialize a single-layer dense network.

        Args:
            node_num: Number of output units of the ``Dense`` layer.

        Returns:
            An initialized ``gluon.nn.Sequential`` holding one ``Dense`` layer.
        """
        model = gluon.nn.Sequential()
        output_layer = gluon.nn.Dense(node_num)
        model.add(output_layer)
        # No input size is given to Dense, so Gluon defers allocating the
        # weights until the first forward pass.
        model.initialize()
        return model
    34
    
    
    35
    
    
    36
    # create trainer
    37
    def trainer(net, loss_method, learning_rate):
        """Create a ``gluon.Trainer`` over all parameters of ``net``.

        Args:
            net: Gluon block whose parameters will be optimized.
            loss_method: Despite its name, this is the *optimizer* handed to
                ``gluon.Trainer`` (e.g. ``'sgd'``), not a loss function.
            learning_rate: Learning rate passed to the optimizer.

        Returns:
            The configured ``gluon.Trainer``.
        """
        optimizer_params = {'learning_rate': learning_rate}
        return gluon.Trainer(net.collect_params(), loss_method, optimizer_params)
    42
    
    
    43
    
    
    44
    # Squared-error loss for the regression (Gluon L2Loss: 0.5 * (pred - label)**2,
    # averaged per sample). weight and batch_axis are written out at their defaults.
    square_loss = gluon.loss.L2Loss(weight=1.0, batch_axis=0)
    45
    
    
    46
    
    
    47
    # start train
    48
    def start_train(epochs, batch_size, data_iter, net, loss_method, tariner, num_examples):
        """Train ``net`` with mini-batch gradient steps and report the first layer.

        Args:
            epochs: Number of passes over ``data_iter``.
            batch_size: Nominal batch size of ``data_iter``. Kept for interface
                compatibility; each update is normalized by the actual number
                of samples in the current batch, so a short final batch is not
                under-weighted.
            data_iter: Iterable of ``(data, label)`` mini-batches.
            net: Gluon network; ``net[0]`` is printed and returned after
                training (assumed to be a ``Dense`` layer -- TODO confirm for
                other callers).
            loss_method: Loss block, called as ``loss_method(output, label)``.
            tariner: The ``gluon.Trainer`` that applies the updates. The
                misspelled name is kept so existing keyword callers still work.
            num_examples: Total number of samples yielded by ``data_iter``;
                used to average the per-epoch loss.

        Returns:
            ``(weight, bias)`` NDArrays of ``net[0]`` after training.
        """
        for e in range(epochs):
            total_loss = 0
            for data, label in data_iter:
                with autograd.record():
                    output = net(data)
                    loss = loss_method(output, label)
                loss.backward()
                # Step the trainer that was passed in (not a module-level global),
                # normalizing gradients by this batch's real size.
                tariner.step(data.shape[0])
                total_loss += nd.sum(loss).asscalar()
            # Average over the actual dataset size rather than a fixed 1000.
            print("第 %d次训练, 平均损失: %f" % (e, total_loss / num_examples))
        dense = net[0]
        print(dense.weight.data())
        print(dense.bias.data())
        return dense.weight.data(), dense.bias.data()
    64
    
    
    65
    
    
    66
    # Ground-truth parameters the model is expected to recover.
    true_w = [5, 8, 6]
    true_b = 6
    # Run settings shared by the data loader and the training loop; defined
    # once so the two calls cannot drift apart.
    num_examples = 1000
    batch_size = 10
    X, y = set_data(true_w=true_w, true_b=true_b, num_examples=num_examples)
    data_iter = data_loader(batch_size=batch_size, X=X, y=y, shuffle=True)
    net = set_net(1)
    # NOTE: this rebinds the module-level name ``trainer`` from the factory
    # function to the Trainer instance it returns, so the factory cannot be
    # called again afterwards. Kept as-is because other code in this module
    # may look the instance up through the global name ``trainer``.
    trainer = trainer(net=net, loss_method='sgd', learning_rate=0.1)
    start_train(epochs=5, batch_size=batch_size, data_iter=data_iter, net=net,
                loss_method=square_loss, tariner=trainer, num_examples=num_examples)
    74
    
    
     
     
  • 相关阅读:
    触发器基本使用
    查询结果合并用逗号分隔
    查询报表增加小计功能
    sql语句格式化数字(前面补0)
    如何在选择画面中创建下拉列表(drop down list)-as list box
    如何更改函数的函数组(function group)
    ABAP语言中如何定义嵌套内表(nested internal table)
    [REUSE_ALV_GRID_DISPLAY]如何指定单元格颜色
    如何创建嵌套动态内表(Nested dynamic internal table)
    如何根据方法名(method)查找所在类(class)-SE84
  • 原文地址:https://www.cnblogs.com/liaoxianfu/p/a00a6b34a5de7cdbd50ef91b7121cef3.html
Copyright © 2020-2023  润新知