如何从SequenceExample TFRecord创建窗口化多变量数据集?

问题描述 投票:1回答:1

我正试图使用tf.data.datasets设置一个Tensorflow管道,以便将一些TFRecord加载到Keras模型中。这些数据是多变量时间序列。

我目前使用的是Tensorflow 2.0。

首先,我从TFRecord中获取我的数据集,并对其进行解析。

dataset = tf.data.TFRecordDataset('...')

context_features = {...}
sequence_features = {...}

def _parse_function(example_proto):
  _, sequence =  tf.io.parse_single_sequence_example(example_proto,context_features, sequence_features)
  return sequence


dataset = dataset.map(_parse_function)

现在的问题是,它给我一个MapDataset,里面有EagerTensor的dict。

for data in dataset.take(3):
  print(type(data))

<class 'dict'>
<class 'dict'>
<class 'dict'>

# which look like : {feature1 : EagerTensor, feature2 : EagerTensor ...}

因为这些字典,我似乎无法让这些数据被分批、洗牌......以便之后在LSTM层中使用它们。例如这个:

def make_window_dataset(ds, window_size=5, shift=1, stride=1):
  windows = ds.window(window_size, shift=shift, stride=stride)

  def sub_to_batch(sub):
    return sub.values().batch(window_size, drop_remainder=True)

  windows = windows.flat_map(sub_to_batch)
  return windows

ds = make_window_dataset(dataset, 10)

gives me :

AttributeError: 'dict_values' object has no attribute 'batch'

谢谢你的帮助。我的研究是基于这个和其他Tensorflow帮助程序。

https:/www.tensorflow.orgguidedata#time_series_windowing

EDIT :

我找到了问题的解决方法。我最终在我的解析函数中使用stack将解析给出的字典转换为一个(None,11)形状的Tensor。

def _parse_function(example_proto):
  # Parse the input `tf.Example` proto using the dictionary above.
  _, sequence =  tf.io.parse_single_sequence_example(example_proto,context_features, sequence_features)
  return tf.stack(list(sequence.values()), axis=1)
python-3.x tensorflow-datasets tensorflow2.0 tf.keras
1个回答
0
投票

在这里提供解决方案(答案部分),即使它存在于问题部分,为社区的利益。

将字典转换为具有形状的张量文件 (None,11) 使用 tf.stackparse_function 已经解决了这个问题。

将代码从

def _parse_function(example_proto):
  _, sequence =  tf.io.parse_single_sequence_example(example_proto,context_features, sequence_features)
  return sequence

def _parse_function(example_proto):
  _, sequence =  tf.io.parse_single_sequence_example(example_proto,context_features, sequence_features)
  return tf.stack(list(sequence.values()), axis=1)
© www.soinside.com 2019 - 2024. All rights reserved.