过滤 Tensorflow 数据集中的 NaN 值

问题描述 投票:0回答:3

是否有一种简单的方法可以从

nan
实例中过滤包含
tensorflow.data.Dataset
值的所有条目?就像 Pandas 中的
dropna
方法一样?


简短示例:

import numpy as np
import tensorflow as tf

X = tf.data.Dataset.from_tensor_slices([[1,2,3], [0,0,0], [np.nan,np.nan,np.nan], [3,4,5], [np.nan,3,4]])
y = tf.data.Dataset.from_tensor_slices([np.nan, 0, 1, 2, 3])
ds = tf.data.Dataset.zip((X,y))
ds = foo(ds)  # foo(x) = ?
for x in iter(ds): print(str(x))

我可以使用什么来获得以下输出:

foo(x)

如果您想亲自尝试,
这里是 Google Colab 笔记本

python tensorflow tensorflow2.0 tensorflow-datasets
3个回答
3
投票
(<tf.Tensor: shape=(3,), dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>, <tf.Tensor: shape=(), dtype=float32, numpy=0.0>) (<tf.Tensor: shape=(3,), dtype=float32, numpy=array([3., 4., 5.], dtype=float32)>, <tf.Tensor: shape=(), dtype=float32, numpy=2.0>)

:

tf.reduce_any

filter_nan = lambda x, y: not tf.reduce_any(tf.math.is_nan(x)) and not tf.math.is_nan(y)

ds = tf.data.Dataset.zip((X,y)).filter(filter_nan)

list(ds.as_numpy_iterator())


2
投票

[(array([0., 0., 0.], dtype=float32), 0.0), (array([3., 4., 5.], dtype=float32), 2.0)]



0
投票

这里我假设在你的张量中你有特征和目标变量作为列,y_indx是你的目标的列索引。您也可以使用合适的布尔掩码。

以下函数从张量 X 中删除目标列中具有 nan 值的行。它为删除的行返回一个布尔掩码,但如果您不想保留它,可以从最后一行跳过它。

def any_nan(t): return tf.reduce_sum( tf.cast( tf.math.is_nan(t), tf.int32, ) ) > tf.constant(0) >>> ds_filtered = ds.filter(lambda x, y: not any_nan(x) and not any_nan(y)) >>> for x in iter(ds_filtered): print(str(x)) (<tf.Tensor: shape=(3,), dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>, <tf.Tensor: shape=(), dtype=float32, numpy=0.0>) (<tf.Tensor: shape=(3,), dtype=float32, numpy=array([3., 4., 5.], dtype=float32)>, <tf.Tensor: shape=(), dtype=float32, numpy=2.0>)

下面的函数删除缺失值的列。请注意,对于 
def drop_na_rows(X, y_indx): not_nan = tf.math.logical_not(tf.math.is_nan(X[:, y_indx])) return X[not_nan, :], no_nan

参数,您可以使用适合切片张量的布尔掩码或索引,但如果您不提供任何内容,则函数将返回列的布尔掩码。或者你也可以跳过它。

cols_to_drop

© www.soinside.com 2019 - 2024. All rights reserved.