Package cc.mallet.fst.semi_supervised.pr

Source Code of cc.mallet.fst.semi_supervised.pr.SumLatticePR

/* Copyright (C) 2011 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.fst.semi_supervised.pr;

import java.util.logging.Logger;

import cc.mallet.fst.CRF;
import cc.mallet.fst.SumLattice;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.Transducer.State;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelVector;
import cc.mallet.types.MatrixOps;
import cc.mallet.types.Sequence;
import cc.mallet.util.MalletLogger;


/**
* Lattice for E-step/I-projection in PR.
*
* @author Gregory Druck
* @author Kedar Bellare
*/

public class SumLatticePR implements SumLattice {
  private static Logger logger = MalletLogger.getLogger(SumLatticePR.class.getName());

  protected double totalWeight;
  protected int latticeLength;
  protected double[][] gammas;
  protected double[][][] xis;
  protected LabelVector labelings[];
  protected Transducer transducer;
  protected LatticeNode[][] nodes;
  private Sequence input;

  public SumLatticePR(Transducer trans, int index, Sequence input, Sequence output,
      PRAuxiliaryModel auxModel, double[][][] cachedDots, boolean incrementConstraints, Transducer.Incrementor incrementor, 
      LabelAlphabet outputAlphabet, boolean saveXis) {
   
    assert (output == null || input.size() == output.size());

    // Initialize some structures
    this.input = input;
    this.transducer = trans;
    this.latticeLength = input.size() + 1;
    int numStates = transducer.numStates();
    this.nodes = new LatticeNode[latticeLength][numStates];
    this.gammas = new double[latticeLength][numStates];
    if (saveXis)
      xis = new double[latticeLength][numStates][numStates];

    double outputCounts[][] = null;
    if (outputAlphabet != null)
      outputCounts = new double[latticeLength][outputAlphabet.size()];

    for (int i = 0; i < numStates; i++) {
      for (int ip = 0; ip < latticeLength; ip++) {
        gammas[ip][i] = Transducer.IMPOSSIBLE_WEIGHT;
      }
      if (saveXis) {
        for (int j = 0; j < numStates; j++) {
          for (int ip = 0; ip < latticeLength; ip++) {
            xis[ip][i][j] = Transducer.IMPOSSIBLE_WEIGHT;
          }
        }
      }
    }

    // Forward pass
    boolean atLeastOneInitialState = false;
    for (int i = 0; i < numStates; i++) {
      double initialWeight = transducer.getState(i).getInitialWeight();
      if (initialWeight > Transducer.IMPOSSIBLE_WEIGHT) {
        getLatticeNode(0, i).alpha = initialWeight;
        atLeastOneInitialState = true;
      }
    }
    if (atLeastOneInitialState == false)
      logger.warning("There are no starting states!");

    for (int ip = 0; ip < latticeLength - 1; ip++)
      for (int i = 0; i < numStates; i++) {
        if (nodes[ip][i] == null
            || nodes[ip][i].alpha == Transducer.IMPOSSIBLE_WEIGHT) {
          continue;
        }
       
        State s = transducer.getState(i);
        CachedDotTransitionIterator iter =
          new CachedDotTransitionIterator((CRF.State)s,input,ip,
              null,cachedDots[ip][i]);

        auxModel.preProcess(index,ip,input);
        while (iter.hasNext()) {
          State destination = iter.next();
          LatticeNode destinationNode = getLatticeNode(ip + 1, destination.getIndex());
          destinationNode.output = iter.getOutput();
          double transitionWeight = iter.getWeight();
          transitionWeight += auxModel.getWeight(index,ip,input,iter);
          destinationNode.alpha = Transducer.sumLogProb(
              destinationNode.alpha, nodes[ip][i].alpha + transitionWeight);
        }
      }

    totalWeight = Transducer.IMPOSSIBLE_WEIGHT;
    for (int i = 0; i < numStates; i++) {
      if (nodes[latticeLength-1][i] != null) {
        totalWeight = Transducer.sumLogProb(totalWeight,
          (nodes[latticeLength-1][i].alpha + transducer.getState(i).getFinalWeight()));
      }
    }

    if (totalWeight == Transducer.IMPOSSIBLE_WEIGHT) {
      return;
    }

    // Backward pass
    for (int i = 0; i < numStates; i++)
      if (nodes[latticeLength - 1][i] != null) {
        State s = transducer.getState(i);
        nodes[latticeLength - 1][i].beta = s.getFinalWeight();
        gammas[latticeLength - 1][i] = nodes[latticeLength - 1][i].alpha
            + nodes[latticeLength - 1][i].beta - totalWeight;
        if (incrementor != null) {
          double p = Math.exp(gammas[latticeLength - 1][i]);
          assert (p >= 0.0 && p <= 1.0 + 1e-6) : "p=" + p
              + ", gamma=" + gammas[latticeLength - 1][i];
          incrementor.incrementFinalState(s, p);
        }
      }

    for (int ip = latticeLength - 2; ip >= 0; ip--) {
      for (int i = 0; i < numStates; i++) {
        if (nodes[ip][i] == null
            || nodes[ip][i].alpha == Transducer.IMPOSSIBLE_WEIGHT)
          continue;
        State s = transducer.getState(i);
        CachedDotTransitionIterator iter =
          new CachedDotTransitionIterator((CRF.State)s,input,ip,
              null,cachedDots[ip][i]);
        auxModel.preProcess(index,ip,input);
        while (iter.hasNext()) {
          State destination = iter.next();
          int j = destination.getIndex();
          LatticeNode destinationNode = nodes[ip + 1][j];
          if (destinationNode != null) {
            double transitionWeight = iter.getWeight();
            transitionWeight += auxModel.getWeight(index,ip,input,iter);

            nodes[ip][i].beta = Transducer.sumLogProb(
                nodes[ip][i].beta, destinationNode.beta
                    + transitionWeight);
            double xi = nodes[ip][i].alpha + transitionWeight
                + nodes[ip + 1][j].beta - totalWeight;
            if (saveXis)
              xis[ip][i][j] = xi;
            if (incrementor != null || auxModel.numParameters() > 0
                || outputAlphabet != null) {
              double p = Math.exp(xi);
              assert (p >= 0.0 && p <= 1.0 + 1e-6) : "p=" + p
                  + ", xis[" + ip + "][" + i + "][" + j
                  + "]=" + xi;
              if (incrementor != null) {
                incrementor.incrementTransition(iter, p);
              }
              if (incrementConstraints) {
                 // preprocess from above still applies
                 auxModel.incrementTransition(index, ip, input, iter, p);
              }
              if (outputAlphabet != null) {
                int outputIndex = outputAlphabet.lookupIndex(iter.getOutput(), false);
                assert (outputIndex >= 0);
                outputCounts[ip][outputIndex] += p;
              }
            }
          }
        }
        gammas[ip][i] = nodes[ip][i].alpha + nodes[ip][i].beta
            - totalWeight;
      }
    }
    if (incrementor != null)
      for (int i = 0; i < numStates; i++) {
        double p = Math.exp(gammas[0][i]);
        assert (p >= 0.0 && p <= 1.0 + 1e-6) : "p=" + p;
        incrementor.incrementInitialState(transducer.getState(i), p);
      }
    if (outputAlphabet != null) {
      labelings = new LabelVector[latticeLength];
      for (int ip = latticeLength - 2; ip >= 0; ip--) {
        assert (Math.abs(1.0 - MatrixOps.sum(outputCounts[ip])) < 0.000001);
        labelings[ip] = new LabelVector(outputAlphabet,
            outputCounts[ip]);
      }
    }
  }
 
