Package edu.stanford.nlp.optimization

Source Code of edu.stanford.nlp.optimization.SGDMinimizer

package edu.stanford.nlp.optimization;

import edu.stanford.nlp.classify.LogPrior;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.util.Timing;

import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.Arrays;
import java.util.Random;

/**
* In place Stochastic Gradient Descent Minimizer.
* <ul>
* <li> Follows weight decay and tuning of learning parameter of crfsgd of
*   Leon Bottou: http://leon.bottou.org/projects/sgd
* <li> Only supports L2 regularization (QUADRATIC)
* <li> Requires objective function to be an AbstractStochasticCachingDiffUpdateFunction.
* </ul>
* NOTE: unlike other minimizers, regularization is done in the minimizer, not the objective function.
*
* This class was previously called StochasticInPlaceMinimizer. This is now SGDMinimizer, and the old SGDMinimizer is now InefficientSGDMinimizer.
*
* @author Angel Chang
*/
public class SGDMinimizer<T extends Function> implements Minimizer<T>, HasEvaluators {

  protected double xscale, xnorm;
  protected double[] x;
  protected int t0;  // Initial stochastic iteration count
  protected final double sigma;
  protected double lambda;
  protected boolean quiet = false;
  private static final int DEFAULT_NUM_PASSES = 50;
  protected final int numPasses; //-1;
  protected int bSize = 1// NOTE: If bSize does not divide evenly into total number of samples,
                            // some samples may get accounted for twice in one pass
  private static final int DEFAULT_TUNING_SAMPLES = 1000;
  protected final int tuningSamples;

  protected Random gen = new Random(1);
  protected long maxTime = Long.MAX_VALUE;

  private int evaluateIters = 0;    // Evaluate every x iterations (0 = no evaluation)
  private Evaluator[] evaluators;  // separate set of evaluators to check how optimization is going


  public SGDMinimizer(double sigma, int numPasses)
  {
    this(sigma, numPasses, -1, 1);
  }

  public SGDMinimizer(double sigma, int numPasses, int tuningSamples) {
    this(sigma, numPasses, tuningSamples, 1);
  }

  public SGDMinimizer(double sigma, int numPasses, int tuningSamples, int batchSize)
  {
    this.bSize = batchSize;
    this.sigma = sigma;
    if (numPasses >= 0) {
      this.numPasses = numPasses;
    } else {
      this.numPasses = DEFAULT_NUM_PASSES;
      sayln("  SGDMinimizer: numPasses=" + numPasses + ", defaulting to " + this.numPasses);
    }
    if (tuningSamples > 0) {
      this.tuningSamples = tuningSamples;
    } else {
      this.tuningSamples = DEFAULT_TUNING_SAMPLES;
      sayln("  SGDMinimizer: tuneSampleSize=" + tuningSamples + ", defaulting to " + this.tuningSamples);
    }
  }

  public SGDMinimizer(LogPrior prior, int numPasses, int batchSize, int tuningSamples)
  {
    if (LogPrior.LogPriorType.QUADRATIC == prior.getType()) {
      sigma = prior.getSigma();
    } else {
      throw new RuntimeException("Unsupported prior type " + prior.getType());
    }
    if (numPasses >= 0) {
      this.numPasses = numPasses;
    } else {
      this.numPasses = DEFAULT_NUM_PASSES;
      sayln("  SGDMinimizer: numPasses=" + numPasses + ", defaulting to " + this.numPasses);
    }
    this.bSize = batchSize;
    if (tuningSamples > 0) {
      this.tuningSamples = tuningSamples;
    } else {
      this.tuningSamples = DEFAULT_TUNING_SAMPLES;
      sayln("  SGDMinimizer: tuneSampleSize=" + tuningSamples + ", defaulting to " + this.tuningSamples);
    }
  }

  public void shutUp() {
    this.quiet = true;
  }

  private static final NumberFormat nf = new DecimalFormat("0.000E0");

  protected String getName() {
    return "SGD_InPlace_b" + bSize + "_lambda" + nf.format(lambda);
  }

