使用 Nd4j 实现 Java 神经网络太慢了

问题描述 投票:0回答:1

我是一名新手学生,正在从事一个用 Java 实现神经网络的个人项目。我按照 Michael Nielsen 先生写的《神经网络和深度学习》一书中的说明进行操作。

但是,书中提供的代码的调用是用Python和Numpy编写的,所以我想我也应该像Numpy中那样利用NDArray对象,所以我在Java中听到和使用的第一个就是Nd4j库。

先生。 Michael Nielsen 的代码非常有趣,我最近刚刚使用 Nd4j 在 Java 中成功实现了一个简单的神经网络:

public class Network{
    int num_layers;
    int[] networkSize;
    List<INDArray> weights = new ArrayList<>();
    List<INDArray> biases = new ArrayList<>();


    public Network(int[] size){
        this.num_layers = size.length;
        this.networkSize = size;
        for(int i=1;i < num_layers;i++){
            // Number of neurons for current layer
            int x = size[i];
            INDArray bias = Nd4j.randn(x, 1);
            this.biases.add(bias);
        }

        for(int i=0; i < num_layers -1; i++){
            // Number of neurons for current layer
            int x = size[i];
            // Number of neurons for next layer
            int y = size[i+1];
            INDArray weight = Nd4j.randn(y,x);
            this.weights.add(weight);
        }
    }


    public INDArray feedforward(INDArray a){
        for(int i = 0; i < this.biases.size(); i++){
            INDArray weightsMatrix = this.weights.get(i);
            INDArray biasesMatrix = this.biases.get(i);
            a = weightsMatrix.mmul(a);
            a = a.add(biasesMatrix);
            a = Transforms.sigmoid(a);
        }
        return a;
    }


    public void stochasticGradientDescent(List<List<INDArray>> training_datas, int epochs, int mini_batch_size, float learning_rate, List<List<INDArray>> test_datas){
        int n_test = test_datas.size();
        int n = training_datas.size();

        for(int i=0; i<epochs; i++){
            Collections.shuffle(training_datas);
            List<List<List<INDArray>>> mini_batches = new ArrayList<>();

            for(int k=0; k<n; k+=mini_batch_size){
                mini_batches.add(training_datas.subList(k, k+mini_batch_size));
            }

            for(List<List<INDArray>> mini_batch : mini_batches){
                this.update_mini_batch(mini_batch, learning_rate);
            }

            System.out.println(String.format("Epoch %d: %d / %d", i, this.evaluate(test_datas), n_test));
        }

    }


    public List<List<INDArray>> backpropagation(INDArray output_activations, INDArray desiredOutput){
        List<INDArray> gradient_biases = new ArrayList<>();
        List<INDArray> gradient_weights = new ArrayList<>();
        
        for(INDArray weight:this.weights){
            gradient_weights.add(Nd4j.zerosLike(weight));
        }
        for(INDArray bias:this.biases){
            gradient_biases.add(Nd4j.zerosLike(bias));
        }

        INDArray activation = output_activations;
        List<INDArray> activations = new ArrayList<>();
        activations.add(output_activations);
        List<INDArray> z_vectors = new ArrayList<>();

        for(int i=0; i<this.biases.size(); i++){
            INDArray weightsMatrix = this.weights.get(i);
            INDArray biasesMatrix = this.biases.get(i);
            INDArray z = weightsMatrix.mmul(activation).add(biasesMatrix);
            z_vectors.add(z);
            activation = Transforms.sigmoid(z);
            activations.add(activation);
        }

        INDArray last_activation_layer = activations.get(activations.size()-1);
        INDArray last_z_vector_layer = z_vectors.get(z_vectors.size()-1);
        // Backward pass
        INDArray delta_vector = this.cost_derivative(last_activation_layer, desiredOutput).mul(Transforms.sigmoidDerivative(last_z_vector_layer));
        gradient_biases.set(gradient_biases.size()-1,delta_vector);
        gradient_weights.set(gradient_weights.size()-1,delta_vector.mmul(activations.get(activations.size()-2).transpose()));


        for(int l=2; l > this.num_layers; l++){
            INDArray z = z_vectors.get(z_vectors.size() - l);
            INDArray sigmoid_prime = Transforms.sigmoidDerivative(z);
            delta_vector = this.weights.get(this.weights.size()-l+1).transpose().mmul(delta_vector).mul(sigmoid_prime);

            gradient_biases.set(gradient_biases.size()-l, delta_vector);
            gradient_weights.set(gradient_weights.size()-l, delta_vector.mmul(activations.get(activations.size()-l-1).transpose()));
        }

        List<List<INDArray>> gradients = new ArrayList<>();
        gradients.add(gradient_biases);
        gradients.add(gradient_weights);
        return gradients;
    }


    public INDArray cost_derivative(INDArray output_activations, INDArray desiredOutput){
        return output_activations.sub(desiredOutput);
    }


