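I'm trying to convert a CSV into a 2D array so I can train a model with the Linfa crate. As far as I understand, the function is trying to return two errors at once, but I don't want to create another error type. My array_from_csv function doesn't seem able to return an acceptable result. Here is my code: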
use csv::ReaderBuilder;
use linfa::prelude::*;
use linfa::Dataset;
use linfa_linear::{Result, TweedieRegressor};
use ndarray::prelude::*;
use ndarray::Axis;
use ndarray_csv::{Array2Reader, ReadError};
use std::io::Read;
// Convert CSV bytes into 2D array
fn array_from_csv<R: Read>(
    csv: R,
    has_headers: bool,
    separator: u8,
) -> Result<Array2<f64>, ReadError> {
    // parse CSV
    let mut reader = ReaderBuilder::new()
        .has_headers(has_headers)
        .delimiter(separator)
        .from_reader(csv);
    // extract ndarray
    reader.deserialize_array2_dynamic();
}
I call it here:
fn get_dataset() -> Dataset<f64, Ix1> {
    let data = include_bytes!("examples/data/AMZN_data.csv");
    let data = array_from_csv(&data[..], true, b',').unwrap();
    let targets = include_bytes!("examples/data/AMZN_targets.csv");
    let targets = array_from_csv(&targets[..], true, b',')
        .unwrap()
        .column(0)
        .to_owned();
    let feature_names = vec![
        "date",
        "open",
        "high",
        "low",
        "close",
        "adj_close",
        "volume",
    ];
    Dataset::new(data, targets).with_feature_names(feature_names)
}
And then I get:
error[E0277]: the trait bound `ReadError: linfa::Float` is not satisfied
--> src/linear_regression.rs:15:6
|
15 | ) -> Result<Array2<f64>, ReadError> {
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ the trait `linfa::Float` is not implemented for `ReadError`
|
= help: the following other types implement trait `linfa::Float`:
f32
f64
note: required by a bound in `LinearError`
--> /home/theod/.cargo/registry/src/index.crates.io-6f17d22bba15001f/linfa-linear-0.7.0/src/error.rs:10:25
|
10 | pub enum LinearError<F: Float> {
| ^^^^^ required by this bound in `LinearError`
Here is my Cargo.toml:
[package]
name = "rust-programs-library"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
ta = "0.4.0"
csv = "1.1.0"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
linfa = "0.7.0"
linfa-linear = "0.7.0"
ndarray = "0.15.6"
ndarray-csv = "0.5.2"
flate2 = "1.0.28"
[dev-dependencies]
assert_approx_eq = "1.0.0"
bencher = "0.1.5"
rand = "0.6.5"
bincode = "1.3.1"
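The reported problem comes from use linfa_linear::Result, which replaces the built-in Result in your scope with a non-standard definition provided by linfa_linear. That alias wraps its second type parameter in LinearError<F: Float>, which is why the compiler complains that ReadError does not implement linfa::Float.
If you remove that use, you get a different error: a type mismatch between the unit type and the declared Result. This is caused by the stray semicolon (;) at the end of array_from_csv(), which turns the final expression into a statement so the function returns () instead. Once you remove it, array_from_csv() compiles.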
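Both points can be reproduced without any external crates. The following is only a minimal sketch: the fake_lib module, LibError, ParseFailure, parse_row and check_positive are hypothetical names made up for illustration, standing in for a crate that exports its own Result alias in the same way linfa_linear appears to.

```rust
// Hypothetical stand-in for a crate that exports its own `Result` alias.
// None of these names come from linfa; they only mimic the shape of the problem.
mod fake_lib {
    pub trait Float {}
    impl Float for f32 {}
    impl Float for f64 {}

    #[derive(Debug)]
    pub enum LibError<F: Float> {
        Invalid(F),
    }

    // Importing this alias with `use fake_lib::Result;` would shadow std's `Result`,
    // and its second parameter would then have to implement `Float`.
    pub type Result<T, F> = std::result::Result<T, LibError<F>>;
}

// Hypothetical error type that does NOT implement `fake_lib::Float`.
#[derive(Debug, PartialEq)]
struct ParseFailure(String);

// There is no `use fake_lib::Result;` here, so `Result` is the standard one
// and any error type is accepted.
fn parse_row(line: &str) -> Result<Vec<f64>, ParseFailure> {
    line.split(',')
        .map(|field| {
            field
                .trim()
                .parse::<f64>()
                .map_err(|e| ParseFailure(e.to_string()))
        })
        // No trailing semicolon: the expression is the function's return value.
        .collect()
}

// When the library's alias is really wanted, naming it by path avoids shadowing.
fn check_positive(x: f64) -> fake_lib::Result<f64, f64> {
    if x > 0.0 {
        Ok(x)
    } else {
        Err(fake_lib::LibError::Invalid(x))
    }
}
```

Applied to the question's code, the same two changes would be to drop Result from the linfa_linear import and to remove the semicolon after reader.deserialize_array2_dynamic().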
The remaining error is:
26 | fn get_dataset() -> Dataset<f64, Ix1> {
| ----------------- expected `DatasetBase<ArrayBase<OwnedRepr<f64>, Dim<[usize; 2]>>, ArrayBase<OwnedRepr<Dim<[usize; 1]>>, Dim<[usize; 2]>>>` because of return type
...
45 | Dataset::new(data, targets).with_feature_names(feature_names)
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected `Dim<[usize; 1]>`, found `f64`
|
= note: expected struct `DatasetBase<_, ArrayBase<OwnedRepr<Dim<[usize; 1]>>, Dim<[usize; 2]>>>`
found struct `DatasetBase<_, ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>>>`
...which I can't help with.
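One possible reading of that message, offered as a guess rather than a confirmed fix: the expected type shows Ix1 sitting in the element position of the targets array, with a 2D dimension after it. That would happen if linfa's Dataset alias takes the record element type, the target element type, and then a target dimension that defaults to Ix2, so that Dataset<f64, Ix1> fills only the first two. The sketch below uses hypothetical stand-in types (not linfa's or ndarray's real definitions) to show how a trailing default type parameter binds positionally.

```rust
use std::any::TypeId;
use std::marker::PhantomData;

// Hypothetical stand-ins for dimension markers and container types.
struct Ix1;
struct Ix2;
struct Array<Elem, Dim>(PhantomData<(Elem, Dim)>);
struct DatasetBase<R, T>(PhantomData<(R, T)>);

// Assumed shape of the alias: records are always 2D, targets default to 2D.
type Dataset<D, T, I = Ix2> = DatasetBase<Array<D, Ix2>, Array<T, I>>;

// Helper that reports whether two types are identical.
fn same_type<A: 'static, B: 'static>() -> bool {
    TypeId::of::<A>() == TypeId::of::<B>()
}
```

Under that assumption, Dataset<f64, Ix1> means "targets whose elements are Ix1, stored in a 2D array", whereas Dataset<f64, f64, Ix1> means "f64 targets in a 1D array", which is what .column(0).to_owned() produces. If linfa's alias really has this shape, changing the return type of get_dataset() to Dataset<f64, f64, Ix1> may resolve the mismatch.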