TF 2.0 MirroredStrategy (multi-GPU) with the ALBERT TF Hub model


I am trying to run the TensorFlow Hub version of ALBERT on multiple GPUs on the same machine. The model runs perfectly on a single GPU.

This is the structure of my code:

strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync)) # it prints 2 .. correct
if __name__ == "__main__":
    with strategy.scope():
        run()

Inside the run() function, I read the data, build the model, and fit it.

I get this error:

Traceback (most recent call last):
  File "Albert.py", line 130, in <module>
    run()
  File "Albert.py", line 88, in run
    model = build_model(bert_max_seq_length)
  File "Albert.py", line 55, in build_model
    model.compile(loss="categorical_crossentropy", optimizer=optimizer, metrics=["accuracy"])
  File "/home/****/py_transformers/lib/python3.5/site-packages/tensorflow_core/python/training/tracking/base.py", line 457, in _method_wrapper
    result = method(self, *args, **kwargs)
  File "/home/bighanem/py_transformers/lib/python3.5/site-packages/tensorflow_core/python/keras/engine/training.py", line 471, in compile
    '  model.compile(...)'% (v, strategy))
ValueError: Variable (<tf.Variable 'bert/embeddings/word_embeddings:0' shape=(30000, 128) dtype=float32>) was not created in the distribution strategy scope of (<tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f62e399df60>). It is most likely due to not all layers or the model or optimizer being created outside the distribution strategy scope. Try to make sure your code looks similar to the following.
with strategy.scope():
  model=_create_model()
  model.compile(...)

Could this be happening because the model was prepared (built and compiled) in advance by the TensorFlow team?

tensorflow tf.keras multi-gpu pre-trained-model tensorflow-hub
1 Answer

Two-part answer:

1) TF Hub hosts two versions of ALBERT (each in several sizes): one in the legacy TF1 Hub format, and one in the TF2 SavedModel format. For use with a TF2 distribution strategy, make sure you are using the TF2 SavedModel version.

2) For SavedModels (from TF Hub or elsewhere), it is the loading that needs to happen inside the distribution strategy scope, because that is what creates the tf.Variable objects in the current program. Specifically, any of the following ways of loading a TF2 SavedModel from TF Hub must happen inside the distribution strategy scope for distribution to work:

  • tf.saved_model.load();
  • hub.load(), which just calls tf.saved_model.load() (after downloading if necessary);
  • hub.KerasLayer when used with a string-valued model handle, on which it then calls hub.load().