Package cc.mallet.fst.semi_supervised.pr

Source Code of cc.mallet.fst.semi_supervised.pr.SumLatticeKL

/* Copyright (C) 2011 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.fst.semi_supervised.pr;

import cc.mallet.fst.SumLattice;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.Transducer.State;
import cc.mallet.fst.Transducer.TransitionIterator;
import cc.mallet.types.LabelVector;
import cc.mallet.types.Sequence;

/**
* Lattice for M-step/M-projection in PR.
*
* @author Kedar Bellare
* @author Gregory Druck
*/

public class SumLatticeKL implements SumLattice {
  // "ip" == "input position", "op" == "output position", "i" == "state index"
  Transducer t;
  double totalWeight;
  int latticeLength;
  double[][][] xis;
  Sequence input;

  protected SumLatticeKL() {}

  // If outputAlphabet is non-null, this will create a LabelVector
  // for each position in the output sequence indicating the
  // probability distribution over possible outputs at that time
  // index
  public SumLatticeKL(Transducer trans, Sequence input,
      double[] initProbs, double[] finalProbs, double[][][] xis,
      double[][][] cachedDots,
      Transducer.Incrementor incrementor) {
    assert (xis != null) : "Need transition probabilities";
    // Initialize some structures
    this.t = trans;

    this.input = input;
   
    latticeLength = input.size() + 1;
    int numStates = t.numStates();
    this.xis = xis;

    totalWeight = 0;

    // increment initial states
    for (int i = 0; i < numStates; i++) {
      if (t.getState(i).getInitialWeight() == Transducer.IMPOSSIBLE_WEIGHT)
        continue;
      if (initProbs != null) {
        totalWeight += initProbs[i] * t.getState(i).getInitialWeight();
        if (incrementor != null)
          incrementor.incrementInitialState(t.getState(i),
              initProbs[i]);
      }
    }

    for (int ip = 0; ip < latticeLength - 1; ip++)
      for (int i = 0; i < numStates; i++) {
        State s = t.getState(i);
        TransitionIterator iter = s.transitionIterator(input, ip);
        while (iter.hasNext()) {
          State destination = iter.next();
          double weight = iter.getWeight();
          double p = xis[ip][i][destination.getIndex()];
          totalWeight += p * weight;
          if (cachedDots != null) {
            cachedDots[ip][i][destination.getIndex()] = weight;
          }
          if (incrementor != null) {
            // this is used to gather "constraints",
            // so only probabilities under q are used
            incrementor.incrementTransition(iter, p);
          }
        }
      }

    for (int i = 0; i < numStates; i++) {
      if (t.getState(i).getFinalWeight() == Transducer.IMPOSSIBLE_WEIGHT)
        continue;
      if (finalProbs != null) {
        totalWeight += finalProbs[i] * t.getState(i).getFinalWeight();
        if (incrementor != null)
          incrementor.incrementFinalState(t.getState(i),
              finalProbs[i]);
      }
    }

    assert (totalWeight > Transducer.IMPOSSIBLE_WEIGHT) : "Total weight="
        + totalWeight;
  }

  public double[][][] getXis() {
    return xis;
  }

  public double[][] getGammas() {
    throw new UnsupportedOperationException("Not handled!");
  }

  public double getTotalWeight() {
    assert (!Double.isNaN(totalWeight));
    return totalWeight;
  }

  public double getGammaWeight(int inputPosition, State s) {
    throw new UnsupportedOperationException("Not handled!");
  }

  public double getGammaWeight(int inputPosition, int stateIndex) {
    throw new UnsupportedOperationException("Not handled!");
  }

  public double getGammaProbability(int inputPosition, State s) {
    throw new UnsupportedOperationException("Not handled!");
  }

  public double getGammaProbability(int inputPosition, int stateIndex) {
    throw new UnsupportedOperationException("Not handled!");
  }

  public double getXiProbability(int ip, State s1, State s2) {
    throw new UnsupportedOperationException("Not handled!");
  }

  public double getXiWeight(int ip, State s1, State s2) {
    throw new UnsupportedOperationException("Not handled!");
  }

  public int length() {
    return latticeLength;
  }

  public double getAlpha(int ip, State s) {
    throw new UnsupportedOperationException("Not handled!");
  }

  public double getBeta(int ip, State s) {
    throw new UnsupportedOperationException("Not handled!");
  }

  public LabelVector getLabelingAtPosition(int outputPosition) {
    return null;
  }

  public Transducer getTransducer() {
    return t;
  }

  public Sequence getInput() {
    return input;
  }
}
TOP

Related Classes of cc.mallet.fst.semi_supervised.pr.SumLatticeKL

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.