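I am trying to implement a Tensor struct that holds an array from the ndarray crate, where T is the element type and I describes the dimension. The core idea is for the struct to hold a 2D or 3D array.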
use std::fmt::Debug;
use ndarray::ArrayBase;
use ndarray::prelude::*;
use ndarray::OwnedRepr;
use ndarray::{Array};
use ndarray::Array3;
use ndarray::Dimension;

struct Tensor<T, I>
where
    T: Clone + Sync + Send,
    I: PartialEq + Debug,
{
    data: ArrayBase<OwnedRepr<T>, Dim<I>>,
}

impl<T, I> Tensor<T, I>
where
    T: Clone + Sync + Send,
    I: PartialEq + Debug,
{
    pub fn new(arr: ArrayBase<OwnedRepr<T>, Dim<I>>) -> Self {
        Self { data: arr }
    }

    pub fn shape(&self) -> &[usize] {
        self.data.shape()
    }
}

fn main() {
    let mut temperature = Array3::<f32>::zeros((3, 4, 5));
    let shape = temperature.shape();
    let t3D: Tensor<f32, [usize; 3]> = Tensor::new(temperature);
}
I want to create a shape() method on the Tensor struct that returns the shape of the ndarray stored in the data field.
However, rustc throws an error:
error[E0599]: the method `shape` exists for struct `ArrayBase<OwnedRepr<T>, Dim<I>>`, but its trait bounds were not satisfied
--> src/main.rs:36:20
|
36 | self.data.shape()
| ^^^^^ method cannot be called on `ArrayBase<OwnedRepr<T>, Dim<I>>` due to unsatisfied trait bounds
| doesn't satisfy `<_ as DimAdd<Dim<IxDynImpl>>>::Output = Dim<IxDynImpl>`
| doesn't satisfy `<_ as DimAdd<Dim<[usize; 0]>>>::Output = Dim<I>`
| doesn't satisfy `<_ as DimAdd<Dim<[usize; 1]>>>::Output = <Dim<I> as Dimension>::Larger`
| doesn't satisfy `<_ as DimMax<<Dim<I> as Dimension>::Larger>>::Output = <Dim<I> as Dimension>::Larger`
| doesn't satisfy `<_ as DimMax<<Dim<I> as Dimension>::Smaller>>::Output = Dim<I>`
| doesn't satisfy `<_ as DimMax<Dim<I>>>::Output = Dim<I>`
| doesn't satisfy `<_ as DimMax<Dim<IxDynImpl>>>::Output = Dim<IxDynImpl>`
| doesn't satisfy `<_ as DimMax<Dim<[usize; 0]>>>::Output = Dim<I>`
| doesn't satisfy `<_ as Index<usize>>::Output = usize`
| doesn't satisfy `<_ as Mul<usize>>::Output = Dim<I>`
| doesn't satisfy `<ndarray::Dim<I> as Add>::Output = ndarray::Dim<I>`
| doesn't satisfy `<ndarray::Dim<I> as Mul>::Output = ndarray::Dim<I>`
| doesn't satisfy `<ndarray::Dim<I> as Sub>::Output = ndarray::Dim<I>`
| doesn't satisfy `_: DimAdd<<Dim<I> as Dimension>::Larger>`
| doesn't satisfy `_: DimAdd<<Dim<I> as Dimension>::Smaller>`
| doesn't satisfy `_: DimMax<<Dim<I> as Dimension>::Larger>`
| doesn't satisfy `_: DimMax<<Dim<I> as Dimension>::Smaller>`
| doesn't satisfy `ndarray::Dim<I>: AddAssign<&'x ndarray::Dim<I>>`
| doesn't satisfy `ndarray::Dim<I>: AddAssign`
| doesn't satisfy `ndarray::Dim<I>: Add`
| doesn't satisfy `ndarray::Dim<I>: Clone`
| doesn't satisfy `ndarray::Dim<I>: Default`
| doesn't satisfy `ndarray::Dim<I>: DimAdd<ndarray::Dim<I>>`
| doesn't satisfy `ndarray::Dim<I>: DimAdd<ndarray::Dim<IxDynImpl>>`
| doesn't satisfy `ndarray::Dim<I>: DimAdd<ndarray::Dim<[usize; 0]>>`
| doesn't satisfy `ndarray::Dim<I>: DimAdd<ndarray::Dim<[usize; 1]>>`
| doesn't satisfy `ndarray::Dim<I>: DimMax<ndarray::Dim<I>>`
| doesn't satisfy `ndarray::Dim<I>: DimMax<ndarray::Dim<IxDynImpl>>`
| doesn't satisfy `ndarray::Dim<I>: DimMax<ndarray::Dim<[usize; 0]>>`
| doesn't satisfy `ndarray::Dim<I>: Dimension`
| doesn't satisfy `ndarray::Dim<I>: Eq`
| doesn't satisfy `ndarray::Dim<I>: IndexMut<usize>`
| doesn't satisfy `ndarray::Dim<I>: Mul<usize>`
| doesn't satisfy `ndarray::Dim<I>: MulAssign<&'x ndarray::Dim<I>>`
| doesn't satisfy `ndarray::Dim<I>: MulAssign<usize>`
| doesn't satisfy `ndarray::Dim<I>: MulAssign`
| doesn't satisfy `ndarray::Dim<I>: Mul`
| doesn't satisfy `ndarray::Dim<I>: Send`
| doesn't satisfy `ndarray::Dim<I>: SubAssign<&'x ndarray::Dim<I>>`
| doesn't satisfy `ndarray::Dim<I>: SubAssign`
| doesn't satisfy `ndarray::Dim<I>: Sub`
| doesn't satisfy `ndarray::Dim<I>: Sync`
| doesn't satisfy `ndarray::Dim<I>: std::ops::Index<usize>`
I tried applying trait bounds to both the struct and the impl, but the compiler keeps throwing errors.
So I have a few questions:
You can add the constraint where Dim<I>: Dimension.
The simplest way is to add the Dim<I>: Dimension bound to the impl block (by the way, note that you can remove the bounds from the struct itself):
use ndarray::prelude::*;
use ndarray::Array;
use ndarray::Array3;
use ndarray::ArrayBase;
use ndarray::Dimension;
use ndarray::OwnedRepr;
use std::fmt::Debug;

struct Tensor<T, I> {
    data: ArrayBase<OwnedRepr<T>, Dim<I>>,
}

impl<T, I> Tensor<T, I>
where
    T: Clone + Sync + Send,
    I: PartialEq + Debug,
    Dim<I>: Dimension,
{
    pub fn new(arr: ArrayBase<OwnedRepr<T>, Dim<I>>) -> Self {
        Self { data: arr }
    }

    pub fn shape(&self) -> &[usize] {
        self.data.shape()
    }
}

fn main() {
    let mut temperature = Array3::<f32>::zeros((3, 4, 5));
    let shape = temperature.shape();
    let t3D: Tensor<f32, [usize; 3]> = Tensor::new(temperature);
}
This can be inferred from the documentation, because the impl ArrayBase block that defines the shape method is:
impl<A, S, D> ArrayBase<S, D>
where
    S: RawData<Elem = A>,
    D: Dimension,
In your case, D is Dim<I>, so you need Dim<I>: Dimension.
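To make the reasoning concrete, here is a minimal std-only sketch of the same pattern. It does not use the ndarray crate: the Dimension trait, Dim, ArrayBase, and the zeros helper below are simplified, hypothetical stand-ins that only mimic the shape of the real types. The point it tries to show is that a method defined in an impl block with a D: Dimension bound can only be called from generic code that carries the same bound.

```rust
// Hypothetical stand-ins for ndarray's types, NOT the real crate.
trait Dimension {
    fn slice(&self) -> &[usize];
}

struct Dim<I>(I);

// In this sketch only fixed-size arrays of usize count as dimensions.
impl<const N: usize> Dimension for Dim<[usize; N]> {
    fn slice(&self) -> &[usize] {
        &self.0
    }
}

struct ArrayBase<A, D> {
    data: Vec<A>,
    dim: D,
}

// No bound on D here, so `len` is callable for any D.
impl<A, D> ArrayBase<A, D> {
    fn len(&self) -> usize {
        self.data.len()
    }
}

// `shape` lives in an impl block that requires `D: Dimension`.
impl<A, D: Dimension> ArrayBase<A, D> {
    fn shape(&self) -> &[usize] {
        self.dim.slice()
    }
}

// The struct itself needs no bounds.
struct Tensor<T, I> {
    data: ArrayBase<T, Dim<I>>,
}

impl<T, I> Tensor<T, I>
where
    // Without this bound, the call to `self.data.shape()` below is
    // rejected, because the compiler cannot prove `Dim<I>: Dimension`
    // for an arbitrary `I`.
    Dim<I>: Dimension,
{
    fn new(data: ArrayBase<T, Dim<I>>) -> Self {
        Self { data }
    }

    fn shape(&self) -> &[usize] {
        self.data.shape()
    }
}

// Hypothetical helper: builds a zero-filled array of the given shape.
fn zeros<const N: usize>(shape: [usize; N]) -> ArrayBase<f32, Dim<[usize; N]>> {
    let len: usize = shape.iter().product();
    ArrayBase { data: vec![0.0; len], dim: Dim(shape) }
}

fn main() {
    let t3 = Tensor::new(zeros([3, 4, 5]));
    let t2 = Tensor::new(zeros([3, 4]));
    println!("{:?} ({} elements)", t3.shape(), t3.data.len());
    println!("{:?} ({} elements)", t2.shape(), t2.data.len());
}
```

If you delete the Dim<I>: Dimension line from the where clause in this sketch, the compiler should reject the shape call for the same underlying reason as in the question, although the message will be much shorter because the stand-in trait has no supertraits.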