为什么 pytorch 在 m1 pro 10 核上比 Linux CPU 上慢?

问题描述 投票:0回答:2

我从这里运行了以下基准测试。

#!/usr/bin/env python3

import torch


def batched_dot_mul_sum(a, b):
    '''Computes batched dot by multiplying and summing'''
    return a.mul(b).sum(-1)


def batched_dot_bmm(a, b):
    '''Computes batched dot by reducing to ``bmm``'''
    a = a.reshape(-1, 1, a.shape[-1])
    b = b.reshape(-1, b.shape[-1], 1)
    return torch.bmm(a, b).flatten(-3)


# Input for benchmarking
x = torch.randn(10000, 64)

# Ensure that both functions compute the same output
assert batched_dot_mul_sum(x, x).allclose(batched_dot_bmm(x, x))

import timeit

t0 = timeit.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

t1 = timeit.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x})

print(f'mul_sum(x, x):  {t0.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'bmm(x, x):      {t1.timeit(100) / 100 * 1e6:>5.1f} us')

我得到了

mul_sum(x, x):  1065.9 us
bmm(x, x):      134.5 us

在 Mac 上

mul_sum(x, x):   52.3 us
bmm(x, x):      120.1 us

在 Linux CPU 上

我看到了巨大的性能差异,这是预期的吗?

我第一次在一个更严肃的程序上注意到这种差异,并试图在这里复制它。

linux pytorch profiling apple-m1
2个回答
4
投票

性能差异可归因于几个明显影响计算执行的关键因素:

  1. 优化进度:PyTorch 对 Apple Silicon 架构的适配仍在完善中。 M1 Pro 的优化可能不如 Linux 系统中常见的成熟 x86_64 CPU 的优化成熟。因此,代码执行效率目前可能有利于更成熟的架构。

  2. 特定于架构的调整:PyTorch 及其依赖的 BLAS 和 LAPACK 库针对特定处理器架构进行了复杂的调整。对于 M1 Pro 的 ARM 架构和 Linux CPU 的典型 x86_64 架构,这些优化可能不会同样完善。因此,针对 x86_64 优化的操作可能无法在 M1 Pro 上高效执行。

  3. 指令集变化:指令集架构显着影响各种操作的执行效率。与 Linux 系统中的传统 x86_64 处理器相比,基于 ARM 的处理器(例如 M1 Pro)部署了独特的指令集。指令集的这种差异本质上会导致执行不同任务的效率不同,从而影响观察到的性能差异。


0
投票

John Zavialov answer 回顾了一般性问题,我将在这里简要列出它们,但我主要是响应赏金,所以这部分基本上将总结该答案并探讨如何加快速度

1.优化进度:PyTorch对Apple Silicon架构的适配还在完善中,还没有Linux的设置那么成熟

特定于架构的调整:PyTorch 是针对特定架构设置的,这意味着它可能不会在每个系统上都具有稳定的性能。

指令集变化:指令集架构显着影响各种操作的执行效率,基于 ARM 的系统(M1 Pros)与 x86_84(linux)不同,这可能导致性能上的巨大差异

为了解决这个问题或加速这个过程。 Mac 添加了新的Metal Performance Shader。如果您为 Pytorch on mac 激活此功能,您应该会看到性能提升。您可以在链接中看到安装说明,我写了一个代码示例来测试并激活它


import torch
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    x = torch.ones(1, device=mps_device)
    print (x)
else:
    print ("MPS device not found.")

© www.soinside.com 2019 - 2024. All rights reserved.