Package edu.uci.jforestsx.applications

Source Code of edu.uci.jforestsx.applications.ClassificationApp

/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package edu.uci.jforestsx.applications;

import java.io.InputStream;
import java.util.Properties;
import java.util.Random;

import edu.uci.jforestsx.config.TrainingConfig;
import edu.uci.jforestsx.dataset.Dataset;
import edu.uci.jforestsx.dataset.DatasetLoader;
import edu.uci.jforestsx.eval.AUC;
import edu.uci.jforestsx.eval.Accuracy;
import edu.uci.jforestsx.eval.BalancedYoundenIndex;
import edu.uci.jforestsx.eval.EvaluationMetric;
import edu.uci.jforestsx.eval.RMSE;
import edu.uci.jforestsx.learning.LearningModule;
import edu.uci.jforestsx.learning.LearningProgressListener;
import edu.uci.jforestsx.learning.boosting.GradientBoosting;
import edu.uci.jforestsx.learning.classification.GradientBoostingBinaryClassifier;
import edu.uci.jforestsx.learning.trees.Ensemble;
import edu.uci.jforestsx.learning.trees.decision.RandomForest;
import edu.uci.jforestsx.learning.trees.regression.RegressionTreeLearner;
import edu.uci.jforestsx.sample.Sample;
import edu.uci.jforestsx.util.ConfigHolder;
import edu.uci.jforestsx.util.Constants;
import edu.uci.jforestsx.util.IOUtils;
import edu.uci.jforestsx.util.Timer;
import edu.uci.jforestsx.util.concurrency.BlockingThreadPoolExecutor;

/**
* @author Yasser Ganjisaffar <ganjisaffar at gmail dot com>
*/

public class ClassificationApp {

  protected Dataset trainDataset;
  protected Dataset validDataset;
  protected LearningModule topLearner;

  protected Sample trainSet;
  protected Sample validSet;

  protected IOUtils ioUtils;
  protected EvaluationMetric evaluationMetric;
  protected Random rnd;

  protected TrainingConfig trainingConfig;
  protected ConfigHolder configHolder;
 
  protected LearningProgressListener progressListener = null;
 
  public ClassificationApp() {
    initIOUtils();
  }

  public void setProgressListener(LearningProgressListener progressListener) {
    this.progressListener = progressListener;
  }

  protected void initIOUtils() {
    if (ioUtils == null) {
      ioUtils = new IOUtils();
    }
  }

  protected void loadConfig() {
    trainingConfig = new TrainingConfig();
    trainingConfig.init(configHolder);
  }

  protected void init() throws Exception {
    BlockingThreadPoolExecutor.init(trainingConfig.numThreads);

    initDataset(trainDataset);
    if (validSet != null) {
      initDataset(validSet.dataset);
    }

    if (trainingConfig.featureNamesFilename != null) {
      trainDataset.loadFeatureNamesFromExternalResource(ioUtils.getInputStream(trainingConfig.featureNamesFilename));
    }
  }

  protected LearningModule getLearningModule(String name) throws Exception {
    int maxNumTrainInstances = trainDataset.numInstances;
    int maxNumValidInstances = (validDataset != null ? validDataset.numInstances : trainDataset.numInstances);
    if (name.equals("GradientBoostingBinaryClassifier")) {
      GradientBoostingBinaryClassifier learner = new GradientBoostingBinaryClassifier();
      learner.init(configHolder, maxNumTrainInstances, maxNumValidInstances, evaluationMetric);
      return learner;
    } else if (name.equals("GradientBoosting")) {
      GradientBoosting learner = new GradientBoosting();
      learner.init(configHolder, maxNumTrainInstances, maxNumValidInstances, evaluationMetric);
      return learner;
    } else if (name.equals("RegressionTree")) {
      RegressionTreeLearner learner = new RegressionTreeLearner();
      learner.init(trainDataset, configHolder, maxNumTrainInstances);
      return learner;
    } else if (name.equals("RandomForest")) {
      RandomForest learner = new RandomForest();
      learner.init(trainDataset, configHolder, maxNumTrainInstances, maxNumValidInstances, evaluationMetric);
      return learner;
    } else {
      throw new Exception("Unknown algorithm: " + name);
    }
  }

  protected EvaluationMetric getEvaluationMetric(String name) throws Exception {
    if (name.equals("AUC")) {
      return new AUC();
    } else if (name.equals("RMSE")) {
      return new RMSE();
    } else if (name.equals("Accuracy")) {
      return new Accuracy();
    } else if (name.equals("BalancedYoundenIndex")) {
      return new BalancedYoundenIndex();
    } else {
      throw new Exception("Unknown evaluation metric: " + name);
    }
  }

