• 自定义pass编写


    自定义pass编写

    TVM是一个框架,抽象了机器学习加速器的异质性。有时,用户可能需要自定义一些分析和IR转​​换,使TVM适应自己的专用硬件。本文可帮助用户在TVM中编写自定义pass。

    先决条件

    在阅读本文之前,假设读者已经熟悉以下主题:

    • 在TVM中编写算法并进行调度。否则,请参见示例教程,例如 如何在CPU上优化GEMM
    • HalideIR的基本结构。否则,请参阅HalideIR/src/ir/IR.h以了解定义了IR节点的哪些属性。
    • 访客设计模式。否则,请检查 Python AST模块以查看AST访问者的实现方式。
    • 如何将Schedule降低为IRModule类或LLVM模块。否则,请参考python/tvm/build_module.py以获得一些基础知识。

    import tvm

    from tvm import te

    import numpy as np

    首先编写一个非常简单的矢量加法,并使用默认调度对其进行构建。然后,使用定制的下降通道来直接操纵IR,而不是使用调度原语。

    n = tvm.tir.const(128, "int32")

    a = te.placeholder((n,), name="a")

    b = te.placeholder((n,), name="b")

    c = te.compute((n,), lambda i: a[i] + b[i], name="c")

     

    sch = te.create_schedule(c.op)

    ir = tvm.lower(sch, [a, b, c])

    print(ir)

    输出:

    primfn(a_1: handle, b_1: handle, c_1: handle) -> ()

      attr = {"global_symbol": "main", "tir.noalias": True}

      buffers = {c: Buffer(c_2: Pointer(float32), float32, [128], []),

                 b: Buffer(b_2: Pointer(float32), float32, [128], []),

                 a: Buffer(a_2: Pointer(float32), float32, [128], [])}

      buffer_map = {a_1: a, b_1: b, c_1: c} {

      for (i: int32, 0, 128) {

        c_2[i] = ((float32*)a_2[i] + (float32*)b_2[i])

      }

    }

    写pass

    本质上,“ IR转换遍历”是将语句映射到新语句的功能。因此,定义此向量化函数并逐步实现它。

    TVM已经为用户提供了两类来分析和转换IR。

    IR Vistor访客

    可以用tvm.tir.stmt_functor.post_order_visit(stmt, func)funcfunc来从Halide IR收集信息。这是一个函数回调。在退出当前IR节点之前,即在后订单访问之前,将调用此函数。然后,利用side effects 副作用来存储IR访问的结果,返回值将被忽略。

    必须使用一些数组来存储IR访问的结果。该值甚至是一个变量。这主要是由于Python-C运行时中的限制。每次递归都会刷新变量值,但会保留数组值。

    loops = []

    def find_width8(op):

        """ Find all the 'tir.For' nodes whose extent can be divided by 8. """

        if isinstance(op, tvm.tir.For):

            if isinstance(op.extent, tvm.tir.IntImm):

                if op.extent.value % 8 == 0:

                    loops.append(op)

    IR转换

    转换界面与访问者界面略有不同。访问者中仅存在一个后回调,但是转换访问者既支持前回调又支持后回调。如果要保留原始IR节点,只需返回None。如果要将当前节点更改为某个节点,请使用TVM IR maker界面进行构建并返回此值。

    笔记

    如果调用了预订功能并返回了非“无”的值,则将跳过 post-order 功能。

    def vectorize8(op):

        """ Split can vectorize the loops found in `find_width8`. """

        if op in loops:

            extent = op.extent.value

            name = op.loop_var.name

            lo, li = te.var(name + ".outer"), te.var(name + ".inner")

            body = tvm.tir.stmt_functor.substitute(op.body, {op.loop_var: lo * 8 + li})

            body = tvm.tir.For(li, 0, 8, tvm.tir.ForKind.VECTORIZED, body)

            body = tvm.tir.For(lo, 0, extent // 8, tvm.tir.ForKind.SERIAL, body)

            return body

        return None

     

     

    @tvm.tir.transform.prim_func_pass(opt_level=0)

    def vectorize(f, mod, ctx):

        global loops

     

        tvm.tir.stmt_functor.post_order_visit(f.body, find_width8)

     

        if not loops:

            return sf

     

        # The last list arugment indicates what kinds of nodes will be transformed.

        # Thus, in this case only `For` nodes will call `vectorize8`

        return f.with_body(tvm.tir.stmt_functor.ir_transform(f.body, None, vectorize8, ["tir.For"]))

    降低胶水Glue to Lowering

    到目前为止,已经完成了编写此IR转换通道的操作。接下来,需要将该pass粘贴到TVM的较低pass上。

    在这种情况下,通过将元组列表作为参数提供给TVM标准降低passtir.add_lower_pass。“元组”表示降低的不同阶段。在TVM中,降级分为四个阶段,每个阶段完成后将调用用户自定义的阶段。

    笔记

    以下是每个阶段完成的基本转换:

    • · 阶段0生成原始IR和环路电平。
    • · 第1阶段将阵列存储平坦化。
    • · 阶段2转换循环,例如展开,向量化和线程绑定。
    • · 第三阶段进行一些清理工作。

    因此,放置此转换过程的好地方就在阶段1之后。

    with tvm.transform.PassContext(config={"tir.add_lower_pass": [(1, vectorize)]}):

        print(tvm.lower(sch, [a, b, c]))

    输出:

    primfn(a_1: handle, b_1: handle, c_1: handle) -> ()

      attr = {"global_symbol": "main", "tir.noalias": True}

      buffers = {b: Buffer(b_2: Pointer(float32), float32, [128], []),

                 c: Buffer(c_2: Pointer(float32), float32, [128], []),

                 a: Buffer(a_2: Pointer(float32), float32, [128], [])}

      buffer_map = {a_1: a, b_1: b, c_1: c} {

      for (i.outer: int32, 0, 16) {

        c_2[ramp((i.outer*8), 1, 8)] = ((float32x8*)a_2[ramp((i.outer*8), 1, 8)] + (float32x8*)b_2[ramp((i.outer*8), 1, 8)])

      }

    }

    快速浏览

    本文提供了编写自定义IR转换过程的快速视图:-tvm.tir.stmt_functor.post_order_visit用于收集每个IR节点上的信息。-tvm.tir.stmt_functor.ir_transform用于转换IR节点。-总结以上两个内容,编写一个IR转换功能。-用tvm.transform.PassContext将此功能用于TVM降准

    人工智能芯片与自动驾驶
  • 相关阅读:
    python的三大控制机构(ifelse、for、while)
    python 异常处理
    《精通javascript》笔记
    IE6与!important
    point
    js 自制滚动条
    《Head first 设计模式》读书笔记
    Asp.net Webform 数据源绑定控件的扩展(懒人的办法):DropDownList
    Asp.net Binarysoft.Library 数据库通用操作层(Binarysoft.Library.Data)
    Asp.net Webform 从项目的数据库设计说起,什么是一个好的数据库设计。
  • 原文地址:https://www.cnblogs.com/wujianming-110117/p/14580171.html
Copyright © 2020-2023  润新知