在for循环之外的批量数据中查找第一次出现的索引的有效方法

问题描述 投票:1回答:1

我正在做一个任务,其中我有批量存储的帧形式的数据。批量的维度就像(batch_size,400),我想找到每400个长度帧中第一次出现的数字1的索引。

目前我使用for循环批量大小但由于数据非常大,因此非常耗时

在tensorflow或numpy中使用某些矩阵运算的任何其他Efficient方法都会

numpy for-loop tensorflow time-complexity batch-processing
1个回答
0
投票

在TensorFlow中:

import tensorflow as tf

def index_of_first_tf(batch, value):
    eq = tf.equal(batch, value)
    has_value = tf.reduce_any(eq, axis=-1)
    _, idx = tf.math.top_k(tf.cast(eq, tf.int8))
    idx = tf.squeeze(idx, -1)
    return tf.where(has_value, idx, -tf.ones_like(idx))

在NumPy中:

import numpy as np

def index_of_first_np(batch, value):
    eq = np.equal(batch, value)
    has_value = np.any(eq, axis=-1)
    idx = np.argmax(eq, axis=-1)
    idx[~has_value] = -1
    return idx

测试:

import tensorflow as tf

batch = [[0, 1, 2, 3],
         [1, 2, 1, 0],
         [0, 2, 3, 4]]
value = 1

print(index_of_first_np(batch, value))
# [ 1  0 -1]

with tf.Graph().as_default(), tf.Session() as sess:
    print(sess.run(index_of_first_tf(batch, value)))
    # [ 1  0 -1]
© www.soinside.com 2019 - 2024. All rights reserved.