将 tf tf.data.Dataset 元组拆分为多个数据集

问题描述 投票:0回答:2

我有一个具有以下形状的 tf.data.Dataset:

<ConcatenateDataset shapes: ((None, None, 12), (None, 5)), types: (tf.float64, tf.float64)>

我可以拆分这个数据集以获得两个如下所示的数据集吗:

<Dataset shapes: (None, None, 12), types: tf.float64>
<Dataset shapes: (None, 5), types: tf.float64>
python tensorflow tensorflow-datasets
2个回答
1
投票

您可以使用

map
功能来分割它们。

演示:

import tensorflow as tf

# Create a random tensorflow dataset.
dataset1 = tf.data.Dataset.from_tensor_slices((tf.random.uniform([40, 10, 12]), tf.random.uniform([40, 5]))).batch(16)
dataset2 = tf.data.Dataset.from_tensor_slices((tf.random.uniform([40, 12, 12]), tf.random.uniform([40, 5]))).batch(16)

dataset = dataset1.concatenate(dataset2)
dataset
>> <ConcatenateDataset shapes: ((None, None, 12), (None, 5)), types: (tf.float32, tf.float32)>

为了分割:

data = dataset.map(lambda x, y: x)
labels = dataset.map(lambda x, y: y)

0
投票

我对上述分割 tf.data.Dataset 元组的方法有一个后续问题

我已经创建了 3D tf.data.Dataset 进行训练,我需要拆分为 train_X 和 train_Y,因为我的主系统需要这种方式。 我使用上面的方法来分割但得到了奇怪的结果。 有人可以发表评论或提供帮助吗? 我不擅长张量流。

import tensorflow as tf
import numpy as np

window_size = 4
batch_size = 5
shuffle_buffer_size = 1000
n_character=6
x_train_All=np.arange(0,window_size*batch_size*n_character)
x_train_All=np.reshape(x_train_All,(window_size*batch_size,n_character))


dataset = tf.data.Dataset.from_tensor_slices(x_train_All)
dataset = dataset.window(window_size + 1, shift=1, drop_remainder=True)
dataset = dataset.flat_map(lambda window: window.batch(window_size + 1))
dataset = dataset.map(lambda window: (window[:-1], window[1:]))
dataset1 = dataset.shuffle(shuffle_buffer_size)
datasetX = dataset1.map(lambda x,y : x)
datasetY = dataset1.map(lambda x,y : y)

dataset_Num_X=[]
dataset_Num_Y=[]
dataset_NumXAfterSplit=[]
dataset_NumYAfterSplit=[]

for element in dataset1.as_numpy_iterator():
    e,f=element
    dataset_Num_X.append(e)
    dataset_Num_Y.append(f)

for window in datasetX.as_numpy_iterator():
    g=window
    dataset_NumXAfterSplit.append(g)

for window in datasetY.as_numpy_iterator():
    g=window
    dataset_NumYAfterSplit.append(g)

根据设计,dataset_Num_X 应与 dataset_NumXAfterSplit 相同,而 dataset_NumYAfterSplit 应与 dataset_Num_Y 相同,但事实并非如此。任何帮助将不胜感激。

最好的,

© www.soinside.com 2019 - 2024. All rights reserved.