• Theano Notes


    The scan function

    theano.scan(fn, sequences=None, outputs_info=None, non_sequences=None, n_steps=None, truncate_gradient=-1, go_backwards=False, mode=None, name=None, profile=False)

     outputs_info is the list of Theano variables or dictionaries describing the initial state of the outputs computed recurrently. 

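    fn is the function applied at each step, sequences are the inputs that scan iterates over, and outputs_info gives the initial state of scan's outputs. The entries of sequences and outputs_info are passed to fn as arguments, in that order.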

    scan(fn, sequences = [ dict(input= Sequence1, taps = [-3,2,-1])
                         , Sequence2
                         , dict(input =  Sequence3, taps = 3) ]
           , outputs_info = [ dict(initial =  Output1, taps = [-3,-5])
                            , dict(initial = Output2, taps = None)
                            , Output3 ]
           , non_sequences = [ Argument1, Argument2])

    fn should expect the following arguments in this given order:

    1. Sequence1[t-3]
    2. Sequence1[t+2]
    3. Sequence1[t-1]
    4. Sequence2[t]
    5. Sequence3[t+3]
    6. Output1[t-3]
    7. Output1[t-5]
    8. Output3[t-1]
    9. Argument1
    10. Argument2
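The ordering rule above can be sketched in plain Python. This is only an illustration: `fn_argument_order` is a hypothetical helper, not part of Theano. It assumes a bare sequence behaves like taps `[0]`, a bare output like taps `[-1]`, and an output with `taps=None` is not passed to fn, which matches the list above.

```python
def fn_argument_order(sequences, outputs_info, non_sequences):
    """Return labels for fn's arguments in the order scan passes them.

    Hypothetical helper for illustration only; it is not part of Theano.
    Each sequence/output is a (name, taps) pair, where taps is a list of
    ints, a single int, or None (outputs only).
    """
    def label(name, tap):
        # Tap 0 is the current step; other taps are offsets from t.
        return "%s[t]" % name if tap == 0 else "%s[t%+d]" % (name, tap)

    def as_list(taps):
        return [taps] if isinstance(taps, int) else list(taps)

    args = []
    for name, taps in sequences:
        args.extend(label(name, tap) for tap in as_list(taps))
    for name, taps in outputs_info:
        if taps is None:
            continue  # this output is not fed back into fn
        args.extend(label(name, tap) for tap in as_list(taps))
    args.extend(non_sequences)
    return args


order = fn_argument_order(
    sequences=[("Sequence1", [-3, 2, -1]), ("Sequence2", [0]), ("Sequence3", 3)],
    outputs_info=[("Output1", [-3, -5]), ("Output2", None), ("Output3", [-1])],
    non_sequences=["Argument1", "Argument2"],
)
for position, name in enumerate(order, start=1):
    print(position, name)
```

Sequence taps come first, then the output taps, then the non-sequences, each group in the order it was declared.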

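    import theano
    import theano.tensor as T
    mode = theano.Mode(linker='cvm')  # defined but not passed to theano.function below
    import numpy as np


    def fun(a, b):
        # a: current element of the sequence; b: output of the previous step
        return a + b

    input = T.vector("input")
    # The recurrence starts from the scalar 1, cast to the input's dtype.
    output, update = theano.scan(fun, sequences=input,
                                 outputs_info=[T.as_tensor_variable(np.asarray(1, input.dtype))])

    out = theano.function(inputs=[input], outputs=output)

    # Match the symbolic input's dtype so the array is accepted without downcasting.
    in1 = np.array([1, 2, 3], dtype=input.dtype)
    print(out(in1))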

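    def fun(a, b):
        # a: current row of the matrix; b: output of the previous step
        return a + b

    input = T.matrix("input")
    # scan iterates over the rows; the recurrence starts from a zero vector.
    output, update = theano.scan(fun, sequences=input,
                                 outputs_info=[T.as_tensor_variable(np.asarray([0, 0, 0], input.dtype))])

    out = theano.function(inputs=[input], outputs=output)

    in1 = np.array([[1, 2, 3], [4, 5, 6]], dtype=input.dtype)
    print(in1)
    print(out(in1))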
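To make the recurrence in the two examples above easier to follow, here is a rough pure-Python stand-in. `scan_like` and `add_rows` are hypothetical names used only for this sketch. It covers just the case of one sequence and one recurrent output, and it runs eagerly on lists rather than building a symbolic graph as Theano does.

```python
def scan_like(fn, sequence, initial):
    """Pure-Python stand-in for scan with one sequence and one recurrent output.

    Illustration only: real theano.scan builds a symbolic loop instead of
    running eagerly like this.
    """
    outputs = []
    previous = initial
    for element in sequence:
        # fn receives the current sequence element first, then the previous output.
        previous = fn(element, previous)
        outputs.append(previous)
    return outputs


# Vector case: running sum that starts from 1.
vector_result = scan_like(lambda a, b: a + b, [1, 2, 3], 1)


def add_rows(row, previous):
    # Element-wise addition of two equal-length lists.
    return [r + p for r, p in zip(row, previous)]


# Matrix case: each step consumes one row, starting from [0, 0, 0].
matrix_result = scan_like(add_rows, [[1, 2, 3], [4, 5, 6]], [0, 0, 0])

print(vector_result)
print(matrix_result)
```

Every intermediate state is collected, which is why the output has one entry per step rather than only the final value.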

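    Shared variables behave like global variables. Their value can be accessed and modified with the .get_value() and .set_value() methods. Modifying them through the updates argument of theano.function allows the updates to be applied in parallel.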

    The output of scan is symbolic; it is used later in theano.function as the outputs and as the update rules. When sequences=None, n_steps must be given a value to fix the number of iterations. When sequences is not empty, scan loops directly over the sequences:

    components, updates = theano.scan(fn=lambda coefficient, power, free_variable: coefficient * (free_variable ** power),
                                      outputs_info=None,
                                      sequences=[coefficients, theano.tensor.arange(max_coefficients_supported)],
                                      non_sequences=x)

    In this example, theano.tensor.arange(max_coefficients_supported) plays the role of the index in enumerate, and coefficients corresponds to the sequence values that enumerate yields. By argument order, x is free_variable.
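The enumerate analogy can be written out in plain Python. This is a sketch of what the scan call computes per step, not Theano code. `polynomial_components` is a hypothetical name, and the final sum is an assumption about how the components would be combined, since the snippet above stops at `components`.

```python
def polynomial_components(coefficients, x):
    """Compute coefficient * x**power for each term, like the lambda passed to scan.

    Illustration only; the power comes from enumerate instead of
    theano.tensor.arange.
    """
    return [coefficient * (x ** power)
            for power, coefficient in enumerate(coefficients)]


# 1 + 0*x + 2*x**2 evaluated at x = 3
components = polynomial_components([1, 0, 2], 3)
polynomial_value = sum(components)  # assumed follow-up step, not shown in the notes

print(components)
print(polynomial_value)
```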

    Debug:

    http://deeplearning.net/software/theano/tutorial/debug_faq.html

    theano.config.compute_test_value = 'warn'
    • off: Default behavior. This debugging mechanism is inactive.
    • raise: Compute test values on the fly. Any variable for which a test value is required, but not provided by the user, is treated as an error. An exception is raised accordingly.
    • warn: Idem, but a warning is issued instead of an Exception.
    • ignore: Silently ignore the computation of intermediate test values, if a variable is missing a test value.
    import theano
    
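    def inspect_inputs(i, node, fn):
        # Called before each node runs; end=' ' keeps the outputs on the same line.
        print(i, node, "input(s) value(s):", [input[0] for input in fn.inputs], end=' ')
    
    def inspect_outputs(i, node, fn):
        # Called after each node runs.
        print("output(s) value(s):", [output[0] for output in fn.outputs])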
    
    x = theano.tensor.dscalar('x')
    f = theano.function([x], [5 * x],
                        mode=theano.compile.MonitorMode(
                            pre_func=inspect_inputs,
                            post_func=inspect_outputs))
    f(3)
    

     

    mode = 'DEBUG_MODE' is very slow, and seems to have no effect?

    Using Print

    import numpy
    import theano

    x = theano.tensor.dvector('x')
    
    # Wrapping x in a Print op prints the message and the value whenever x is evaluated.
    x_printed = theano.printing.Print('this is a very important value')(x)
    
    f = theano.function([x], x * 5)
    f_with_print = theano.function([x], x_printed * 5)
    
    # This runs the graph without any printing.
    assert numpy.all(f([1, 2, 3]) == [5, 10, 15])
    
    # This runs the graph with the message and the value printed.
    assert numpy.all(f_with_print([1, 2, 3]) == [5, 10, 15])
  • Original post: https://www.cnblogs.com/huashiyiqike/p/3553325.html