加载线性回归模型以支持 Rust-linfa 中的预测

问题描述 投票:0回答:1

我一直在研究 Rust 机器学习的 linfa,特别是线性回归模型。我希望能够保存和加载我训练过的线性回归模型,但是我无法找到方法来做到这一点。

方法一:

到目前为止,我的方法是获取训练中涉及的主要参数,这些参数可以从 linfa 的线性回归实现中获取,并将它们存储在一个可以存储为 JSON 文件的结构中(通过 serde_json 完成)。然而,在此之后我不知道如何将其加载回来进行训练。

以上详情如下:

存储训练参数的结构:

struct ModelJson {
    coefficients: Vec<f64>,
    intercept: f64,
}

存储过程:

let model = lin_reg.fit(&dataset)?;
let model_json = ModelJson {
    coefficients: model.params().to_vec(),
    intercept: model.intercept(),
};

存储数据的外观:

{"coefficients":[-0.00017907873576254802,-0.00100659702068151,-0.0008275037845519519,0.0004613216043979551,0.0010300634934599436],"intercept":50.525680622870084}

方法2:

关于整个模型的序列化和反序列化,我发现以下信息表明 linfa 中支持相同的操作。 加载和保存模型

这引出了我的第二种方法,其中我使用了 linfa-linear 的 serde 功能(包含 LinearRegression 模型),首先在我的 Cargo.toml 中包含以下内容:

linfa-linear = {version="0.7.0", features=["serde"]}

根据我对实现的理解,此功能为 LinearRegression 实现了以下功能: Serde 序列化和反序列化实现 - 派生

上述实施:

#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
/// A fitted linear regression model which can be used for making predictions.
pub struct FittedLinearRegression<F> {
    intercept: F,
    params: Array1<F>,
}

发现于:linfa-线性导出实现

我的实现如下:

let model = lin_reg.fit(&dataset)?;
let serialized = serde_json::to_string(&model).unwrap();

但是此方法出现以下错误:

the trait bound `FittedLinearRegression<f64>: serde::ser::Serialize` is not satisfied
the following other types implement trait `serde::ser::Serialize`:
  bool
  char
  isize
  i8
  i16
  i32
  i64
  i128
and 133 othersrustcClick for full compiler diagnostic
main.rs(82, 22): required by a bound introduced by this call

是否有另一种方法可以做到这一点,或者有什么方法可以使这些方法之一发挥作用?

json machine-learning rust rust-cargo serde
1个回答
0
投票

在方法 1 中,您已手动将模型参数存储在 JSON 文件中,然后尝试将它们加载回来。它可能无法捕获更复杂模型中的所有必要信息。在方法 2 中,

FittedLinearRegression<f64>
这可能是由于 f64 默认情况下在 serde 中没有实现 Serialize 特征。通过使用
serde_with::serde_as
,您可以为
FittedLinearRegression<f64>
提供自定义序列化实现。

use serde::{Serialize, Deserialize};
use serde_with::serde_as;

#[cfg_attr(feature = "serde", serde_as]
#[derive(Serialize, Deserialize)]
pub struct FittedLinearRegression<F> {
    intercept: F,
    params: Array1<F>,
}

let model = lin_reg.fit(&dataset)?;
let serialized = serde_json::to_string(&model).unwrap();
let deserialized: FittedLinearRegression<f64> = serde_json::from_str(&serialized).unwrap();
© www.soinside.com 2019 - 2024. All rights reserved.