我正在 Google colab 中使用 Transformers 库,当我使用 Transformers 库中的 TrainingArguments 时,我收到此代码的导入错误:
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer
model = AutoModelForImageClassification.from_pretrained(
checkpoint,
num_labels=len(labels),
id2label=id2label,
label2id=label2id,
)
training_args = TrainingArguments(
output_dir="my_awesome_food_model",
remove_unused_columns=False,
evaluation_strategy="epoch",
save_strategy="epoch",
learning_rate=5e-5,
per_device_train_batch_size=16,
gradient_accumulation_steps=4,
per_device_eval_batch_size=16,
num_train_epochs=3,
warmup_ratio=0.1,
logging_steps=10,
load_best_model_at_end=True,
metric_for_best_model="accuracy",
push_to_hub=True,
)
trainer = Trainer(
model=model,
args=training_args,
data_collator=data_collator,
train_dataset=food["train"],
eval_dataset=food["test"],
tokenizer=image_processor,
compute_metrics=compute_metrics,
)
trainer.train()
这是我遇到的错误
---------------------------------------------------------------------------
ImportError Traceback (most recent call last)
<ipython-input-26-cfdcb4612c2d> in <cell line: 1>()
----> 1 training_args = TrainingArguments(
2 output_dir="my_awesome_food_model",
3 remove_unused_columns=False,
4 evaluation_strategy="epoch",
5 save_strategy="epoch",
4 frames
/usr/local/lib/python3.10/dist-packages/transformers/training_args.py in _setup_devices(self)
1785 if not is_sagemaker_mp_enabled():
1786 if not is_accelerate_available(min_version="0.20.1"):
-> 1787 raise ImportError(
1788 "Using the `Trainer` with `PyTorch` requires `accelerate>=0.20.1`: Please run `pip install transformers[torch]` or `pip install accelerate -U`"
1789 )
ImportError: Using the `Trainer` with `PyTorch` requires `accelerate>=0.20.1`: Please run `pip install transformers[torch]` or `pip install accelerate -U`
我已经尝试过一次又一次安装加速,但每次都出现同样的错误。 同样使用虚拟环境仍然没有区别。 是的,Accelerte 和 Transformers 的版本都是最新的。
transformers.__version__, accelerate.__version__
这些是版本 -> ('4.35.2', '0.27.2') 我使用的是Colab环境
我能够复制这个问题,然后通过将运行时类型更改为 T4 GPU 来解决它。如果您使用的是 Colab,可以通过右上角的设置来完成此操作。请参阅此处