Neural network MNIST digit recognition

Question (votes: -3, answers: 1)

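I recently tried to build my first project with a neural network, and this is what I came up with. I want it to recognize MNIST handwritten digits. The problem is that when I run this code and train it for about 400k iterations, I get roughly 28% accuracy on the test data. Is that expected? Is 400k iterations too few to get better results, or is it because my neural network can only have one hidden layer?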

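To put the question briefly: is this how things should be, or am I doing something wrong? There is a lot of redundant code and similar clutter below; I just wanted to get it working.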

Assume that everything inside my neural network itself works correctly.

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// Data and NeuralNetwork are my own classes (not shown here).
public static void main(String[] args) {

    List<Data> trainData = new ArrayList<>();
    List<Data> testData = new ArrayList<>();

    byte[] trainLabels;
    byte[] trainImages;
    byte[] testLabels;
    byte[] testImages;

    try {

        // Training labels: header is a magic number followed by the item count.
        Path tempPath1 = Paths.get("res/train-labels-idx1-ubyte");
        trainLabels = Files.readAllBytes(tempPath1);
        ByteBuffer bufferLabels = ByteBuffer.wrap(trainLabels);
        int magicLabels = bufferLabels.getInt();
        int numberOfItems = bufferLabels.getInt();

        // Training images: header is magic number, item count, rows, columns.
        Path tempPath = Paths.get("res/train-images-idx3-ubyte");
        trainImages = Files.readAllBytes(tempPath);
        ByteBuffer bufferImages = ByteBuffer.wrap(trainImages);
        int magicImages = bufferImages.getInt();
        int numberOfImageItems = bufferImages.getInt();
        int rows = bufferImages.getInt();
        int cols = bufferImages.getInt();

        // One training sample per label: one-hot target plus rows*cols pixel inputs.
        for (int i = 0; i < numberOfItems; i++) {
            int t = bufferLabels.get();
            double[] target = createTargets(t);
            double[] inputs = new double[rows * cols];
            for (int j = 0; j < inputs.length; j++) {
                inputs[j] = bufferImages.get();
            }
            Data tobj = new Data(inputs, target);
            trainData.add(tobj);
        }

        // Test labels.
        tempPath = Paths.get("res/t10k-labels-idx1-ubyte");
        testLabels = Files.readAllBytes(tempPath);
        ByteBuffer testLabelBuffer = ByteBuffer.wrap(testLabels);
        int testMagicLabels = testLabelBuffer.getInt();
        int numberOfTestLabels = testLabelBuffer.getInt();

        // Test images.
        tempPath = Paths.get("res/t10k-images-idx3-ubyte");
        testImages = Files.readAllBytes(tempPath);
        ByteBuffer testImageBuffer = ByteBuffer.wrap(testImages);
        int testMagicImages = testImageBuffer.getInt();
        int numberOfTestImages = testImageBuffer.getInt();
        int testRows = testImageBuffer.getInt();
        int testCols = testImageBuffer.getInt();

        // Test samples keep the raw digit label as a single-element target.
        for (int i = 0; i < numberOfTestImages; i++) {
            double[] target = new double[]{testLabelBuffer.get()};
            double[] inputs = new double[testRows * testCols];
            for (int j = 0; j < inputs.length; j++) {
                inputs[j] = testImageBuffer.get();
            }
            Data tobj = new Data(inputs, target);
            testData.add(tobj);
        }

        // 784 inputs (28x28 pixels), 64 hidden nodes, 10 outputs (digits 0-9).
        NeuralNetwork neuralNetwork = new NeuralNetwork(784, 64, 10);

        // Train on 400,000 randomly chosen samples.
        int len = trainData.size();
        Random randomGenerator = new Random();
        for (int i = 0; i < 400000; i++) {
            int randomInt = randomGenerator.nextInt(len);
            Data sample = trainData.get(randomInt);
            neuralNetwork.train(sample.getInputs(), sample.getTargets());
        }
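
        float rightAnswers = 0;

        for (Data testObj : testData) {
            double[] output = neuralNetwork.feedforward(testObj.getInputs());
            double[] answer = testObj.getTargets();
            // Assumed reconstruction: the predicted digit is the index of the largest output.
            int predicted = 0;
            for (int k = 1; k < output.length; k++) {
                if (output[k] > output[predicted]) {
                    predicted = k;
                }
            }
            if (predicted == (int) answer[0]) {
                rightAnswers++;
            }
        }
        // Percentage of test images classified correctly.
        float percentage = rightAnswers / testData.size() * 100;
        System.out.println(percentage);
    } catch (IOException e) {
        e.printStackTrace();
    }
}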

// Builds a one-hot target vector: 1 at the index of the digit, 0 elsewhere.
public static double[] createTargets(int number) {
    double[] result = new double[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
    result[number] = 1;
    return result;
}
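
The accuracy check at the end of main has to turn the ten network outputs into a single predicted digit before it can be compared with the label. Here is a minimal sketch of one common way to do that, assuming the prediction is the index of the largest output. The names Evaluation, argMax and accuracyPercent are hypothetical and do not come from the original post.

```java
final class Evaluation {

    /** Returns the index of the largest value; ties resolve to the first one. */
    static int argMax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Percentage of samples whose predicted digit equals the expected label. */
    static double accuracyPercent(double[][] outputs, int[] labels) {
        int right = 0;
        for (int i = 0; i < outputs.length; i++) {
            if (argMax(outputs[i]) == labels[i]) {
                right++;
            }
        }
        return 100.0 * right / outputs.length;
    }
}
```

With ten balanced classes, random guessing lands near 10%, so a result such as 28% suggests the network is learning something but is being held back, for example by badly scaled inputs.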
java neural-network mnist
1 Answer
0 votes

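In case anyone is interested, I did have a bug. While logging everything, I noticed that some input pixel values were negative, whereas the MNIST documentation says they should be 0-255 (Java's byte is signed, so anything above 127 was read as a negative number). On top of that, my inputs were not normalized, so some were 0 while others were 255. This is what I added. Hopefully I haven't missed anything. My accuracy now reaches 90%.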

for (int i = 0; i < numberOfTestImages; i++) {
    // Mask with 0xFF so the signed byte is read as an unsigned value (0-255).
    double[] target = new double[]{testLabelBuffer.get() & 0xFF};
    double[] inputs = new double[testRows * testCols];
    for (int j = 0; j < inputs.length; j++) {
        // Normalize input from 0-255 to 0-1
        double temp = (testImageBuffer.get() & 0xFF) / 255.0;
        inputs[j] = temp;
    }
    Data tobj = new Data(inputs, target);
    testData.add(tobj);
}
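
The snippet above only shows the test loop. The training loop reads pixels the same way, so it presumably needs the same masking and scaling. Here is a minimal sketch of the conversion as a reusable helper, assuming the raw pixels are available as a byte array. The names PixelDecoder, toUnsigned and normalize are hypothetical and not part of the original code.

```java
final class PixelDecoder {

    /** Reads a Java (signed) byte as an unsigned value in the range 0..255. */
    static int toUnsigned(byte b) {
        return b & 0xFF;
    }

    /** Converts raw MNIST pixel bytes to doubles in the range 0.0..1.0. */
    static double[] normalize(byte[] rawPixels) {
        double[] result = new double[rawPixels.length];
        for (int i = 0; i < rawPixels.length; i++) {
            result[i] = toUnsigned(rawPixels[i]) / 255.0;
        }
        return result;
    }
}
```

The standard library method Byte.toUnsignedInt(b) (Java 8 and later) does the same thing as b & 0xFF. Dividing by 255.0 rather than 255f keeps the arithmetic in double precision, which matches the double[] inputs the network receives.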