scikit-learn - Using RandomForestClassifier.predict() with a single string?


I'm an sklearn dummy... I'm trying to predict the label of a given string with a RandomForestClassifier() that was fitted on texts and labels.

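Clearly I don't know how to use predict() with a single string. The reason I'm using reshape() is that a while ago I got this error: "Reshape your data either using array.reshape(-1, 1) if your data has a single feature or array.reshape(1, -1) if it contains a single sample."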
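As a minimal sketch (assuming a recent scikit-learn, with hypothetical training texts standing in for small_list.txt), the reshape hint is aimed at 1-D arrays that hold a single sample. A single string wrapped in a list and passed through an already-fitted CountVectorizer comes out as a 2-D sparse matrix with exactly one row, so reshape(1, -1) changes nothing there:

```python
# Sketch: why reshape(1, -1) is not needed for a single string.
# The training texts are hypothetical stand-ins for small_list.txt.
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer

train_texts = [
    "phishing email asking for bank credentials",
    "ssh brute force attempts from scanner",
    "marketing spam with unsubscribe link",
]
cv = CountVectorizer(stop_words="english", lowercase=True)
cv.fit(train_texts)

# Wrapping the string in a list makes it one sample, so the result is
# already 2-D: (1 sample, one column per training vocabulary term).
single = cv.transform(["possible phishing or fraud"])
print(single.shape)

# The reshape hint targets 1-D arrays such as a bare feature vector;
# reshape(1, -1) turns that vector into a single row.
vector = np.array([0.5, 1.0, 2.0])
print(vector.reshape(1, -1).shape)  # (1, 3)
```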

How do I predict the label of a single text string?

The script:

#!/usr/bin/env python
''' Read a txt file consisting of '<label>: <long string of text>'
    to use as a model for predicting the label for a string
'''

from argparse import ArgumentParser
import json
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import LabelEncoder


def main(args):
    '''
    args: Arguments obtained by _Get_Args()
    '''

    print('Loading data...')
    # Load data from args.txtfile and split the lines into
    # two lists (labels, texts).
    data = open(args.txtfile).readlines()
    labels, texts = ([], [])
    for line in data:
        label, text = line.split(': ', 1)
        labels.append(label)
        texts.append(text)

    # Print a list of unique labels
    print(json.dumps(list(set(labels)), indent=4))

    # Vectorize the texts with a CountVectorizer, encode the
    # labels, and fit a RandomForestClassifier on the result.
    cv = CountVectorizer(
            stop_words='english',
            strip_accents='unicode',
            lowercase=True,
            )
    matrix = cv.fit_transform(texts)
    encoder = LabelEncoder()
    labels = encoder.fit_transform(labels)
    rf = RandomForestClassifier()
    rf.fit(matrix, labels)

    # Try to predict the label for args.string.
    prediction = Predict_Label(args.string, cv, rf)
    print(prediction)


def Predict_Label(string, cv, rf):
    '''
    string: str() - A string of text
    cv: The fitted CountVectorizer instance
    rf: The fitted RandomForestClassifier instance
    '''

    matrix = cv.fit_transform([string])
    matrix = matrix.reshape(1, -1)
    try:
        prediction = rf.predict(matrix)
    except Exception as E:
        print(str(E))
    else:
        return prediction


def _Get_Args():
    parser = ArgumentParser(description='Learn labels from text')
    parser.add_argument('-t', '--txtfile', required=True)
    parser.add_argument('-s', '--string', required=True)
    return parser.parse_args()


if __name__ == '__main__':
    args = _Get_Args()
    main(args)

The actual training data file is 43663 lines long, but a sample is in small_list.txt, which contains lines in this format: <label>: <long text string>
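The parsing step from the script can be sketched on its own as below. The sample lines are hypothetical, and skipping lines that lack ': ' is an added assumption: the script as written raises ValueError on such a line (for example a trailing blank line). Splitting with maxsplit=1 keeps any later ': ' inside the text.

```python
# Sketch of the '<label>: <long text string>' parsing step; the sample
# lines below are hypothetical stand-ins for small_list.txt.
def parse_lines(lines):
    """Split each '<label>: <text>' line into parallel label/text lists."""
    labels, texts = [], []
    for line in lines:
        line = line.rstrip("\n")
        if ": " not in line:
            # Assumption: skip blank or malformed lines instead of raising.
            continue
        label, text = line.split(": ", 1)  # split only on the first ': '
        labels.append(label)
        texts.append(text)
    return labels, texts


sample = [
    "Fraud__Phishing: Fake login page asking for credentials\n",
    "Spam__Marketing Spam: Buy now: limited offer\n",
    "\n",
]
labels, texts = parse_lines(sample)
print(labels)  # ['Fraud__Phishing', 'Spam__Marketing Spam']
print(texts)   # ['Fake login page asking for credentials', 'Buy now: limited offer']
```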

The error shows up in the exception output:

$ ./learn.py -t small_list.txt -s 'This is a string that might have something to do with phishing or fraud'
Loading data...
[
    "Vulnerabilities__Unknown",
    "Vulnerabilities__MSSQL Browsing Service",
    "Fraud__Phishing",
    "Fraud__Copyright/Trademark Infringement",
    "Attacks and Reconnaissance__Web Attacks",
    "Vulnerabilities__Vulnerable SMB",
    "Internal Report__SBL Notify",
    "Objectionable Content__Russian Federation Objectionable Material",
    "Malicious Code/Traffic__Malicious URL",
    "Spam__Marketing Spam",
    "Attacks and Reconnaissance__Scanning",
    "Malicious Code/Traffic__Unknown",
    "Attacks and Reconnaissance__SSH Brute Force",
    "Spam__URL in Spam",
    "Vulnerabilities__Vulnerable Open Memcached",
    "Malicious Code/Traffic__Sinkhole",
    "Attacks and Reconnaissance__SMTP Brute Force",
    "Illegal content__Child Pornography"
]
Number of features of the model must match the input. Model n_features is 2070 and input n_features is 3 
None
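The mismatch can be reproduced in a few lines. This sketch assumes scikit-learn 0.24 or newer (for the n_features_in_ attribute), and the three training texts and integer labels are hypothetical. transform() on the already-fitted vectorizer yields the training columns, while refitting a copy on the single string, which is what Predict_Label() does via fit_transform(), yields a different, much smaller column count:

```python
# Sketch reproducing the feature-count mismatch; training data is hypothetical.
from sklearn.base import clone
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import CountVectorizer

train_texts = [
    "fake bank login page used for phishing",
    "ssh brute force attempts against the server",
    "copyright infringement notice for pirated movies",
]
train_labels = [0, 1, 2]

cv = CountVectorizer(stop_words="english", strip_accents="unicode", lowercase=True)
matrix = cv.fit_transform(train_texts)
rf = RandomForestClassifier(random_state=0).fit(matrix, train_labels)

new_text = "This is a string that might have something to do with phishing or fraud"

# Reusing the fitted vocabulary: transform() maps the new string onto
# the training columns, so the shapes line up with the model.
good = cv.transform([new_text])
print(good.shape[1] == rf.n_features_in_)  # True
print(rf.predict(good))

# What Predict_Label() does instead: fit_transform() learns a brand-new
# vocabulary from the single string (done on a clone so cv is untouched).
refit = clone(cv).fit_transform([new_text])
print(refit.shape[1], "vs", rf.n_features_in_)
```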
1 Answer

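You need to take the vocabulary of the first CountVectorizer (cv) and use it to transform the new single text before predicting. The error happens because Predict_Label() calls cv.fit_transform([string]), which refits the vectorizer on that one string, so its vocabulary shrinks to the distinct non-stop-word tokens in the string (3 in the error above) instead of the 2070 terms the forest was trained on.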

...

cv = CountVectorizer(
        stop_words='english',
        strip_accents='unicode',
        lowercase=True,
        )

matrix = cv.fit_transform(texts)
encoder = LabelEncoder()
labels = encoder.fit_transform(labels)
rf = RandomForestClassifier()
rf.fit(matrix, labels)

# Try to predict the label for args.string.
cv_new = CountVectorizer(
        stop_words='english',
        strip_accents='unicode',
        lowercase=True,
        vocabulary=cv.vocabulary_
        )
prediction = Predict_Label(args.string, cv_new, rf)
print(prediction)

...
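A common alternative, not the approach in the answer above, is to chain the vectorizer and the forest in a scikit-learn Pipeline: fit() then uses fit_transform() on the training texts, and predict() always uses transform(), so a raw string can be passed directly. The texts below are hypothetical (the label names are borrowed from the list in the question). The sketch also decodes the prediction with LabelEncoder.inverse_transform(), since the original script prints the encoded integer rather than the label name:

```python
# Sketch: Pipeline so predict() accepts raw strings; texts are hypothetical.
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import LabelEncoder

texts = [
    "fake bank login page used for phishing",
    "ssh brute force attempts against the server",
    "buy cheap watches limited marketing offer",
]
labels = [
    "Fraud__Phishing",
    "Attacks and Reconnaissance__SSH Brute Force",
    "Spam__Marketing Spam",
]

encoder = LabelEncoder()
y = encoder.fit_transform(labels)

# fit(): fit_transform on the vectorizer, then fit on the forest.
# predict(): transform (never fit_transform), then predict on the forest.
model = make_pipeline(
    CountVectorizer(stop_words="english", strip_accents="unicode", lowercase=True),
    RandomForestClassifier(random_state=0),
)
model.fit(texts, y)

# A single string still has to be wrapped in a list (one sample).
pred = model.predict(["This is a string that might have something to do with phishing or fraud"])
# Map the encoded integer class back to the original label text.
print(encoder.inverse_transform(pred))
```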