Package edu.stanford.nlp.optimization

Source Code of edu.stanford.nlp.optimization.StochasticMinimizer$setGain

package edu.stanford.nlp.optimization;

import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.util.Timing;

import java.text.NumberFormat;
import java.text.DecimalFormat;
import java.io.PrintWriter;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.List;
import java.util.ArrayList;
import java.util.Random;
import edu.stanford.nlp.util.Pair;

/**
* Stochastic Gradient Descent Minimizer.
* Note: If you want a fast SGD minimizer, then you probably want to use
* StochasticInPlaceMinimizer, not this class!
*
* The basic way to use the minimizer is with a null constructor, then
* the simple minimize method:
* <p/>
* <p><code>Minimizer smd = new SGDMinimizer();</code>
* <br><code>DiffFunction df = new SomeDiffFunction(); //Note that it must be an instance of AbstractStochasticCachingDiffFunction</code>
* <br><code>double tol = 1e-4;</code>
* <br><code>double[] initial = getInitialGuess();</code>
* <br><code>int maxIterations = someSafeNumber;</code>
* <br><code>double[] minimum = smd.minimize(df,tol,initial,maxIterations);</code>
* <p/>
* Constructing with a null constructor will use the default values of
* <p>
* <br><code>batchSize = 15;</code>
* <br><code>initialGain = 0.1;</code>
* <p/>
*
* @author <a href="mailto:akleeman@stanford.edu">Alex Kleeman</a>
* @version 1.0
* @since 1.0
*/
public abstract class StochasticMinimizer<T extends Function> implements Minimizer<T>, HasEvaluators {

  // When true, minimize() opens two files named getName() + ".output" and getName() + ".info"
  // and writes progress values and run metadata to them.
  public boolean outputIterationsToFile = false;
  // Controls how often a value is written to the output file (every k % outputFrequency == 0).
  // NOTE(review): minimize() overwrites this field with ceil(numBatches / outputFrequency) on
  // every call, so the value changes across repeated runs.
  public int outputFrequency = 1000;
  // Gain value; assigned directly by the tuning methods. How it is applied is decided by
  // takeStep() in subclasses (presumably as a step size -- confirm against the subclass).
  public double gain = 0.1;

  // x:       current point; minimize() makes it alias the caller's "initial" array.
  // newX:    next point; checked for finiteness after takeStep() and then copied into x.
  // grad:    average of the gradients currently held in gradList.
  // newGrad: gradient of the most recent batch.
  // v:       passed as the second argument of derivativeAt(); never assigned in this class.
  protected double[] x, newX, grad, newGrad,v;
  // Batches per pass over the data: dataDimension() / bSize (integer division).
  protected int numBatches;
  // Iteration counter of the main loop in minimize().
  protected int k;
  // Batch size requested for each gradient evaluation.
  protected int bSize = 15;
  // When true, say()/sayln() print nothing and minimize() prints a "." per iteration instead.
  protected boolean quiet = false;
  // Most recent batch gradients, oldest first; holds at most "memory" entries.
  protected List<double[]> gradList = null;
  // Maximum number of gradients kept in gradList for averaging.
  protected int memory = 10;
  // Number of passes over the data; a non-positive value means "not set".
  protected int numPasses = -1;
  // Random generator with a fixed seed; not used by any code in this class.
  protected Random gen = new Random(1);
  // Writers for the ".output" and ".info" files; only opened when outputIterationsToFile is true.
  protected PrintWriter file = null;
  protected PrintWriter infoFile = null;
  // Time budget compared against Timing.report() in minimize(); the tuning methods set it to
  // msPerTest (presumably milliseconds -- confirm against Timing).
  protected long maxTime = Long.MAX_VALUE;

  private int evaluateIters = 0;    // Evaluate every x iterations (0 = no evaluation)
  private Evaluator[] evaluators;  // separate set of evaluators to check how optimization is going

  public void shutUp() {
    this.quiet = true;
  }

