I am trying to implement a feedforward neural network (FFNN) in Java with backpropagation, and I don’t know what I am doing wrong. It worked when I had only one neuron in the network, but since I wrote another class to handle larger networks, nothing converges. This looks like a problem in the math, or more precisely in my implementation of the math, but I have checked it several times and cannot find anything wrong. It should work.
Node class:
    package arr;

    import util.ActivationFunction;
    import util.Functions;

    public class Node {
        public ActivationFunction f;
        public double output;
        public double error;
        private double sumInputs;
        private double sumErrors;

        public Node(){
            sumInputs = 0;
            sumErrors = 0;
            f = Functions.SIG;
            output = 0;
            error = 0;
        }

        public Node(ActivationFunction func){
            this();
            this.f = func;
        }

        public void addIW(double iw){
            sumInputs += iw;
        }

        public void addIW(double input, double weight){
            sumInputs += (input*weight);
        }

        public double calculateOut(){
            output = f.eval(sumInputs);
            return output;
        }

        public void addEW(double ew){
            sumErrors+=ew;
        }

        public void addEW(double error, double weight){
            sumErrors+=(error*weight);
        }

        public double calculateError(){
            error = sumErrors * f.deriv(sumInputs);
            return error;
        }

        public void resetValues(){
            sumErrors = 0;
            sumInputs = 0;
        }
    }
LineNetwork class:
    package arr;

    import util.Functions;

    public class LineNetwork {
        public double[][][] weights; //layer of node to, # of node to, # of node from
        public Node[][] nodes; //layer, #
        public double lc;

        public LineNetwork(){
            weights = new double[2][][];
            weights[0] = new double[2][1];
            weights[1] = new double[1][3];
            initializeWeights();
            nodes = new Node[2][];
            nodes[0] = new Node[2];
            nodes[1] = new Node[1];
            initializeNodes();
            lc = 1;
        }

        private void initializeWeights(){
            for(double[][] layer: weights)
                for(double[] curNode: layer)
                    for(int i=0; i<curNode.length; i++)
                        curNode[i] = Math.random()/10;
        }

        private void initializeNodes(){
            for(Node[] layer: nodes)
                for(int i=0; i<layer.length; i++)
                    layer[i] = new Node();
            nodes[nodes.length-1][0].f = Functions.HSF;
        }

        public double feedForward(double[] inputs) {
            for(int j=0; j<nodes[0].length; j++)
                nodes[0][j].addIW(inputs[j], weights[0][j][0]);
            double[] outputs = new double[nodes[0].length];
            for(int i=0; i<nodes[0].length; i++)
                outputs[i] = nodes[0][i].calculateOut();
            for(int l=1; l<nodes.length; l++){
                for(int i=0; i<nodes[l].length; i++){
                    for(int j=0; j<nodes[l-1].length; j++)
                        nodes[l][i].addIW( outputs[j], weights[l][i][j]);
                    nodes[l][i].addIW(weights[l][i][weights[l][i].length-1]);
                }
                outputs = new double[nodes[l].length];
                for(int i=0; i<nodes[l].length; i++)
                    outputs[i] = nodes[l][i].calculateOut();
            }
            return outputs[0];
        }

        public void backpropagate(double[] inputs, double expected) {
            nodes[nodes.length-1][0].addEW(expected-nodes[nodes.length-1][0].output);
            for(int l=nodes.length-2; l>=0; l--){
                for(Node n: nodes[l+1])
                    n.calculateError();
                for(int i=0; i<nodes[l].length; i++)
                    for(int j=0; j<nodes[l+1].length; j++)
                        nodes[l][i].addEW(nodes[l+1][j].error, weights[l+1][j][i]);
                for(int j=0; j<nodes[l+1].length; j++){
                    for(int i=0; i<nodes[l].length; i++)
                        weights[l+1][j][i] += nodes[l][i].output*lc*nodes[l+1][j].error;
                    weights[l+1][j][nodes[l].length] += lc*nodes[l+1][j].error;
                }
            }
            for(int i=0; i<nodes[0].length; i++){
                weights[0][i][0] += inputs[i]*lc*nodes[0][i].calculateError();
            }
        }

        public double train(double[] inputs, double expected) {
            double r = feedForward(inputs);
            backpropagate(inputs, expected);
            return r;
        }

        public void resetValues() {
            for(Node[] layer: nodes)
                for(Node n: layer)
                    n.resetValues();
        }

        public static void main(String[] args) {
            LineNetwork ln = new LineNetwork();
            System.out.println(str2d(ln.weights[0]));
            for(int i=0; i<10000; i++){
                double[] in = {Math.round(Math.random()),Math.round(Math.random())};
                int out = 0;
                if(in[1]==1 ^ in[0] ==1) out = 1;
                ln.resetValues();
                System.out.print(i+": {"+in[0]+", "+in[1]+"}: "+out+" ");
                System.out.println((int)ln.train(in, out));
            }
            System.out.println(str2d(ln.weights[0]));
        }

        private static String str2d(double[][] a){
            String str = "[";
            for(double[] arr: a)
                str = str + str1d(arr) + ",\n";
            str = str.substring(0, str.length()-2)+"]";
            return str;
        }

        private static String str1d(double[] a){
            String str = "[";
            for(double d: a)
                str = str+d+", ";
            str = str.substring(0, str.length()-2)+"]";
            return str;
        }
    }
Quick explanation of the structure: each node has an activation function f; f.eval evaluates the function, and f.deriv evaluates its derivative. Functions.SIG is the standard sigmoid function, and Functions.HSF is the Heaviside step function. To set the inputs of a node, you call addIW with a value that already includes the weight applied to the previous layer's output. Backpropagation works the same way in the reverse direction, using addEW. The nodes are organized in a 2D array, and the weights are stored separately in a 3D array, indexed as described in the code comments.
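To make that calling sequence concrete, here is a minimal, self-contained sketch of one forward step and one backward step on a single sigmoid node. NodeSketch is a hypothetical stand-in for the Node class above, with the sigmoid hard-coded instead of going through ActivationFunction; the inputs, weights, and target in main are made-up values for illustration only.

```java
// Hypothetical, reduced stand-in for the Node class: one node, sigmoid only.
class NodeSketch {
    double sumInputs = 0;
    double sumErrors = 0;
    double output = 0;
    double error = 0;

    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    static double sigmoidDeriv(double x) {
        double s = sigmoid(x);
        return s * (1.0 - s);
    }

    // Accumulate a weighted input (input * weight).
    void addIW(double input, double weight) { sumInputs += input * weight; }

    // Accumulate an already-weighted term, e.g. a bias weight.
    void addIW(double iw) { sumInputs += iw; }

    // Forward step: apply the activation to the accumulated sum.
    double calculateOut() {
        output = sigmoid(sumInputs);
        return output;
    }

    // Accumulate an error term coming from the next layer (or the target).
    void addEW(double ew) { sumErrors += ew; }

    // Backward step: accumulated error times the activation derivative.
    double calculateError() {
        error = sumErrors * sigmoidDeriv(sumInputs);
        return error;
    }

    public static void main(String[] args) {
        NodeSketch n = new NodeSketch();
        n.addIW(1.0, 0.5);   // assumed input 1.0 with assumed weight 0.5
        n.addIW(0.1);        // assumed bias weight
        double out = n.calculateOut();
        n.addEW(1.0 - out);  // expected (assumed 1.0) minus actual output
        double delta = n.calculateError();
        System.out.println("output=" + out + " delta=" + delta);
    }
}
```

The order matters in this sketch: calculateError reads sumInputs, so the accumulated input sum has to stay intact between the forward and backward steps.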
I understand that this may be a lot to ask, and I am well aware of how many Java conventions this code breaks, but I appreciate any help anyone can offer.
EDIT: Since this question and my code are such gigantic walls of text, if there is a line with a lot of brackets and complex expressions that you don’t want to parse, leave a comment asking about it and I will try to answer as quickly as possible.
EDIT 2: The specific problem here is that this network does not converge on XOR. Here are some results to illustrate this:
9995: {1.0, 0.0}: 1 1
9996: {0.0, 1.0}: 1 1
9997: {0.0, 0.0}: 0 1
9998: {0.0, 1.0}: 1 0
9999: {0.0, 1.0}: 1 1
Each line has the format TEST NUMBER: {INPUTS}: EXPECTED ACTUAL. The program calls train once per test, so the network goes through 10,000 training steps.
Here are the two additional classes, in case someone wants to run it:
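As a small illustration of that log format and of the XOR target used in main, the sketch below builds the same kind of line without any network involved. XorSamples, xor, and formatLine are hypothetical names introduced only for this example, and the ACTUAL column is filled with a placeholder value.

```java
// Hypothetical helper, not part of the classes in this question.
class XorSamples {
    // XOR target for two inputs that are each exactly 0.0 or 1.0.
    static int xor(double a, double b) {
        return ((a == 1.0) ^ (b == 1.0)) ? 1 : 0;
    }

    // Builds one log line in the format TEST NUMBER: {INPUTS}: EXPECTED ACTUAL.
    static String formatLine(int test, double[] in, int expected, int actual) {
        return test + ": {" + in[0] + ", " + in[1] + "}: " + expected + " " + actual;
    }

    public static void main(String[] args) {
        double[][] cases = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};
        for (int i = 0; i < cases.length; i++) {
            int expected = xor(cases[i][0], cases[i][1]);
            // No network here: the ACTUAL column is a placeholder 0.
            System.out.println(formatLine(i, cases[i], expected, 0));
        }
    }
}
```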
    // util/Functions.java
    package util;

    public class Functions {
        public static final ActivationFunction LIN = new ActivationFunction(){
            public double eval(double x) {
                return x;
            }
            public double deriv(double x) {
                return 1;
            }
        };

        public static final ActivationFunction SIG = new ActivationFunction(){
            public double eval(double x) {
                return 1/(1+Math.exp(-x));
            }
            public double deriv(double x) {
                double ev = eval(x);
                return ev * (1-ev);
            }
        };

        public static final ActivationFunction HSF = new ActivationFunction(){
            public double eval(double x) {
                if(x>0) return 1;
                return 0;
            }
            public double deriv(double x) {
                return (1);
            }
        };
    }

    // util/ActivationFunction.java
    package util;

    public interface ActivationFunction {
        public double eval(double x);
        public double deriv(double x);
    }
Now it is even longer. Heck.