• tf.function (TensorFlow > API > TensorFlow Core v2.2.0 > Python)


    tf.function is one of the main new features in TF 2.x: a function decorator that compiles a Python function into a callable TensorFlow graph.

    tf.function(
        func=None, input_signature=None, autograph=True, experimental_implements=None,
        experimental_autograph_options=None, experimental_relax_shapes=False,
        experimental_compile=None
    )
    
    Used in the guides: Introduction to graphs and functions, Concrete functions, Better performance with tf.function, Using the SavedModel format, Ragged tensors
    Used in the tutorials: DeepDream, Neural style transfer, Distributed Input, Pix2Pix, Transformer model for language understanding

    tf.function traces the TensorFlow operations in func and compiles them into a TensorFlow graph (tf.Graph); it then builds a callable that executes this graph, so func runs efficiently as a TensorFlow graph.

    Example usage:

    >>> @tf.function
    ... def f(x, y):
    ...   return x ** 2 + y
    >>> x = tf.constant([2, 3])
    >>> y = tf.constant([3, -2])
    >>> f(x, y)
    
    <tf.Tensor: shape=(2,), dtype=int32, numpy=array([7, 7])>
    
    • Features
      func may use data-dependent control flow, including if, for, while, break, continue, and return statements:
    >>> @tf.function
    ... def f(x):
    ...   if tf.reduce_sum(x) > 0:
    ...     return x * x
    ...   else:
    ...     return -x // 2
    >>> f(tf.constant(-2))
    
    <tf.Tensor: shape=(), dtype=int32, numpy=1>
    

    The closure of func may contain tf.Tensor and tf.Variable objects:

    >>> @tf.function
    ... def f():
    ...   return x ** 2 + y
    >>> x = tf.constant([-2, -3])
    >>> y = tf.Variable([3, -2])
    >>> f()
    
    <tf.Tensor: shape=(2,), dtype=int32, numpy=array([7, 7])>
    

    func can also use ops with side effects, such as tf.print, tf.Variable, and others:

    >>> v = tf.Variable(1)
    >>> @tf.function
    ... def f(x):
    ...   for i in tf.range(x):
    ...     v.assign_add(i)
    >>> f(3)
    >>> v
    
    <tf.Variable 'Variable:0' shape=() dtype=int32, numpy=4>
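
    The example above uses tf.Variable.assign_add; as a further minimal sketch (not from the original page), tf.print is another side-effecting op, and unlike Python's built-in print it runs every time the graph executes, not just during tracing:

    import tensorflow as tf

    @tf.function
    def g(x):
      print("traced with", x)       # Python side effect: fires only while the function is being traced
      tf.print("executed with", x)  # TF op: fires every time the compiled graph runs

    g(tf.constant(1))  # first call traces, so both messages appear
    g(tf.constant(2))  # same input signature, no retracing: only the tf.print message appears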
    

    Important
    Any Python side effects (appending to a list, printing, etc.) happen only once, when func is traced. To have side effects executed inside the tf.function, they need to be written as TF ops. For example:

    >>> l = []
    >>> @tf.function
    ... def f(x):
    ...   for i in x:
    ...     l.append(i + 1)    # Caution! Will only happen once when tracing
    >>> f(tf.constant([1, 2, 3]))
    >>> l
    
    [<tf.Tensor 'add:0' shape=() dtype=int32>]
    

    The list l is only appended to once, during tracing (graph construction). To have the write happen on every iteration, use TF collections such as tf.TensorArray:

    >>> @tf.function
    ... def f(x):
    ...   ta = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
    ...   for i in range(len(x)):
    ...     ta = ta.write(i, x[i] + 1)
    ...   return ta.stack()
    >>> f(tf.constant([1, 2, 3]))
    
    <tf.Tensor: shape=(3,), dtype=int32, numpy=array([2, 3, 4])>
    

    tf.function is polymorphic
    TensorFlow graphs built for specific shapes and dtypes are more efficient, so tf.function can build multiple graphs to support arguments with different dtypes and shapes. tf.function treats any pure Python value as an opaque object and builds a separate graph for every set of Python arguments it encounters.

    To obtain an individual graph, use the get_concrete_function method of the callable created by tf.function. It can be called with the same arguments as func and returns a concrete function whose graph attribute is a tf.Graph object:

    >>> @tf.function
    ... def f(x):
    ...   return x + 1
    >>> isinstance(f.get_concrete_function(1).graph, tf.Graph)
    True
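
    As a small additional sketch (not part of the original text), the returned concrete function is itself callable for inputs matching the traced shapes and dtypes, and exposes the underlying graph via its graph attribute:

    import tensorflow as tf

    @tf.function
    def f(x):
      return x + 1

    cf = f.get_concrete_function(tf.constant([1, 2]))  # traces a graph for int32 vectors of shape (2,)
    print(cf(tf.constant([10, 20])))        # tf.Tensor([11 21], shape=(2,), dtype=int32), reusing the traced graph
    print(isinstance(cf.graph, tf.Graph))   # True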
    

    Caution: passing Python scalars or lists as arguments to tf.function will always build a new graph. To avoid this, pass numeric arguments as Tensors whenever possible:

    >>> @tf.function
    ... def f(x):
    ...   return tf.abs(x)
    >>> f1 = f.get_concrete_function(1)
    >>> f2 = f.get_concrete_function(2)  # Slow - builds new graph
    >>> f1 is f2
    False
    
    >>> f1 = f.get_concrete_function(tf.constant(1))
    >>> f2 = f.get_concrete_function(tf.constant(2))  # Fast - reuses f1
    >>> f1 is f2
    True
    

    Use Python numerical arguments only when they take few distinct values, such as hyperparameters like the number of layers in a neural network.
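
    For example, a rough sketch (the function and names here are illustrative, not from the original) of a Python int used as a layer-count hyperparameter:

    import tensorflow as tf

    @tf.function
    def forward(x, num_layers):
      # num_layers is a Python int hyperparameter: each distinct value (e.g. 2 or 3)
      # traces its own graph, which is acceptable because only a few values are ever used.
      for _ in range(num_layers):
        x = tf.nn.relu(tf.matmul(x, tf.ones([4, 4])))
      return x

    x = tf.ones([1, 4])
    y2 = forward(x, 2)  # traces a graph with 2 layers unrolled
    y3 = forward(x, 3)  # traces a second graph with 3 layers unrolled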

    Input signatures

    For Tensor arguments, tf.function instantiates a separate graph for every unique set of input shapes and dtypes; inputs with the same shapes and dtypes share a single graph. The example below instantiates two graphs, because each input has a different shape:

    >>> @tf.function
    ... def f(x):
    ...   return x + 1
    >>> vector = tf.constant([1.0, 1.0])
    >>> matrix = tf.constant([[3.0]])
    >>> f.get_concrete_function(vector) is f.get_concrete_function(matrix)
    False
    

    An input signature can optionally be provided to tf.function to control which graphs are traced. The input signature specifies the shape and dtype of each Tensor argument using tf.TensorSpec objects, and more general shapes can be used. This avoids creating multiple graphs when Tensors have dynamic shapes. However, sharing a single graph also restricts the sizes and dtypes of Tensors that can be used:

    >>> @tf.function(
    ...     input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
    ... def f(x):
    ...   return x + 1
    >>> vector = tf.constant([1.0, 1.0])
    >>> matrix = tf.constant([[3.0]])
    >>> f.get_concrete_function(vector) is f.get_concrete_function(matrix)
    True
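
    As a minimal sketch of that restriction (the exact exception type may differ across TF versions, so both TypeError and ValueError are caught here), a Tensor whose dtype does not match the declared signature is rejected:

    import tensorflow as tf

    @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
    def f(x):
      return x + 1

    print(f(tf.constant([1.0, 2.0])))   # OK: any shape is accepted, as long as the dtype is float32
    try:
      f(tf.constant([1, 2]))            # int32 input does not match the float32 signature
    except (TypeError, ValueError) as e:
      print("rejected:", e)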
    

    Variables may only be created once

    tf.function only allows creating new tf.Variable objects the first time it is called:

    >>> class MyModule(tf.Module):
    ...   def __init__(self):
    ...     self.v = None
    ...
    ...   @tf.function
    ...   def call(self, x):
    ...     if self.v is None:
    ...       self.v = tf.Variable(tf.ones_like(x))
    ...     return self.v * x
    

    In general, it is recommended to create stateful objects like tf.Variable outside of tf.function and pass them as arguments:

    v = tf.Variable(1.0)
    
    @tf.function
    def f(x):
      return v.assign_add(x)
    
    print(f(1.0))  # 2.0
    print(f(2.0))  # 4.0
    
    Args
      func: the function to be compiled. If func is None, tf.function returns a decorator that can be invoked with a single argument - func. In other words, tf.function(input_signature=...)(func) is equivalent to tf.function(func, input_signature=...). The former can be used as a decorator (see the sketch after this list).
      input_signature: a possibly nested sequence of tf.TensorSpec objects specifying the shapes and dtypes of the Tensors that will be supplied to this function. If None, a separate function is instantiated for each inferred input signature. If input_signature is specified, every input to func must be a Tensor, and func cannot accept **kwargs.
      autograph: whether autograph should be applied on func before tracing a graph. Data-dependent control flow requires autograph=True. For more information, see the tf.function and AutoGraph guide.
      experimental_implements: if provided, contains the name of a "known" function this implements, for example "mycompany.my_recurrent_cell". This is stored as an attribute in the inference function, which can then be detected when processing the serialized function. See standardizing composite ops for details. As an example of using this attribute, TFLite can detect a function that implements "embedded_matmul" and substitute its own implementation. A TensorFlow user can mark that their function also implements embedded_matmul (perhaps more efficiently!) by specifying it with this parameter: @tf.function(experimental_implements="embedded_matmul").
      experimental_autograph_options: optional tuple of tf.autograph.experimental.Feature values.
      experimental_relax_shapes: when True, tf.function may generate fewer graphs that are less specialized on input shapes.
      experimental_compile: if True, the function is always compiled by XLA. XLA may be more efficient in some cases (e.g. TPU, XLA_GPU, dense tensor computations).
    Returns
      If func is not None, returns a callable that will execute the compiled function (and return zero or more tf.Tensor objects). If func is None, returns a decorator that, when invoked with a single func argument, returns a callable equivalent to the case above.
    Raises
      ValueError: when attempting to use experimental_compile, but XLA support is not enabled.
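
    As a short sketch of the equivalence described under func above (add_one and sig are illustrative names), the direct-call and decorator-factory forms produce the same behavior:

    import tensorflow as tf

    def add_one(x):
      return x + 1

    sig = [tf.TensorSpec(shape=[None], dtype=tf.int32)]

    f_a = tf.function(add_one, input_signature=sig)    # tf.function(func, input_signature=...)
    f_b = tf.function(input_signature=sig)(add_one)    # tf.function(input_signature=...)(func)

    x = tf.constant([1, 2])
    print(f_a(x))  # tf.Tensor([2 3], shape=(2,), dtype=int32)
    print(f_b(x))  # same result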
  • Original post: https://www.cnblogs.com/banluxinshou/p/13305557.html