// Package cc.mallet.fst.semi_supervised
//
// Source Code of cc.mallet.fst.semi_supervised.GELatticeTask

/* Copyright (C) 2010 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.fst.semi_supervised;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

import cc.mallet.fst.CRF;
import cc.mallet.fst.SumLattice;
import cc.mallet.fst.SumLatticeDefault;
import cc.mallet.fst.semi_supervised.constraints.GEConstraint;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.MatrixOps;

/**
* Optimizable for CRF using Generalized Expectation constraints that
* consider either a single label or a pair of labels of a linear chain CRF.
*
* See:
* "Generalized Expectation Criteria for Semi-Supervised Learning of Conditional Random Fields"
* Gideon Mann and Andrew McCallum
* ACL 2008
*
* @author Gregory Druck
*/

public class CRFOptimizableByGE implements Optimizable.ByGradientValue {

  private static final int DEFAULT_GPV = 10;
 
  private CRF crf;
  private ArrayList<GEConstraint> constraints;
  private InstanceList data;
 
  private int numThreads;
  private double gpv;
  private double weight;
 
  // indicator that keeps track of whether
  // the gradient / value need to be re-computed
  private int cache;
  private double cachedValue;
  private CRF.Factors cachedGradient;
 
  // lists of source states / transition indices
  // for each destination state
  // used in GELattice
  private int[][] reverseTrans;
  private int[][] reverseTransIndices;
 
  // instances in which at least one
  // constraint fires
  private BitSet instancesWithConstraints;
 
  private ThreadPoolExecutor executor;
 
  /**
   * @param crf CRF
   * @param constraints List of GEConstraints
   * @param data Unlabeled data.
   * @param map Map between states and labels.
   * @param numThreads Number of threads to use for training (DEFAULT=1)
   */
  public CRFOptimizableByGE(CRF crf, ArrayList<GEConstraint> constraints, InstanceList data, StateLabelMap map, int numThreads) {
    this(crf,constraints,data,map,numThreads,1);
  }
 
  public CRFOptimizableByGE(CRF crf, ArrayList<GEConstraint> constraints, InstanceList data, StateLabelMap map, int numThreads, double weight) {
    this.crf = crf;
    this.constraints = constraints;
    this.cache = Integer.MAX_VALUE;
    this.cachedValue = Double.NaN;
    this.cachedGradient = new CRF.Factors(crf);
    this.data = data;
    this.numThreads = numThreads;
    this.weight = weight;
   
    instancesWithConstraints = new BitSet(data.size());
   
    for (GEConstraint constraint : constraints) {
      constraint.setStateLabelMap(map);
      BitSet bitset = constraint.preProcess(data);
      instancesWithConstraints.or(bitset);
    }
    this.gpv = DEFAULT_GPV;
   
    if (numThreads > 1) {
      this.executor = (ThreadPoolExecutor)Executors.newFixedThreadPool(numThreads);
    }
   
    createReverseTransitionMatrices(crf);
  }

  /**
   * Initializes data structures for mapping between a
   * destination state and its source states / transition indices.
   *
   * @param crf CRF
   */
  public void createReverseTransitionMatrices(CRF crf) {
    int[] counts = new int[crf.numStates()];
    for (int si = 0; si < crf.numStates(); si++) {
      CRF.State prevState = (CRF.State)crf.getState(si);
      for (int di = 0; di < prevState.numDestinations(); di++) {
        int sj = prevState.getDestinationState(di).getIndex();
        counts[sj]++;
      }
    }
 
    this.reverseTrans = new int[crf.numStates()][];
    this.reverseTransIndices = new int[crf.numStates()][];
    for (int i = 0; i < counts.length; i++) {
      this.reverseTrans[i] = new int[counts[i]];
      this.reverseTransIndices[i] = new int[counts[i]];
    }
   
    int[] indices = new int[crf.numStates()];
    for (int si = 0; si < crf.numStates(); si++) {
      CRF.State prevState = (CRF.State)crf.getState(si);
      for (int di = 0; di < prevState.numDestinations(); di++) {
        int sj = prevState.getDestinationState(di).getIndex();
        this.reverseTrans[sj][indices[sj]] = si;
        this.reverseTransIndices[sj][indices[sj]] = di;
        indices[sj]++;
      }
    }
  }
 
  public int getNumParameters() {
    return crf.getNumParameters();
  }

  public void getParameters(double[] buffer) {
    crf.getParameters().getParameters(buffer);
  }

  public double getParameter(int index) {
    return crf.getParameters().getParameter(index);
  }

