如何微调Deplot模型(VQA) + LLM模型?

问题描述 投票:0回答:1

我试图从 Huggingface 中微调 DePlot 模型(https://huggingface.co/google/deplot)。它能够加载模型并使用图表图像进行测试以转换为表格。 这是结果。 Deplot test image and result

问题是我如何将解码表的结果用于LLM模型。如果可能的话,我想与 T5model 一起使用进行文本求和和问题回答。

任何人都可以帮助我如何将LLM应用到deplot模型上?

我尝试基于huggingface实现图像字幕(图像字幕演示),但我遇到了错误,你能帮我解决这个问题吗?

这是代码

# Load Library
from torch.utils.data import Dataset, DataLoader
import torch
from tqdm import tqdm, trange
from datasets import load_dataset
from transformers import AutoProcessor, AutoModelForSeq2SeqLM
from datasets import concatenate_datasets, load_dataset

# Customized Dataset
dataset = load_dataset("/content/drive/MyDrive/sample_dataset", split="train[:90%]")
dataset

>Results:
Dataset({
    features: ['image', 'metadata'],
    num_rows: 54
})

# features 
dataset.features

>{'image': [{'id': Value(dtype='int64', id=None),
   'filename': Value(dtype='string', id=None),
   'width': Value(dtype='int64', id=None),
   'height': Value(dtype='int64', id=None)}],
 'metadata': {'image_id': Value(dtype='int64', id=None),
  'data_category': Value(dtype='string', id=None),
  'chart_source': Value(dtype='string', id=None),
  'chart_color': Value(dtype='string', id=None),
  'chart_multi': Value(dtype='string', id=None),
  'chart_year': Value(dtype='string', id=None),
  'chart_main': Value(dtype='string', id=None),
  'chart_sub': Value(dtype='string', id=None),
  'chart_text': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None)}}


# Checkpoint, processor, model
checkpoint = "google/deplot"
processor = AutoProcessor.from_pretrained(checkpoint)
model = Pix2StructForConditionalGeneration.from_pretrained(checkpoint)

def transforms(example_batch):
    images = [x for x in example_batch["image"]]
    captions = [x for x in example_batch["text"]]
    inputs = processor(images=images,
                       text=captions,
                       padding=True,
                       max_length=512,
                       truncation=True )
    inputs.update({"labels": inputs["input_ids"]})
    return inputs



train_ds.set_transform(transforms)
test_ds.set_transform(transforms)

# evaluate function
from evaluate import load
import torch

wer = load("wer")


def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predicted = logits.argmax(-1)
    decoded_labels = processor.batch_decode(labels, skip_special_tokens=True)
    decoded_predictions = processor.batch_decode(predicted, skip_special_tokens=True)
    wer_score = wer.compute(predictions=decoded_predictions, references=decoded_labels)
    return {"wer_score": wer_score}

# Training args
from transformers import TrainingArguments, Trainer

model_name = checkpoint.split("/")[1]

training_args = TrainingArguments(
    output_dir=f"{model_name}-ko-deplot",
    learning_rate=5e-5,
    num_train_epochs=10,
    fp16=True,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=2,
    save_total_limit=3,
    evaluation_strategy="steps",
    eval_steps=50,
    save_strategy="steps",
    save_steps=50,
    logging_steps=50,
    remove_unused_columns=False,
    push_to_hub=True,
    label_names=["labels"],
    load_best_model_at_end=True,
)

# set trainer 
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    compute_metrics=compute_metrics,
)
trainer.train()

> KeyError                                  Traceback (most recent call last)
<ipython-input-58-3435b262f1ae> in <cell line: 1>()
----> 1 trainer.train()

12 frames
<ipython-input-57-fda57d6f6ce9> in transforms(example_batch)
      1 def transforms(example_batch):
      2     images = [x for x in example_batch["image"]]
----> 3     captions = [x for x in example_batch["text"]]
      4     inputs = processor(images=images, 
      5                        text=captions,

KeyError: 'text'

transformer-model large-language-model
1个回答
0
投票

我们可以通过 google/deplot 添加/使用自定义 LLM 吗?

© www.soinside.com 2019 - 2024. All rights reserved.