  @Override
  public void setEvaluators(int iters, Evaluator[] evaluators)
  {
    this.evaluateIters = iters;
    this.evaluators = evaluators;
  }


  //This can be filled if an extending class needs to initialize things.
  protected void init(AbstractStochasticCachingDiffUpdateFunction func) {
  }

  public double getObjective(AbstractStochasticCachingDiffUpdateFunction function, double[] w, double wscale, int[] sample)
  {
    double wnorm = getNorm(w) * wscale*wscale;
    double obj = function.valueAt(w,wscale,sample);
    // Calculate objective with L2 regularization
    return obj + 0.5*sample.length*lambda*wnorm;
  }

  public double tryEta(AbstractStochasticCachingDiffUpdateFunction function, double[] initial, int[] sample, double eta)
  {
    int numBatches =  sample.length / bSize;
    double[] w = new double[initial.length];
    double wscale = 1;
    System.arraycopy(initial, 0, w, 0, w.length);
    int[] sampleBatch = new int[bSize];
    int sampleIndex = 0;
    for (int batch = 0; batch < numBatches; batch++) {
      for (int i = 0; i < bSize; i++) {
        sampleBatch[i] = sample[(sampleIndex + i) % sample.length];
      }
      sampleIndex += bSize;
      double gain = eta/wscale;
      function.calculateStochasticUpdate(w, wscale, sampleBatch, gain);
      wscale *= (1 - eta * lambda*bSize);
    }
    double obj = getObjective(function, w, wscale, sample);
    return obj;
  }

  /**
   * Finds a good learning rate to start with.
   * eta = 1/(lambda*(t0+t)) - we find good t0
   * @param function
   * @param initial
   * @param sampleSize
   * @param seta
   */
  public double tune(AbstractStochasticCachingDiffUpdateFunction function, double[] initial, int sampleSize, double seta)
  {
    Timing timer = new Timing();
    int[] sample = function.getSample(sampleSize);
    double sobj = getObjective(function, initial, 1, sample);
    double besteta = 1;
    double bestobj = sobj;
    double eta = seta;
    int totest = 10;
    double factor = 2;
    boolean phase2 = false;
    while (totest > 0 || !phase2)
    {
      double obj = tryEta(function, initial, sample, eta);
      boolean okay = (obj < sobj);
      sayln("  Trying eta=" + eta + "  obj=" + obj + ((okay)? "(possible)":"(too large)"));
      if (okay)
      {
        totest -= 1;
        if (obj < bestobj) {
          bestobj = obj;
          besteta = eta;
        }
      }
      if (! phase2)
      {
        if (okay) {
          eta = eta * factor;
        } else {
          phase2 = true;
          eta = seta;
        }
      }
      if (phase2) {
        eta = eta / factor;
      }
    }
    // take it on the safe side (implicit regularization)
    besteta /= factor;
    // determine t
    t0 = (int) (1 / (besteta * lambda));
    sayln("  Taking eta=" + besteta + " t0=" + t0);
    sayln("  Tuning completed in: " + Timing.toSecondsString(timer.report()) + " s");
    return besteta;
  }

  // really this is the square of the L2 norm....
  private static double getNorm(double[] w)
  {
    double norm = 0;
    for (int i = 0; i < w.length; i++) {
      norm += w[i]*w[i];
    }
    return norm;
  }

  private void rescale()
  {
    if (xscale == 1) return;
    for (int i = 0; i < x.length; i++) {
      x[i] *= xscale;
    }
    xscale = 1;
  }

  private void doEvaluation(double[] x) {
    // Evaluate solution
    if (evaluators == null) return;
    for (Evaluator eval:evaluators) {
      sayln("  Evaluating: " + eval.toString());
      eval.evaluate(x);
    }
  }

  @Override
  public double[] minimize(Function function, double functionTolerance, double[] initial) {
    return minimize(function, functionTolerance, initial, -1);
  }

