WEKA: running classifyInstance against a saved Random Forest model

Problem description

I am trying to run classifyInstance against a pre-trained Random Forest model in Java. The same code works with my pre-trained SMO and Naive Bayes models, but with the Random Forest model it fails with the error below.

java.lang.ArrayIndexOutOfBoundsException: Index 1 out of bounds for length 1
    at weka.classifiers.meta.Bagging.distributionForInstance(Bagging.java:791)
    at weka.classifiers.AbstractClassifier.classifyInstance(AbstractClassifier.java:173)
    at com.msc.sinhalasongpredictorbackend.service.ClassifyService.runRandomForest(ClassifyService.java:97) 

Code snippet

private Integer runRandomForest() throws Exception {
    Bagging randomForest = (Bagging) SerializationHelper.read(randomForestLocation);
    CSVLoader loader = new CSVLoader();
    try (InputStream fis = new FileInputStream(csvFeatureOutput)) {
        loader.setSource(fis);

        Instances trainingDataSet = loader.getDataSet();

        // Nominal class attribute with a single value "1.0", appended to the CSV data
        List<String> values = new ArrayList<>();
        values.add("1.0");
        trainingDataSet.insertAttributeAt(new Attribute("label", values), trainingDataSet.numAttributes());
        trainingDataSet.setClassIndex(trainingDataSet.numAttributes() - 1);

        for (Instance i : trainingDataSet) {
            // ArrayIndexOutOfBoundsException is thrown here (ClassifyService.java:97)
            Double result = randomForest.classifyInstance(i);
            return result.intValue();
        }

        return -1;
    } catch (Exception e) {
        e.printStackTrace();
        throw new RuntimeException(e.getMessage());
    }
}

One row (instance) of the CSV looks like this (output of System.out.println(i)):

13.46,0.1023,0.002532,221.9,0.001693,0.06354,0.05821,26.46,13.22,134.3,0.000406,0.0742,0.3509,0.2297,0.2114,0.1953,0.2013,0.1948,0.1646,0.1817,0,0.166,15.16,1359,107900,22360000,29.32,0.1582,0.001677,1679,0.00328,0.1308,0.5461,67.77,215.6,462.1,0.0095,-0.8899,0.4612,-0.04562,0.2517,-0.005843,-0.004694,0.01806,0.006126,0.02237,0,0.3106,58.65,3476,279200,51460000,?
java machine-learning deep-learning random-forest weka
1 Answer

Found the solution: unlike with the other classifiers, the class values have to be added to the label attribute as a full list of nominal values, so that the class attribute of the test instance matches the one the Random Forest was trained with.

private Integer runRandomForest() throws Exception {
    Classifier randomForest = (Classifier) SerializationHelper.read(randomForestLocation);
    CSVLoader loader = new CSVLoader();
    try (InputStream fis = new FileInputStream(csvFeatureOutput)) {
        loader.setSource(fis);

        Instances trainingDataSet = loader.getDataSet();

        // Declare the class attribute with ALL labels the model was trained on,
        // not just a single value, so the class count matches the trained model.
        ArrayList<String> labels = new ArrayList<>();
        labels.add("0.0");
        labels.add("1.0");
        labels.add("2.0");
        Attribute attributeCls = new Attribute("label", labels);

        // Copy the feature attributes from the loaded CSV and append the class attribute.
        ArrayList<Attribute> attributeArrayList = new ArrayList<>();
        for (int i = 0; i < trainingDataSet.numAttributes(); i++) {
            attributeArrayList.add(trainingDataSet.attribute(i));
        }
        attributeArrayList.add(attributeCls);

        // getInstances(...) (not shown here) converts the CSV row into a double[]
        // that matches the new header built above.
        double[] arrayDouble = getInstances(trainingDataSet);
        Instances dataset = new Instances("testdata", attributeArrayList, 0);
        dataset.setClassIndex(dataset.numAttributes() - 1);

        DenseInstance instance = new DenseInstance(1.0, arrayDouble);
        instance.setDataset(dataset);

        double prediction = randomForest.classifyInstance(instance);
        System.out.println(prediction);

        // Return the predicted class index instead of a placeholder value.
        return (int) prediction;
    } catch (Exception e) {
        e.printStackTrace();
        throw new RuntimeException(e.getMessage());
    }
}
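
The getInstances helper referenced above is not included in the answer. A minimal sketch of what it might look like, assuming it copies the feature values of the first CSV row into a double[] with one trailing slot left missing for the appended class attribute (this implementation is an assumption, not part of the original answer, and assumes weka.core.Utils is imported):

// Hypothetical helper: copies the first CSV row's values into an array sized
// for the new header (all CSV attributes plus the appended "label" attribute).
private double[] getInstances(Instances trainingDataSet) {
    Instance row = trainingDataSet.firstInstance();
    double[] values = new double[trainingDataSet.numAttributes() + 1];
    for (int i = 0; i < trainingDataSet.numAttributes(); i++) {
        values[i] = row.value(i);
    }
    // Leave the class slot missing; classifyInstance fills in the prediction.
    values[values.length - 1] = Utils.missingValue();
    return values;
}

As a design note: if the model file also contains the training header (for example, if the classifier was serialized together with its Instances object and read back with SerializationHelper.readAll), reusing that header instead of rebuilding the attribute list by hand guarantees the label values match the training data exactly.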