How to select only part of a TensorFlow dataset and change its dimensions

Question (votes: 1, answers: 1)

I want to train my model on 10-frame clips from UCF101, without any labels. This is what I have so far:

import tensorflow as tf
import tensorflow_datasets as tfds
x_train = tfds.load('ucf101', split='train', shuffle_files=True, batch_size = 64)
>>> print(x_train)
<_OptionsDataset shapes: {label: (None,), video: (None, None, 256, 256, 3)}, types: {label: tf.int64, video: tf.uint8}>

I want the dataset to have shape (None, 10, 256, 256, 3), with the labels left out.

Edit: I tried using a lambda expression inside .map(), but that raised an error:

new_x_train = x_train.map(lambda x: tf.py_function(func=lambda y: tf.convert_to_tensor(sample(y.numpy().tolist(), 10), dtype=uint8), inp=[x['video']], Tout=tf.uint8))
NameError: name 'sample' is not defined
python tensorflow tensorflow2.0 tensorflow-datasets
1 answer

Votes: 1

Forgive the approximate answer; I am not going to download a 6 GB dataset to test it.

Why not select the video directly while iterating over the dataset?

next(iter(x_train))['video']

To select along a dimension, you can use ordinary NumPy-style indexing. Here is an example with MNIST:

import tensorflow_datasets as tfds

data = tfds.load('mnist', split='train', batch_size=16)
<PrefetchDataset shapes: {image: (None, 28, 28, 1), 
    label: (None,)}, types: {image: tf.uint8, label: tf.int64}>

Now we select only image and keep the first 10 observations:

dim = lambda x: x['image'][:10, ...]

next(iter(data.map(dim))).shape
TensorShape([10, 28, 28, 1])

Notice how one None in the shape has been replaced by 10, using nothing more than simple indexing.
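Note that the MNIST slice above acts on the first (batch) axis, whereas the question asks for 10 frames, which is the second axis of the video tensor. The following is a minimal sketch of the difference using plain NumPy; the array and its small sizes are made up for illustration and stand in for one batch of videos.

```python
import numpy as np

# Hypothetical stand-in for one batch of videos, laid out as
# (batch, frames, height, width, channels). Sizes are made up and kept small;
# UCF101 frames would be 256x256.
batch = np.zeros((16, 12, 8, 8, 3), dtype=np.uint8)

# Slicing the first axis keeps the first 10 videos of the batch.
first_videos = batch[:10, ...]

# Slicing the second axis keeps the first 10 frames of every video.
first_frames = batch[:, :10, ...]

print(first_videos.shape)  # (10, 12, 8, 8, 3)
print(first_frames.shape)  # (16, 10, 8, 8, 3)
```

Inside .map(), the equivalent would presumably be lambda x: x['video'][:, :10, ...], which also drops the label. This has not been tested against UCF101 and assumes every video has at least 10 frames.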


Votes: 0

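I solved this by simply downloading the dataset files separately, so that I had a list of .avi files in my directory, and then preprocessing those files outside TensorFlow. I used the cv2 library and the code below, in which two of the functions are borrowed from elsewhere.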

import glob
import math

import cv2
import numpy as np

# Utilities to open video files using CV2
def crop_center_square(frame):
  y, x = frame.shape[0:2]
  min_dim = min(y, x)
  start_x = (x // 2) - (min_dim // 2)
  start_y = (y // 2) - (min_dim // 2)
  return frame[start_y:start_y+min_dim,start_x:start_x+min_dim]

def load_video(path, max_frames=0, resize=(256, 256)):
  cap = cv2.VideoCapture(path)
  frames = []
  try:
    while True:
      ret, frame = cap.read()
      if not ret:
        break
      frame = crop_center_square(frame)
      frame = cv2.resize(frame, resize)
      frame = frame[:, :, [2, 1, 0]]
      frames.append(frame)

      if len(frames) == max_frames:
        break
  finally:
    cap.release()
  return np.array(frames) / 255.0


# Collect every .avi file below the current directory.
files = glob.glob("**/*.avi", recursive=True)

for video_path in files:
  video = load_video(video_path)
  # Drop the first directory component of the path.
  video_name = video_path[video_path.find('/')+1:]
  num_frames = video.shape[0]
  print("Video in " + video_path + " has " + str(num_frames) + " frames.")
  # Write each full 10-frame segment to its own file; leftover frames are discarded.
  for seg_num in range(math.floor(num_frames/10)):
    result = video[seg_num*10:(seg_num+1)*10, ...]
    new_filepath = video_name[:-4] + "_" + str(seg_num).zfill(2) + ".avi"
    print(new_filepath)
    out = cv2.VideoWriter(new_filepath, 0, 25.0, (256, 256))
    for frame_n in range(0, 10):
      # load_video returns RGB floats in [0, 1]; VideoWriter expects BGR uint8.
      frame = np.uint8(255*result[frame_n, ...])
      out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    out.release()
    del result
  del video
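
The segmenting step in the loop above can also be sketched on its own, without writing files. This is a minimal illustration, not the code used in the answer: split_into_clips is a hypothetical helper, and the synthetic 25-frame array stands in for the output of load_video.

```python
import numpy as np

def split_into_clips(video, clip_len=10):
    """Split (frames, H, W, C) into (n_clips, clip_len, H, W, C).

    Frames left over after the last full clip are dropped.
    """
    n_clips = video.shape[0] // clip_len
    trimmed = video[: n_clips * clip_len]
    return trimmed.reshape((n_clips, clip_len) + video.shape[1:])

# Synthetic 25-frame "video" with tiny 4x4 RGB frames (made-up sizes).
video = np.arange(25 * 4 * 4 * 3, dtype=np.float32).reshape(25, 4, 4, 3)
clips = split_into_clips(video)
print(clips.shape)  # (2, 10, 4, 4, 3)
```

Stacking the clips into a single array like this gives the (None, 10, 256, 256, 3) layout asked for in the question once the frames are 256x256, at the cost of holding the whole video in memory.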