  // Formatter for the function values and gains printed by the tuning methods
  // (scientific notation with three fraction digits).
  // NOTE(review): DecimalFormat instances are not synchronized and this one is shared statically.
  protected static final NumberFormat nf = new DecimalFormat("0.000E0");


  /**
   * Returns the name of this minimizer. Within this class it is only used as the base name of
   * the ".output" and ".info" files.
   */
  protected abstract String getName();

  /**
   * Performs one update step. {@code minimize} calls this once per iteration, after
   * {@code grad} has been refreshed, then checks {@code newX} for finiteness and copies it into
   * {@code x} -- so implementations are presumably expected to write the next point into
   * {@code newX} (confirm against the subclasses).
   *
   * @param dfunction the function being minimized
   */
  protected abstract void takeStep(AbstractStochasticCachingDiffFunction dfunction);

  public void setEvaluators(int iters, Evaluator[] evaluators)
  {
    this.evaluateIters = iters;
    this.evaluators = evaluators;
  }

  /*
    This is the scaling factor for the gains to ensure convergence
  */
  protected static double gainSchedule(int it, double tau){
    return (tau / (tau + it));
  }

  /*
   * This is used to smooth the gradients, providing a more robust calculation which
   * generally leads to a better routine.
   */

  protected static double[] smooth(List<double[]> toSmooth){
    double[] smoothed = new double[toSmooth.get(0).length];

    for(double[] thisArray:toSmooth){
      ArrayMath.pairwiseAddInPlace(smoothed,thisArray);
    }

    ArrayMath.multiplyInPlace(smoothed,1/((double) toSmooth.size() ));
    return smoothed;
  }


  private void initFiles() {
    if (outputIterationsToFile) {

      String fileName = getName() + ".output";
      String infoName = getName() + ".info";

      try {
        file = new PrintWriter(new FileOutputStream(fileName),true);
        infoFile = new PrintWriter(new FileOutputStream(infoName),true);
      }
      catch (IOException e) {
        System.err.println("Caught IOException outputting data to file: " + e.getMessage());
        System.exit(1);
      }
    }
  }


  /**
   * Tunes this minimizer on the given function; the strategy is defined by the subclass.
   *
   * @param function the function to tune against
   * @param initial the starting point used for each trial
   * @param msPerTest time budget per trial (presumably milliseconds -- confirm against Timing)
   * @return a pair of tuned values; presumably (batch size, gain), matching the list-based
   *         {@code tune} overload of this class -- confirm against the subclass
   */
  public abstract Pair<Integer,Double> tune(Function function,double[] initial, long msPerTest);

  public double tuneDouble(edu.stanford.nlp.optimization.Function function, double[] initial, long msPerTest,PropertySetter<Double> ps,double lower,double upper){
    return this.tuneDouble(function, initial, msPerTest, ps, lower, upper, 1e-3*Math.abs(upper-lower));
  }