  public void setParameters(double[] params) {
    crf.getParameters().setParameters(params);
    crf.weightsValueChanged();
  }

  public void setParameter(int index, double value) {
    crf.getParameters().setParameter(index, value);
    crf.weightsValueChanged();
  }

  public void cacheValueAndGradient() {
   
    // compute and cache lattices
    //System.gc();
    //System.err.println("Used Memory "+String.format("%.3f", (Runtime.getRuntime().totalMemory()-Runtime.getRuntime().freeMemory())/1000000.) + " before lattice");  
    ArrayList<SumLattice> lattices = new ArrayList<SumLattice>();
    if (numThreads == 1) {
      for (int ii = 0; ii < data.size(); ii++) {
        if (instancesWithConstraints.get(ii)) {
          SumLatticeDefault lattice = new SumLatticeDefault(
              this.crf, (FeatureVectorSequence)data.get(ii).getData(),
              null, null, true);
          lattices.add(lattice);
        }
        else {
          lattices.add(null);
        }
      }
    }
    else {
      // mutli-threaded version
      ArrayList<Callable<Void>> tasks = new ArrayList<Callable<Void>>();
     
      if (data.size() < numThreads) {
        numThreads = data.size();
      }
     
      int increment = data.size() / numThreads;
      int start = 0;
      int end = increment;
      for (int thread = 0; thread < numThreads; thread++) {
        tasks.add(new SumLatticeTask(crf,data,instancesWithConstraints,start,end));
        start += increment;
        if (thread == numThreads - 2) {
          end = data.size();
        }
        else {
          end += increment;
        }
      }
     
      try {
        // run all threads and wait for them to finish
        executor.invokeAll(tasks);
      } catch (InterruptedException ie) {
        ie.printStackTrace();
      }
     
      for (Callable<Void> task : tasks) {
        lattices.addAll(((SumLatticeTask)task).getLattices());
      }
      assert(lattices.size() == data.size()) : lattices.size() + " " + data.size();
    }
    System.err.println("Done computing lattices.");
   
    for (GEConstraint constraint : constraints) {
      constraint.zeroExpectations();
      constraint.computeExpectations(lattices);
    }
    System.err.println("Done computing expectations.");
    //System.gc();
    //System.err.println("Used Memory "+String.format("%.3f", (Runtime.getRuntime().totalMemory()-Runtime.getRuntime().freeMemory())/1000000.) + " after lattice"); 
   
    // compute GE value
    this.cachedValue = 0;
    for (GEConstraint constraint : constraints) {
      this.cachedValue += constraint.getValue();
    }
   
    cachedGradient.zero();
   
    // compute GE gradient
    if (numThreads == 1) {
      for (int ii = 0; ii < data.size(); ii++) {
        if (instancesWithConstraints.get(ii)) {
          SumLattice lattice = lattices.get(ii);
          FeatureVectorSequence fvs = (FeatureVectorSequence)data.get(ii).getData();
          new GELattice(fvs, lattice.getGammas(), lattice.getXis(), crf, reverseTrans, reverseTransIndices, cachedGradient,this.constraints, false);
        }
      }
    }
    else {
      // multi-threaded version
      ArrayList<Callable<Void>> tasks = new ArrayList<Callable<Void>>();
     
      if (data.size() < numThreads) {
        numThreads = data.size();
      }
     
      int increment = data.size() / numThreads;
      int start = 0;
      int end = increment;
      for (int thread = 0; thread < numThreads; thread++) {
        ArrayList<GEConstraint> constraintsCopy = new ArrayList<GEConstraint>();
        for (GEConstraint constraint : constraints) {
          constraintsCopy.add(constraint.copy());
        }
       
        tasks.add(new GELatticeTask(crf,data,lattices,constraintsCopy,instancesWithConstraints,
          reverseTrans,reverseTransIndices,start,end));
        start += increment;
        if (thread == numThreads - 2) {
          end = data.size();
        }
        else {
          end += increment;
        }
       
      }
     
      try {
        // run all threads and wait for them to finish
        executor.invokeAll(tasks);
      } catch (InterruptedException ie) {
        ie.printStackTrace();
      }
     
      for (Callable<Void> task : tasks) {
        cachedGradient.plusEquals(((GELatticeTask)task).getGradient(), 1);
      }
    }

    System.err.println("Done computing gradient.");

    this.cachedValue += crf.getParameters().gaussianPrior(gpv);
    cachedGradient.plusEqualsGaussianPriorGradient(crf.getParameters(), gpv);
   
    System.err.println("Done computing regularization.");  
   
    if (weight != 1) {
      this.cachedValue *= weight;
    }
    System.err.println("GE Value = " + this.cachedValue);
  }
 