    public void update_mini_batch(List<List<INDArray>> mini_batches, float learning_rate){
        List<INDArray> gradient_biases = new ArrayList<>();
        List<INDArray> gradient_weights = new ArrayList<>();
        
        for(INDArray weight:this.weights){
            gradient_weights.add(Nd4j.zerosLike(weight));
        }
        for(INDArray bias:this.biases){
            gradient_biases.add(Nd4j.zerosLike(bias));
        }


        // Iterate throught the mini batches
        for(List<INDArray> mini_batch : mini_batches){
            INDArray output_activations = mini_batch.get(0);
            INDArray desiredOutput = mini_batch.get(1);

            List<List<INDArray>> gradients = this.backpropagation(output_activations, desiredOutput);
            List<INDArray> delta_gradient_biases = gradients.get(0);
            List<INDArray> delta_gradient_weights = gradients.get(1);

            for(int i=0; i<delta_gradient_biases.size(); i++){
                INDArray gradient_bias = gradient_biases.get(i);
                INDArray delta_gradient_bias = delta_gradient_biases.get(i);
                INDArray new_gradient_bias = gradient_bias.add(delta_gradient_bias);

                gradient_biases.set(i, new_gradient_bias);
            }

            for(int i=0; i<delta_gradient_weights.size(); i++){
                INDArray gradient_weight = gradient_weights.get(i);
                INDArray delta_gradient_weight = delta_gradient_weights.get(i);
                INDArray new_gradient_weight = gradient_weight.add(delta_gradient_weight);

                gradient_weights.set(i, new_gradient_weight);
            }

            // Updating the network weights and biases base on the average gradient of the mini batches
            // Update the weights
            for(int i=0; i<this.weights.size(); i++){
                INDArray current_weight = this.weights.get(i);
                INDArray gradient_weight_sum = gradient_weights.get(i);

                INDArray average_weight_gradient = gradient_weight_sum.muli(learning_rate/mini_batches.size());
                INDArray new_weight = current_weight.sub(average_weight_gradient);

                this.weights.set(i, new_weight);
            }

            // Update the biases
            for(int i=0; i<this.biases.size(); i++){
                INDArray current_bias = this.biases.get(i);
                INDArray gradient_bias_sum = gradient_biases.get(i);

                INDArray average_bias_gradient = gradient_bias_sum.muli(learning_rate/mini_batches.size());
                INDArray new_bias = current_bias.sub(average_bias_gradient);

                this.biases.set(i, new_bias);
            }
        }
    }


    public int evaluate(List<List<INDArray>> test_datas){
        int correct_predictions = 0;
        List<List<Integer>> test_results = new ArrayList<>();

        for(List<INDArray> test_data:test_datas){
            INDArray raw_input = test_data.get(0);
            INDArray desired_output = test_data.get(1);

            INDArray finale_output_layer = this.feedforward(raw_input);
            Integer class_index = Nd4j.argMax(finale_output_layer).getInt(0);
            Integer class_result = Nd4j.argMax(desired_output).getInt(0);

            List<Integer> test_result = new ArrayList<>();
            test_result.add(class_index);
            test_result.add(class_result);

            test_results.add(test_result);
        }

        for(List<Integer> result:test_results){
            if(result.get(0) == result.get(1)) correct_predictions++;
        }

        return correct_predictions;
    }
}

我尝试在 MNist 数据集上训练我的第一个网络,当我看到正在发生一些“学习”时,我感到非常兴奋。

Its actually

再次,没过多久我就意识到我的网络的训练时间远远没有迈克尔·尼尔森先生承诺的几分钟。我真的很担心我的网络实现是错误的,因为让它运行已经非常痛苦了。但网络显示了一些学习成果,所以我必须希望我的实现在某种程度上是正确的。

除了低于预期的准确度之外,我的网络的学习速度非常慢(~ 1 分钟/Epoch)。这让我作为一个非常没有经验的学生感到非常失望。

我尝试过的

  • 我试图在网上寻找解决方案或任何其他提供快速矩阵计算的替代库,或NDArray,或线性代数,或任何类似Numpy但适用于Java的库。

  • 我看到了一些非常突出的建议,例如提供 NDArray、EJML 和 Apache Commons Math 的 Deep Java Library。然而,我还在一些论坛上看到基准测试,表明 ND4j 就性能而言仍然是这些库中的最佳选择。

  • 我尝试不在每个纪元后测试网络,但它似乎并没有大大改善时间。

我非常绝望,必须创建我的第一个 Stack Overflow 帐户来寻求帮助。

先生。 Michael Nielsen 的神经网络在几分钟内就在 MNist 数据集上进行了训练,并轻松实现了 95% 的准确率,而我的网络在最佳配置下只能达到 76%。

为什么我的网络这么慢,为什么它的收敛精度这么低?请帮助我,我不知道该怎么办。

java deep-learning neural-network numpy-ndarray nd4j
1个回答
0
投票

此答案旨在作为指南,帮助用户确定性能不佳的原因,并提供解决这些问题的步骤。 性能问题可能包括: CPU/GPU 利用率低 训练或操作执行速度慢于预期 首先,我们总结了一些可能导致性能问题的原因: 使用了错误的 ND4J 后端(例如,预期使用 GPU 后端时使用了 CPU 后端) 使用 CUDA GPU 时不使用 cuDNN ETL(数据加载)瓶颈 垃圾收集开销 小批量 多线程使用MultiLayerNetwork/ComputationGraph进行推理(非线程安全) 应使用单精度时使用的双精度浮点数据类型 不使用工作区进行内存管理(默认启用) 网络配置不当 层或操作仅限 CPU CPU:缺乏对现代 AVX 等扩展的硬件支持 使用 CPU 或 GPU 资源的其他进程 CPU:同时使用多个模型/线程时缺少 OMP_NUM_THREADS 配置。

© www.soinside.com 2019 - 2024. All rights reserved.