• theano中的scan用法


    Scan是干什么的

    函数scan是Theano中迭代的一般形式,所以可以用于类似循环(looping)的场景。 如果你熟悉Reduction和map两个函数,这两个都是scan的特殊形式,即将某函数依次作用一个序列的每个元素上。 函数scan的输入也是一些序列(一维数组,或者多维数组,以第一维为leading dimension),将某个函数作用于输入序列上,得到每一步输出的结果。 和Reduction和map两个函数不同之处在于,scan在计算的时候,可以访问以前n步的输出结果,所以比较适合RNN网络。

    为什么要使用scan

    看起来scan完全可以用for... loop来代替,然而scan有其自身的优点:

    • 由于Theano是使用符号代数的,迭代的次数就自然成为符号代数的一部分。也就是说迭代次数也会体现在构造符号代数的图中。

    (Theano用一个图来表示符号代数)

    • 由于上面一条,可以直接用Theano计算梯度。
    • 优化减少CPU和GPU之间的数据传输,比Python Loop稍微快一点。
    • 说不定以后还会有符号代数的其他优点,例如自动优化 y = x/x*x。

    大概参数说明

    函数scan调用的一般形式的一个例子大概是这样:

    results, updates = theano.scan(fn = lambda y, p, x_tm2, x_tm1,A: y+p+x_tm2+xtm1+A,
    sequences=[Y, P[::-1]],
    outputs_info=[dict(initial=X, taps=[-2, -1])]),
    non_sequences=A)
    
    • 参数fn是一个你需要计算的函数,一般用lambda来定义,参数是有顺序要求的,先是sequances的参数(y,p),然后是output_info的参数(x_tm2,x_tm1),然后是no_sequences的参数(A)。
    • sequences就是需要迭代的序列,序列的第一个维度(leading dimension)就是需要迭代的次数。所以,Y和P[::-1]的第一维大小应该相同,如果不同的话,就会取最小的。
    • outputs_info描述了需要用到前几次迭代输出的结果,dict(initial=X, taps=[-2, -1])表示使用前一次和前两次输出的结果。如果当前迭代输出为x(t),则计算中使用了(x(t-1)和x(t-2)。 
    • non_sequences描述了非序列的输入,即A是一个固定的输入,每次迭代加的A都是相同的。如果Y是一个向量,A就是一个常数,总之,A比Y少一个维度。

    举例

    计算 Ak , 大材小用一下

    k = T.iscalar("k")
    A = T.vector("A")
    
    # Symbolic description of the result
    result, updates = theano.scan(fn=lambda prior_result, A: prior_result * A,
    outputs_info=T.ones_like(A),
    non_sequences=A,
    n_steps=k)
    
    # We only care about A**k, but scan has provided us with A**1 through A**k.
    # Discard the values that we don't care about. Scan is smart enough to
    # notice this and not waste memory saving them.
    final_result = result[-1]
    
    # compiled function that returns A**k
    power = theano.function(inputs=[A,k], outputs=final_result, updates=updates)
    
    print power(range(10),2)
    print power(range(10),4)
    


    输出:

    [  0.   1.   4.   9.  16.  25.  36.  49.  64.  81.]
    [  0.00000000e+00   1.00000000e+00   1.60000000e+01   8.10000000e+01
       2.56000000e+02   6.25000000e+02   1.29600000e+03   2.40100000e+03
       4.09600000e+03   6.56100000e+03]
    

    计算 Computing tanh(x(t).dot(W) + b) 

    X = T.matrix("X")
    W = T.matrix("W")
    b_sym = T.vector("b_sym")
    
    results, updates = theano.scan(lambda v: T.tanh(T.dot(v, W) + b_sym), sequences=X)
    compute_elementwise = theano.function(inputs=[X, W, b_sym], outputs=[results])
    
    # test values
    x = np.eye(2, dtype=theano.config.floatX)
    w = np.ones((2, 2), dtype=theano.config.floatX)
    b = np.ones((2), dtype=theano.config.floatX)
    b[1] = 2
    print compute_elementwise(x, w, b)[0]
    # comparison with numpy
    print np.tanh(x.dot(w) + b)
    


    输出:

    [[ 0.96402758  0.99505475]
     [ 0.96402758  0.99505475]]
    [[ 0.96402758  0.99505475]
     [ 0.96402758  0.99505475]]
  • 相关阅读:
    mybatis的注意事项一
    java代码操作word模板生成PDF文件
    使用mybatis框架实现带条件查询多条件(传入实体类)
    MyBatis框架ResultMap节点
    优化mybatis框架中的查询用户记录数的案例
    Mybatis框架联表查询显示问题解决
    使用mybatis框架实现带条件查询单条件
    [DB] 如何彻底卸载删除MySQL 【MYSQL】
    [DB] MySQL窗口输入密码后消失问题 【MYSQL】
    [acm] 曾经 刷题记录 [只有正确的坚持才是胜利]
  • 原文地址:https://www.cnblogs.com/anyview/p/5014650.html
Copyright © 2020-2023  润新知