  public void setGaussianPriorVariance(double variance) {
    this.gpv = variance;
  }
 
  public void getValueGradient(double[] buffer) {
    if (crf.getWeightsValueChangeStamp() != cache) {
      cacheValueAndGradient();
      cache = crf.getWeightsValueChangeStamp();
    }
    // TODO this will also multiply the prior, if active!
    cachedGradient.getParameters(buffer);
    if (weight != 1) {
      MatrixOps.timesEquals(buffer, weight);
    }
  }

  public double getValue() {
    if (crf.getWeightsValueChangeStamp() != cache) {
      cacheValueAndGradient();
      cache = crf.getWeightsValueChangeStamp();
    }
    return this.cachedValue;
  }
 
  /**
   * Should be called after training is complete
   * to shutdown all threads.
   */
  public void shutdown() {
    if (executor == null) return;
    executor.shutdown();
    try {
      executor.awaitTermination(30, TimeUnit.SECONDS);
    } catch (InterruptedException e) {
      e.printStackTrace();
    }
    assert(executor.shutdownNow().size() == 0) : "All tasks didn't finish";
  }
}

class SumLatticeTask implements Callable<Void> {

  private int start;
  private int end;
  private ArrayList<SumLatticeDefault> lattices;
  private InstanceList data;
  private CRF crf;
  private BitSet instancesWithConstraints;
 
 
  public SumLatticeTask(CRF crf, InstanceList data, BitSet instancesWithConstraints, int start, int end) {
    this.crf = crf;
    this.data = data;
    this.start = start;
    this.end = end;
    this.lattices = new ArrayList<SumLatticeDefault>();
    this.instancesWithConstraints = instancesWithConstraints;
  }
 
  public ArrayList<SumLatticeDefault> getLattices() {
    return this.lattices;
  }

  public Void call() throws Exception {
    for (int ii = start; ii < end; ii++) {
      if (instancesWithConstraints.get(ii)) {
        Instance instance = data.get(ii);
        SumLatticeDefault lattice = new SumLatticeDefault(
          this.crf, (FeatureVectorSequence)instance.getData(),
          null, null, true);
        lattices.add(lattice);
      }
      else {
        lattices.add(null);
      }
    }
    return null;
  }
}

class GELatticeTask implements Callable<Void> {

  private int start;
  private int end;
  private ArrayList<GEConstraint> constraints;
  private ArrayList<SumLattice> lattices;
  private InstanceList data;
  private CRF crf;
  private CRF.Factors gradient;
  private BitSet instancesWithConstraints;
  private int[][] reverseTrans;
  private int[][] reverseTransIndices;
 
 
  /**
   * @param crf CRF
   * @param data Unlabeled data
   * @param lattices Cached SumLattices
   * @param constraints List of GEConstraints
   * @param instancesWithConstraints BitSet which indices whether any constraints fire for an instance
   * @param reverseTrans Source state indices for each destination state
   * @param reverseTransIndices Transition indices for each destination state
   * @param start Position in unlabeled data where this thread starts computing
   * @param end Position in unlabeled data where this thread stops computing
   */
  public GELatticeTask(CRF crf, InstanceList data, ArrayList<SumLattice> lattices,
      ArrayList<GEConstraint> constraints, BitSet instancesWithConstraints,
      int[][] reverseTrans, int[][] reverseTransIndices,
      int start, int end) {
    this.crf = crf;
    this.data = data;
    this.lattices = lattices;
    this.constraints = constraints;
    this.start = start;
    this.end = end;
    this.gradient = new CRF.Factors(crf);
    this.instancesWithConstraints = instancesWithConstraints;
    this.reverseTrans = reverseTrans;
    this.reverseTransIndices = reverseTransIndices;
  }
 
  public CRF.Factors getGradient() {
    return this.gradient;
  }

  public Void call() throws Exception {
    for (int ii = start; ii < end; ii++) {
      if (instancesWithConstraints.get(ii)) {
        SumLattice lattice = lattices.get(ii);
        FeatureVectorSequence fvs = (FeatureVectorSequence)data.get(ii).getData();
        new GELattice(fvs, lattice.getGammas(), lattice.getXis(),
          crf, reverseTrans, reverseTransIndices, gradient,this.constraints, false);
      }
    }
    return null;
  }
}
// TOP
//
// Related Classes of cc.mallet.fst.semi_supervised.GELatticeTask
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.