  public double tuneDouble(edu.stanford.nlp.optimization.Function function, double[] initial, long msPerTest,PropertySetter<Double> ps,double lower,double upper,double TOL){

    double[] xtest = new double[initial.length];
    this.maxTime = msPerTest;
    // check for stochastic derivatives
    if (!(function instanceof AbstractStochasticCachingDiffFunction)) {
      throw new UnsupportedOperationException();
    }
    AbstractStochasticCachingDiffFunction dfunction = (AbstractStochasticCachingDiffFunction) function;

    List<Pair<Double,Double>> res = new ArrayList<Pair<Double,Double>>();
    Pair<Double,Double> best = new Pair<Double,Double>(lower,Double.POSITIVE_INFINITY); //this is set to lower because the first it will always use the lower first, so it has to be best
    Pair<Double,Double> low = new Pair<Double,Double>(lower,Double.POSITIVE_INFINITY);
    Pair<Double,Double> high = new Pair<Double,Double>(upper,Double.POSITIVE_INFINITY);
    Pair<Double,Double> cur = new Pair<Double,Double>();
    Pair<Double,Double> tmp = new Pair<Double,Double>();

    List<Double> queue = new ArrayList<Double>();
    queue.add(lower);
    queue.add(upper);
    //queue.add(0.5* (lower + upper));

    boolean  toContinue = true;
    this.numPasses = 10000;

    do{
      System.arraycopy(initial, 0, xtest, 0, initial.length);
      if(queue.size() != 0){
        cur.first = queue.remove(0);
      }else{
        cur.first = 0.5*( low.first() + high.first() );
      }

      ps.set(cur.first() );

      System.err.println("");
      System.err.println("About to test with batch size:  " + bSize +
              "  gain: "  + gain + " and  " +
              ps.toString() + " set to  " + cur.first());

      xtest = this.minimize(function, 1e-100, xtest);

      if(Double.isNaN( xtest[0] ) ){
        cur.second = Double.POSITIVE_INFINITY;
      } else {
        cur.second = dfunction.valueAt(xtest);
      }

      if( cur.second() < best.second() ){

        copyPair(best,tmp);
        copyPair(cur,best);

        if(tmp.first() > best.first()){
          copyPair(tmp,high); // The old best is now the upper bound
        }else{
          copyPair(tmp,low)// The old best is now the lower bound
        }
        queue.add( 0.5 * ( cur.first() + high.first()  ) ); // check in the right interval next
      } else if ( cur.first() < best.first() ){
        copyPair(cur,low);
      } else if( cur.first() > best.first()){
        copyPair(cur,high);
      }

      if( Math.abs( low.first() - high.first() ) < TOL   ) {
        toContinue = false;
      }

      res.add(new Pair<Double,Double>(cur.first(),cur.second()));

      System.err.println("");
      System.err.println("Final value is: " + nf.format(cur.second()));
      System.err.println("Optimal so far using " + ps.toString() + " is: "+ best.first() );
    } while(toContinue);


    //output the results to screen.
    System.err.println("-------------");
    System.err.println(" RESULTS          ");
    System.err.println(ps.getClass().toString());
    System.err.println("-------------");
    System.err.println("  val    ,    function after " + msPerTest + " ms");
    for(int i=0;i<res.size();i++ ){
      System.err.println(res.get(i).first() + "    ,    " + res.get(i).second() );
     }
    System.err.println("");
    System.err.println("");

    return best.first();

  }

  private static void copyPair(Pair<Double,Double> from, Pair<Double,Double> to) {
    to.first = from.first();
    to.second = from.second();
  }


  private class setGain implements PropertySetter<Double>{
    StochasticMinimizer<T> parent = null;

    public setGain(StochasticMinimizer<T> min) {
      parent = min;
    }

    public void set(Double in) {
      gain = in ;
    }
  }


  public double tuneGain(Function function, double[] initial, long msPerTest, double lower, double upper){

    return tuneDouble(function,initial,msPerTest,new setGain(this),lower,upper);
  }


  // [cdm 2012: The version that used to be here was clearly buggy;
  // I changed it a little, but didn't test it. It's now more correct, but
  // I think it is still conceptually faulty, since it will keep growing the
  // batch size so long as any minute improvement in the function value is
  // obtained, whereas the whole point of using a small batch is to get speed
  // at the cost of small losses.]
  public int tuneBatch(Function function, double[] initial, long msPerTest, int bStart) {

    double[] xTest = new double[initial.length];
    int bOpt = 0;
    double min = Double.POSITIVE_INFINITY;
    this.maxTime = msPerTest;
    double prev = Double.POSITIVE_INFINITY;

    // check for stochastic derivatives
    if (!(function instanceof AbstractStochasticCachingDiffFunction)) {
      throw new UnsupportedOperationException();
    }
    AbstractStochasticCachingDiffFunction dFunction = (AbstractStochasticCachingDiffFunction) function;

    int b = bStart;
    boolean  toContinue = true;

    do {
      System.arraycopy(initial, 0, xTest, 0, initial.length);
      System.err.println("");
      System.err.println("Testing with batch size:  " + b );
      bSize = b;
      shutUp();
      this.minimize(function, 1e-5, xTest);
      double result = dFunction.valueAt(xTest);

      if (result < min) {
        min = result;
        bOpt = bSize;
        b *= 2;
        prev = result;
      } else if(result < prev) {
        b *= 2;
        prev = result;
      } else if (result > prev) {
        toContinue = false;
      }

      System.err.println("");
      System.err.println("Final value is: " + nf.format(result));
      System.err.println("Optimal so far is:  batch size: " + bOpt );
    } while (toContinue);

    return bOpt;
  }

