构建张量流图时修改Python对象

问题描述 投票:0回答:1

我正在尝试根据我收到的输入类型进行数据增强,特别是我想使用文件名来决定将使用哪种类型的增强。

为此,我使用 Tensorflow 的对象检测 API,并修改 augment_input_data 函数。我已将以下代码添加到函数的开头:

my_data_augmentation_options = data_augmentation_options.copy()
  
@tf.function
def filter_supervised(data_aug_options):
  for opt in data_aug_options:
    if opt[0].__name__ in ['random_vertical_flip', 'random_horizontal_flip']:
      data_aug_options.remove(opt)

@tf.function
def filter_unsupervised(data_aug_options):
  for opt in data_aug_options:
    if opt[0].__name__ in ['random_distort_color']:
      data_aug_options.remove(opt)

tf.cond(tf.strings.regex_full_match(tensor_dict['filename'], '\Scrop\S'), filter_unsupervised(my_data_augmentation_options), filter_supervised(my_data_augmentation_options))

不幸的是,这会引发以下错误:

ValueError: filter_unsupervised() should not modify its Python input arguments. Check if it modifies any lists or dicts passed as arguments. Modifying a copy is allowed.

有没有办法修改Python对象?或者我是否需要搜索它停止热切执行的位置并尝试保持热切执行直到这一部分?

提前致谢。

重现我所面临的错误的一个最小示例是

def custom_augment_input_data(tensor_dict, data_augmentation_options):
  my_data_augmentation_options = data_augmentation_options.copy()
  @tf.function
  def filter_supervised(data_aug_options):
    for opt in data_aug_options:
      if opt[0].__name__ in ['random_horizontal_flip']:
        data_aug_options.remove(opt)

  @tf.function
  def filter_unsupervised(data_aug_options):
    for opt in data_aug_options:
      if opt[0].__name__ in ['random_distort_color']:
        data_aug_options.remove(opt)
  tf.cond(tf.strings.regex_full_match(tensor_dict['filename'], '\Scrop\S'),
          filter_unsupervised(my_data_augmentation_options), filter_supervised(my_data_augmentation_options))
  # Continue doing stuff below

from object_detection.core.preprocessor import random_horizontal_flip, random_distort_color
data_augmentation_options = [(random_horizontal_flip, {'keypoint_flip_permutation': None, 'probability':0.5}),
                             (random_distort_color, {})]

dummy_dict = {fields.InputDataFields.filename: './crop_32412.jpg'}
custom_augment_input_data(dummy_dict, data_augmentation_options)
  
python tensorflow object-detection-api
1个回答
0
投票

我也遇到了同样的错误。 就我而言,我要修改的对象是

dict
。我尝试使用
copy.deepcopy
进行复制,但在图形中执行时该对象结果是不可序列化的
tensor
,无法以这种方式复制。 (操作似乎做了一个浅拷贝
data_augmentation_options.copy()
,它仍然指向同一个对象。)

所以我的解决方案是手动复制:

dct = {k:v for k,v in dct.items()}
,这样就绕过了这个问题。我预计它也适用于其他类型的对象(例如
lst = [e for e in lst]
)。

© www.soinside.com 2019 - 2024. All rights reserved.