Package cc.mallet.fst.semi_supervised

Source Code of cc.mallet.fst.semi_supervised.CRFTrainerByGE

package cc.mallet.fst.semi_supervised;

import java.util.HashMap;
import java.util.logging.Logger;

import cc.mallet.fst.CRF;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.TransducerTrainer;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.types.InstanceList;
import cc.mallet.util.MalletLogger;

/**
* Trains a CRF using Generalized Expectation constraints that
* considers a single label of a linear chain CRF.
*
* See:
* "Generalized Expectation Criteria for Semi-Supervised Learning of Conditional Random Fields"
* Gideon Mann and Andrew McCallum
* ACL 2008
*
* @author Gregory Druck
*/
public class CRFTrainerByGE extends TransducerTrainer {

  private static Logger logger = MalletLogger.getLogger(CRFTrainerByGE.class.getName());
 
  private static final int DEFAULT_NUM_RESETS = 1;
  private static final int DEFAULT_GPV = 10;
 
  private boolean converged;
  private int iteration;
  private int numThreads;
  private double gaussianPriorVariance;
  private HashMap<Integer,GECriterion> constraints;
  private CRF crf;
  private StateLabelMap stateLabelMap;
 
  public CRFTrainerByGE(CRF crf, HashMap<Integer,GECriterion> constraints) {
    this(crf,constraints,1);
  }
 
  public CRFTrainerByGE(CRF crf, HashMap<Integer,GECriterion> constraints, int numThreads) {
    this.converged = false;
    this.iteration = 0;
    this.constraints = constraints;
    this.crf = crf;
    this.numThreads = numThreads;
    this.gaussianPriorVariance = DEFAULT_GPV;
    // default one to one state label map
    // other maps can be set with setStateLabelMap
    this.stateLabelMap = new StateLabelMap(crf.getOutputAlphabet(),true);
  }
 
  @Override
  public int getIteration() {
    return iteration;
  }

  @Override
  public Transducer getTransducer() {
    return crf;
  }

  @Override
  public boolean isFinishedTraining() {
    return converged;
  }
 
  public void setGaussianPriorVariance(double gpv) {
    this.gaussianPriorVariance = gpv;
  }

  // map between states in CRF FST and labels
  public void setStateLabelMap(StateLabelMap map) {
    this.stateLabelMap = map;
  }
 
  @Override
  public boolean train(InstanceList unlabeledSet, int numIterations) {
   
    assert(constraints.size() > 0);
    if (constraints.size() == 0) {
      throw new RuntimeException("No constraints specified!");
    }

    // TODO implement initialization
    //initMaxEnt(crf);
   
    // Check what type of constraints we have.
    // XXX Could instead implement separate trainers...
    boolean kl = false;
    boolean l2 = false;
    for (GECriterion constraint : constraints.values()) {
      if (constraint instanceof GEL2Criterion) {
        l2 = true;
      }
      else if (constraint instanceof GEKLCriterion) {
        kl = true;
      }
      else {
        throw new RuntimeException("Only KL and L2 constraints are supported " +
          "by this trainer. Constraint type is " + constraint.getClass());
      }
    }
    if (kl && l2) {
      throw new RuntimeException("Currently constraints must be either all KL " +
          "or all L2.");
    }
   
    GECriteria criteria;
    if (kl) {
      logger.info("kl");
      criteria = new GEKLCriteria(crf.numStates(), stateLabelMap, constraints);
    }
    else {
      logger.info("l2");
      criteria = new GEL2Criteria(crf.numStates(), stateLabelMap, constraints);
    }
   
    CRFOptimizableByGECriteria ge =
      new CRFOptimizableByGECriteria(criteria, crf, unlabeledSet, numThreads);
    ge.setGaussianPriorVariance(gaussianPriorVariance);
   
    LimitedMemoryBFGS bfgs = new LimitedMemoryBFGS(ge);
   
    converged = false;
    logger.info ("CRF about to train with "+numIterations+" iterations");
    // sometimes resetting the optimizer helps to find
    // a better parameter setting
    int iter = 0;
    for (int reset = 0; reset < DEFAULT_NUM_RESETS + 1; reset++) {
      for (; iter < numIterations; iter++) {
        try {
          converged = bfgs.optimize (1);
          iteration++;
          logger.info ("CRF finished one iteration of maximizer, i="+iter);
          runEvaluators();
        } catch (IllegalArgumentException e) {
          e.printStackTrace();
          logger.info ("Catching exception; saying converged.");
          converged = true;
        } catch (Exception e) {
          e.printStackTrace();
          logger.info("Catching exception; saying converged.");
          converged = true;
        }
        if (converged) {
          logger.info ("CRF training has converged, i="+iter);
          break;
        }
      }
      bfgs.reset();
    }
   
    ge.shutdown();
   
    return converged;
  }
}
TOP

Related Classes of cc.mallet.fst.semi_supervised.CRFTrainerByGE

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.