  public Pair<Integer,Double> tune(Function function, double[] initial, long msPerTest,List<Integer> batchSizes, List<Double> gains){

    double[] xtest = new double[initial.length];
    int bOpt = 0;
    double gOpt = 0.0;
    double min = Double.POSITIVE_INFINITY;

    double[][] results = new double[batchSizes.size()][gains.size()];

    this.maxTime = msPerTest;

    for( int b=0;b<batchSizes.size();b++){
      for(int g=0;g<gains.size();g++){
        System.arraycopy(initial, 0, xtest, 0, initial.length);
        bSize = batchSizes.get(b);
        gain = gains.get(g);
        System.err.println("");
        System.err.println("Testing with batch size: " + bSize + "    gain:  " + nf.format(gain) );
        this.quiet = true;
        this.minimize(function, 1e-100, xtest);
        results[b][g] = function.valueAt(xtest);

        if( results[b][g] < min ){
          min = results[b][g];
          bOpt = bSize;
          gOpt = gain;
        }

        System.err.println("");
        System.err.println("Final value is: " + nf.format(results[b][g]));
        System.err.println("Optimal so far is:  batch size: " + bOpt + "   gain:  " + nf.format(gOpt) );

      }
    }

    return new Pair<Integer,Double>(bOpt,gOpt);
  }


  /**
   * Initialization hook that an extending class can override if it needs to set things up.
   * {@code minimize} calls it once per run, after the working arrays have been allocated and
   * before the output files are opened. The default implementation does nothing.
   *
   * @param func the function about to be minimized
   */
  protected void init(AbstractStochasticCachingDiffFunction func){
  }


  private void doEvaluation(double[] x) {
    // Evaluate solution
    if (evaluators == null) return;
    for (Evaluator eval:evaluators) {
      sayln("  Evaluating: " + eval.toString());
      eval.evaluate(x);
    }
  }

  public double[] minimize(Function function, double functionTolerance, double[] initial) {
    return minimize(function, functionTolerance, initial, -1);
  }


