tf.data 窗口(window)数据集如何处理多个输入

问题描述 投票:0回答:1

我想使用 tf.data 的 `window` 来创建数据集。在下面的代码中,如何在 `flat_map` 里对 2 个输入分别做 `batch`?我在网上找到的示例都只有 1 个输入。

import tensorflow as tf

def make_window_dataset(ds, window_size=3, shift=1, stride=1):
  """Turn ``ds`` into a dataset of sliding windows, one window per element.

  Works when the elements of ``ds`` are nested structures (tuples/dicts) of
  tensors, e.g. ``((x1, x2), y)``: every component is windowed and batched
  independently, then the components are zipped back together so that each
  output element has the same nested structure as the input, with every
  tensor gaining a leading ``window_size`` dimension.

  Args:
    ds: A ``tf.data.Dataset``.
    window_size: Number of consecutive input elements per window.
    shift: Number of input elements between the starts of successive windows.
    stride: Step between input elements taken inside one window.

  Returns:
    A ``tf.data.Dataset`` of windows. Trailing windows that contain fewer
    than ``window_size`` elements are dropped.
  """
  windows = ds.window(window_size, shift=shift, stride=stride)

  def _batch(sub):
    # Each window component is itself a small dataset; collapse it into a
    # single tensor. drop_remainder=True discards the short tail windows.
    return sub.batch(window_size, drop_remainder=True)

  def sub_to_batch(*subs):
    # flat_map unpacks tuple elements into positional arguments, so a
    # ((x1, x2), y) element arrives as sub_to_batch((x1_ds, x2_ds), y_ds).
    # Re-pack the arguments so the output keeps the original structure.
    structure = subs[0] if len(subs) == 1 else subs
    batched = tf.nest.map_structure(_batch, structure)
    if isinstance(batched, tf.data.Dataset):
      # Plain (non-nested) elements: nothing to zip.
      return batched
    return tf.data.Dataset.zip(batched)

  return windows.flat_map(sub_to_batch)
# Two inputs per element; each row has shape (2,).
ds = tf.data.Dataset.from_tensor_slices((
    [[1, 2], [3, 4], [5, 6], [7, 8], [1, 2], [3, 4], [5, 6], [7, 8]],
    [[1, 2], [3, 4], [5, 6], [7, 8], [1, 2], [3, 4], [5, 6], [7, 8]],
))
# One target value per element (zipped as the second component, so it is
# what model.fit would treat as the label).
v = tf.data.Dataset.from_tensor_slices([1, 3, 5, 7, 1, 3, 5, 7])
# Elements are ((x1, x2), y); window them, then group 2 windows per batch.
ds = make_window_dataset(tf.data.Dataset.zip((ds, v))).batch(2).repeat(2)

for example in ds.take(10):
  # example is a nested tuple of tensors, which has no .numpy() method;
  # convert each tensor individually before printing.
  print('---', tf.nest.map_structure(lambda t: t.numpy(), example))

# model.fit(ds, ...)  # Placeholder: train a model on the windowed dataset.
tensorflow2.0 tensorflow-datasets
1个回答
0
投票

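答案是:`flat_map` 会把元组形式的窗口元素拆成多个参数传给 `sub_to_batch`(例如 `((x1, x2), y)` 会以 `sub_to_batch((x1_ds, x2_ds), y_ds)` 的形式调用),每个参数都是一个子数据集。因此需要在 `sub_to_batch` 里对这些子数据集调用 `batch`,再用与原元素结构一致的嵌套元组通过 `tf.data.Dataset.zip` 把它们组合起来。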

import tensorflow as tf
# Switch on TF2 behavior when running under TensorFlow 1.x; under TF 2.x this
# is already the default, so the call has no effect there.
tf.compat.v1.enable_v2_behavior()
def make_window_dataset(ds, window_size=3, shift=1, stride=1):
  """Window a ``((x1, x2), y)`` dataset into fixed-size sliding windows.

  Each component (x1, x2 and y) is windowed and batched to ``window_size``
  consecutive elements. Output elements have the structure
  ``((x1_win, x2_win), (y_win, y_win))``. The target window is emitted twice,
  presumably to feed a model with two outputs (TODO confirm).

  Args:
    ds: A ``tf.data.Dataset`` whose elements are ``((x1, x2), y)``.
    window_size: Number of consecutive input elements per window.
    shift: Number of input elements between the starts of successive windows.
    stride: Step between input elements taken inside one window.

  Returns:
    A ``tf.data.Dataset`` of windows. Trailing windows that contain fewer
    than ``window_size`` elements are dropped.
  """
  windows = ds.window(window_size, shift=shift, stride=stride)

  def _batch(sub):
    # Collapse one window component (a small dataset) into a single tensor;
    # drop_remainder=True discards the short tail windows.
    return sub.batch(window_size, drop_remainder=True)

  def sub_to_batch(features, labels):
    # flat_map unpacks the ((x1, x2), y) element into two arguments:
    # features is the tuple (x1_ds, x2_ds) and labels is y_ds.
    x1_ds, x2_ds = features
    labels_win = _batch(labels)
    # Batch BOTH inputs; zipping an unbatched dataset would truncate the
    # window to a single row of that input.
    return tf.data.Dataset.zip(
        ((_batch(x1_ds), _batch(x2_ds)), (labels_win, labels_win)))

  return windows.flat_map(sub_to_batch)

# Elements are ((x1, x2), y): two (2,)-shaped inputs and a (1,)-shaped target.
ds = tf.data.Dataset.from_tensor_slices((
    (
        [[1, 2], [3, 4], [5, 6], [7, 8], [1, 2], [3, 4], [5, 6], [7, 8]],
        [[2, 3], [3, 4], [5, 6], [7, 8], [1, 2], [3, 4], [5, 6], [7, 8]],
    ),
    [[1], [3], [5], [7], [1], [3], [5], [7]],
))

# Append e.g. .batch(2).repeat(2) here to group windows for training.
ds = make_window_dataset(ds)
for example in ds.take(10):
  print('---', example)

# model.fit(ds, ...)  # Placeholder: train a model on the windowed dataset.
© www.soinside.com 2019 - 2024. All rights reserved.