为什么triton语言比pytorch更快?

问题描述 投票:0回答:1

This博客介绍了 OpenAI 的新 Python 扩展 Triton,解释了为什么 Triton 可以比 pytorch 更快地进行矩阵数学(参考如何使用 Triton 沿 m × n 矩阵的行计算 Softmax 的示例) )

重要的是,softmax 的这种特殊实现在整个标准化过程中将 X 的行保留在 SRAM 中,这在适用时最大限度地提高了数据重用性(~<32K columns). This differs from PyTorch’s internal CUDA code, whose use of temporary memory makes it more general but significantly slower (below). The bottom line here is not that Triton is inherently better, but that it simplifies the development of specialized kernels that can be much faster than those found in general-purpose libraries.

  1. pytorch如何为设备张量分配内存,这里所说的“临时内存”是什么?为什么这种临时存储器的使用比较普遍,但比使用SRAM慢?
  2. 这里的SRAM是指高速缓冲存储器吗?如果是这样,这个库如何/为什么比 pytorch 内部更好地利用缓存内存?我的理解是,缓存哪些数据主要取决于硬件而不是软件。
pytorch nvidia openai-api
1个回答
0
投票

“临时存储器”是指HBM,或者DRAM,或者VRAM,它是显卡的主存储器。例如,A100 上的 40GB 内存。

SRAM 通常是高速缓存,即片上存储器,其速度明显快于 HBM(片外存储器)。

在 triton 中,您可以使用

tl.load
函数显式地将数据从 HBM 加载到 SRAM,并使用
tl.save
函数显式地将数据从 SRAM 保存到 HBM。

因为 SRAM 速度更快,但比 HBM 更小。开发 Triton 内核的一种常见做法是加载较大矩阵或张量的一小部分,然后执行尽可能多的操作,然后将其保存回 HBM。这个技巧通常被称为“融合”内核。 为了说明上面的说法,这里有一张来自FlashAttention的图片,这是一种融合注意力的算法。

© www.soinside.com 2019 - 2024. All rights reserved.