如何从numpy数组中获取tensorflow 2中的窗口化数据集?

问题描述 投票:2回答:1

想象我有一些数据:

some_data = np.array([[1,2,3,4], [5, 6, 7,8]])

它看起来像这样:

array([[1, 2, 3, 4],
       [5, 6, 7, 8]])

每一行代表一个不同的观察值,因此不应将它们组合在一起。我想创建一个窗口化的数据集,每个窗口的大小为3,偏移1。当我通过一个观察时,我得到了想要的结果,像这样:

dataset = tf.data.Dataset.from_tensor_slices(some_data[0])
dataset = dataset.window(size=3, shift=1, drop_remainder=True)
dataset = dataset.flat_map(lambda window: window.batch(3))

结果:

for x in dataset:
    print(x.numpy())

[1 2 3]
[2 3 4]

但是当我传递整个numpy数组时,我什么也得不到。

dataset = tf.data.Dataset.from_tensor_slices(some_data)
dataset = dataset.window(size=3, shift=1, drop_remainder=True)
dataset = dataset.flat_map(lambda window: window.batch(3))

这是我期望的:

for x in dataset:
    print(x.numpy())

[1 2 3]
[2 3 4]
[5 6 7]
[6 7 8]

我想我可以遍历some_data并一次传递一个数组,然后将数据集连接起来,但这似乎是一个不好的解决方案。什么是正确的方法?

我正在使用Tensorflow 2.0。谢谢!

python tensorflow tensorflow-datasets
1个回答
0
投票

使用dataset = tf.data.Dataset.from_tensor_slices(some_data[0])时数据集的每一行只有一个元素。

dataset = tf.data.Dataset.from_tensor_slices(some_data[0])
for x in dataset:
    print(x.numpy())
1
2
3
4

但是当您使用dataset = tf.data.Dataset.from_tensor_slices(some_data)时,数据集的每一行都有四个元素。

dataset = tf.data.Dataset.from_tensor_slices(some_data)
for x in dataset:
    print(x.numpy())
[1 2 3 4]
[5 6 7 8]

所以您需要做的是转换每一行并合并。

import numpy as np
import tensorflow as tf

some_data = np.array([[1,2,3,4], [5, 6, 7,8]])
dataset = tf.data.Dataset.from_tensor_slices(some_data)

def parse_samples(x):
    return tf.data.Dataset.from_tensor_slices(x)\
        .window(size=3, shift=1, drop_remainder=True)\
        .flat_map(lambda window: window.batch(3))

dataset = dataset.flat_map(parse_samples)

for x in dataset:
    print(x.numpy())

[1 2 3]
[2 3 4]
[5 6 7]
[6 7 8]
© www.soinside.com 2019 - 2024. All rights reserved.