使用共享变量
# -*- coding: utf-8 -*-
"""
Created on Wed Jun 4 23:28:21 2014
@author: wencc
"""
from theano import shared
from theano import function
import theano.tensor as T
if __name__ == '__main__':
state = shared(0)
inc = T.iscalar('inc')
accumulator = function([inc], state, updates=[(state, state+inc)])
print state.get_value()
print accumulator(1)
print state.get_value()
print accumulator(20)
print state.get_value()
fn_of_state = state*2 + inc
foo = T.scalar(dtype=state.dtype)
skip_shared = function([inc, foo], fn_of_state, givens=[(state, foo)])
skip_shared(1, 3)
state.get_value()
shared函数构造共享变量,共享变量的get_value,set_value函数用来查看和设置共享变量的值
function函数中的updates参数用来更新共享变量,它是一个list,list中的每一项用map(共享变量,共享变量的新值表达式)的形式来表示。