This博客介绍了 OpenAI 的新 Python 扩展 Triton,解释了为什么 Triton 可以比 pytorch 更快地进行矩阵数学(参考如何使用 Triton 沿 m × n 矩阵的行计算 Softmax 的示例) )
重要的是,softmax 的这种特殊实现在整个标准化过程中将 X 的行保留在 SRAM 中,这在适用时最大限度地提高了数据重用性(~<32K columns). This differs from PyTorch’s internal CUDA code, whose use of temporary memory makes it more general but significantly slower (below). The bottom line here is not that Triton is inherently better, but that it simplifies the development of specialized kernels that can be much faster than those found in general-purpose libraries.
“临时存储器”是指HBM,或者DRAM,或者VRAM,它是显卡的主存储器。例如,A100 上的 40GB 内存。
SRAM 通常是高速缓存,即片上存储器,其速度明显快于 HBM(片外存储器)。
在 triton 中,您可以使用
tl.load
函数显式地将数据从 HBM 加载到 SRAM,并使用 tl.save
函数显式地将数据从 SRAM 保存到 HBM。
因为 SRAM 速度更快,但比 HBM 更小。开发 Triton 内核的一种常见做法是加载较大矩阵或张量的一小部分,然后执行尽可能多的操作,然后将其保存回 HBM。这个技巧通常被称为“融合”内核。 为了说明上面的说法,这里有一张来自FlashAttention的图片,这是一种融合注意力的算法。