Loading a custom dataset to train wav2vec2


I am trying to fine-tune wav2vec2 on a custom dataset with two columns, path and transcription, where path holds the location of each wav file. I get this error:

Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/transformers/trainer.py", line 2699, in training_step
    loss = self.compute_loss(model, inputs)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/transformers/trainer.py", line 2731, in compute_loss
    outputs = model(**inputs)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py", line 1684, in forward
    outputs = self.wav2vec2(
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py", line 1306, in forward
    extract_features = self.feature_extractor(input_values)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py", line 433, in forward
    hidden_states = input_values[:, None]
TypeError: list indices must be integers or slices, not tuple
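The final TypeError is the telling part: input_values[:, None] is tuple indexing, which a torch.Tensor supports but a plain Python list does not, so the model must be receiving input_values as a list rather than a batched tensor. A minimal standalone sketch reproducing just that failure:

import torch

samples = [torch.randn(16000), torch.randn(16000)]  # a Python list of tensors

try:
    samples[:, None]  # the same indexing the model applies internally
except TypeError as err:
    print(err)  # list indices must be integers or slices, not tuple

print(torch.stack(samples)[:, None].shape)  # works once stacked: (2, 1, 16000)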

Here is the rest of the code:

from transformers import Wav2Vec2ForCTC, Wav2Vec2CTCTokenizer, Wav2Vec2Processor
from transformers import Trainer, TrainingArguments
import pandas as pd
import torch
from torch.utils.data import Dataset
import torchaudio
import warnings

# Suppress all warnings
warnings.filterwarnings("ignore")

tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("facebook/wav2vec2-base")
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base")


class CustomDataset(Dataset):
    def __init__(self, csv_file, tokenizer, processor):
        self.data = pd.read_csv(csv_file)[['path', 'transcription']]
        self.tokenizer = tokenizer
        self.processor = processor

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        audio_path = self.data.iloc[idx]['path']
        transcription = self.data.iloc[idx]['transcription']

        # Load and process the audio file
        waveform, _ = torchaudio.load(audio_path)
        input_values = self.processor(waveform, return_tensors="pt").input_values
        attention_mask = torch.ones_like(input_values)

        # Tokenize the transcription
        labels = self.tokenizer(transcription, return_tensors="pt").input_ids

        return {"input_values": input_values, "attention_mask": attention_mask, "labels": labels}


# Load the custom dataset
train_dataset = CustomDataset('train_data.csv', tokenizer, processor)

# Define the TrainingArguments
training_args = TrainingArguments(
    output_dir="./wav2vec2-finetuned",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    save_steps=500,
    save_total_limit=2,
    learning_rate=1e-4,
    logging_dir="./logs",
    logging_steps=100,
    logging_first_step=True,
    overwrite_output_dir=True,
)

# Define the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=lambda data: {
        # Note: each value here is a Python list of per-sample tensors,
        # not a batched tensor.
        "input_values": [item["input_values"] for item in data],
        "attention_mask": [item["attention_mask"] for item in data],
        "labels": [item["labels"] for item in data],
    },
)

# Start training
trainer.train()
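
For reference, the lambda data_collator above returns each field as a Python list of per-sample tensors, which matches the traceback: the model ends up calling input_values[:, None] on a list. Below is a sketch of a collator that pads and stacks everything into batched tensors first, following the common CTC padding-collator pattern; the class name PaddingCollator is illustrative, not a transformers API, and it reuses the processor, model, training_args, and train_dataset defined above.

import torch
from dataclasses import dataclass

@dataclass
class PaddingCollator:
    # Illustrative collator, assuming each sample carries (1, T)-shaped
    # input_values and (1, L)-shaped labels as produced by CustomDataset above.
    processor: Wav2Vec2Processor

    def __call__(self, features):
        # Pad raw waveforms to a common length and stack them, so the model
        # receives a (batch, time) tensor instead of a Python list.
        audio = [{"input_values": f["input_values"].squeeze(0)} for f in features]
        batch = self.processor.pad(audio, padding=True, return_tensors="pt")

        # Pad the label sequences too; positions covered by padding are set
        # to -100 so the CTC loss ignores them.
        label_feats = [{"input_ids": f["labels"].squeeze(0)} for f in features]
        labels_batch = self.processor.tokenizer.pad(label_feats, padding=True, return_tensors="pt")
        batch["labels"] = labels_batch["input_ids"].masked_fill(
            labels_batch["attention_mask"].ne(1), -100
        )
        return batch

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=PaddingCollator(processor),
)
trainer.train()

Separately, Wav2Vec2Processor expects 1-D audio at the model's sampling rate, so inside __getitem__ it is safer to pass waveform.squeeze(0) together with sampling_rate=16000 (resampling first if the wav files are not 16 kHz), and to let the collator produce the attention mask rather than building one per sample.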