I'm getting the error below when trying to modify, chunk, and re-save a Hugging Face dataset.
I was wondering whether anyone could help?
Traceback (most recent call last):
File "C:\Users\conno\LegalAIDataset\LegalAIDataset\main.py", line 39, in <module>
new_dataset = dataset.map(process_row, batched=True, batch_size=1, remove_columns=None)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\conno\LegalAIDataset\LegalAIDataset\.venv\Lib\site-packages\datasets\arrow_dataset.py", line 602, in wrapper
out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\conno\LegalAIDataset\LegalAIDataset\.venv\Lib\site-packages\datasets\arrow_dataset.py", line 567, in wrapper
out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\conno\LegalAIDataset\LegalAIDataset\.venv\Lib\site-packages\datasets\arrow_dataset.py", line 3156, in map
for rank, done, content in Dataset._map_single(**dataset_kwargs):
File "C:\Users\conno\LegalAIDataset\LegalAIDataset\.venv\Lib\site-packages\datasets\arrow_dataset.py", line 3570, in _map_single
writer.write_batch(batch)
File "C:\Users\conno\LegalAIDataset\LegalAIDataset\.venv\Lib\site-packages\datasets\arrow_writer.py", line 571, in write_batch
pa_table = pa.Table.from_arrays(arrays, schema=schema)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "pyarrow\\table.pxi", line 4642, in pyarrow.lib.Table.from_arrays
File "pyarrow\\table.pxi", line 3922, in pyarrow.lib.Table.validate
File "pyarrow\\error.pxi", line 91, in pyarrow.lib.check_status
pyarrow.lib.ArrowInvalid: Column 1 named type expected length 44 but got length 21
My minimal reproducible code is as follows:
import datasets
from datasets import load_dataset, Dataset
from semantic_text_splitter import TextSplitter

# Step 1: Load the existing dataset
dataset = load_dataset('HF_Dataset')

# Slice the 'train' split of the dataset
sliced_data = dataset['train'][:100]

# Convert the sliced data back into a Dataset object
dataset = Dataset.from_dict(sliced_data)

def chunk_text(text_list, metadata):
    splitter = TextSplitter(1000)
    chunks = [chunk for text in text_list for chunk in splitter.chunks(text)]
    return {"text_chunks": chunks, **metadata}

# Define a global executor
#executor = ThreadPoolExecutor(max_workers=1)

def process_row(batch):
    # Initialize a dictionary to store the results
    results = {k: [] for k in batch.keys()}
    results['text_chunks'] = []  # Add 'text_chunks' key to the results dictionary
    # Process each row in the batch
    for i in range(len(batch['text'])):
        # Apply the chunk_text function to the text
        chunks = chunk_text(batch['text'][i], {k: v[i] for k, v in batch.items() if k != 'text'})
        # Add the results to the dictionary
        for k, v in chunks.items():
            results[k].extend(v)
    # Return the results
    return results

# Apply the function to the dataset
new_dataset = dataset.map(process_row, batched=True, batch_size=1, remove_columns=None)

# Save and upload the new dataset
new_dataset.to_json('dataset.jsonl')
dataset_dict = datasets.DatasetDict({"split": new_dataset})
# dataset_dict.save_to_disk("", format="json")
# dataset_dict.upload_to_hub("", "This is a test dataset")
I expected the code to chunk the dataset, preserve the metadata, and save it as a .jsonl file.
Instead, I got the error above.
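For context, every column returned by a batched map function has to end up with the same length before datasets can turn the batch into an Arrow table. The same ArrowInvalid can be reproduced directly with pyarrow; this is only a minimal sketch with made-up column values, not taken from my dataset:

import pyarrow as pa

# Three text chunks but only one metadata value: pyarrow refuses to build
# the table, which is what happens inside writer.write_batch in the traceback.
pa.table({
    "text_chunks": ["chunk a", "chunk b", "chunk c"],
    "type": ["contract"],
})
# -> pyarrow.lib.ArrowInvalid: Column 1 named type expected length 3 but got length 1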
In case anyone else runs into this, I managed to solve it. The problem was that my original process_row did not add one metadata entry per generated chunk, so the columns returned from the batch had different lengths (44 text chunks vs. 21 values in the 'type' column), which is exactly what the ArrowInvalid error reports. The fix is to replicate each row's metadata once for every chunk it produces.
See my working code below.
def process_row(batch):
    # Initialize lists to store the processed rows
    processed_texts = []
    processed_metadata = {k: [] for k, v in batch.items() if k != 'text'}
    # Process each row in the batch
    for i in range(len(batch['text'])):
        # Apply the chunk_text function to the text
        text = batch['text'][i]
        metadata = {k: v[i] for k, v in batch.items() if k != 'text'}
        chunks = chunk_text([text], metadata)['text_chunks']
        # For each chunk, add it to the processed_texts and replicate the metadata
        for chunk in chunks:
            processed_texts.append(chunk)
            for k in processed_metadata:
                processed_metadata[k].append(metadata[k])
    # Combine processed texts and metadata into a single dictionary
    results = {'text': processed_texts}
    for k, v in processed_metadata.items():
        results[k] = v
    return results
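With this version of process_row, the rest of my original script works unchanged; a minimal sketch of the end-to-end call, using the same arguments as in the question ('dataset.jsonl' is just an example path):

# Every returned column now has one entry per chunk, so the Arrow table
# builds without the length mismatch.
new_dataset = dataset.map(process_row, batched=True, batch_size=1, remove_columns=None)
new_dataset.to_json('dataset.jsonl')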