我正在尝试针对具有两列
path
和 transcription
的自定义数据集微调 wav2vec2。 path
包含 wav 文件的位置。我收到错误:
回溯(最近一次调用最后一次): 文件“/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/transformers/trainer.py”,第2699行,在training_step中 损失= self.compute_loss(模型,输入) 文件“/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/transformers/trainer.py”,第2731行,在compute_loss中 输出=模型(**输入) 文件“/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torch/nn/modules/module.py”,第 1501 行,在 _call_impl 中 返回forward_call(*args, **kwargs) 文件“/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py”,第 1684 行,向前 输出 = self.wav2vec2( 文件“/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torch/nn/modules/module.py”,第 1501 行,在 _call_impl 中 返回forward_call(*args, **kwargs) 文件“/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py”,第 1306 行,向前 extract_features = self.feature_extractor(input_values) 文件“/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torch/nn/modules/module.py”,第 1501 行,在 _call_impl 中 返回forward_call(*args, **kwargs) 文件“/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py”,第 433 行,向前 隐藏状态 = 输入值[:, 无] 类型错误:列表索引必须是整数或切片,而不是元组
这是其余的代码:
from transformers import Wav2Vec2ForCTC, Wav2Vec2CTCTokenizer, Wav2Vec2Processor
from transformers import Trainer, TrainingArguments
import pandas as pd
import torch
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2Processor
from torch.utils.data import Dataset
import torchaudio
import warnings
# Filter and suppress all warnings
warnings.filterwarnings("ignore")
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("facebook/wav2vec2-base")
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base")
class CustomDataset(Dataset):
def __init__(self, csv_file, tokenizer, processor):
self.data = pd.read_csv(csv_file)[['path', 'transcription']]
self.tokenizer = tokenizer
self.processor = processor
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
audio_path = self.data.iloc[idx]['path']
transcription = self.data.iloc[idx]['transcription']
# Load and process the audio file
waveform, _ = torchaudio.load(audio_path)
input_values = self.processor(waveform, return_tensors="pt").input_values
attention_mask = torch.ones_like(input_values)
# Tokenize the transcription
labels = self.tokenizer(transcription, return_tensors="pt").input_ids
return {"input_values": input_values, "attention_mask": attention_mask, "labels": labels}
# Load the custom dataset
train_dataset = CustomDataset('train_data.csv', tokenizer, processor)
# Define the TrainingArguments
training_args = TrainingArguments(
output_dir="./wav2vec2-finetuned",
num_train_epochs=3,
per_device_train_batch_size=4,
save_steps=500,
save_total_limit=2,
learning_rate=1e-4,
logging_dir="./logs",
logging_steps=100,
logging_first_step=True,
overwrite_output_dir=True,
)
# Define the Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
data_collator=lambda data: {
"input_values": [item["input_values"] for item in data],
"attention_mask": [item["attention_mask"] for item in data],
"labels": [item["labels"] for item in data],
},
)
# Start training
trainer.train()