具有任意嵌套Vect的张量定义

问题描述 投票:1回答:1

我正在尝试创建Tensor类型,但是在使用构造函数的类型签名时遇到了麻烦。在thisthis问题中,他们将Tensor定义为Vect s的Tensor,在this问题中将其定义为嵌套Vect s的类型别名,但都不适合我目的。我需要一个Tensor是原子的(它不是由其他Tensor组成的),以及一个独特的类型(它不会因为成为别名而继承方法)。

[我尝试了以下方法,该方法通过Vect从任意嵌套的array_type中隐式提取形状和数据类型,并将其包装为最小的Tensor类型

import Data.Vect

total array_type: (shape: Vect r Nat) -> (dtype: Type) -> Type
array_type [] dtype = dtype
array_type (d :: ds) dtype = Vect d (array_type ds dtype)

data Tensor : (shape: Vect r Nat) -> (dtype: Type) -> Type where
  MkTensor : array_type shape dtype -> Tensor shape dtype

然后,我定义了各种功能以检查其是否正常工作(此处未包括)。所有这些都可以很好地编译,但是当我尝试定义一个将每个元素乘以2的函数时,我陷入了真正的纠结。我试图首先在嵌套的Vect上定义它:

times_two : Num dtype => array_type shape dtype -> array_type shape dtype
times_two (x :: xs) = (times_two x) :: (times_two xs)
times_two x = 2 * x

但我知道

[检查times_two的左侧时:检查Main.times_two的应用程序时:没有歧义,因为没有名称具有合适的类型:Prelude.List。::,Prelude.Stream。::,Data.Vect。::

::替换Data.Vect.::没有帮助。我想做的事可能吗?和明智?

idris
1个回答
1
投票

您无法在array_type shape dtype上进行匹配,因为它不是数据类型。您需要先弄清楚(即匹配)shape是什么,然后该类型才能简化为数据类型。

times_two {shape = []} x = 2 * x
times_two {shape = n :: ns} xs = map times_two xs

((在这种情况下,xs上的匹配项位于map内。)

© www.soinside.com 2019 - 2024. All rights reserved.