使用 arrow-rs 加载数据时进行单热编码

问题描述 投票:0回答:1

在我的 Rust 项目中,我从 Mongo 加载文档并将它们反序列化为 serde_json 值:

match cursor.deserialize_current() {
    Ok(d) => {
        let doc = serde_json::to_value(&d).unwrap();
        doc_vec.push(doc);
    }

之后,我使用解码器创建了一个箭头

RecordBatch

let mut decoder = ReaderBuilder::new(schema.clone()).build_decoder().unwrap();
if !doc_vec.is_empty() {
    decoder.serialize(&doc_vec).unwrap();
    let batch = decoder.flush().unwrap().unwrap();

我的架构是:

let schema = Schema::new(vec![
    Field::new("Amount", DataType::Float32, false),
    Field::new(
        "Country",
        DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)),
        false,
    ),
]);

代码失败并显示:

called `Result::unwrap()` on an `Err` value: NotYetImplemented("Support for Dictionary(UInt16, Utf8) in JSON reader")called `Result::unwrap()` on an `Err` value: NotYetImplemented("Support for Dictionary(UInt16, Utf8) in JSON reader")

当我通过 arrow Flight 将国家/地区发送到 pyarrow 客户端时,我希望对该国家/地区进行单热编码,然后将其转换为 Pandas 数据帧。

您能指导我如何从这里继续吗?我对所有使用的技术都很陌生。

rust one-hot-encoding apache-arrow
1个回答
0
投票

解决方法是将列读取为

Utf8
,然后使用
cast
内核将其转换为字典编码。

根据我的理解,one-hot 编码与字典编码不同。您可以通过使用比较内核,与不同的“国家/地区”值进行比较来获得单热编码布尔列。

© www.soinside.com 2019 - 2024. All rights reserved.