我从这里运行了以下基准测试。
#!/usr/bin/env python3
import torch
def batched_dot_mul_sum(a, b):
'''Computes batched dot by multiplying and summing'''
return a.mul(b).sum(-1)
def batched_dot_bmm(a, b):
'''Computes batched dot by reducing to ``bmm``'''
a = a.reshape(-1, 1, a.shape[-1])
b = b.reshape(-1, b.shape[-1], 1)
return torch.bmm(a, b).flatten(-3)
# Input for benchmarking
x = torch.randn(10000, 64)
# Ensure that both functions compute the same output
assert batched_dot_mul_sum(x, x).allclose(batched_dot_bmm(x, x))
import timeit
t0 = timeit.Timer(
stmt='batched_dot_mul_sum(x, x)',
setup='from __main__ import batched_dot_mul_sum',
globals={'x': x})
t1 = timeit.Timer(
stmt='batched_dot_bmm(x, x)',
setup='from __main__ import batched_dot_bmm',
globals={'x': x})
print(f'mul_sum(x, x): {t0.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'bmm(x, x): {t1.timeit(100) / 100 * 1e6:>5.1f} us')
我得到了
mul_sum(x, x): 1065.9 us
bmm(x, x): 134.5 us
在 Mac 上
和
mul_sum(x, x): 52.3 us
bmm(x, x): 120.1 us
在 Linux CPU 上
我看到了巨大的性能差异,这是预期的吗?
我第一次在一个更严肃的程序上注意到这种差异,并试图在这里复制它。
性能差异可归因于几个明显影响计算执行的关键因素:
优化进度:PyTorch 对 Apple Silicon 架构的适配仍在完善中。 M1 Pro 的优化可能不如 Linux 系统中常见的成熟 x86_64 CPU 的优化成熟。因此,代码执行效率目前可能有利于更成熟的架构。
特定于架构的调整:PyTorch 及其依赖的 BLAS 和 LAPACK 库针对特定处理器架构进行了复杂的调整。对于 M1 Pro 的 ARM 架构和 Linux CPU 的典型 x86_64 架构,这些优化可能不会同样完善。因此,针对 x86_64 优化的操作可能无法在 M1 Pro 上高效执行。
指令集变化:指令集架构显着影响各种操作的执行效率。与 Linux 系统中的传统 x86_64 处理器相比,基于 ARM 的处理器(例如 M1 Pro)部署了独特的指令集。指令集的这种差异本质上会导致执行不同任务的效率不同,从而影响观察到的性能差异。
John Zavialov answer 回顾了一般性问题,我将在这里简要列出它们,但我主要是响应赏金,所以这部分基本上将总结该答案并探讨如何加快速度
1.优化进度:PyTorch对Apple Silicon架构的适配还在完善中,还没有Linux的设置那么成熟
特定于架构的调整:PyTorch 是针对特定架构设置的,这意味着它可能不会在每个系统上都具有稳定的性能。
指令集变化:指令集架构显着影响各种操作的执行效率,基于 ARM 的系统(M1 Pros)与 x86_84(linux)不同,这可能导致性能上的巨大差异
为了解决这个问题或加速这个过程。 Mac 添加了新的Metal Performance Shader。如果您为 Pytorch on mac 激活此功能,您应该会看到性能提升。您可以在链接中看到安装说明,我写了一个代码示例来测试并激活它
import torch
if torch.backends.mps.is_available():
mps_device = torch.device("mps")
x = torch.ones(1, device=mps_device)
print (x)
else:
print ("MPS device not found.")