tf.device()出现异常

问题描述 投票:0回答:1

当我使用tf.device()指定GPU编号时,程序似乎会抛出异常。这是我第一次在Stack Overflow上提问,如果有不妥之处,请见谅并告诉我。

当我在代码中放入allow_soft_placement = True时,它可以工作。

tensorflow nlp
1个回答
0
投票
def init_graph(self):
    """
    Build the BERT classification graph, open a session and restore weights.

    Reads from ``self``:
        tf_instance: object exposing the TensorFlow 1.x API used below
            (presumably the imported ``tensorflow`` module -- confirm
            against the constructor).
        gpu_no: index of the GPU the graph ops are pinned to.
        args.vocab_file, args.config_file, args.ckpt_file: paths to the
            vocabulary, the BERT JSON config and the checkpoint.
        input_ids, input_mask, input_type_ids: model inputs handed to
            ``BertModel`` (presumably placeholders created elsewhere).

    Sets on ``self``:
        tokenizer: ``tokenization.FullTokenizer`` built from the vocab file.
        model: the ``modeling.BertModel`` instance.
        probabilities: softmax over the classification logits.
        sess: session bound to the graph that owns ``probabilities``, with
            variables restored from ``args.ckpt_file``.
    """
    with self.tf_instance.device('device:GPU:{}'.format(str(self.gpu_no))):
        # add tokenizer
        from bert import tokenization
        self.tokenizer = tokenization.FullTokenizer(self.args.vocab_file)
        from bert import modeling
        bert_config = modeling.BertConfig.from_json_file(self.args.config_file)
        self.model = modeling.BertModel(config=bert_config,
                                        is_training=False,
                                        input_ids=self.input_ids,
                                        input_mask=self.input_mask,
                                        token_type_ids=self.input_type_ids,
                                        use_one_hot_embeddings=False)

        # get output weights and output bias
        # They are read straight from the checkpoint as values rather than
        # being created as graph variables.
        reader = self.tf_instance.train.NewCheckpointReader(self.args.ckpt_file)
        output_weights = reader.get_tensor('output_weights')
        output_bias = reader.get_tensor('output_bias')

        # get result op
        output_layer = self.model.get_pooled_output()
        logits = self.tf_instance.matmul(output_layer, output_weights, transpose_b=True)
        logits = self.tf_instance.nn.bias_add(logits, output_bias)
        self.probabilities = self.tf_instance.nn.softmax(logits, axis=-1)

        sess_config = self.tf_instance.ConfigProto()
        sess_config.gpu_options.allow_growth = True
        # Everything built inside this device scope is explicitly pinned to
        # the requested GPU. Without soft placement, TensorFlow raises a
        # device-placement error for any op that cannot run there (no GPU
        # kernel, or the GPU index does not exist). Soft placement lets such
        # ops fall back to a supported device instead.
        sess_config.allow_soft_placement = True

        graph = self.probabilities.graph
        saver = self.tf_instance.train.Saver()
        self.sess = self.tf_instance.Session(config=sess_config, graph=graph)
        self.sess.run(self.tf_instance.global_variables_initializer())
        # The session was created with an explicit ``graph=`` and the saver
        # was built before this call, so the restore below still goes
        # through ``self.sess`` and its graph.
        self.tf_instance.reset_default_graph()
        saver.restore(self.sess, self.args.ckpt_file)
© www.soinside.com 2019 - 2024. All rights reserved.