代码:
package dl;
import java.util.Arrays;
import java.util.Random;
/**
* Ann layer.
*/
public class AnnLayer {
/**
* The number of input.
*/
int numInput;
/**
* The number of output.
*/
int numOutput;
/**
* The learning rate.
*/
double learningRate;
/**
* The mobp.
*/
double mobp;
/**
* The weight matrix.
*/
double[][] weights;
/**
* The delta weight matrix.
*/
double[][] deltaWeights;
/**
* Error on nodes.
*/
double[] errors;
/**
* The inputs.
*/
double[] input;
/**
* The outputs.
*/
double[] output;
/**
* The output after activate.
*/
double[] activatedOutput;
/**
* The inputs.
*/
Activator activator;
/**
* The inputs.
*/
Random random = new Random();
/**
*********************
* The first constructor.
*
* @param paraActivator
* The activator.
*********************
*/
public AnnLayer(int paraNumInput, int paraNumOutput, char paraActivator,
double paraLearningRate, double paraMobp) {
numInput = paraNumInput;
numOutput = paraNumOutput;
learningRate = paraLearningRate;
mobp = paraMobp;
weights = new double[numInput + 1][numOutput];
deltaWeights = new double[numInput + 1][numOutput];
for (int i = 0; i < numInput + 1; i++) {
for (int j = 0; j < numOutput; j++) {
weights[i][j] = random.nextDouble();
} // Of for j
} // Of for i
errors = new double[numInput];
input = new double[numInput];
output = new double[numOutput];
activatedOutput = new double[numOutput];
activator = new Activator(paraActivator);
}// Of the first constructor
/**
********************
* Set parameters for the activator.
*
* @param paraAlpha
* Alpha. Only valid for certain types.
* @param paraBeta
* Beta.
* @param paraAlpha
* Alpha.
********************
*/
public void setParameters(double paraAlpha, double paraBeta, double paraGamma) {
activator.setAlpha(paraAlpha);
activator.setBeta(paraBeta);
activator.setGamma(paraGamma);
}// Of setParameters
/**
********************
* Forward prediction.
*
* @param paraInput
* The input data of one instance.
* @return The data at the output end.
********************
*/
public double[] forward(double[] paraInput) {
//System.out.println("Ann layer forward " + Arrays.toString(paraInput));
// Copy data.
for (int i = 0; i < numInput; i++) {
input[i] = paraInput[i];
} // Of for i
// Calculate the weighted sum for each output.
for (int i = 0; i < numOutput; i++) {
output[i] = weights[numInput][i];
for (int j = 0; j < numInput; j++) {
output[i] += input[j] * weights[j][i];
} // Of for j
activatedOutput[i] = activator.activate(output[i]);
} // Of for i
return activatedOutput;
}// Of forward
/**
********************
* Back propagation and change the edge weights.
*
* @param paraTarget
* For 3-class data, it is [0, 0, 1], [0, 1, 0] or [1, 0, 0].
********************
*/
public double[] backPropagation(double[] paraErrors) {
//Step 1. Adjust the errors.
for (int i = 0; i < paraErrors.length; i++) {
paraErrors[i] = activator.derive(output[i], activatedOutput[i]) * paraErrors[i];
}//Of for i
//Step 2. Compute current errors.
for (int i = 0; i < numInput; i++) {
errors[i] = 0;
for (int j = 0; j < numOutput; j++) {
errors[i] += paraErrors[j] * weights[i][j];
deltaWeights[i][j] = mobp * deltaWeights[i][j]
+ learningRate * paraErrors[j] * input[i];
weights[i][j] += deltaWeights[i][j];
} // Of for j
} // Of for i
for (int j = 0; j < numOutput; j++) {
deltaWeights[numInput][j] = mobp * deltaWeights[numInput][j] + learningRate * paraErrors[j];
weights[numInput][j] += deltaWeights[numInput][j];
} // Of for j
return errors;
}// Of backPropagation
/**
********************
* I am the last layer, set the errors.
*
* @param paraTarget
* For 3-class data, it is [0, 0, 1], [0, 1, 0] or [1, 0, 0].
********************
*/
public double[] getLastLayerErrors(double[] paraTarget) {
double[] resultErrors = new double[numOutput];
for (int i = 0; i < numOutput; i++) {
resultErrors[i] = (paraTarget[i] - activatedOutput[i]);
} // Of for i
return resultErrors;
}// Of getLastLayerErrors
/**
********************
* Show me.
********************
*/
public String toString() {
String resultString = "";
resultString += "Activator: " + activator;
resultString += "\r\n weights = " + Arrays.deepToString(weights);
return resultString;
}// Of toString
/**
********************
* Unit test.
********************
*/
public static void unitTest() {
AnnLayer tempLayer = new AnnLayer(2, 3, 's', 0.01, 0.1);
double[] tempInput = { 1, 4 };
System.out.println(tempLayer);
double[] tempOutput = tempLayer.forward(tempInput);
System.out.println("Forward, the output is: " + Arrays.toString(tempOutput));
double[] tempError = tempLayer.backPropagation(tempOutput);
System.out.println("Back propagation, the error is: " + Arrays.toString(tempError));
}// Of unitTest
/**
********************
* Test the algorithm.
********************
*/
public static void main(String[] args) {
unitTest();
}// Of main
}// Of class AnnLayer
结果: