既然 Rust 不允许多个可变引用,类似 PyTorch 的自动微分如何在 Rust 中工作?

问题描述 投票:0回答:1

我主要是一个局外人,试图了解 Rust 是否适合我的项目。

Rust 中有一些框架可以自动微分。具体来说,我认为,candle 和其他一些项目,根据他们的描述,以某种类似于 PyTorch 的方式做到了这一点。

但是,我知道 Rust 不允许多个可变引用。看起来这就是类似 PyTorch 的自动微分所需要的:

x = torch.rand(10) # an array of 10 elements
x.requires_grad = True

y = x.sin()
z = x**2

y
z
都必须保留对
x
的可变引用,因为您可能想要反向传播它们,这将修改
x.grad
。例如:

(y.dot(z)).backwards()
print(x.grad)

既然 Rust 不允许多个可变引用,那么如何在 Rust 中实现类似的行为呢?

rust backpropagation automatic-differentiation
1个回答
0
投票

在 Rust 中提供看似多个可变引用的方法是通过“内部可变性”,它允许通过共享引用进行突变。 Rust 仍然提出了不允许同时发生突变的要求,但是有几种方法可以确保这一点,因此有一些类型通常提供内部可变性:Cell

RefCell
Mutex
RwLock
。它们以 
UnsafeCell
 作为核心原语,告诉编译器 
&
并不一定意味着所包含的值是不可变的。
如果我们查看蜡烛的来源,基本

Tensor

包含一个

Arc
,它允许多个张量“句柄”引用相同的
- Arc提供共享所有权(
): pub struct Tensor(Arc<Tensor_>);

隐藏的内部
Tensor_

类型看起来像这样(

source
): pub struct Tensor_ { id: TensorId, // As we provide inner mutability on the tensor content, the alternatives are: // - Using a mutex, this would have the highest cost when retrieving the storage but would // prevent errors when concurrent access takes place. Mutex would also be subject to // deadlocks for example using the current code if the same tensor is used twice by a single // binary op. // - Using a refcell unsafe cell would have some intermediary cost, borrow checking would be // verified dynamically, but the resulting tensors would not be send or sync. // - Using an unsafe cell would have the lowest cost but undefined behavior on concurrent // accesses. // Ideally, we would use Arc<Storage> for tensors on which we don't plan on modifying the data // and Arc<Mutex<Storage>> for tensors where the data could be modified, e.g. variables but // that's tricky to encode in the current setup. storage: Arc<RwLock<Storage>>, layout: Layout, op: BackpropOp, is_variable: bool, dtype: DType, device: Device, }

其中方便地有一个注释,为我们权衡了 
storage

内部可变性的选项。

RwLock
是允许多个句柄访问和/或改变张量内容的部分。
因此,当反向传播发生时,它通过获取 

storage

来访问相关张量的

RwLockReadGuard
,以便访问数据来执行操作,然后在对结果张量执行任何操作之前释放这些保护以避免死锁(因为如果在现有守卫被控制的情况下尝试突变,
RwLock
将会阻止)。
看来该库没有利用这种内部可变性,因为它更喜欢创建新的张量而不是改变现有的张量,除非它是一个“应该”更新以反映新数据的变量。在这种情况下,它获取一个 

RwLockWriteGuard

以用新值交换数据,并再次快速释放防护。 很难在蜡烛的源代码中给出准确的行,因为有很多层用于反向传播、存储操作和跟踪结果。我也无法用公式进行具体演示,因为我不太熟悉该主题。但我希望这仍然是清楚的,并且可以帮助您自己创业。

    

© www.soinside.com 2019 - 2024. All rights reserved.