Package edu.uci.jforestsx.learning.boosting

Source Code of edu.uci.jforestsx.learning.boosting.GradientBoosting

/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package edu.uci.jforestsx.learning.boosting;

import java.util.Arrays;
import java.util.Random;

import edu.uci.jforestsx.config.TrainingConfig;
import edu.uci.jforestsx.eval.EvaluationMetric;
import edu.uci.jforestsx.learning.LearningModule;
import edu.uci.jforestsx.learning.LearningUtils;
import edu.uci.jforestsx.learning.trees.Ensemble;
import edu.uci.jforestsx.learning.trees.Tree;
import edu.uci.jforestsx.learning.trees.TreeLeafInstances;
import edu.uci.jforestsx.learning.trees.regression.RegressionTree;
import edu.uci.jforestsx.sample.Sample;
import edu.uci.jforestsx.util.ConfigHolder;

/**
* @author Yasser Ganjisaffar <ganjisaffar at gmail dot com>
*/

public class GradientBoosting extends LearningModule {

  protected double[] trainPredictions;
  protected double[] validPredictions;

  protected double[] residuals;

  protected int numInstances;

  private int numSubModules;
  protected double learningRate;
  protected double samplingRate;
  protected double earlyStoppingTolerance;

  protected Sample curTrainSet;
  protected Sample curValidSet;
  protected int curIteration;

  protected double bestValidationMeasurement;
  protected boolean printIntermediateValidMeasurements;
  protected EvaluationMetric evaluationMetric;

  protected Random rnd;

  public GradientBoosting(String algorithmName) {
    super(algorithmName);
  }
 
  public GradientBoosting() {
    super("GradientBoosting");
  }

  public void init(ConfigHolder configHolder, int maxNumTrainInstances, int maxNumValidInstances, EvaluationMetric evaluationMetric) throws Exception {
    this.evaluationMetric = evaluationMetric;
    GradientBoostingConfig gradientBoostingConfig = configHolder.getConfig(GradientBoostingConfig.class);
    this.numSubModules = gradientBoostingConfig.numTrees;
    this.learningRate = gradientBoostingConfig.learningRate;
    this.samplingRate = gradientBoostingConfig.samplingRate;
    this.earlyStoppingTolerance = gradientBoostingConfig.earlyStoppingTolerance;

    trainPredictions = new double[maxNumTrainInstances];
    residuals = new double[maxNumTrainInstances];
    validPredictions = new double[maxNumValidInstances];

    TrainingConfig trainingConfig = configHolder.getConfig(TrainingConfig.class);
    printIntermediateValidMeasurements = trainingConfig.printIntermediateValidMeasurements;
    rnd = new Random(trainingConfig.randomSeed);
  }

  protected void preprocess() {
    Arrays.fill(trainPredictions, 0, curTrainSet.size, 0);
    Arrays.fill(validPredictions, 0, curValidSet.size, 0);
  }

  @Override
  public Ensemble learn(Sample trainSet, Sample validSet) throws Exception {
    curTrainSet = trainSet;
    curValidSet = validSet;

    preprocess();

    Ensemble ensemble = new Ensemble();
    bestValidationMeasurement = Double.NaN;
    int earlyStoppingIteration = 0;
    int bestIteration = 0;
    int[] treeCounts = new int[numSubModules];
    subLearner.setTreeWeight(treeWeight);
    for (curIteration = 1; curIteration <= numSubModules; curIteration++) {
      Sample subLeanerSample = getSubLearnerSample();
      Ensemble subEnsemble = subLearner.learn(subLeanerSample, validSet);
      if (subEnsemble == null) {
        break;
      }
      for (int t = 0; t < subEnsemble.getNumTrees(); t++) {
        Tree tree = subEnsemble.getTreeAt(t);
        ensemble.addTree(tree, subEnsemble.getWeightAt(t));
        if (validSet != null) {
          LearningUtils.updateScores(validSet, validPredictions, ((RegressionTree) tree), 1.0);
        }
      }
      treeCounts[curIteration - 1] = ensemble.getNumTrees();

      if (validSet == null) {
        earlyStoppingIteration = curIteration;
      } else {
        double validMeasurement = getValidMeasurement();
        if (evaluationMetric.isFirstBetter(validMeasurement, bestValidationMeasurement, earlyStoppingTolerance)) {
          earlyStoppingIteration = curIteration;
          if (evaluationMetric.isFirstBetter(validMeasurement, bestValidationMeasurement, 0)) {
            bestValidationMeasurement = validMeasurement;
            bestIteration = curIteration;
          }
        }

        // FIXME
        if (curIteration - bestIteration > 100) {
          break;
        }

        if (printIntermediateValidMeasurements) {
          printTrainAndValidMeasurement(curIteration, validMeasurement, getTrainMeasurement(), evaluationMetric);
        }
      }
      onIterationEnd();
    }

    if (earlyStoppingIteration > 0) {
      int treesToKeep = treeCounts[earlyStoppingIteration - 1];
      int treesToDelete = ensemble.getNumTrees() - treesToKeep;
      ensemble.removeLastTrees(treesToDelete);
    }

    onLearningEnd();
    return ensemble;
  }

  @Override
  public double getValidationMeasurement() {
    return bestValidationMeasurement;
  }

  protected double getValidMeasurement() throws Exception {
    return curValidSet.evaluate(validPredictions, evaluationMetric);
  }

  protected double getTrainMeasurement() throws Exception {
    return curTrainSet.evaluate(trainPredictions, evaluationMetric);
  }

  protected Sample getSubLearnerSample() {
    for (int i = 0; i < curTrainSet.size; i++) {
      residuals[i] = curTrainSet.targets[i] - trainPredictions[i];
    }
    Sample subLearnerSample = curTrainSet.getClone();
    subLearnerSample.targets = residuals;
    subLearnerSample = subLearnerSample.getRandomSubSample(samplingRate, rnd);
    return subLearnerSample;
  }

  protected void adjustOutputs(Tree tree, TreeLeafInstances treeLeafInstances) {
    ((RegressionTree) tree).multiplyLeafOutputs(learningRate);
  }

  @Override
  public void postProcess(Tree tree, TreeLeafInstances treeLeafInstances) {
    adjustOutputs(tree, treeLeafInstances);
    LearningUtils.updateScores(curTrainSet, trainPredictions, ((RegressionTree) tree), 1.0);
    postProcessScores();
  }

  protected void postProcessScores() {
    // Do nothing
  }

}
TOP

Related Classes of edu.uci.jforestsx.learning.boosting.GradientBoosting

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.