实现黎曼梯度的正确方法是什么?

问题描述 投票:0回答:1

这是本文中的一个玩具示例,低秩矩阵和张量训练流形上黎曼优化的自动微分

import torch
import torch.nn as nn

def f(X):
    return torch.sum(X**2)

def g(delta_U, delta_V, U, V, f):
    perturbed_matrix = U @ delta_V.t() + delta_U @ V.t()
    return f(perturbed_matrix)

def compute_riemannian_gradient(X):
    U, S, V = torch.svd(X)
    delta_U = U @ torch.diag(S)
    delta_V = torch.zeros_like(V)
    delta_U.requires_grad_(True)
    delta_V.requires_grad_(True)
    perturbed_value = g(delta_U, delta_V, U, V, f)
    perturbed_value.backward()

    return delta_U.grad, delta_V.grad

def apply_gauge_conditions(delta_U, delta_V, V):
    delta_V -= V @ (V.t() @ delta_V)
    return delta_U, delta_V

def riemannian_gradient(X):
    U, _, V = torch.svd(X)
    delta_U, delta_V = compute_riemannian_gradient(X)
    delta_U, delta_V = apply_gauge_conditions(delta_U, delta_V, V)
    return delta_U @ V.t() + U @ delta_V.t()

X = torch.randn(5, 3)
y = X**2 + 0.1*torch.randn_like(X)
rgrad = riemannian_gradient(X)

for i in range(10):
    rgrad = riemannian_gradient(X)
    X = X - 0.01*rgrad
    # X = retraction(X, rgrad, 0.01)
    print(f(X))

所以你可以看到,在训练推理中,我不需要X的梯度,或者[U,S,V]。相反,我需要 delta_U 和 delta_V 的梯度来更新 X。因此,如果我想将这段代码集成到 torch.optimizer 模块中,我无法简单地循环访问参数中注册的参数。

我的问题是,当权重 X 通过其他参数的梯度更新时,在 optim.step() 函数中实现此优化算法的正确方法是什么?

python optimization pytorch autodiff
1个回答
0
投票

我不明白你所说的“训练推理”是什么意思,总的来说,这个问题需要更清晰。

据我快速浏览论文了解到,它们为您提供了一种特殊的算法来计算低秩矩阵的导数。

在 Torch 中,您可以使用

torch.autograd.Function
定义具有前向和后向方法的自定义函数,如 this gist 中所述。请参阅官方文档以获取广泛的概述。

此类应处理梯度计算步骤,而优化器将使用

[X]
作为要优化的变量列表进行实例化。

© www.soinside.com 2019 - 2024. All rights reserved.