  public double[] minimize(Function function, double functionTolerance, double[] initial, int maxIterations) {

    // check for stochastic derivatives
    if (!(function instanceof AbstractStochasticCachingDiffFunction)) {
      throw new UnsupportedOperationException();
    }
    AbstractStochasticCachingDiffFunction dfunction = (AbstractStochasticCachingDiffFunction) function;

    dfunction.method = StochasticCalculateMethods.GradientOnly;

    /* ---
      StochasticDiffFunctionTester sdft = new StochasticDiffFunctionTester(dfunction);
      ArrayMath.add(initial, gen.nextDouble() ); // to make sure that priors are working.
      sdft.testSumOfBatches(initial, 1e-4);
      System.exit(1);
    --- */

    x = initial;
    grad = new double[x.length];
    newX = new double[x.length];
    gradList = new ArrayList<double[]>();
    numBatches =  dfunction.dataDimension()/ bSize;
    outputFrequency = (int) Math.ceil( ((double) numBatches) /( (double) outputFrequency) )  ;

    init(dfunction);
    initFiles();

    boolean have_max = (maxIterations > 0 || numPasses > 0);

    if (!have_max){
      throw new UnsupportedOperationException("No maximum number of iterations has been specified.");
    }else{
      maxIterations = Math.max(maxIterations, numPasses)*numBatches;
    }

    sayln("       Batchsize of: " + bSize);
    sayln("       Data dimension of: " + dfunction.dataDimension() );
    sayln("       Batches per pass through data:  " + numBatches );
    sayln("       Max iterations is = " + maxIterations);

    if (outputIterationsToFile) {
      infoFile.println(function.domainDimension() + "; DomainDimension " );
      infoFile.println(bSize + "; batchSize ");
      infoFile.println(maxIterations + "; maxIterations");
      infoFile.println(numBatches + "; numBatches ");
      infoFile.println(outputFrequency  + "; outputFrequency");
    }

    //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    //            Loop
    //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    Timing total = new Timing();
    Timing current = new Timing();
    total.start();
    current.start();
    for (k = 0; k<maxIterations ; k++)  {
      try{
        boolean doEval = (k > 0 && evaluateIters > 0 && k % evaluateIters == 0);
        if (doEval) {
          doEvaluation(x);
        }
        int pass = k/numBatches;
        int batch = k%numBatches;
        say("Iter: " + k + " pass " + pass + " batch " + batch);

        // restrict number of saved gradients
        //  (recycle memory of first gradient in list for new gradient)
        if(k > 0 && gradList.size() >= memory){
          newGrad = gradList.remove(0);
        }else{
          newGrad = new double[grad.length];
        }

        dfunction.hasNewVals = true;
        System.arraycopy(dfunction.derivativeAt(x,v,bSize),0,newGrad,0,newGrad.length);
        ArrayMath.assertFinite(newGrad,"newGrad");
        gradList.add(newGrad);
        grad = smooth(gradList);

        //Get the next X
        takeStep(dfunction);

        ArrayMath.assertFinite(newX,"newX");

        //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
        // THIS IS FOR DEBUG ONLY
        //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
        if(outputIterationsToFile && (k%outputFrequency == 0) && k!=0 ) {
          double curVal = dfunction.valueAt(x);
          say(" TrueValue{ " + curVal + " } ");
          file.println(k + " , " + curVal + " , " + total.report() );
        }
        //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
        // END OF DEBUG STUFF
        //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

        if (k >= maxIterations) {
          sayln("Stochastic Optimization complete.  Stopped after max iterations");
          x = newX;
          break;
        }

        if (total.report() >= maxTime){
          sayln("Stochastic Optimization complete.  Stopped after max time");
          x = newX;
          break;
        }

        System.arraycopy(newX, 0, x, 0, x.length);

        say("[" + ( total.report() )/1000.0 + " s " );
        say("{" + (current.restart()/1000.0) + " s}] ");
        say(" "+dfunction.lastValue());

        if (quiet) {
          System.err.print(".");
        }else{
          sayln("");
        }

      }catch(ArrayMath.InvalidElementException e){
        System.err.println(e.toString());
        for(int i=0;i<x.length;i++){ x[i]=Double.NaN; }
        break;
      }

    }
    if (evaluateIters > 0) {
      // do final evaluation
      doEvaluation(x);
    }

    if(outputIterationsToFile){
      infoFile.println(k + "; Iterations");
      infoFile.println(( total.report() )/1000.0 + "; Completion Time");
      infoFile.println(dfunction.valueAt(x) + "; Finalvalue");

      infoFile.close();
      file.close();
      System.err.println("Output Files Closed");
      //System.exit(1);
    }

    say("Completed in: " + ( total.report() )/1000.0 + " s");

    return x;
  }


  public interface PropertySetter <T1> {
    public void set(T1 in);
  }

  protected void sayln(String s) {
    if (!quiet) {
      System.err.println(s);
    }
  }

  protected void say(String s) {
    if (!quiet) {
      System.err.print(s);
    }
  }

}
TOP

Related Classes of edu.stanford.nlp.optimization.StochasticMinimizer$setGain

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.