  protected void createLearner() throws Exception {
    String[] parts = trainingConfig.learningAlgorithm.split("-");
    topLearner = getLearningModule(parts[0]);
    if (progressListener != null) {
      topLearner.setProgressListener(progressListener);
    }
    LearningModule curModule = topLearner;
    for (int i = 1; i < parts.length; i++) {
      LearningModule newModule = getLearningModule(parts[i]);
      if (progressListener != null) {
        newModule.setProgressListener(progressListener);
      }
      curModule.setSubModule(newModule);
      curModule = newModule;
    }
  }

  protected void loadDataset(InputStream in, Dataset dataset) throws Exception {
    DatasetLoader.load(in, dataset);
  }

  public void loadDataset(String uri, Dataset dataset) throws Exception {
    if (dataset != null && dataset.uri != null && dataset.uri.equals(uri)) {
      // This data set is already loaded.
      dataset.needsInitialization = false;
      return;
    }
    InputStream in = ioUtils.getInputStream(uri);
    loadDataset(in, dataset);
    dataset.uri = uri;
    dataset.needsInitialization = true;
    in.close();
  }

  protected double getMeasurement(double[] scores, Sample sample) throws Exception {
    return sample.evaluate(scores, evaluationMetric, 1.0);
  }

  protected Dataset newDataset() {
    return new Dataset();
  }

  protected void initDataset(Dataset dataset) throws Exception {
    // Will be overridden by subclasses
  }

  protected Sample createSample(Dataset dataset, boolean trainSample) {
    return new Sample(dataset);
  }

  protected int getMaxTrainInstances() {
    return trainDataset.numInstances;
  }

  public Ensemble run(Properties configProperties) {
    try {
      configHolder = new ConfigHolder(configProperties);
      loadConfig();
      if (!trainingConfig.validate(ioUtils)) {
        System.out.println("Error: " + trainingConfig.getErrorMessage());
        return null;
      }
      rnd = new Random(trainingConfig.randomSeed);

      System.out.println("Loading datasets...");
      if (trainDataset == null) {
        trainDataset = newDataset();
      }
      loadDataset(trainingConfig.trainFilename, trainDataset);
      int maxInstances = getMaxTrainInstances();

      if (trainingConfig.validFilename != null) {
        if (validDataset == null) {
          validDataset = newDataset();
        }
        loadDataset(trainingConfig.validFilename, validDataset);
        if (validDataset.numInstances > maxInstances) {
          maxInstances = validDataset.numInstances;
        }
      } else {
        validDataset = null;
      }
      System.out.println("Finished loading datasets.");

      Constants.init(maxInstances);

      Sample allTrainSample = createSample(trainDataset, true);
      trainSet = allTrainSample.getRandomSubSample(trainingConfig.trainFraction, rnd);

      if (validDataset != null) {
        validSet = createSample(validDataset, false);
        if (trainingConfig.validFraction < 1.0) {
          validSet = validSet.getRandomSubSample(trainingConfig.validFraction, rnd);
        }
      } else if (trainingConfig.validOutOfTrain) {
        validSet = allTrainSample.getOutOfSample(trainSet);
      }

      init();

      evaluationMetric = getEvaluationMetric(trainingConfig.evaluationMetric);

      createLearner();

      Timer timer = new Timer();
      timer.start();
      Ensemble ensemble = topLearner.learn(trainSet, validSet);
      System.out.println("Time taken to build model: " + (timer.getElapsedMillis() / 1000.0) + " seconds.");
      return ensemble;

    } catch (Exception e) {
      e.printStackTrace();
    }

    return null;
  }

  public int getTrainingRandomSeed() {
    return trainingConfig.randomSeed;
  }

  public static void shutdown() {
    BlockingThreadPoolExecutor executor = BlockingThreadPoolExecutor.getInstance();
    if (executor != null && !executor.isShutdown()) {
      executor.shutdownNow();
    }
  }

  public EvaluationMetric getEvaluationMetric() {
    return evaluationMetric;
  }

  public double getValidMeasurement() throws Exception {
    return topLearner.getValidationMeasurement();
  }

  public Sample getTrainSample() {
    return trainSet;
  }

  public Sample getValidSample() {
    return validSet;
  }
 
  public ConfigHolder getConfigHolder() {
    return configHolder;
  }
 
  public IOUtils getIOUtils() {
    return ioUtils;
  }
 
  public LearningProgressListener getProgressListener() {
    return progressListener;
  }
}
TOP

Related Classes of edu.uci.jforestsx.applications.ClassificationApp

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.