由于“tf.shape”和“Tensor.shape”,我的代码出现问题。 `tf.shape` 和 `Tensor.shape` 都不起作用

问题描述 投票:0回答:2

我已经在 Tensorflow 中从头开始编写了 DETR 对象检测管道。
DETR:Kaggle Notebook 链接:包含所有代码;制作您自己的笔记本副本以重现该问题

我已经测试了管道中的所有单独组件并且它可以工作。
但是当我开始在我的数据集上训练它时(以

tf.data.Dataset
形式)
给我错误的部分代码的链接

ValueError: in user code:

    File "/tmp/ipykernel_19/4115406382.py", line 7, in train_step  *
        y_pred = matcher(y_train, y_pred)
    File "/tmp/ipykernel_19/968499204.py", line 64, in __call__  *
        class_prob, bbox_pred = Matcher.match(class_true, bbox_true, class_prob, bbox_pred)
    File "/tmp/ipykernel_19/968499204.py", line 53, in match  *
        C = Matcher.batched_cost_matrix(class_true, bbox_true, class_prob, bbox_pred)
    File "/tmp/ipykernel_19/968499204.py", line 46, in batched_cost_matrix  *
        tf.range(tf.shape(class_true)[0]), fn_output_signature=tf.float32
    File "/tmp/ipykernel_19/968499204.py", line 22, in compute_cost_matrix  *
        N = tf.shape(class_true)[0]

ValueError: slice index 0 of dimension 0 out of bounds. 
for '{{node map/while/strided_slice_4}} = StridedSlice[Index=DT_INT32, T=DT_INT32, begin_mask=0, ellipsis_mask=0, end_mask=0, new_axis_mask=0, shrink_axis_mask=1](map/while/Shape, map/while/strided_slice_4/stack, map/while/strided_slice_4/stack_1, map/while/strided_slice_4/stack_2)' 
with input shapes: [0], [1], [1], [1] and with computed input tensors: input[1] = <0>, input[2] = <1>, input[3] = <1>.

当我使用

class_true
打印
tf.shape
的形状时,我得到
Tensor("Shape_2:0", shape=(1,), dtype=int32)
我听不懂。
使用
Tensor.shape
时,它会返回带有
None
的形状,因此我再次收到错误。

但是在单独测试(而不是训练)时打印它,我得到了

class_true
的正确形状,如
tf.Tensor([42], shape=(1,), dtype=int32)

我该如何解决这个问题?

python tensorflow keras tf.keras tf.data.dataset
2个回答
0
投票

您的错误显示在

compute_cost_matrix
函数中。

class_true 张量的形状为 (N,),在映射过程中,您尝试访问其元素,就好像它的形状为 (m, N) 一样。这种形状差异可能会导致出界错误。

在映射成本之前尝试重塑张量:

def compute_cost_matrix(class_true, class_prob, bbox_true, bbox_pred):
    """(N), (N, n_classes), (N, 4), (N, 4)"""
    N = tf.shape(class_true)[0]

    # Ensure class_true has the shape (N,)
    class_true = tf.reshape(class_true, (N,))

我无法检查,因为 Kaggle 不断将我注销 - 期待您的测试运行和更多信息。


0
投票

shape_tensor = tf.shape(your_tensor) shape_tuple = your_tensor.shape

© www.soinside.com 2019 - 2024. All rights reserved.