Tensorflow参差不齐的张量中的掩码值

问题描述 投票:1回答:1

我的张量参差不齐:

tf.ragged.constant([[[17712], [16753], [11850], [13028], [10155], [15734, 15938], [126], [10135], [17665]]], dtype=tf.int32)

我想将长度大于1的行中的元素值设置为特定值。例如:

tf.ragged.constant([[[17712], [16753], [11850], [13028], [10155], [15734, 0], [126], [10135], [17665]]], dtype=tf.int32)

我如何在Tensorflow中表达这种转变?

tensorflow tensorflow2.0
1个回答
1
投票

参差不齐的张量总是使事情变得棘手,但这是一种可能的实现:

import tensorflow as tf

# Using an intermediate NumPy array avoids having the second dimension as ragged
a = tf.ragged.constant([[[17712], [16753], [11850], [13028], [10155], [15734, 15938],
                         [126], [10135], [17665]]], dtype=tf.int32)
# Index from which values are replaced
replace_from_idx = 1
# Replacement value
new_value = 0

# Get size of each element in the last dimension
s = a.row_lengths(axis=-1)
# Make ragged ranges
r = tf.ragged.range(s.flat_values)
# Un-flatten
r = tf.RaggedTensor.from_row_lengths(r, a.row_lengths(1))
# Replace values
m = tf.dtypes.cast(r < replace_from_idx, a.dtype)
out = a * m + new_value * (1 - m)
print(out.to_list())
# [[[17712], [16753], [11850], [13028], [10155], [15734, 0], [126], [10135], [17665]]]
© www.soinside.com 2019 - 2024. All rights reserved.