I want to implement non-negative matrix factorization (NMF) with PyTorch. This is my initial implementation:
import torch

def nmf(X, k, lr, epochs):
    # X: input matrix of size (m, n)
    # k: number of latent factors
    # lr: learning rate
    # epochs: number of training epochs
    m, n = X.shape
    W = torch.rand(m, k, requires_grad=True)  # initialize W randomly
    H = torch.rand(k, n, requires_grad=True)  # initialize H randomly
    # training loop
    for i in range(epochs):
        # compute reconstruction error
        loss = torch.norm(X - torch.matmul(W, H), p='fro')
        # compute gradients
        loss.backward()
        # update parameters using additive update rule
        # (plain gradient descent, so entries of W and H can become negative)
        with torch.no_grad():
            W -= lr * W.grad
            H -= lr * H.grad
            W.grad.zero_()
            H.grad.zero_()
        if i % 10 == 0:
            print(f"Epoch {i}: loss = {loss.item()}")
    return W.detach(), H.detach()
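In this paper, Lee and Seung propose using an adaptive learning rate to avoid the subtraction, and thereby avoid producing negative elements. Here is the stats.SE thread where I got some of the ideas. However, I don't know how to implement the multiplicative update rule for W and H in PyTorch, because it requires separating the positive and negative parts of their gradients. Yes, I could implement it manually, but I would like to use torch autograd for it.
Any idea how to do this? Thanks in advance.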
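In the multiplicative update rule, the gradient is split into a positive part and a negative part (gradient = positive part - negative part), and each entry of W or H is multiplied by the ratio of the negative part to the positive part. Because every factor in that product is non-negative, W and H stay non-negative without any subtraction.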
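Written out for the loss $\tfrac12\lVert X - WH\rVert_F^2$ (the factor $\tfrac12$ only removes a 2 from the gradient), the split and the resulting Lee and Seung updates look like this. Fractions and $\odot$ are element-wise; with $X, W, H \ge 0$ both parts are non-negative. The last line shows that the multiplicative rule is ordinary gradient descent with the element-wise learning rate $\eta_W = W / (WHH^\top)$, which is the "adaptive learning rate" mentioned in the question.

$$
\begin{aligned}
\nabla_W L &= \underbrace{WHH^\top}_{\text{positive part}} - \underbrace{XH^\top}_{\text{negative part}},
\qquad
\nabla_H L = \underbrace{W^\top W H}_{\text{positive part}} - \underbrace{W^\top X}_{\text{negative part}},\\
H &\leftarrow H \odot \frac{W^\top X}{W^\top W H + \varepsilon},
\qquad
W \leftarrow W \odot \frac{X H^\top}{W H H^\top + \varepsilon},\\
W - \eta_W \odot \nabla_W L &= W - W + W \odot \frac{XH^\top}{WHH^\top} = W \odot \frac{XH^\top}{WHH^\top}
\quad\text{for}\quad \eta_W = \frac{W}{WHH^\top}.
\end{aligned}
$$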
Note: a small value eps is added to the denominator to avoid division by zero.
import torch

def nmf(X, k, epochs, eps=1e-9):
    # X: non-negative input matrix of size (m, n)
    # k: number of latent factors
    # epochs: number of training epochs
    # eps: small value added to the denominators to avoid division by zero
    # (no learning rate is needed: the multiplicative rule sets its own step size)
    m, n = X.shape
    W = torch.rand(m, k, dtype=X.dtype)  # initialize W randomly (non-negative)
    H = torch.rand(k, n, dtype=X.dtype)  # initialize H randomly (non-negative)
    # training loop
    for i in range(epochs):
        # gradient of 0.5 * ||X - W H||_F^2 w.r.t. H is W^T W H - W^T X
        grad_H_pos = W.t() @ W @ H  # positive part
        grad_H_neg = W.t() @ X      # negative part
        H = H * grad_H_neg / (grad_H_pos + eps)
        # gradient w.r.t. W is W H H^T - X H^T (uses the freshly updated H)
        grad_W_pos = W @ H @ H.t()  # positive part
        grad_W_neg = X @ H.t()      # negative part
        W = W * grad_W_neg / (grad_W_pos + eps)
        if i % 10 == 0:
            # compute reconstruction error
            loss = torch.norm(X - torch.matmul(W, H), p='fro')
            print(f"Epoch {i}: loss = {loss.item()}")
    return W, H
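Since the question asks specifically about autograd, here is a minimal sketch, assuming the loss 0.5·||X - W H||_F^2 and a non-negative X, that obtains the positive and negative parts from autograd instead of writing the gradient formulas by hand. Expanding the loss as 0.5·||X||^2 - <X, W H> + 0.5·||W H||^2, the gradient of the last term is the positive part and the gradient of <X, W H> is the negative part. The names `split_grads` and `nmf_autograd` are hypothetical, not part of PyTorch.

```python
import torch


def split_grads(X, W, H):
    """Positive and negative gradient parts of 0.5 * ||X - W H||_F^2 via autograd.

    The loss is expanded as 0.5*||X||^2 - <X, W H> + 0.5*||W H||^2 (constant
    term dropped). The gradient of 0.5*||W H||^2 is the positive part and the
    gradient of <X, W H> is the negative part, so grad = pos - neg.
    """
    W = W.detach().requires_grad_(True)
    H = H.detach().requires_grad_(True)
    WH = W @ H
    pos_loss = 0.5 * (WH * WH).sum()  # 0.5 * ||W H||_F^2
    neg_loss = (X * WH).sum()         # <X, W H>
    # retain_graph=True: both losses share the W @ H node of the graph
    pos_W, pos_H = torch.autograd.grad(pos_loss, (W, H), retain_graph=True)
    neg_W, neg_H = torch.autograd.grad(neg_loss, (W, H))
    return pos_W, neg_W, pos_H, neg_H


def nmf_autograd(X, k, epochs, eps=1e-9, seed=0):
    """Multiplicative-update NMF whose gradient parts come from autograd (sketch)."""
    torch.manual_seed(seed)
    m, n = X.shape
    W = torch.rand(m, k, dtype=X.dtype)
    H = torch.rand(k, n, dtype=X.dtype)
    for _ in range(epochs):
        # update H with W fixed: H <- H * neg / (pos + eps)
        _, _, pos_H, neg_H = split_grads(X, W, H)
        H = H * neg_H / (pos_H + eps)
        # update W with the new H fixed
        pos_W, neg_W, _, _ = split_grads(X, W, H)
        W = W * neg_W / (pos_W + eps)
    return W, H
```

For the Frobenius loss this reproduces W H H^T and X H^T (and W^T W H, W^T X) exactly. The same split can be tried for other losses that separate into two terms with non-negative gradients, but the monotone-decrease guarantee from Lee and Seung only covers specific losses (Frobenius norm and KL divergence).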
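However, expressing this as an adaptive learning rate on top of autograd (for W the element-wise step size is W / (W H H^T)) takes some extra code, because none of the built-in optimizers in torch.optim implements this step size.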
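A sketch of such extra code, assuming the loss 0.5·||X - W H||_F^2 and a non-negative X: autograd supplies the full gradient, only the positive part is written by hand to form the step size, and the plain subtraction param - eta * grad then reduces (up to eps) to the multiplicative rule, so the entries stay non-negative. `nmf_adaptive_lr` is a hypothetical name, and the `clamp_(min=0)` calls are an extra guard against floating-point round-off, not part of the original rule.

```python
import torch


def nmf_adaptive_lr(X, k, epochs, eps=1e-9, seed=0, verbose=False):
    """NMF as gradient descent with an element-wise adaptive learning rate (sketch).

    Assumes X >= 0 and the loss L = 0.5 * ||X - W H||_F^2. Autograd supplies the
    gradient; only the positive part (W^T W H for H, W H H^T for W) is written by
    hand to build the step size eta = param / positive_part. With that eta,
    param - eta * grad equals (up to eps) the multiplicative update.
    """
    torch.manual_seed(seed)
    m, n = X.shape
    W = torch.rand(m, k, dtype=X.dtype, requires_grad=True)
    H = torch.rand(k, n, dtype=X.dtype, requires_grad=True)
    for i in range(epochs):
        # update H with W fixed
        loss = 0.5 * ((X - W @ H) ** 2).sum()
        (grad_H,) = torch.autograd.grad(loss, (H,))
        with torch.no_grad():
            eta_H = H / (W.T @ W @ H + eps)  # element-wise learning rate
            H -= eta_H * grad_H
            H.clamp_(min=0)  # extra round-off guard, not part of the rule
        # update W with the new H fixed
        loss = 0.5 * ((X - W @ H) ** 2).sum()
        (grad_W,) = torch.autograd.grad(loss, (W,))
        with torch.no_grad():
            eta_W = W / (W @ H @ H.T + eps)  # element-wise learning rate
            W -= eta_W * grad_W
            W.clamp_(min=0)
        if verbose and i % 10 == 0:
            print(f"Epoch {i}: loss = {loss.item()}")
    return W.detach(), H.detach()
```

With `verbose=True` it prints the loss every 10 epochs, like the original loop. After one epoch it should match the closed-form multiplicative update from the same initialization up to eps and round-off.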