  @Override
  public double[] minimize(Function f, double functionTolerance, double[] initial, int maxIterations) {
    if (!(f instanceof AbstractStochasticCachingDiffUpdateFunction)) {
      throw new UnsupportedOperationException();
    }
    AbstractStochasticCachingDiffUpdateFunction function = (AbstractStochasticCachingDiffUpdateFunction) f;
    int totalSamples = function.dataDimension();
    int tuneSampleSize = Math.min(totalSamples, tuningSamples);
    if (tuneSampleSize < tuningSamples) {
      System.err.println("WARNING: Total number of samples=" + totalSamples +
              " is smaller than requested tuning sample size=" + tuningSamples + "!!!");
    }
    lambda = 1.0/(sigma*totalSamples);
    sayln("Using sigma=" + sigma + " lambda=" + lambda + " tuning sample size " + tuneSampleSize);
    // tune(function, initial, tuneSampleSize, 0.1);
    t0 = (int) (1 / (0.1 * lambda));

    x = new double[initial.length];
    System.arraycopy(initial, 0, x, 0, x.length);
    xscale = 1;
    xnorm = getNorm(x);
    int numBatches =  totalSamples/ bSize;

    init(function);

    boolean have_max = (maxIterations > 0 || numPasses > 0);

    if (!have_max){
      throw new UnsupportedOperationException("No maximum number of iterations has been specified.");
    } else{
      maxIterations = Math.max(maxIterations, numPasses)*numBatches;
    }

    sayln("       Batch size of: " + bSize);
    sayln("       Data dimension of: " + totalSamples );
    sayln("       Batches per pass through data:  " + numBatches );
    sayln("       Number of passes is = " + numPasses);
    sayln("       Max iterations is = " + maxIterations);

    //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    //            Loop
    //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

    Timing total = new Timing();
    Timing current = new Timing();
    total.start();
    current.start();
    int t=t0;
    int iters = 0;
    for (int pass = 0; pass < numPasses; pass++)  {
      boolean doEval = (pass > 0 && evaluateIters > 0 && pass % evaluateIters == 0);
      if (doEval) {
        rescale();
        doEvaluation(x);
      }

      double totalValue = 0;
      double lastValue = 0;
      say("Iter: " + iters + " pass " + pass + " batch 1 ... ");
      for (int batch = 0; batch < numBatches; batch++) {
        iters++;

        //Get the next X
        double eta = 1/(lambda*t);
        double gain = eta/xscale;
        lastValue = function.calculateStochasticUpdate(x, xscale, bSize, gain);
        totalValue += lastValue;
        // weight decay (for L2 regularization)
        xscale *= (1 - eta * lambda*bSize);
        t+=bSize;
      }
      if (xscale < 1e-6) {
        rescale();
      }
      try {
        ArrayMath.assertFinite(x,"x");
      } catch (ArrayMath.InvalidElementException e) {
        System.err.println(e.toString());
        for(int i=0;i<x.length;i++){ x[i]=Double.NaN; }
        break;
      }
      xnorm = getNorm(x)*xscale*xscale;
      // Calculate loss based on L2 regularization
      double loss = totalValue + 0.5 * xnorm * lambda * totalSamples;
      say(String.valueOf(numBatches));
      say("[" + ( total.report() )/1000.0 + " s " );
      say("{" + (current.restart()/1000.0) + " s}] ");
      sayln(" "+lastValue + " " + totalValue + " " + loss);

      if (iters >= maxIterations) {
        sayln("Stochastic Optimization complete.  Stopped after max iterations");
        break;
      }

      if (total.report() >= maxTime){
        sayln("Stochastic Optimization complete.  Stopped after max time");
        break;
      }

    }
    rescale();

    if (evaluateIters > 0) {
      // do final evaluation
      doEvaluation(x);
    }

    sayln("Completed in: " + Timing.toSecondsString(total.report()) + " s");

    return x;
  }

  protected void sayln(String s) {
    if (!quiet) {
      System.err.println(s);
    }
  }

  protected void say(String s) {
    if (!quiet) {
      System.err.print(s);
    }
  }

}
TOP

Related Classes of edu.stanford.nlp.optimization.SGDMinimizer

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.