我试图弄清楚如何从长度可变的复数值张量中提取顺序斑块。提取正在作为tf.data
管道的一部分执行。
如果张量不复杂,我将像在此tf.image.extract_image_patches
中那样使用answer。
但是,该函数不适用于复杂的张量。我尝试了以下技术,但由于张量的长度未知,因此失败了。
def extract_sequential_patches(image):
image_length = tf.shape(image)[0]
num_patches = image_length // (128 // 4)
patches = []
for i in range(num_patches):
start = i * 128
end = start + 128
patches.append(image[start:end, ...])
return tf.stack(patches)
但是我得到了错误:
InaccessibleTensorError: The tensor 'Tensor("strided_slice:0", shape=(None, 512, 2), dtype=complex64)' cannot be accessed here: it is defined in another function or code block. Use return values, explicit Python locals or TensorFlow collections to access it. Defined in: FuncGraph(name=while_body_2100, id=140313967335120)
我曾尝试用@tf.function
进行自由装饰
[我认为您需要调整索引的计算以确保它们不会超出范围,但是撇开这些细节,除了使用Python外,您的代码几乎是tf.function
所期望的清单您需要改用TensorArray。
类似这样的方法应该起作用(索引计算可能不完全正确):
@tf.function
def extract_sequential_patches(image, size, stride):
image_length = tf.shape(image)[0]
num_patches = (image_length - size) // stride + 1
patches = tf.TensorArray(image.dtype, size=num_patches)
for i in range(num_patches):
start = i * stride
end = start + size
patches = patches.write(i, image[start:end, ...])
return patches.stack()
您可以在autograph reference docs中找到有关为什么Python列表当前不起作用的更多详细信息。
就是说,如果对extract_image_patches的内核进行了优化,则使用real / imag技巧可能会更快。我建议测试两种方法。