Using torchaudio.transforms.MelSpectrogram on tensors that reside on the GPU

Question

I want to compute a MelSpectrogram on the GPU with torchaudio. To test this, I wrote the following code:

from typing import Optional

import torch
import torchaudio

import numpy as np

from tests.__init__ import (
    __target_clock__ as TARGET_CLOCK,
    __number_of_test_data_vals__ as NUMBER_OF_TEST_DATA_VALS,
)

# Set general parameters:
TARGET_DEVICE = "CUDA"
TARGET_FREQUENCY: int = 440
NUMBER_OF_FFT_SLOTS: int = 1024
HOP_LENGTH: Optional[int] = None
NUMBER_OF_MEL_SLOTS: int = 128

if __name__ == "__main__":
    target_device = torch.device(
        "cuda" if (TARGET_DEVICE == "CUDA" and torch.cuda.is_available()) else "cpu"
    )
    print(f"Using device {target_device}")
    sampling_vec: np.ndarray = np.arange(NUMBER_OF_TEST_DATA_VALS) / TARGET_CLOCK
    frequency_vec: np.ndarray = np.sin(
        2 * np.pi * TARGET_FREQUENCY * sampling_vec
    ).astype("float32")
    frequency_tensor: torch.Tensor = torch.Tensor(frequency_vec).to(target_device)
    mel_spectrogram: torch.Tensor = torchaudio.transforms.MelSpectrogram(
        sample_rate=TARGET_CLOCK,
        n_fft=NUMBER_OF_FFT_SLOTS,
        hop_length=HOP_LENGTH,
        n_mels=NUMBER_OF_MEL_SLOTS,
    )(frequency_tensor)
    print(f"Obtained MEL-Spectrogram: {mel_spectrogram}")

When running on the CPU (i.e., with TARGET_DEVICE set to anything other than "CUDA"), the code runs without problems. However, when I try to use CUDA, I get the following error:

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "~\testing_modules\test_melspectrogram_GPU.py", line 39, in <module>
    mel_spectrogram: torch.Tensor = torchaudio.transforms.MelSpectrogram(
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~\AppData\Local\pypoetry\Cache\virtualenvs\testbed-rg5q6nje-py3.11\Lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~\AppData\Local\pypoetry\Cache\virtualenvs\testbed-rg5q6nje-py3.11\Lib\site-packages\torch\nn\modules\module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~\AppData\Local\pypoetry\Cache\virtualenvs\testbed-rg5q6nje-py3.11\Lib\site-packages\torchaudio\transforms\_transforms.py", line 619, in forward
    specgram = self.spectrogram(waveform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~\AppData\Local\pypoetry\Cache\virtualenvs\testbed-rg5q6nje-py3.11\Lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~\AppData\Local\pypoetry\Cache\virtualenvs\testbed-rg5q6nje-py3.11\Lib\site-packages\torch\nn\modules\module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~\AppData\Local\pypoetry\Cache\virtualenvs\testbed-rg5q6nje-py3.11\Lib\site-packages\torchaudio\transforms\_transforms.py", line 110, in forward
    return F.spectrogram(
           ^^^^^^^^^^^^^^
  File "~\AppData\Local\pypoetry\Cache\virtualenvs\testbed-rg5q6nje-py3.11\Lib\site-packages\torchaudio\functional\functional.py", line 126, in spectrogram
    spec_f = torch.stft(
             ^^^^^^^^^^^
  File "~\AppData\Local\pypoetry\Cache\virtualenvs\testbed-rg5q6nje-py3.11\Lib\site-packages\torch\functional.py", line 660, in stft
    return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore[attr-defined]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: stft input and window must be on the same device but got self on cuda:0 and window on cpu

What am I doing wrong here, and what do I need to do to run MelSpectrogram on the GPU? The current torch version is 2.2.2+cu121.

Tags: python, pytorch, cuda, torchaudio
Answer

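If you want this computation to run on CUDA, you also have to move the transform itself to the device. MelSpectrogram is a torch.nn.Module whose STFT window and mel filterbank are registered buffers created on the CPU; calling .to(...) on the module moves them to the same device as the input tensor: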

transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=TARGET_CLOCK,
    n_fft=NUMBER_OF_FFT_SLOTS,
    hop_length=HOP_LENGTH,
    n_mels=NUMBER_OF_MEL_SLOTS,
).to(target_device)  # move the window and mel filterbank buffers to the input's device

mel_spectrogram = transform(frequency_tensor)