TensorFlow 中用于正交匹配追踪的同步逆矩阵计算

问题描述 投票:0回答:1

我目前正在创建一个可以在不同补丁上同时执行的正交匹配追踪(OMP)版本,利用 TensorFlow 的强大功能通过静态图编译进行优化。

我提供了以下算法主循环的单个步骤的伪代码:


support = tf.TensorArray(dtype=tf.int64, size=1, dynamic_size=True)

# For each element in patches, compute the projection on the dictionary using only TensorFlow API
dot_products = tf.abs(tf.matmul(A, patches, transpose_a=True))
max_index = tf.argmax(dot_products, axis=1)
support = support.write(i, max_index + 1)
support = support.write(i + 1, max_index + 1)
idx = support.stack()
non_zero_rows = tf.reduce_all(tf.not_equal(idx, 0), axis=1)

idx = tf.boolean_mask(idx, non_zero_rows) - 1
A_tilde = tf.gather(A, idx, axis=1)
m, n = A.shape
selected_atoms = tf.matmul(A_tilde, A_tilde, transpose_a=True)

获得selected_atoms(由l个大小为nxm的矩阵组成的3D张量)后,我需要解决最小二乘问题

|patches - selected_atoms * x_coeff| ** 2
为此,我需要计算
selected_atoms.T @ selected_atoms
的倒数。有没有一种方法使用 TensorFlow 的 API 同时计算 l 个逆矩阵并将它们存储在形状为 lxnxm 的 3D 张量中?我将非常感谢任何指导。

谢谢你。

python tensorflow linear-algebra matrix-inverse
1个回答
0
投票

要获得伪逆,可以使用 pinv 函数,它可以满足您的要求(对于最小二乘,伪逆是明确定义的,而逆可能不是)。

import tensorflow as tf
x_batch = tf.constant([[[1.,  0.4,  0.5],
                 [0.4, 0.2,  0.25],
                 [0.5, 0.25, 0.35]],
                [[0.5,  0,  0],
                 [0, 0.5,  0],
                 [0, 0, 0.5]]])
tf.linalg.pinv(x_batch)

函数调用返回:

tf.Tensor: shape=(2, 3, 3), dtype=float32, numpy=
array([[[ 4.9999995e+00, -9.9999933e+00, -4.2915344e-06],
        [-9.9999905e+00,  6.6666672e+01, -3.3333347e+01],
        [-6.1988831e-06, -3.3333344e+01,  2.6666681e+01]],

       [[ 2.0000000e+00,  0.0000000e+00,  0.0000000e+00],
        [ 0.0000000e+00,  2.0000000e+00,  0.0000000e+00],
        [ 0.0000000e+00,  0.0000000e+00,  2.0000000e+00]]], dtype=float32)

如果你只想要回归的系数,则不需要存储pseudo_inverse,但pinv*patches(可能会节省大量内存)请参见:http://www.seas.ucla.edu/~范登贝/133A/lectures/ls.pdf

© www.soinside.com 2019 - 2024. All rights reserved.