Package tv.floe.metronome.classification.neuralnetworks.eval

Source Code of tv.floe.metronome.classification.neuralnetworks.eval.ModelEvaluator

package tv.floe.metronome.classification.neuralnetworks.eval;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.Properties;

import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.mapred.JobConf;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;

import tv.floe.metronome.classification.neuralnetworks.core.NeuralNetwork;
import tv.floe.metronome.classification.neuralnetworks.networks.MultiLayerPerceptronNetwork;
import tv.floe.metronome.io.records.MetronomeRecordFactory;

public class ModelEvaluator {


  private static String model_path = "/tmp/nn.wine.model";
  private static String schema = "i:64 | o:10";
  private static String src_input_data_path = "src/test/resources/data/nist/HandWritingDigits/optdigits.tra.mne";


  //private static Properties props = null;
  //private static String app_properties_file = "";
  private static JobConf defaultConf = new JobConf();
  private static FileSystem localFs = null;
  static {
    try {
      defaultConf.set("fs.defaultFS", "file:///");
      localFs = FileSystem.getLocal(defaultConf);
    } catch (IOException e) {
      throw new RuntimeException("init failure", e);
    }
  }
 
 
  public static MultiLayerPerceptronNetwork loadModelFromDisk(String model_path) {
   
    MultiLayerPerceptronNetwork nnet = null;
   
    try {

      Path out = new Path(model_path.trim());
      FileSystem fs =
            out.getFileSystem(defaultConf);
      /*
      FSDataOutputStream fos;

      fos = fs.create(out);
        //LOG.info("Writing master results to " + out.toString());
        mnode.complete(fos);
       
        fos.flush();
        fos.close();
*/
      //BufferedWriter bw = new BufferedWriter(new FileWriter(output_path));
      //master.complete( bw );
     
      //bw.close();
      System.out.println("Loading Model: " + out.toString().replaceAll("file:", ""));
      FileInputStream fis = new FileInputStream( out.toString().replaceAll("file:", "") );
        //ObjectInputStream ois = new ObjectInputStream(fis);
       
        //oos.writeObject( master. );
        //master.complete(oos);
       
        //nnet = (MultiLayerPerceptronNetwork) ois.readObject();
        byte[] b = new byte[1024 * 200];
        int bytes_read = fis.read(b);
       
        //System.out.println("bytes read: " + bytes_read);
       
        nnet = (MultiLayerPerceptronNetwork) NeuralNetwork.Deserialize(b);
       
        //oos.flush();
        //ois.close();

     
    } catch (IOException e) {
      // TODO Auto-generated catch block
      e.printStackTrace();
   
   
    return nnet;
   
  }
 
  public static void scoreNeuralNetworkModel(String modelFileLocation, String schema, String records_file) throws Exception {
   
    // load model
    MultiLayerPerceptronNetwork mlp = loadModelFromDisk(modelFileLocation);
   
    //Utils.PrintNeuralNetwork(  mlp );
   
    //System.out.println("Layers: " + mlp.getLayersCount());
   
    MetronomeRecordFactory rec_factory = new MetronomeRecordFactory(schema);
 
     
      String recs = records_file;
      Vector v_in_0 = new RandomAccessSparseVector(rec_factory.getInputVectorSize());
      Vector v_out_0 = new RandomAccessSparseVector(rec_factory.getOutputVectorSize());
     
     
     
      int total_recs = 0;
      int correct = 0;
     
      BufferedReader br = new BufferedReader(new FileReader(recs));
      String line;
      while ((line = br.readLine()) != null) {
       
       

        rec_factory.vectorizeLine( line, v_in_0, v_out_0 );
       
        if (isCorrect(mlp, v_in_0, v_out_0) ) {
          correct++;
        }
        total_recs++;
       
        //System.out.println("rec > " + line);
        //System.out.println("out > " + v_out_0.toString());
       
      }
      br.close();   
     
      System.out.println("Total: " + total_recs);
      System.out.println("Correct: " + correct);
   
      double percent = (double)correct / (double)total_recs;
   
      System.out.println("Percent: " + percent);
     
  }
 
  public static int getLargestIndex(Vector v) {
   
    double largest_val = 0;
    int largest_index = 0;
   
    for (int x = 0; x < v.size(); x++) {
     
      if ( v.get(x) > largest_val ) {
        largest_val = v.get(x);
        largest_index = x;
      }
     
    }
   
    return largest_index;
  }
 
  public static boolean isCorrect(NeuralNetwork mlp, Vector input_vec, Vector answer_vec) throws Exception {
   
    mlp.setInputVector( input_vec );
    mlp.calculate();
        Vector networkOutput = mlp.getOutputVector();

       
   
        Vector v_answer_selected = new RandomAccessSparseVector(answer_vec.size());
       
        v_answer_selected.set(getLargestIndex(networkOutput), 1.0);
       
        //System.out.println( "> out: " + v_answer_selected.toString() );
       
        if ( answer_vec.equals(v_answer_selected)) {
          return true;
        }
       
       
    return false;
  }
 
  public static void parsePropertiesFile(String app_properties_file) throws Exception {
   
    Properties props = new Properties();
    // Configuration conf = getConf();

    try {
      FileInputStream fis = new FileInputStream(app_properties_file);
      props.load(fis);
      fis.close();
    } catch (FileNotFoundException ex) {
      // throw ex; // TODO: be nice
      System.out.println(ex);
    } catch (IOException ex) {
      // throw ex; // TODO: be nice
      System.out.println(ex);
   
   
    model_path = props.getProperty("app.output.path");
    if (null == model_path) {
      throw new Exception("Can't find the model output path in the properites file!");
    }

    schema = props.getProperty("tv.floe.metronome.neuralnetwork.conf.InputRecordSchema");
    if (null == schema) {
      throw new Exception("Can't find the input record schema in the properites file!");
    }

    src_input_data_path = props.getProperty("tv.floe.metronome.neuralnetwork.conf.evaluate.dataset.path");
    if (null == src_input_data_path) {
      throw new Exception("Can't find the eval/test recordset in the properites file!");
    }
/*   
    System.out.println("Conf ------------");
    System.out.println("Model: " + model_path);
    System.out.println("Schema: " + schema);
    System.out.println("Test Dataset: " + src_input_data_path);
  */ 
   
  }
 
  public static void eval(String propertiesFile) throws Exception {
   
    parsePropertiesFile(propertiesFile);
   
    scoreNeuralNetworkModel( model_path, schema, src_input_data_path );
   
  }
}
TOP

Related Classes of tv.floe.metronome.classification.neuralnetworks.eval.ModelEvaluator

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.