Tensorflow:从任意长度的复数张量中提取顺序斑块

问题描述 投票:1回答:1

我试图弄清楚如何从长度可变的复数值张量中提取顺序斑块。提取正在作为tf.data管道的一部分执行。

如果张量不复杂,我将像在此tf.image.extract_image_patches中那样使用answer

但是,该函数不适用于复杂的张量。我尝试了以下技术,但由于张量的长度未知,因此失败了。

def extract_sequential_patches(image):
    image_length = tf.shape(image)[0]
    num_patches = image_length // (128 // 4)
    patches = []
    for i in range(num_patches):
        start = i * 128
        end = start + 128
        patches.append(image[start:end, ...])
    return tf.stack(patches)

但是我得到了错误:

InaccessibleTensorError: The tensor 'Tensor("strided_slice:0", shape=(None, 512, 2), dtype=complex64)' cannot be accessed here: it is defined in another function or code block. Use return values, explicit Python locals or TensorFlow collections to access it. Defined in: FuncGraph(name=while_body_2100, id=140313967335120)

我曾尝试用@tf.function进行自由装饰

python tensorflow tensorflow2.0 tensorflow-datasets
1个回答
0
投票

[我认为您需要调整索引的计算以确保它们不会超出范围,但是撇开这些细节,除了使用Python外,您的代码几乎是tf.function所期望的清单您需要改用TensorArray。

类似这样的方法应该起作用(索引计算可能不完全正确):

@tf.function
def extract_sequential_patches(image, size, stride):
    image_length = tf.shape(image)[0]
    num_patches = (image_length - size) // stride + 1
    patches = tf.TensorArray(image.dtype, size=num_patches)
    for i in range(num_patches):
        start = i * stride
        end = start + size
        patches = patches.write(i, image[start:end, ...])
    return patches.stack()

您可以在autograph reference docs中找到有关为什么Python列表当前不起作用的更多详细信息。

就是说,如果对extract_image_patches的内核进行了优化,则使用real / imag技巧可能会更快。我建议测试两种方法。

© www.soinside.com 2019 - 2024. All rights reserved.