这是本文中的一个玩具示例,低秩矩阵和张量训练流形上黎曼优化的自动微分:
import torch
import torch.nn as nn
def f(X):
return torch.sum(X**2)
def g(delta_U, delta_V, U, V, f):
perturbed_matrix = U @ delta_V.t() + delta_U @ V.t()
return f(perturbed_matrix)
def compute_riemannian_gradient(X):
U, S, V = torch.svd(X)
delta_U = U @ torch.diag(S)
delta_V = torch.zeros_like(V)
delta_U.requires_grad_(True)
delta_V.requires_grad_(True)
perturbed_value = g(delta_U, delta_V, U, V, f)
perturbed_value.backward()
return delta_U.grad, delta_V.grad
def apply_gauge_conditions(delta_U, delta_V, V):
delta_V -= V @ (V.t() @ delta_V)
return delta_U, delta_V
def riemannian_gradient(X):
U, _, V = torch.svd(X)
delta_U, delta_V = compute_riemannian_gradient(X)
delta_U, delta_V = apply_gauge_conditions(delta_U, delta_V, V)
return delta_U @ V.t() + U @ delta_V.t()
X = torch.randn(5, 3)
y = X**2 + 0.1*torch.randn_like(X)
rgrad = riemannian_gradient(X)
for i in range(10):
rgrad = riemannian_gradient(X)
X = X - 0.01*rgrad
# X = retraction(X, rgrad, 0.01)
print(f(X))
所以你可以看到,在训练推理中,我不需要X的梯度,或者[U,S,V]。相反,我需要 delta_U 和 delta_V 的梯度来更新 X。因此,如果我想将这段代码集成到 torch.optimizer 模块中,我无法简单地循环访问参数中注册的参数。
我的问题是,当权重 X 通过其他参数的梯度更新时,在 optim.step() 函数中实现此优化算法的正确方法是什么?