我正在寻找一种在给定 XLA-HLO 计算图时打印运行时间的工具。 我知道有 HLO 成本模型(分析模型)用于打印计算图的算子节点的 FLOP。 但是是否有任何工具可以打印 XLA-HLO 计算图的预期运行时间或运行时间的任何相关值?
我需要它的源代码或示例使用工具。谢谢:)
如果您使用 JAX,则可以使用 提前降低和编译 API 来了解计算的资源消耗情况。例如:
import jax
import numpy as np
def f(M, x):
for i in range(10):
x = M @ x
return x
M = np.random.randn(1000, 1000)
x = np.random.randn(1000)
print(jax.jit(f).lower(M, x).compile().cost_analysis())
[{'bytes accessed': 40080000.0,
'bytes accessed operand 0 {}': 40000000.0,
'bytes accessed operand 1 {}': 40000.0,
'bytes accessed output {}': 40000.0,
'flops': 20000000.0,
'optimal_seconds': 0.0,
'utilization operand 0 {}': 10.0,
'utilization operand 1 {}': 10.0}]
你能告诉我打印计算图算子节点的FLOPs的“HLO成本模型(分析模型)”是什么吗?我真的很需要这个工具,而且我找到它太久了。如果您愿意提供帮助,非常感谢。