使用元组输入(Tuple Inputs)进行计算和归约
备注
单击此处下载完整的示例代码
作者:Ziheng Jiang
我们经常需要在单个循环中计算具有相同 shape 的多个输出,或执行涉及多个值的归约(例如 argmax)。这些问题可以通过元组输入来解决。
本教程介绍了 TVM 中元组输入的用法。
from __future__ import absolute_import, print_function
import tvm
from tvm import te
import numpy as np
描述批量计算
对于 shape 相同的算子,若希望在后续的调度过程中将它们一起调度,可以把它们放在同一个 te.compute 中:让计算函数返回一个元组,te.compute 便会相应地返回多个输出张量。
# Symbolic extents for the two dimensions; they appear as the int32
# variables `m` and `n` in the lowered IR below.
n = te.var("n")
m = te.var("m")
# Two input placeholders sharing the same (m, n) shape.
A0 = te.placeholder((m, n), name="A0")
A1 = te.placeholder((m, n), name="A1")
# One te.compute whose compute function returns a tuple: te.compute yields one
# output tensor per tuple element, unpacked here into B0 and B1. Because both
# come from a single compute, they share one loop nest — in the printed IR a
# single i/j loop writes both `B` and `B_1`.
# NOTE: the lambda's parameter names (i, j) and name="B" surface verbatim in
# the lowered IR, so renaming them changes the printed output.
B0, B1 = te.compute((m, n), lambda i, j: (A0[i, j] + 2, A1[i, j] * 3), name="B")
# Generated IR code:
# Build a schedule from B0's op only. B0 and B1 were produced by the same
# te.compute, so this one schedule covers both outputs (both stores appear
# inside the same loop nest in the printed IR).
s = te.create_schedule(B0.op)
# Lower the schedule and print it. The argument list fixes the parameter order
# of the resulting PrimFunc: A0, A1, B0, B1 map to buffers A0, A1, B, B_1.
print(tvm.lower(s, [A0, A1, B0, B1], simple_mode=True))
输出结果:
@main = primfn(A0_1: handle, A1_1: handle, B_2: handle, B_3: handle) -> ()
attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
buffers = {A0: Buffer(A0_2: Pointer(float32), float32, [(stride: int32*m: int32)], [], type="auto"),
A1: Buffer(A1_2: Pointer(float32), float32, [(stride_1: int32*m)], [], type="auto"),
B: Buffer(B_4: Pointer(float32), float32, [(stride_2: int32*m)], [], type="auto"),
B_1: Buffer(B_5: Pointer(float32), float32, [(stride_3: int32*m)], [], type="auto")}
buffer_map = {A0_1: A0, A1_1: A1, B_2: B, B_3: B_1}
preflattened_buffer_map = {A0_1: A0_3: Buffer(A0_2, float32, [m, n: int32], [stride, stride_4: int32], type="auto"), A1_1: A1_3: Buffer(A1_2, float32, [m, n], [stride_1, stride_5: int32], type="auto"), B_2: B_6: Buffer(B_4, float32, [m, n], [stride_2, stride_6: int32], type="auto"), B_3: B_7: Buffer(B_5, float32, [m, n], [stride_3, stride_7: int32], type="auto")} {
for (i: int32, 0, m) {
for (j: int32, 0, n) {
B[((i*stride_2) + (j*stride_6))] = (A0[((i*stride) + (j*stride_4))] + 2f32)
B_1[((i*stride_3) + (j*stride_7))] = (A1[((i*stride_1) + (j*stride_5))]*3f32)
}
}
}