如何在 Android 上通过 DL4J 的 DataSetIterator 获取 MNIST 数据?

问题描述 投票:0回答:1
package com.example.minwoo_k.neural_network;

import android.os.AsyncTask;
import android.support.v7.app.AppCompatActivity;
import android.os.Bundle;
import android.util.Log;

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.reflections.vfs.CommonsVfs2UrlType;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;

import static android.R.id.input;
import static org.reflections.Reflections.log;

/**
 * Demo activity that builds a small fully-connected network with DL4J, trains it on a subset of
 * MNIST, and logs the evaluation statistics to logcat.
 *
 * <p>All DL4J work (dataset download, training, evaluation) runs on a background thread via
 * {@link AsyncTask#execute(Runnable)} so the UI thread is not blocked.
 */
public class MainActivity extends AppCompatActivity {

    /** Logcat tag used for every message emitted by this activity. */
    private static final String TAG = "ANN";

    /** Number of MNIST output classes (digits 0-9). */
    private static final int NUM_CLASSES = 10;

    /** Mini-batch size used by both the training and the test iterators. */
    private static final int BATCH_SIZE = 500;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        // DL4J's MNIST fetcher derives its download directory from the "user.home" system property.
        // On Android that property is "/", which the app cannot write to. That is the cause of
        // "IOException: ... mkdir /MNIST". Point it at the app-private files directory instead.
        // NOTE(review): this is a process-wide setting. It is set here, before any MNIST iterator is
        // created, because the fetcher presumably reads it once when its class is initialised.
        // The download also requires android.permission.INTERNET in the manifest.
        System.setProperty("user.home", getFilesDir().getAbsolutePath());

        AsyncTask.execute(new Runnable() {
            @Override
            public void run() {
                try {
                    createAndUseNetwork();
                } catch (IOException e) {
                    // Dataset download/extraction failed: report it in logcat instead of printing only.
                    Log.e(TAG, "Failed to load MNIST data", e);
                }
            }
        });
    }

    /**
     * Builds a 784-200-10-10 sigmoid/softmax network, trains it for one pass over 10000 MNIST
     * training images, and evaluates it on 100 images from the MNIST test split.
     *
     * @throws IOException if the MNIST dataset cannot be downloaded or read
     */
    private void createAndUseNetwork() throws IOException {
        DenseLayer inputLayer = new DenseLayer.Builder()  // Input Layer
                .nIn(784)  // 28x28 pixel images, flattened
                .nOut(200)
                .name("Input")
                .activation(Activation.SIGMOID)  // Sigmoid Activation function
                .build();

        DenseLayer hiddenLayer = new DenseLayer.Builder()  // Hidden Layer
                .nIn(200)
                .nOut(NUM_CLASSES)
                .name("Hidden")
                .activation(Activation.SIGMOID)  // Sigmoid Activation function
                .build();

        OutputLayer outputLayer = new OutputLayer.Builder()  // Output Layer
                .nIn(NUM_CLASSES)
                .nOut(NUM_CLASSES)
                .name("Output")
                .activation(Activation.SOFTMAX)  // Softmax Activation function
                .build();

        NeuralNetConfiguration.Builder nncBuilder = new NeuralNetConfiguration.Builder();
        nncBuilder.iterations(5);
        nncBuilder.learningRate(0.05);  // Learning Rate
        nncBuilder.weightInit(WeightInit.XAVIER);
        nncBuilder.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT);  // use SGD

        NeuralNetConfiguration.ListBuilder listBuilder = nncBuilder.list();
        listBuilder.layer(0, inputLayer);
        listBuilder.layer(1, hiddenLayer);
        listBuilder.layer(2, outputLayer);
        listBuilder.backprop(true);  // backpropagation

        Log.d(TAG, "****************Create ANN********************");
        MultiLayerNetwork myNetwork = new MultiLayerNetwork(listBuilder.build());
        myNetwork.init();

        myNetwork.setListeners(new ScoreIterationListener(1));

        Log.d(TAG, "****************Get Data********************");
        // (batchSize, numExamples, binarize): this constructor always reads the training split.
        DataSetIterator mnistTrain = new MnistDataSetIterator(BATCH_SIZE, 10000, true);
        // (batchSize, numExamples, binarize, train, shuffle, rngSeed): train = false selects the
        // held-out test split, so evaluation does not run on images the network was trained on.
        DataSetIterator mnistTest = new MnistDataSetIterator(BATCH_SIZE, 100, true, false, false, 0);

        Log.d(TAG, "****************Train ANN********************");
        myNetwork.fit(mnistTrain);

        Log.d(TAG, "****************Evaluate ANN********************");
        Evaluation eval = new Evaluation(NUM_CLASSES);  // evaluation object with 10 possible classes
        while (mnistTest.hasNext()) {
            DataSet next = mnistTest.next();
            INDArray output = myNetwork.output(next.getFeatureMatrix());  // the network's prediction
            eval.eval(next.getLabels(), output);  // check the prediction against the true class
        }

        // Use Android's logger: the previously used static logger from the reflections library
        // is not an Android logging facility and may not reach logcat.
        Log.i(TAG, eval.stats());
        Log.i(TAG, "****************Example finished********************");
    }
}

这是我的程序的完整源代码,我无法读取mnist数据。如何获取mnist数据集?

12-15 12:26:06.526 3910-3930/com.example.minwoo_k.neural_network W/System.err: java.io.IOException: 无法mkdir /MNIST
12-15 12:26:06.526 3910-3930/com.example.minwoo_k.neural_network W/System.err:     at org.deeplearning4j.base.MnistFetcher.downloadAndUntar(MnistFetcher.java:66)
12-15 12:26:06.529 3910-3930/com.example.minwoo_k.neural_network W/System.err:     at org.deeplearning4j.datasets.fetchers.MnistDataFetcher.<init>(MnistDataFetcher.java:65)
12-15 12:26:06.529 3910-3930/com.example.minwoo_k.neural_network W/System.err:     at org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator.<init>(MnistDataSetIterator.java:65)
12-15 12:26:06.529 3910-3930/com.example.minwoo_k.neural_network W/System.err:     at org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator.<init>(MnistDataSetIterator.java:43)
12-15 12:26:06.529 3910-3930/com.example.minwoo_k.neural_network W/System.err:     at com.example.minwoo_k.neural_network.MainActivity.createAndUseNetwork(MainActivity.java:93)
12-15 12:26:06.529 3910-3930/com.example.minwoo_k.neural_network W/System.err:     at com.example.minwoo_k.neural_network.MainActivity.access$000(MainActivity.java:33)
12-15 12:26:06.531 3910-3930/com.example.minwoo_k.neural_network W/System.err:     at com.example.minwoo_k.neural_network.MainActivity$1.run(MainActivity.java:44)
12-15 12:26:06.531 3910-3930/com.example.minwoo_k.neural_network W/System.err:     at android.os.AsyncTask$SerialExecutor$1.run(AsyncTask.java:245)
12-15 12:26:06.532 3910-3930/com.example.minwoo_k.neural_network W/System.err:     at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1162)
12-15 12:26:06.532 3910-3930/com.example.minwoo_k.neural_network W/System.err:     at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:636)
12-15 12:26:06.532 3910-3930/com.example.minwoo_k.neural_network W/System.err:     at java.lang.Thread.run(Thread.java:764)

这是我的Logcat记录。我怎么解决这个问题?

android mnist dl4j
1个回答
0
投票

我认为错误信息非常明确。

> 12-15 12:26:06.526 3910-3930/com.example.minwoo_k.neural_network W/System.err: java.io.IOException: 无法mkdir /MNIST

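很可能是你的程序没有权限在根目录“/”下创建“MNIST”目录。MnistFetcher 以系统属性 user.home 作为下载目录的父目录,而在 Android 上该属性通常就是“/”,应用无权写入。可以在创建 MnistDataSetIterator 之前调用 System.setProperty("user.home", getFilesDir().getAbsolutePath()),把它指向应用的私有目录。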

该异常是在这里抛出的:https://github.com/deeplearning4j/deeplearning4j/blob/master/deeplearning4j-core/src/main/java/org/deeplearning4j/base/MnistFetcher.java

© www.soinside.com 2019 - 2024. All rights reserved.