防止Tensosflow Dataset在多次调用model.predict时重置生成器。

问题描述 投票:0回答:1

我正在使用tensorflow Dataset from_generator方法在不同批次上使用CNN模型进行预测。但我想在每个批次预测后添加一些额外的逻辑。具体来说,我想汇总不同的结果。

这是我的生成器函数。

def gen_predict(img_no):
  img_data = nib.load('./testing-images/10' + '%02d' %img_no + '_3.nii.gz').get_fdata()
  patch_size = 23
  dist_center = (patch_size - 1) // 2
  l, b, h = img_data.shape
  for zc in range(dist_center, h - dist_center - 1):
    for yc in range(dist_center, b - dist_center - 1):
      for xc in range(dist_center, l - dist_center - 1):    
        print(xc,yc,zc) 
        xl, yl, zl = (xc - dist_center, yc - dist_center, zc - dist_center)
        xr, yr, zr = (xc + dist_center, yc + dist_center, zc + dist_center)
        cartesianCoordinate = np.array([xc, yc, zc])
        spectralCoordinates = np.array([0, 0, 0])
        X = (np.array(img_data[xl:(xr + 1), yl:(yr + 1), zl:(zr + 1)]), np.concatenate((cartesianCoordinate, spectralCoordinates)).reshape((6,1)))
        yield (X,)

问题是在每次预测调用后,生成器都会重置 在下一次预测调用时,它会给出同一组数据的预测结果。这是我的代码。

dataset_pred = tf.data.Dataset.from_generator(lambda: gen_predict(3), ((tf.float32, tf.float32),), output_shapes=((tf.TensorShape([23,23,23]), tf.TensorShape([6,1])),))
dataset_pred = dataset_pred.batch(BS)
for i in range(num_batches):
  temp_pred = np.array(model.predict(dataset_pred, batch_size=BS, steps=1))
  ## aggregate the temp_pred result ##

我想模仿 model.predict(dataset_pred, batch_size=BS, steps=num_batches) 的行为,并加上额外的逻辑。另外,由于num_batches很大,我无法存储这个调用的结果。

EDIT: 我已经添加了答案。但真的很感谢任何能提高效率的帮助。

python python-3.x tensorflow tensorflow-datasets
1个回答
0
投票

我找到了答案。基本上,我们可以将相应的生成器存储在一个变量中,然后使用lambda使其可调用。这样就不会重置生成器了。

cur_gen = gen_predict(img_no)
dataset_pred = tf.data.Dataset.from_generator(lambda: cur_gen, ((tf.float32, tf.float32),), output_shapes=((tf.TensorShape([23,23,23]), tf.TensorShape([6,1])),))
dataset_pred = dataset_pred.batch(BS)
for i in range(num_batches):
  temp_pred = np.array(model.predict(dataset_pred, batch_size=BS, steps=1))
  ## aggregate the temp_pred result ##
© www.soinside.com 2019 - 2024. All rights reserved.