我一直在研究 Rust 机器学习的 linfa,特别是线性回归模型。我希望能够保存和加载我训练过的线性回归模型,但是我无法找到方法来做到这一点。
到目前为止,我的方法是获取训练中涉及的主要参数,这些参数可以从 linfa 的线性回归实现中获取,并将它们存储在一个可以存储为 JSON 文件的结构中(通过 serde_json 完成)。然而,在此之后我不知道如何将其加载回来进行训练。
以上详情如下:
存储训练参数的结构:
struct ModelJson {
coefficients: Vec<f64>,
intercept: f64,
}
存储过程:
let model = lin_reg.fit(&dataset)?;
let model_json = ModelJson {
coefficients: model.params().to_vec(),
intercept: model.intercept(),
};
存储数据的外观:
{"coefficients":[-0.00017907873576254802,-0.00100659702068151,-0.0008275037845519519,0.0004613216043979551,0.0010300634934599436],"intercept":50.525680622870084}
关于整个模型的序列化和反序列化,我发现以下信息表明 linfa 中支持相同的操作。 加载和保存模型
这引出了我的第二种方法,其中我使用了 linfa-linear 的 serde 功能(包含 LinearRegression 模型),首先在我的 Cargo.toml 中包含以下内容:
linfa-linear = {version="0.7.0", features=["serde"]}
根据我对实现的理解,此功能为 LinearRegression 实现了以下功能: Serde 序列化和反序列化实现 - 派生
上述实施:
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
/// A fitted linear regression model which can be used for making predictions.
pub struct FittedLinearRegression<F> {
intercept: F,
params: Array1<F>,
}
发现于:linfa-线性导出实现
我的实现如下:
let model = lin_reg.fit(&dataset)?;
let serialized = serde_json::to_string(&model).unwrap();
但是此方法出现以下错误:
the trait bound `FittedLinearRegression<f64>: serde::ser::Serialize` is not satisfied
the following other types implement trait `serde::ser::Serialize`:
bool
char
isize
i8
i16
i32
i64
i128
and 133 othersrustcClick for full compiler diagnostic
main.rs(82, 22): required by a bound introduced by this call
是否有另一种方法可以做到这一点,或者有什么方法可以使这些方法之一发挥作用?
在方法 1 中,您已手动将模型参数存储在 JSON 文件中,然后尝试将它们加载回来。它可能无法捕获更复杂模型中的所有必要信息。在方法 2 中,
FittedLinearRegression<f64>
这可能是由于 f64 默认情况下在 serde 中没有实现 Serialize 特征。通过使用 serde_with::serde_as
,您可以为 FittedLinearRegression<f64>
提供自定义序列化实现。
use serde::{Serialize, Deserialize};
use serde_with::serde_as;
#[cfg_attr(feature = "serde", serde_as]
#[derive(Serialize, Deserialize)]
pub struct FittedLinearRegression<F> {
intercept: F,
params: Array1<F>,
}
let model = lin_reg.fit(&dataset)?;
let serialized = serde_json::to_string(&model).unwrap();
let deserialized: FittedLinearRegression<f64> = serde_json::from_str(&serialized).unwrap();