Text classification with Spark-NLP and Spark MLlib in PySpark


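I want to build a text classification model that uses Spark-NLP to preprocess the documents and Spark MLlib's OneVsRest classifier to produce a probability prediction for each class label (I have 55 class labels to predict). I have already set up the Spark-NLP pipeline successfully. Here is the code: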

# Import Spark NLP
from sparknlp.base import *
from sparknlp.annotator import *
from sparknlp.pretrained import PretrainedPipeline
import sparknlp
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline

# Start Spark Session with Spark NLP
#spark = sparknlp.start()

spark = SparkSession.builder \
    .appName("Spark NLP") \
    .master("local[4]") \
    .config("spark.driver.memory", "2G") \
    .config("spark.driver.maxResultSize", "2G") \
    .config("spark.jars.packages", "com.johnsnowlabs.nlp:spark-nlp_2.11:2.4.5") \
    .config("spark.kryoserializer.buffer.max", "1000M") \
    .getOrCreate()

# File location and type
file_location = r'path_to_data'
file_type = "csv"

# CSV options
infer_schema = "true"
first_row_is_header = "true"
delimiter = ","


df = spark.read.format(file_type) \
    .option("inferSchema", infer_schema) \
    .option("header", first_row_is_header) \
    .option("sep", delimiter) \
    .load(file_location)


df.count()

stats=df.groupBy(df['target_labels']).count()

stats.show(55)

documentAssembler = DocumentAssembler().setInputCol("cleantext_woRACs").setOutputCol("document").setCleanupMode("shrink")


sentence_detector = SentenceDetector().setInputCols("document").setOutputCol("sentence")


tokenizer = Tokenizer().setInputCols(["sentence"]).setOutputCol("token").setMinLength(3).setMaxLength(10)


stopwords=['i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', "you're", "you've", "you'll", "you'd", 'your', 'yours', 'yourself', 'yourselves', 'he', 'him', 'his', 'himself', 'she', "she's", 'her', 'hers', 'herself', 'it', "it's", 'its', 'itself', 'they', 'them', 'their', 'theirs', 'themselves', 'what', 'which', 'who', 'whom', 'this', 'that', "that'll", 'these', 'those', 'am', 'is', 'are', 'was', 'were', 'be', 'been', 'being', 'have', 'has', 'had', 'having', 'do', 'does', 'did', 'doing', 'a', 'an', 'the', 'and', 'but', 'if', 'or', 'because', 'as', 'until', 'while', 'of', 'at', 'by', 'for', 'with', 'about', 'against', 'between', 'into', 'through', 'during', 'before', 'after', 'above', 'below', 'to', 'from', 'up', 'down', 'in', 'out', 'on', 'off', 'over', 'under', 'again', 'further', 'then', 'once', 'here', 'there', 'when', 'where', 'why', 'how', 'all', 'any', 'both', 'each', 'few', 'more', 'most', 'other', 'some', 'such', 'no', 'nor', 'not', 'only', 'own', 'same', 'so', 'than', 'too', 'very', 's', 't', 'can', 'will', 'just', 'don', "don't", 'should', "should've", 'now', 'd', 'll', 'm', 'o', 're', 've', 'y', 'ain', 'aren', "aren't", 'couldn', "couldn't", 'didn', "didn't", 'doesn', "doesn't", 'hadn', "hadn't", 'hasn', "hasn't", 'haven', "haven't", 'isn', "isn't", 'ma', 'mightn', "mightn't", 'mustn', "mustn't", 'needn', "needn't", 'shan', "shan't", 'shouldn', "shouldn't", 'wasn', "wasn't", 'weren', "weren't", 'won', "won't", 'wouldn', "wouldn't"]

stop_words_cleaner = StopWordsCleaner().setInputCols(["token"]).setOutputCol("cleanTokens").setCaseSensitive(False).setStopWords(stopwords)


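# Stem the stop-word-filtered tokens; with "token" as input, the StopWordsCleaner output would go unused
stemmer = Stemmer().setInputCols(["cleanTokens"]).setOutputCol("stem")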

normalizer = Normalizer().setInputCols(["stem"]).setOutputCol("normalized")

df_subset = df.limit(10)

nlp_pipeline = Pipeline(stages=[documentAssembler,sentence_detector,tokenizer, stop_words_cleaner,stemmer, normalizer])

nlp_model = nlp_pipeline.fit(df_subset)
processed = nlp_model.transform(df_subset).persist()

processed.count()
processed.show()


processed.printSchema()

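The code above works, and I get the normalized tokens. The next step is to generate TF-IDF vectors. I am trying to generate them with HashingTF and IDF, but I get an error.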
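For reference, the TF-IDF step counts each document's terms in a fixed number of hash buckets and then weights every bucket by log((m + 1) / (df + 1)), where m is the number of documents and df is the number of documents that contain the bucket (the smoothed formula described in the Spark MLlib feature documentation). The plain-Python sketch below mirrors that on tokens that are already plain strings. It is an illustration, not the Spark API: it assumes zlib.crc32 as a stand-in for Spark's MurmurHash3, so its bucket indices will not match Spark's.

```python
import math
import zlib
from collections import Counter


def hashing_tf(tokens, num_features=20):
    """Map a list of string tokens to {bucket: count}.

    zlib.crc32 is a stand-in for Spark's MurmurHash3 (an assumption for
    illustration), so indices differ from Spark's HashingTF output.
    """
    counts = Counter()
    for tok in tokens:
        if not isinstance(tok, str):
            # This sketch accepts only strings; Spark's HashingTF likewise
            # fails on struct values such as Spark NLP annotations.
            raise TypeError(f"expected str token, got {type(tok).__name__}")
        counts[zlib.crc32(tok.encode("utf-8")) % num_features] += 1
    return dict(counts)


def fit_idf(tf_vectors):
    """Return {bucket: idf} with idf = log((m + 1) / (df + 1))."""
    m = len(tf_vectors)
    doc_freq = Counter()
    for vec in tf_vectors:
        doc_freq.update(vec.keys())  # each bucket counted once per document
    return {b: math.log((m + 1) / (df + 1)) for b, df in doc_freq.items()}


def tf_idf(tf_vectors, idf):
    """Scale every term count by its bucket's idf weight."""
    return [{b: c * idf[b] for b, c in vec.items()} for vec in tf_vectors]


docs = [["great", "world", "buffet"], ["great", "point"]]
tf = [hashing_tf(d) for d in docs]
idf = fit_idf(tf)
features = tf_idf(tf, idf)
```

A bucket that occurs in every document gets weight 0. With numFeatures=20, as in the code below, many distinct terms will share a bucket; Spark's HashingTF defaults to 2^18 features.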

Here is the code that produces the error:

# Tokenizer is not imported here: it would shadow Spark NLP's Tokenizer imported above
from pyspark.ml.feature import HashingTF, IDF

# Term frequency
hashingTF = HashingTF(inputCol="normalized", outputCol="rawFeatures", numFeatures=20)
featurized_data = hashingTF.transform(processed)

# Inverse document frequency
idf = IDF(inputCol="rawFeatures", outputCol="features")
idfModel = idf.fit(featurized_data)
rescaled_data = idfModel.transform(featurized_data)

Here is the error:

Py4JJavaError: An error occurred while calling o614.fit.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 23.0 failed 1 times, most recent failure: Lost task 0.0 in stage 23.0 (TID 223, localhost, executor driver): org.apache.spark.SparkException: Failed to execute user defined function($anonfun$1: (array<struct<annotatorType:string,begin:int,end:int,result:string,metadata:map<string,string>,embeddings:array<float>>>) => struct<type:tinyint,size:int,indices:array<int>,values:array<double>>)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
    at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
    at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$13$$anon$1.hasNext(WholeStageCodegenExec.scala:636)
    at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
    at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
    at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
    at scala.collection.Iterator$class.foreach(Iterator.scala:891)
    at scala.collection.AbstractIterator.foreach(Iterator.scala:1334)
    at scala.collection.TraversableOnce$class.foldLeft(TraversableOnce.scala:157)
    at scala.collection.AbstractIterator.foldLeft(Iterator.scala:1334)
    at scala.collection.TraversableOnce$class.aggregate(TraversableOnce.scala:214)
    at scala.collection.AbstractIterator.aggregate(Iterator.scala:1334)
    at org.apache.spark.rdd.RDD$$anonfun$treeAggregate$1$$anonfun$24.apply(RDD.scala:1145)
    at org.apache.spark.rdd.RDD$$anonfun$treeAggregate$1$$anonfun$24.apply(RDD.scala:1145)
    at org.apache.spark.rdd.RDD$$anonfun$treeAggregate$1$$anonfun$25.apply(RDD.scala:1146)
    at org.apache.spark.rdd.RDD$$anonfun$treeAggregate$1$$anonfun$25.apply(RDD.scala:1146)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:801)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:801)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
    at org.apache.spark.scheduler.Task.run(Task.scala:123)
    at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:408)
    at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:414)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(Unknown Source)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(Unknown Source)
    at java.lang.Thread.run(Unknown Source)
Caused by: org.apache.spark.SparkException: HashingTF with murmur3 algorithm does not support type org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema of input data.
    at org.apache.spark.mllib.feature.HashingTF$.murmur3Hash(HashingTF.scala:164)
    at org.apache.spark.mllib.feature.HashingTF$$anonfun$getHashFunction$1.apply(HashingTF.scala:84)
    at org.apache.spark.mllib.feature.HashingTF$$anonfun$getHashFunction$1.apply(HashingTF.scala:84)
    at org.apache.spark.mllib.feature.HashingTF$$anonfun$transform$1.apply(HashingTF.scala:101)
    at org.apache.spark.mllib.feature.HashingTF$$anonfun$transform$1.apply(HashingTF.scala:100)
    at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
    at scala.collection.mutable.WrappedArray.foreach(WrappedArray.scala:35)
    at org.apache.spark.mllib.feature.HashingTF.transform(HashingTF.scala:100)
    at org.apache.spark.ml.feature.HashingTF$$anonfun$1.apply(HashingTF.scala:98)
    at org.apache.spark.ml.feature.HashingTF$$anonfun$1.apply(HashingTF.scala:98)
    ... 29 more

Driver stacktrace:
    at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1889)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1877)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1876)
    at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
    at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
    at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1876)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
    at scala.Option.foreach(Option.scala:257)
    at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:926)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2110)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2059)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2048)
    at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
    at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:737)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2158)
    at org.apache.spark.rdd.RDD$$anonfun$fold$1.apply(RDD.scala:1098)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
    at org.apache.spark.rdd.RDD.withScope(RDD.scala:363)
    at org.apache.spark.rdd.RDD.fold(RDD.scala:1092)
    at org.apache.spark.rdd.RDD$$anonfun$treeAggregate$1.apply(RDD.scala:1161)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
    at org.apache.spark.rdd.RDD.withScope(RDD.scala:363)
    at org.apache.spark.rdd.RDD.treeAggregate(RDD.scala:1137)
    at org.apache.spark.mllib.feature.IDF.fit(IDF.scala:54)
    at org.apache.spark.ml.feature.IDF.fit(IDF.scala:92)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(Unknown Source)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(Unknown Source)
    at java.lang.reflect.Method.invoke(Unknown Source)
    at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
    at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
    at py4j.Gateway.invoke(Gateway.java:282)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:238)
    at java.lang.Thread.run(Unknown Source)
Caused by: org.apache.spark.SparkException: Failed to execute user defined function($anonfun$1: (array<struct<annotatorType:string,begin:int,end:int,result:string,metadata:map<string,string>,embeddings:array<float>>>) => struct<type:tinyint,size:int,indices:array<int>,values:array<double>>)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
    at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
    at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$13$$anon$1.hasNext(WholeStageCodegenExec.scala:636)
    at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
    at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
    at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
    at scala.collection.Iterator$class.foreach(Iterator.scala:891)
    at scala.collection.AbstractIterator.foreach(Iterator.scala:1334)
    at scala.collection.TraversableOnce$class.foldLeft(TraversableOnce.scala:157)
    at scala.collection.AbstractIterator.foldLeft(Iterator.scala:1334)
    at scala.collection.TraversableOnce$class.aggregate(TraversableOnce.scala:214)
    at scala.collection.AbstractIterator.aggregate(Iterator.scala:1334)
    at org.apache.spark.rdd.RDD$$anonfun$treeAggregate$1$$anonfun$24.apply(RDD.scala:1145)
    at org.apache.spark.rdd.RDD$$anonfun$treeAggregate$1$$anonfun$24.apply(RDD.scala:1145)
    at org.apache.spark.rdd.RDD$$anonfun$treeAggregate$1$$anonfun$25.apply(RDD.scala:1146)
    at org.apache.spark.rdd.RDD$$anonfun$treeAggregate$1$$anonfun$25.apply(RDD.scala:1146)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:801)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:801)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
    at org.apache.spark.scheduler.Task.run(Task.scala:123)
    at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:408)
    at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:414)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(Unknown Source)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(Unknown Source)
    ... 1 more
Caused by: org.apache.spark.SparkException: HashingTF with murmur3 algorithm does not support type org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema of input data.
    at org.apache.spark.mllib.feature.HashingTF$.murmur3Hash(HashingTF.scala:164)
    at org.apache.spark.mllib.feature.HashingTF$$anonfun$getHashFunction$1.apply(HashingTF.scala:84)
    at org.apache.spark.mllib.feature.HashingTF$$anonfun$getHashFunction$1.apply(HashingTF.scala:84)
    at org.apache.spark.mllib.feature.HashingTF$$anonfun$transform$1.apply(HashingTF.scala:101)
    at org.apache.spark.mllib.feature.HashingTF$$anonfun$transform$1.apply(HashingTF.scala:100)
    at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
    at scala.collection.mutable.WrappedArray.foreach(WrappedArray.scala:35)
    at org.apache.spark.mllib.feature.HashingTF.transform(HashingTF.scala:100)
    at org.apache.spark.ml.feature.HashingTF$$anonfun$1.apply(HashingTF.scala:98)
    at org.apache.spark.ml.feature.HashingTF$$anonfun$1.apply(HashingTF.scala:98)
    ... 29 more

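Judging from the log above, I think the error is caused by a type mismatch: HashingTF receives Spark NLP annotation structs (GenericRowWithSchema) instead of plain strings. How can I integrate the output of Spark-NLP with Spark MLlib?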

The schema of the processed DataFrame is:

root
 |-- _c0: integer (nullable = true)
 |-- cleantext_woRACs: string (nullable = true)
 |-- target_labels: string (nullable = true)
 |-- document: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- annotatorType: string (nullable = true)
 |    |    |-- begin: integer (nullable = false)
 |    |    |-- end: integer (nullable = false)
 |    |    |-- result: string (nullable = true)
 |    |    |-- metadata: map (nullable = true)
 |    |    |    |-- key: string
 |    |    |    |-- value: string (valueContainsNull = true)
 |    |    |-- embeddings: array (nullable = true)
 |    |    |    |-- element: float (containsNull = false)
 |-- sentence: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- annotatorType: string (nullable = true)
 |    |    |-- begin: integer (nullable = false)
 |    |    |-- end: integer (nullable = false)
 |    |    |-- result: string (nullable = true)
 |    |    |-- metadata: map (nullable = true)
 |    |    |    |-- key: string
 |    |    |    |-- value: string (valueContainsNull = true)
 |    |    |-- embeddings: array (nullable = true)
 |    |    |    |-- element: float (containsNull = false)
 |-- token: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- annotatorType: string (nullable = true)
 |    |    |-- begin: integer (nullable = false)
 |    |    |-- end: integer (nullable = false)
 |    |    |-- result: string (nullable = true)
 |    |    |-- metadata: map (nullable = true)
 |    |    |    |-- key: string
 |    |    |    |-- value: string (valueContainsNull = true)
 |    |    |-- embeddings: array (nullable = true)
 |    |    |    |-- element: float (containsNull = false)
 |-- cleanTokens: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- annotatorType: string (nullable = true)
 |    |    |-- begin: integer (nullable = false)
 |    |    |-- end: integer (nullable = false)
 |    |    |-- result: string (nullable = true)
 |    |    |-- metadata: map (nullable = true)
 |    |    |    |-- key: string
 |    |    |    |-- value: string (valueContainsNull = true)
 |    |    |-- embeddings: array (nullable = true)
 |    |    |    |-- element: float (containsNull = false)
 |-- stem: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- annotatorType: string (nullable = true)
 |    |    |-- begin: integer (nullable = false)
 |    |    |-- end: integer (nullable = false)
 |    |    |-- result: string (nullable = true)
 |    |    |-- metadata: map (nullable = true)
 |    |    |    |-- key: string
 |    |    |    |-- value: string (valueContainsNull = true)
 |    |    |-- embeddings: array (nullable = true)
 |    |    |    |-- element: float (containsNull = false)
 |-- normalized: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- annotatorType: string (nullable = true)
 |    |    |-- begin: integer (nullable = false)
 |    |    |-- end: integer (nullable = false)
 |    |    |-- result: string (nullable = true)
 |    |    |-- metadata: map (nullable = true)
 |    |    |    |-- key: string
 |    |    |    |-- value: string (valueContainsNull = true)
 |    |    |-- embeddings: array (nullable = true)
 |    |    |    |-- element: float (containsNull = false)

This is what the first row of the normalized column looks like:

+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|normalized                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   |
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|[[token, 3, 7, until, [sentence -> 0], []], [token, 9, 14, jurong, [sentence -> 0], []], [token, 16, 20, point, [sentence -> 0], []], [token, 23, 27, crazi, [sentence -> 0], []], [token, 31, 35, avail, [sentence -> 2], []], [token, 41, 44, onli, [sentence -> 2], []], [token, 49, 52, bugi, [sentence -> 2], []], [token, 57, 61, great, [sentence -> 2], []], [token, 63, 67, world, [sentence -> 2], []], [token, 74, 79, buffet, [sentence -> 2], []], [token, 84, 87, cine, [sentence -> 3], []], [token, 89, 93, there, [sentence -> 3], []], [token, 95, 97, got, [sentence -> 3], []], [token, 99, 102, amor, [sentence -> 3], []], [token, 105, 107, wat, [sentence -> 3], []]]|
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+

Thanks for the help!

apache-spark pyspark johnsnowlabs-spark-nlp
1 Answer

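I found a solution: I had not used the Finisher transformer. The Finisher converts Spark NLP annotation columns into plain arrays of strings, which Spark MLlib stages such as HashingTF can consume. Add it as the last stage of the Spark NLP pipeline: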

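from sparknlp.base import Finisher

# Turn the "stem" annotations into a plain array<string> column
finisher = Finisher() \
    .setInputCols(["stem"]) \
    .setOutputCols(["token_features"]) \
    .setOutputAsArray(True) \
    .setCleanAnnotations(False)  # keep the original annotation columns
# Then use "token_features" as HashingTF's inputCol instead of an annotation column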
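In effect, the Finisher keeps the `result` field of each annotation (the token text visible in the sample `normalized` row in the question) and drops the rest of the struct. The plain-Python sketch below models that conversion; the `Annotation` namedtuple is a hypothetical stand-in that mirrors the schema's fields, not a Spark NLP class. Inside Spark, selecting the nested field, e.g. `processed.selectExpr("normalized.result AS normalized_tokens")`, should likewise produce an `array<string>` column.

```python
from collections import namedtuple

# Hypothetical stand-in mirroring the fields of a Spark NLP annotation struct.
Annotation = namedtuple(
    "Annotation", ["annotatorType", "begin", "end", "result", "metadata", "embeddings"]
)

# The first few annotations from the sample `normalized` row in the question.
normalized = [
    Annotation("token", 3, 7, "until", {"sentence": "0"}, []),
    Annotation("token", 9, 14, "jurong", {"sentence": "0"}, []),
    Annotation("token", 16, 20, "point", {"sentence": "0"}, []),
    Annotation("token", 23, 27, "crazi", {"sentence": "0"}, []),
    Annotation("token", 31, 35, "avail", {"sentence": "2"}, []),
]


def annotations_to_strings(annotations):
    """Keep only the token text: the plain-string shape HashingTF can hash."""
    return [a.result for a in annotations]


tokens = annotations_to_strings(normalized)
print(tokens)  # ['until', 'jurong', 'point', 'crazi', 'avail']
```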

You can find a detailed example from John Snow Labs at this link.
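Downstream of the TF-IDF features, the question's goal is one-vs-rest classification over 55 labels with a score per label. One-vs-rest trains one binary classifier per label and predicts the label whose classifier is most confident. The plain-Python sketch below shows only that decision step on hypothetical per-label confidences (the labels "food", "travel", and "sport" are made up). Rescaling the confidences so they sum to 1 is a heuristic added here for illustration, not something Spark's OneVsRestModel is documented to output.

```python
def one_vs_rest_predict(scores):
    """scores: {label: confidence from that label's binary classifier}.

    Assumes non-negative confidences (e.g. each binary model's class-1
    probability). Returns (predicted_label, rescaled_scores); the rescaling
    is a heuristic, not a calibrated multiclass probability.
    """
    if not scores:
        raise ValueError("need at least one label score")
    # The prediction is the label whose binary classifier is most confident.
    best = max(scores, key=scores.get)
    total = sum(scores.values())
    if total > 0:
        rescaled = {k: v / total for k, v in scores.items()}
    else:
        rescaled = {k: 1 / len(scores) for k in scores}
    return best, rescaled


# Hypothetical confidences for three of the labels.
label, probs = one_vs_rest_predict({"food": 0.8, "travel": 0.3, "sport": 0.1})
```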

© www.soinside.com 2019 - 2024. All rights reserved.