使用TensorFlow Object Detection API确定最大批量大小

问题描述 投票:0回答:1

TF Object Detection API默认抓取所有GPU内存,因此很难说我可以进一步增加批量大小。通常我会继续增加它,直到我收到CUDA OOM错误。

另一方面,PyTorch默认不会占用所有GPU内存,因此很容易看到我剩下的百分比,没有所有的试验和错误。

有没有更好的方法来确定我丢失的TF对象检测API的批量大小?像allow-growthmodel_main.py旗帜?

tensorflow object-detection-api batchsize
1个回答
1
投票

我一直在寻找源代码,我发现没有与此相关的FLAG。

但是,在model_main.py的文件https://github.com/tensorflow/models/blob/master/research/object_detection/model_main.py中,您可以找到以下主要函数定义:

def main(unused_argv):
  flags.mark_flag_as_required('model_dir')
  flags.mark_flag_as_required('pipeline_config_path')
  config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir)

  train_and_eval_dict = model_lib.create_estimator_and_inputs(
      run_config=config,
...

我们的想法是以类似的方式修改它,例如以下方式:

config_proto = tf.ConfigProto()
config_proto.gpu_options.allow_growth = True

config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir, session_config=config_proto)

所以,添加config_proto和改变config,但保持所有其他事情相等。

此外,allow_growth使程序使用尽可能多的GPU内存。所以,根据你的GPU,你最终可能会吃掉所有的内存。在这种情况下,您可能想要使用

config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9

它定义了要使用的内存部分。

希望这有所帮助。

如果您不想修改文件,似乎应该打开一个问题,因为我没有看到任何FLAG。除非FLAG

flags.DEFINE_string('pipeline_config_path', None, 'Path to pipeline config '
                    'file.')

意味着与此相关的事情。但我不认为这是因为它似乎在model_lib.py它与火车,评估和推断配置有关,而不是GPU使用配置。

© www.soinside.com 2019 - 2024. All rights reserved.