package dl;
import java.util.Arrays;
import java.util.Random;
/**
 * Ann layer: one fully connected layer of an artificial neural network.
 */
public class AnnLayer {
* The number of input.
int numInput;
* The number of output.
int numOutput;
* The learning rate.
double learningRate;
* The mobp.
double mobp;
* The weight matrix.
double[][] weights;
* The delta weight matrix.
double[][] deltaWeights;
* Error on nodes.
double[] errors;
* The inputs.
double[] input;
* The outputs.
double[] output;
* The output after activate.
double[] activatedOutput;
* The inputs.
Activator activator;
* The inputs.
Random random = new Random();
* The first constructor.
* @param paraActivator
* The activator.
public AnnLayer(int paraNumInput, int paraNumOutput, char paraActivator,
double paraLearningRate, double paraMobp) {
numInput = paraNumInput;
numOutput = paraNumOutput;
learningRate = paraLearningRate;
mobp = paraMobp;
weights = new double[numInput + 1][numOutput];
deltaWeights = new double[numInput + 1][numOutput];
for (int i = 0; i < numInput + 1; i++) {
for (int j = 0; j < numOutput; j++) {
weights[i][j] = random.nextDouble();
} // Of for j
} // Of for i
errors = new double[numInput];
input = new double[numInput];
output = new double[numOutput];
activatedOutput = new double[numOutput];
activator = new Activator(paraActivator);
}// Of the first constructor
* Set parameters for the activator.
* @param paraAlpha
* Alpha. Only valid for certain types.
* @param paraBeta
* Beta.
* @param paraAlpha
* Alpha.
public void setParameters(double paraAlpha, double paraBeta, double paraGamma) {
}// Of setParameters
* Forward prediction.
* @param paraInput
* The input data of one instance.
* @return The data at the output end.
public double[] forward(double[] paraInput) {
//System.out.println("Ann layer forward " + Arrays.toString(paraInput));
// Copy data.
for (int i = 0; i < numInput; i++) {
input[i] = paraInput[i];
} // Of for i
// Calculate the weighted sum for each output.
for (int i = 0; i < numOutput; i++) {
output[i] = weights[numInput][i];
for (int j = 0; j < numInput; j++) {
output[i] += input[j] * weights[j][i];
} // Of for j
activatedOutput[i] = activator.activate(output[i]);
} // Of for i
return activatedOutput;
}// Of forward
* Back propagation and change the edge weights.
* @param paraTarget
* For 3-class data, it is [0, 0, 1], [0, 1, 0] or [1, 0, 0].
public double[] backPropagation(double[] paraErrors) {
//Step 1. Adjust the errors.
for (int i = 0; i < paraErrors.length; i++) {
paraErrors[i] = activator.derive(output[i], activatedOutput[i]) * paraErrors[i];
}//Of for i
//Step 2. Compute current errors.
for (int i = 0; i < numInput; i++) {
errors[i] = 0;
for (int j = 0; j < numOutput; j++) {
errors[i] += paraErrors[j] * weights[i][j];
deltaWeights[i][j] = mobp * deltaWeights[i][j]
+ learningRate * paraErrors[j] * input[i];
weights[i][j] += deltaWeights[i][j];
} // Of for j
} // Of for i
for (int j = 0; j < numOutput; j++) {
deltaWeights[numInput][j] = mobp * deltaWeights[numInput][j] + learningRate * paraErrors[j];
weights[numInput][j] += deltaWeights[numInput][j];
} // Of for j
return errors;
}// Of backPropagation
* I am the last layer, set the errors.
* @param paraTarget
* For 3-class data, it is [0, 0, 1], [0, 1, 0] or [1, 0, 0].
public double[] getLastLayerErrors(double[] paraTarget) {
double[] resultErrors = new double[numOutput];
for (int i = 0; i < numOutput; i++) {
resultErrors[i] = (paraTarget[i] - activatedOutput[i]);
} // Of for i
return resultErrors;
}// Of getLastLayerErrors
* Show me.
public String toString() {
String resultString = "";
resultString += "Activator: " + activator;
resultString += "\r\n weights = " + Arrays.deepToString(weights);
return resultString;
}// Of toString
* Unit test.
public static void unitTest() {
AnnLayer tempLayer = new AnnLayer(2, 3, 's', 0.01, 0.1);
double[] tempInput = { 1, 4 };
double[] tempOutput = tempLayer.forward(tempInput);
System.out.println("Forward, the output is: " + Arrays.toString(tempOutput));
double[] tempError = tempLayer.backPropagation(tempOutput);
System.out.println("Back propagation, the error is: " + Arrays.toString(tempError));
}// Of unitTest
* Test the algorithm.
public static void main(String[] args) {
}// Of main
}// Of class AnnLayer