• MXNet 中symbol 绑定


    于看到MXNet 的doc 了。今天准备做些GAN的试验需要手工hack些步骤,遇到要绑定,想起毕业设计时在julia中闷着拼差点又要热血上涌。。。。(罪过)。还好,在example 里面发现了一些可参考的,但doc 中关于bind 的部分有些残缺,似乎只介绍了关于executor的接口:

    A = mx.Variable(:A)
    B = mx.Variable(:B)
    C = A .* B
    a = mx.ones(3) * 4
    b = mx.ones(3) * 2
    c_exec = mx.bind(C, context=mx.cpu(), args=Dict(:A => a, :B => b))
    mx.forward(c_exec)
    copy(c_exec.outputs[1]) # copy turns NDArray into Julia Array
    # =>
    # 3-element Array{Float32,1}:
    # 8.0
    # 8.0
    # 8.0
    

    example中的程序更实在些,稍微改了下,记到这里:

    import mxnet as mx
    import numpy as np
    
    M,N=10,20
    device=mx.cpu()
    
    data=mx.sym.Variable('data')
    label=mx.sym.Variable('label')
    conv1=mx.sym.Convolution(data=data,kernel=(3,3),num_filter=2)
    flatten=mx.sym.Flatten(data=conv1)
    fc1=mx.sym.FullyConnected(data=flatten,num_hidden=1)
    loss_data=fc1# flatten
    loss=mx.sym.LogisticRegressionOutput(data=loss_data,label=label)
    img=np.zeros((M,N)).reshape((1,1,M,N))
    gdt=[1,]
    
    img=mx.nd.array(img)
    gdt=mx.nd.array(gdt)
    
    
    mod=mx.module.Module(symbol=loss,data_names=('data',),label_names=('label',))
    #D={'data':img,'label':gdt}
    mod.bind(data_shapes=[('data',(1,1,M,N))],label_shapes=[('label',(1,))],inputs_need_grad=True)
    mod.init_params()
    mod.init_optimizer(optimizer='adam')
    mod.forward(mx.io.DataBatch([img],[gdt]),is_train=True)
    out=mod.get_outputs()
    
  • 相关阅读:
    Servlet再度学习
    JSP九大内置对象
    Java I/O学习
    Java内存管理
    数据库面试常问的一些基本概念
    JVM类加载原理学习笔记
    Ajax原理学习
    Java基础之泛型
    Java基础之集合
    java多线程快速入门(二)
  • 原文地址:https://www.cnblogs.com/chenyliang/p/6780354.html
Copyright © 2020-2023  润新知