Persist BERT模型作为pickle文件在磁盘上

问题描述 投票:1回答:1

我设法使BERT模型在johnsnowlabs-spark-nlp库上工作。我可以按以下方式在磁盘上保存“训练模型”。

适合型号

df_bert_trained = bert_pipeline.fit(textRDD)

df_bert=df_bert_trained.transform(textRDD)

保存模型

df_bert_trained.write().overwrite().save("/home/XX/XX/trained_model")

但是,

首先,根据此处https://nlp.johnsnowlabs.com/docs/en/concepts的文档,有人说可以将模型加载为

EmbeddingsHelper.load(path, spark, format, reference, dims, caseSensitive) 

但是目前我不清楚变量“引用”代表什么。

[其次,有人设法将BERT嵌入内容保存为python中的pickle文件吗?

apache-spark johnsnowlabs-spark-nlp
1个回答
0
投票

在Spark NLP中,BERT是经过预训练的模型。这意味着它已经是已经过训练,拟合等并以正确格式保存的模型。

话虽如此,没有理由再次容纳或保存它。但是,将DataFrame转换为每个令牌都具有BERT嵌入的新DataFrame之后,您可以保存其结果。

示例:

使用Spark NLP软件包在spark-shell中启动Spark会话

spark-shell --packages JohnSnowLabs:spark-nlp:2.4.0
import com.johnsnowlabs.nlp.annotators._
import com.johnsnowlabs.nlp.base._

val documentAssembler = new DocumentAssembler()
      .setInputCol("text")
      .setOutputCol("document")

    val sentence = new SentenceDetector()
      .setInputCols("document")
      .setOutputCol("sentence")

    val tokenizer = new Tokenizer()
      .setInputCols(Array("sentence"))
      .setOutputCol("token")

    // Download and load the pretrained BERT model
    val embeddings = BertEmbeddings.pretrained(name = "bert_base_cased", lang = "en")
      .setInputCols("sentence", "token")
      .setOutputCol("embeddings")
      .setCaseSensitive(true)
      .setPoolingLayer(0)

    val pipeline = new Pipeline()
      .setStages(Array(
        documentAssembler,
        sentence,
        tokenizer,
        embeddings
      ))

// Test and transform

   val testData = Seq(
      "I like pancakes in the summer. I hate ice cream in winter.",
      "If I had asked people what they wanted, they would have said faster horses"
    ).toDF("text")

    val predictionDF = pipeline.fit(testData).transform(testData)

predictionDF是一个DataFrame,其中包含数据集中每个令牌的BERT嵌入。 BertEmbeddings预训练模型来自TF Hub,这意味着它们与Google发布的预训练权重完全相同。所有5种型号均可用:

  • bert_base_cased(en)
  • bert_base_uncased(en)
  • bert_large_cased(en)
  • bert_large_uncased(en)
  • bert_multi_cased(xx)

让我知道您是否有任何疑问或问题,我会更新我的答案。

参考


推荐问答