Package tv.floe.metronome.classification.neuralnetworks.iterativereduce.iris

Source Code of tv.floe.metronome.classification.neuralnetworks.iterativereduce.iris.IrisDatasetUtils

package tv.floe.metronome.classification.neuralnetworks.iterativereduce.iris;

import static org.junit.Assert.assertEquals;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.ArrayList;

import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.mapred.JobConf;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;

import tv.floe.metronome.types.Pair;
import tv.floe.metronome.classification.neuralnetworks.core.NeuralNetwork;
import tv.floe.metronome.classification.neuralnetworks.learning.BackPropogationLearningAlgorithm;
import tv.floe.metronome.classification.neuralnetworks.networks.MultiLayerPerceptronNetwork;
import tv.floe.metronome.classification.neuralnetworks.utils.Utils;
import tv.floe.metronome.io.records.MetronomeRecordFactory;

public class IrisDatasetUtils {


  private static JobConf defaultConf = new JobConf();
  private static FileSystem localFs = null;
  static {
    try {
      defaultConf.set("fs.defaultFS", "file:///");
      localFs = FileSystem.getLocal(defaultConf);
    } catch (IOException e) {
      throw new RuntimeException("init failure", e);
    }
  }
 
 
  public static MultiLayerPerceptronNetwork loadModelFromDisk() {
   
    MultiLayerPerceptronNetwork nnet = null;
   
    try {

      Path out = new Path("/tmp/nn.model");
      FileSystem fs =
            out.getFileSystem(defaultConf);
      /*
      FSDataOutputStream fos;

      fos = fs.create(out);
        //LOG.info("Writing master results to " + out.toString());
        mnode.complete(fos);
       
        fos.flush();
        fos.close();
*/
      //BufferedWriter bw = new BufferedWriter(new FileWriter(output_path));
      //master.complete( bw );
     
      //bw.close();
     
      FileInputStream fis = new FileInputStream( out.toString() );
        //ObjectInputStream ois = new ObjectInputStream(fis);
       
        //oos.writeObject( master. );
        //master.complete(oos);
       
        //nnet = (MultiLayerPerceptronNetwork) ois.readObject();
        byte[] b = new byte[1024 * 16];
        int bytes_read = fis.read(b);
       
        System.out.println("bytes read: " + bytes_read);
       
        nnet = (MultiLayerPerceptronNetwork) NeuralNetwork.Deserialize(b);
       
        //oos.flush();
        //ois.close();

     
    } catch (IOException e) {
      // TODO Auto-generated catch block
      e.printStackTrace();
   
   
    return nnet;
   
  }
 
 
  public static Pair<Matrix, Matrix> getIrisAsDataset() throws IOException {
   

    MetronomeRecordFactory rec_factory = new MetronomeRecordFactory("i:4 | o:3");
 
     
      String recs = "src/test/resources/data/uci/iris/iris_data_normalised.mne";

      Vector v_in_0 = new RandomAccessSparseVector(rec_factory.getInputVectorSize());
      Vector v_out_0 = new RandomAccessSparseVector(rec_factory.getOutputVectorSize());
     
     
//      ArrayList<Double[]> input_recs = new ArrayList<Double[]>();

      // not very well built but in a hurry
      BufferedReader br = new BufferedReader(new FileReader(recs));
     
     
      int row_count = 0;
      String line;
      while ((line = br.readLine()) != null) {
        row_count++;
      }

      Matrix input_matrix = new DenseMatrix( row_count, rec_factory.getInputVectorSize() );
      Matrix output_matrix = new DenseMatrix( row_count, rec_factory.getOutputVectorSize() );
     
      int row_num = 0;
      //String line;
      br.close();
      br = new BufferedReader(new FileReader(recs));
      while ((line = br.readLine()) != null) {
       

        rec_factory.vectorizeLine( line, v_in_0, v_out_0 );
       
        //System.out.println("line: " + line);
        //System.out.println("v in: " + v_in_0.toString());
       
        input_matrix.viewRow( row_num ).assign( v_in_0 );
       
        output_matrix.viewRow( row_num ).assign( v_out_0 );
       
        row_num++;
       
      }
      br.close()
     
      return new Pair< Matrix, Matrix >( input_matrix, output_matrix );
   
  }
 
 
  public static void scoreIrisNeuralNetworkModel(String modelFileLocation) throws Exception {
   
    // load model
    MultiLayerPerceptronNetwork mlp = loadModelFromDisk();
   
    //Utils.PrintNeuralNetwork(  mlp );
   
    //System.out.println("Layers: " + mlp.getLayersCount());
   
    MetronomeRecordFactory rec_factory = new MetronomeRecordFactory("i:4 | o:3");
 
     
      String recs = "src/test/resources/data/uci/iris/iris_data_normalised.mne";
      Vector v_in_0 = new RandomAccessSparseVector(rec_factory.getInputVectorSize());
      Vector v_out_0 = new RandomAccessSparseVector(rec_factory.getOutputVectorSize());
     
     
     
      int total_recs = 0;
      int correct = 0;
     
      BufferedReader br = new BufferedReader(new FileReader(recs));
      String line;
      while ((line = br.readLine()) != null) {
       
       

        rec_factory.vectorizeLine( line, v_in_0, v_out_0 );
       
        if (isCorrect(mlp, v_in_0, v_out_0) ) {
          correct++;
        }
        total_recs++;
       
        //System.out.println("rec > " + line);
        //System.out.println("out > " + v_out_0.toString());
       
      }
      br.close();   
     
      BackPropogationLearningAlgorithm bp = ((BackPropogationLearningAlgorithm)mlp.getLearningRule());
     
      System.out.println("Avg RMSE: " + bp.getMetrics().getLastRMSE());
      System.out.println("Total: " + total_recs);
      System.out.println("Correct: " + correct);
   
      double percent = (double)correct / (double)total_recs;
   
      System.out.println("Percent: " + percent);
     
  }
 
  public static int getLargestIndex(Vector v) {
   
    double largest_val = 0;
    int largest_index = 0;
   
    for (int x = 0; x < v.size(); x++) {
     
      if ( v.get(x) > largest_val ) {
        largest_val = v.get(x);
        largest_index = x;
      }
     
    }
   
    return largest_index;
  }
 
  public static boolean isCorrect(NeuralNetwork mlp, Vector input_vec, Vector answer_vec) throws Exception {
   
    mlp.setInputVector( input_vec );
    mlp.calculate();
        Vector networkOutput = mlp.getOutputVector();

       
   
        Vector v_answer_selected = new RandomAccessSparseVector(answer_vec.size());
       
        v_answer_selected.set(getLargestIndex(networkOutput), 1.0);
       
        //System.out.println( "> out: " + v_answer_selected.toString() );
       
        if ( answer_vec.equals(v_answer_selected)) {
          return true;
        }
       
       
    return false;
  }
 
  public static void main(String[] args) throws Exception {
   
    scoreIrisNeuralNetworkModel("");
   
    //Vector
   
  }
 

}
TOP

Related Classes of tv.floe.metronome.classification.neuralnetworks.iterativereduce.iris.IrisDatasetUtils

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.