Package tv.floe.metronome.classification.neuralnetworks.iterativereduce.xor

Source Code of tv.floe.metronome.classification.neuralnetworks.iterativereduce.xor.TestTwoWorkersXOR_IR_NN

package tv.floe.metronome.classification.neuralnetworks.iterativereduce.xor;

import static org.junit.Assert.*;

import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.junit.Test;

import tv.floe.metronome.classification.neuralnetworks.core.NeuralNetwork;
import tv.floe.metronome.classification.neuralnetworks.iterativereduce.MasterNode;
import tv.floe.metronome.classification.neuralnetworks.iterativereduce.WorkerNode;
import tv.floe.metronome.irunit.IRUnitDriver;

public class TestTwoWorkersXOR_IR_NN {

  public static void PrintNeuralNetwork(NeuralNetwork nn) {
   
   
    for ( int x = 1; x < nn.getLayersCount(); x++ ) {
     
      System.out.println("Layer " + x );
     
      for ( int n = 0; n < nn.getLayerByIndex(x).getNeuronsCount(); n++ ) {
       
        System.out.println("\tNeuron " + n);
       
        for ( int c = 0; c < nn.getLayerByIndex(x).getNeuronAt(n).inConnections.size(); c++) {
         
          System.out.println("\t\tConnection " + c + ", weight: " + nn.getLayerByIndex(x).getNeuronAt(n).inConnections.get(c).getWeight().value );
         
        } // for each incoming connection
       
      } // for each neuron in layer
     
     
    } // for each layer
   
   
   
  }
 
  public static void scoreNeuralNetworkXor(NeuralNetwork mlp_network) throws Exception {
   
    System.out.println("Scoring NN for XOR Function ---------- ");
   
    Vector v0 = new DenseVector(2);
    v0.set(0, 0);
    v0.set(1, 0);
    Vector v0_out = new DenseVector(1);
    v0_out.set(0, 0);
    //xor_recs.add(v0);

    Vector v1 = new DenseVector(2);
    v1.set(0, 0);
    v1.set(1, 1);

    Vector v1_out = new DenseVector(1);
    v1_out.set(0, 1);
    //xor_recs.add(v1);

   
   
    Vector v2 = new DenseVector(2);
    v2.set(0, 1);
    v2.set(1, 0);

    Vector v2_out = new DenseVector(1);
    v2_out.set(0, 1);
    //xor_recs.add(v2);

   
   
    Vector v3 = new DenseVector(2);
    v3.set(0, 1);
    v3.set(1, 1);

    Vector v3_out = new DenseVector(1);
    v3_out.set(0, 0)
   
   
    mlp_network.setInputVector( v0 );
    mlp_network.calculate();
        Vector networkOutput = mlp_network.getOutputVector();

        System.out.println( "> out: 0 =? " + networkOutput.get(0) );
       
             
       
       
    mlp_network.setInputVector( v1 );
    mlp_network.calculate();
        Vector networkOutput_1 = mlp_network.getOutputVector();

        System.out.println( "> out: 1 =? " + networkOutput_1.get(0) );
                 

    mlp_network.setInputVector( v2 );
    mlp_network.calculate();
        Vector networkOutput_2 = mlp_network.getOutputVector();

        System.out.println( "> out: 1 =? " + networkOutput_2.get(0) );


    mlp_network.setInputVector( v3 );
    mlp_network.calculate();
        Vector networkOutput_3 = mlp_network.getOutputVector();

        System.out.println( "> out: 0 =? " + networkOutput_3.get(0) );
       
       
        assertEquals(0.05, networkOutput.get(0), 0.05);
        assertEquals(0.95, networkOutput_1.get(0), 0.05);
        assertEquals(0.95, networkOutput_2.get(0), 0.05);
        assertEquals(0.05, networkOutput_3.get(0), 0.05);
       
   
  }
 
  @Test
  public void testLearnXORFunctionViaIRNN_MLP() throws Exception {
   
    IRUnitDriver polr_ir = new IRUnitDriver("src/test/resources/run_profiles/unit_tests/nn/xor/app.unit_test.nn.xor.twoworkers.properties");
    polr_ir.Setup();

    polr_ir.SimulateRun();

   
    WorkerNode singleWorker = (WorkerNode)polr_ir.getWorker().get(0);
    WorkerNode secondWorker = (WorkerNode)polr_ir.getWorker().get(1);
   
    MasterNode master = (MasterNode) polr_ir.getMaster();
   
    System.out.println("\n\nComplete: ");
    //Utils.PrintVector( master.polr.getBeta().viewRow(0) );

//    this.scoreNeuralNetworkXor( master.master_nn );
   
    System.out.println("w1 > RMSE: " + singleWorker.lastRMSE );
    System.out.println("w2 > RMSE: " + secondWorker.lastRMSE );
   
//    this.scoreNeuralNetworkXor( singleWorker.nn );
//    PrintNeuralNetwork( singleWorker.nn );
   
//    this.scoreNeuralNetworkXor( secondWorker.nn );
//    PrintNeuralNetwork( secondWorker.nn );
   
   
  }

}
TOP

Related Classes of tv.floe.metronome.classification.neuralnetworks.iterativereduce.xor.TestTwoWorkersXOR_IR_NN

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.