猿问

如何用Java构建一个基本的神经网络?

我正在尝试XOR用 Java构建一个基本的神经网络来计算逻辑函数。


该网络有两个输入神经元,一个包含三个神经元的隐藏层和一个输出神经元。


但是经过几次迭代后,输出中的误差变为NaN。


我已经浏览了其他实现神经网络的实现和教程,但我找不到错误。我觉得问题在于我的落后功能。


请帮助我理解我哪里出错了。


我的代码:


import org.ejml.simple.SimpleMatrix;


import java.util.ArrayList;

import java.util.List;

import java.util.Random;


// SimpleMatrix constructor format: SimpleMatrix(rows, cols)

//The layers are represented as a matrix with 1 row and multiple columns (row vector)

public class Network {

    private SimpleMatrix inputs, outputs, hidden, W1, W2, predicted;

    static final double LEARNING_RATE = 0.3;


    Network(List<double[]> ips, List<double[]> ops){

        hidden = new SimpleMatrix(1, 3);

        W1 = new SimpleMatrix(ips.get(0).length, hidden.numCols());

        W2 = new SimpleMatrix(hidden.numCols(), ops.get(0).length);

        initWeights(W1,W2);


        for(int i=0;i<5000;i++){

            for(int j=0;j<ips.size();j++){

                train(ips.get(j), ops.get(j));

            }

        }

        System.out.println("Trained");

    }


    //Prints output matrix

    SimpleMatrix predict(double[] ip){

        SimpleMatrix bkpInputs = inputs.copy();

        SimpleMatrix bkpOutputs = outputs.copy();


        inputs = new SimpleMatrix(1, ip.length);

        inputs.setRow(0, 0, ip);


        forward();

        inputs = bkpInputs;

        outputs = bkpOutputs;


        predicted.print();

        return predicted;

    }


    void train(double[] inputs, double[] outputs){

        this.inputs = new SimpleMatrix(1, inputs.length);

        this.inputs.setRow(0, 0, inputs);

        this.outputs = new SimpleMatrix(1, outputs.length);

        this.outputs.setRow(0,0,outputs);

        this.predicted = new SimpleMatrix(1,outputs.length);


        forward();

        backward();

    }


慕桂英546537
浏览 172回答 2
2回答

江户川乱折腾

因此,这可能不是导致您出现问题的原因,但我注意到:W1.get(i,j)&nbsp;+&nbsp;LEARNING_RATE*W1_delta.get(i,&nbsp;0));当您更新权重时。我认为正确的公式是:所以你的代码应该是:W1(i,j)&nbsp;+=&nbsp;LEARNING_RATE&nbsp;*&nbsp;W1_delta.get(i,&nbsp;0)&nbsp;*&nbsp;&nbsp;<output&nbsp;from&nbsp;the&nbsp;connected&nbsp;node>;它可能无法解决它,但值得一试!

千巷猫影

尝试使用较低的学习率。当错误出现时NaN,通常意味着您的成本/错误函数已经爆炸。尝试范围内的东西[10^-3, 10^-5]。
随时随地看视频慕课网APP

相关分类

Java
我要回答