Package cc.mallet.fst.semi_supervised.pr

Source Code of cc.mallet.fst.semi_supervised.pr.ConstraintsOptimizableByPR$ExpectationTask

/* Copyright (C) 2011 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.fst.semi_supervised.pr;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.logging.Logger;

import cc.mallet.fst.CRF;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.Transducer.TransitionIterator;
import cc.mallet.fst.semi_supervised.pr.constraints.PRConstraint;
import cc.mallet.optimize.Optimizable.ByGradientValue;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Sequence;
import cc.mallet.util.MalletLogger;

/**
* Optimizable for E-step/I-projection in Posterior Regularization (PR).
*
* @author Kedar Bellare
* @author Gregory Druck
*/

public class ConstraintsOptimizableByPR implements Serializable, ByGradientValue {
  private static Logger logger = MalletLogger.getLogger(ConstraintsOptimizableByPR.class.getName());
  private static final long serialVersionUID = 1;
 
  protected boolean cacheStale;
  protected int numParameters;
  protected int numThreads;
  protected InstanceList trainingSet;
  protected double cachedValue = -123456789;
  protected double[] cachedGradient;
  protected CRF crf;
  protected ThreadPoolExecutor executor;
  protected double[][][][] cachedDots;
  PRAuxiliaryModel model;

  public ConstraintsOptimizableByPR(CRF crf, InstanceList ilist, PRAuxiliaryModel model) {
    this(crf,ilist,model,1);
  }
 
  public ConstraintsOptimizableByPR(CRF crf, InstanceList ilist, PRAuxiliaryModel model, int numThreads) {
    this.crf = crf;
    this.trainingSet = ilist;

    this.model = model;
    this.numParameters = model.numParameters();
    cachedGradient = new double[numParameters];
    this.cacheStale = true;
   
    this.numThreads = numThreads;
    this.executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(numThreads);
 
    cacheDotProducts();
   
  }

  public void cacheDotProducts() {
    cachedDots = new double[trainingSet.size()][][][];
    for (int i = 0; i < trainingSet.size(); i++) {
      FeatureVectorSequence input = (FeatureVectorSequence)trainingSet.get(i).getData();

      cachedDots[i] = new double[input.size()][crf.numStates()][crf.numStates()];
      for (int j = 0; j < input.size(); j++) {
        for (int k = 0; k < crf.numStates(); k++) {
          for (int l = 0; l < crf.numStates(); l++) {
            cachedDots[i][j][k][l] = Transducer.IMPOSSIBLE_WEIGHT;
          }
        }
      }

      for (int j = 0; j < input.size(); j++) {
        for (int k = 0; k < crf.numStates(); k++) {
          TransitionIterator iter = crf.getState(k).transitionIterator(input, j);
          while (iter.hasNext()) {
            int l = iter.next().getIndex();
            cachedDots[i][j][k][l] = iter.getWeight();
          }
        }
      }
    }
  }
 
  public int getNumParameters() {
    return numParameters;
  }

  public void getParameters(double[] params) {
    model.getParameters(params);
  }

  public double getParameter(int index) {
    return model.getParameter(index);
  }

  public void setParameters(double[] params) {
    cacheStale = true;
    model.setParameters(params);
  }

  public void setParameter(int index, double value) {
    cacheStale = true;
    model.setParameter(index, value);
  }

  protected double getExpectationValue() {
    model.zeroExpectations();

    // updating tasks
    ArrayList<Callable<Double>> tasks = new ArrayList<Callable<Double>>();
    int increment = trainingSet.size() / numThreads;
    int start = 0;
    int end = increment;
    for (int taskIndex = 0; taskIndex < numThreads; taskIndex++) {
      tasks.add(new ExpectationTask(start,end,model.copy()));
      start = end;
      if (taskIndex == numThreads - 2) {
        end = trainingSet.size();
      }
      else {
        end = start + increment;
      }
    }
   
    double value = 0;
    try {
      List<Future<Double>> results = executor.invokeAll(tasks);
   
      // compute value
      for (Future<Double> f : results) {
        try {
          value += f.get();
        } catch (ExecutionException ee) {
          ee.printStackTrace();
        }
      }
    } catch (InterruptedException ie) {
      ie.printStackTrace();
    }

    // combine results
    combine(model,tasks);

    // mu*b - w*||mu||^2
    value += model.getValue();
    return value;
  }

  /**
   * Returns the log probability of the training sequence labels and the prior
   * over parameters.
   */
  public double getValue() {
    if (cacheStale) {
      cachedValue = getExpectationValue();
      model.getValueGradient(cachedGradient);
      cacheStale = false;
      logger.info("getValue (auxiliary distribution) = " + cachedValue);
    }
    return cachedValue;
  }

  public double getCompleteValueContribution() {
    if (cacheStale) {
      getValue();
    }
    double value = model.getCompleteValueContribution();
    return value;
  }
 
  public void getValueGradient(double[] buffer) {
    if (cacheStale) {
      getValue();
    }
    System.arraycopy(cachedGradient, 0, buffer, 0, cachedGradient.length);
  }

  private void combine(PRAuxiliaryModel orig, ArrayList<Callable<Double>> tasks) {
    for (int i = 0; i < tasks.size(); i++) {
      ExpectationTask task = (ExpectationTask)tasks.get(i);
      PRAuxiliaryModel model = task.getModelCopy();
      for (int ci = 0; ci < model.numConstraints(); ci++) {
        PRConstraint origConstraint = orig.getConstraint(ci);
        PRConstraint copyConstraint = model.getConstraint(ci);
        double[] expectation = new double[origConstraint.numDimensions()];
        copyConstraint.getExpectations(expectation);
        origConstraint.addExpectations(expectation);
      }
    }
  }

  public void shutdown() {
    executor.shutdown();
  }
 
  public double[][][][] getCachedDots() {
    return cachedDots;
  }
 
  public PRAuxiliaryModel getAuxModel() {
    return model;
  }
 
  private class ExpectationTask implements Callable<Double> {
   
    private int start;
    private int end;
    private PRAuxiliaryModel modelCopy;
   
    public ExpectationTask(int start, int end, PRAuxiliaryModel modelCopy) {
      this.start = start;
      this.end = end;
      this.modelCopy = modelCopy;
    }
   
    public Double call() throws Exception {
      double value = 0;
      for (int ii = start; ii < end; ii++) {
        Instance inst = trainingSet.get(ii);
        Sequence input = (Sequence) inst.getData();
        // logZ     
        value -= new SumLatticePR(crf, ii, input, null, modelCopy, cachedDots[ii], true, null, null, false).getTotalWeight();
      }
      return value;
    }
   
    public PRAuxiliaryModel getModelCopy() {
      return modelCopy;
    }
  }
}
TOP

Related Classes of cc.mallet.fst.semi_supervised.pr.ConstraintsOptimizableByPR$ExpectationTask

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and owned by ORACLE Inc. Contact coftware#gmail.com.