尝试针对 Java 的随机森林预训练模型执行classifyInstance。我能够对相同的代码执行 SMO 朴素贝叶斯,但不能对随机森林预训练模型执行低于错误的操作。
java.lang.ArrayIndexOutOfBoundsException: Index 1 out of bounds for length 1
at weka.classifiers.meta.Bagging.distributionForInstance(Bagging.java:791)
at weka.classifiers.AbstractClassifier.classifyInstance(AbstractClassifier.java:173)
at com.msc.sinhalasongpredictorbackend.service.ClassifyService.runRandomForest(ClassifyService.java:97)
代码片段
private Integer runRandomForest() throws Exception {
Bagging randomForest = (Bagging) SerializationHelper.read(randomForestLocation);
CSVLoader loader = new CSVLoader();
try (InputStream fis = new FileInputStream(csvFeatureOutput)) {
loader.setSource(fis);
Instances trainingDataSet = loader.getDataSet();
List values = new ArrayList();
values.add("1.0");
trainingDataSet.insertAttributeAt(new Attribute("label", values), trainingDataSet.numAttributes());
trainingDataSet.setClassIndex(trainingDataSet.numAttributes() - 1);
for (Instance i : trainingDataSet) {
Double result = randomForest.classifyInstance(i);
return result.intValue();
}
return -1;
} catch (Exception e) {
e.printStackTrace();
throw new RuntimeException(e.getMessage());
}
}
CSV 一行(实例)如下所示(System.out.println(i))
13.46,0.1023,0.002532,221.9,0.001693,0.06354,0.05821,26.46,13.22,134.3,0.000406,0.0742,0.3509,0.2297,0.2114,0.1953,0.2013,0.1948,0.1646,0.1817,0,0.166,15.16,1359,107900,22360000,29.32,0.1582,0.001677,1679,0.00328,0.1308,0.5461,67.77,215.6,462.1,0.0095,-0.8899,0.4612,-0.04562,0.2517,-0.005843,-0.004694,0.01806,0.006126,0.02237,0,0.3106,58.65,3476,279200,51460000,?
请找到解决方案,与其他分类不同,应将类别索引值作为值列表添加到属性中。
private Integer runRandomForest() throws Exception {
Classifier randomForest = (Classifier) SerializationHelper.read(randomForestLocation);
CSVLoader loader = new CSVLoader();
try (InputStream fis = new FileInputStream(csvFeatureOutput)) {
loader.setSource(fis);
Instances trainingDataSet = loader.getDataSet();
ArrayList labels = new ArrayList();
labels.add("0.0");
labels.add("1.0");
labels.add("2.0");
Attribute attributeCls = new Attribute("label",labels);
ArrayList<Attribute> attributeArrayList = new ArrayList<>();
for (int i = 0; i < trainingDataSet.numAttributes(); i++) {
Attribute attribute = trainingDataSet.attribute(i);
attributeArrayList.add(attribute);
}
attributeArrayList.add(attributeCls);
double[] arrayDouble = getInstances(trainingDataSet);
Instances dataset = new Instances("testdata", attributeArrayList, 0);
DenseInstance instance = new DenseInstance(1.0, arrayDouble);
instance.setDataset(dataset);
dataset.setClassIndex(dataset.numAttributes() - 1);
double prediction = randomForest.classifyInstance(instance);
System.out.println(prediction);
return -1;
} catch (Exception e) {
e.printStackTrace();
throw new RuntimeException(e.getMessage());
}
}