在tensorflow中,如何计算从具有多列NaN的csv生成的批处理中每一列的均值?

问题描述 投票:0回答:1

我正在批量读取csv,每个批次在不同位置都有null。我不想使用tensorflow转换,因为它需要将整个数据加载到内存中。目前,我无法忽略每一列中存在的NaN,而计算意味着我是否要一次对整个批次进行处理。我可以遍历每列,然后以这种方式找到每列的均值,但这似乎是一个不太好的解决方案。

有人可以帮忙找到正确的方法来计算多列中存在NaN的csv批次的每列均值。另外,[1,2,np.nan]应该产生1.5而不是1。

python-3.x tensorflow tensorflow2.0 tensorflow-datasets
1个回答
0
投票

我目前正在这样做:给定张量arank 2 tf.math.divide_no_nan(tf.reduce_sum(tf.where(tf.math.is_finite(a),a,0.),axis=0),tf.reduce_sum(tf.cast(tf.math.is_finite(a),tf.float32),axis=0))

让我知道有人有更好的选择

© www.soinside.com 2019 - 2024. All rights reserved.