   protected LatticeNode getLatticeNode(int ip, int stateIndex) {
      if (nodes[ip][stateIndex] == null)
        nodes[ip][stateIndex] = new LatticeNode(ip, transducer.getState(stateIndex));
      return nodes[ip][stateIndex];
    }

  public double[][][] getXis() {
    return xis;
  }

  public double[][] getGammas() {
    return gammas;
  }

  public double getTotalWeight() {
    assert (!Double.isNaN(totalWeight));
    return totalWeight;
  }

  public double getGammaWeight(int inputPosition, State s) {
    return gammas[inputPosition][s.getIndex()];
  }

  public double getGammaWeight(int inputPosition, int stateIndex) {
    return gammas[inputPosition][stateIndex];
  }

  public double getGammaProbability(int inputPosition, State s) {
    return Math.exp(gammas[inputPosition][s.getIndex()]);
  }

  public double getGammaProbability(int inputPosition, int stateIndex) {
    return Math.exp(gammas[inputPosition][stateIndex]);
  }

  public double getXiProbability(int ip, State s1, State s2) {
    if (xis == null)
      throw new IllegalStateException("xis were not saved.");
    int i = s1.getIndex();
    int j = s2.getIndex();
    return Math.exp(xis[ip][i][j]);
  }

  public double getXiWeight(int ip, State s1, State s2) {
    if (xis == null)
      throw new IllegalStateException("xis were not saved.");

    int i = s1.getIndex();
    int j = s2.getIndex();
    return xis[ip][i][j];
  }

  public int length() {
    return latticeLength;
  }

  public double getAlpha(int ip, State s) {
    LatticeNode node = getLatticeNode(ip, s.getIndex());
    return node.alpha;
  }

  public double getBeta(int ip, State s) {
    LatticeNode node = getLatticeNode(ip, s.getIndex());
    return node.beta;
  }

  public LabelVector getLabelingAtPosition(int outputPosition) {
    if (labelings != null)
      return labelings[outputPosition];
    return null;
  }

  public Transducer getTransducer() {
    return transducer;
  }

  protected class LatticeNode {
    int inputPosition;
    State state;
    Object output;
    double alpha = Transducer.IMPOSSIBLE_WEIGHT;
    double beta = Transducer.IMPOSSIBLE_WEIGHT;

    LatticeNode(int inputPosition, State state) {
      this.inputPosition = inputPosition;
      this.state = state;
      assert (this.alpha == Transducer.IMPOSSIBLE_WEIGHT);
    }
  }
 
  public Sequence getInput() {
    return input;
  }
}
TOP

Related Classes of cc.mallet.fst.semi_supervised.pr.SumLatticePR

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., and is owned by ORACLE Inc. Contact coftware#gmail.com.