// Package edu.stanford.nlp.ie.crf
//
// Source code of edu.stanford.nlp.ie.crf.CRFLogConditionalObjectiveFunctionWithDropout

package edu.stanford.nlp.ie.crf;

import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.util.concurrent.*;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Timing;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.Quadruple;

import java.util.*;

/**
* @author Mengqiu Wang
*/

public class CRFLogConditionalObjectiveFunctionWithDropout extends CRFLogConditionalObjectiveFunction {

  private final double delta;
  private final double dropoutScale;
  private double[][] dropoutPriorGradTotal;
  private final boolean dropoutApprox;
  private double[][] weightSquare;

  private final int[][][][] totalData;  // data[docIndex][tokenIndex][][]
  private int unsupDropoutStartIndex;
  private final double unsupDropoutScale;

  private List<List<Set<Integer>>> dataFeatureHash;
  private List<Map<Integer, List<Integer>>> condensedMap;
  private int[][] dataFeatureHashByDoc;
  private int edgeLabelIndexSize;
  private int nodeLabelIndexSize;
  private int[][] edgeLabels;
  private Map<Integer, List<Integer>> currPrevLabelsMap;
  private Map<Integer, List<Integer>> currNextLabelsMap;

  private ThreadsafeProcessor<Pair<Integer, Boolean>, Quadruple<Integer, Double, Map<Integer, double[]>, Map<Integer, double[]>>> dropoutPriorThreadProcessor =
        new ThreadsafeProcessor<Pair<Integer, Boolean>, Quadruple<Integer, Double, Map<Integer, double[]>, Map<Integer, double[]>>>() {
    @Override
    public Quadruple<Integer, Double, Map<Integer, double[]>, Map<Integer, double[]>> process(Pair<Integer,Boolean> docIndexUnsup) {
      return expectedCountsAndValueForADoc(docIndexUnsup.first(), false, docIndexUnsup.second());
    }
    @Override
    public ThreadsafeProcessor<Pair<Integer, Boolean>, Quadruple<Integer, Double, Map<Integer, double[]>, Map<Integer, double[]>>> newInstance() {
      return this;
    }
  };

  //TODO(Mengqiu) Need to figure out what to do with dataDimension() in case of
  // mixed supervised+unsupervised data for SGD (AdaGrad)
  CRFLogConditionalObjectiveFunctionWithDropout(int[][][][] data, int[][] labels, int window, Index<String> classIndex, List<Index<CRFLabel>> labelIndices, int[] map, String priorType, String backgroundSymbol, double sigma, double[][][][] featureVal, double delta, double dropoutScale, int multiThreadGrad, boolean dropoutApprox, double unsupDropoutScale, int[][][][] unsupDropoutData) {
    super(data, labels, window, classIndex, labelIndices, map, priorType, backgroundSymbol, sigma, featureVal, multiThreadGrad);
    this.delta = delta;
    this.dropoutScale = dropoutScale;
    this.dropoutApprox = dropoutApprox;
    dropoutPriorGradTotal = empty2D();
    this.unsupDropoutStartIndex = data.length;
    this.unsupDropoutScale = unsupDropoutScale;
    if (unsupDropoutData != null) {
      this.totalData = new int[data.length + unsupDropoutData.length][][][];
      for (int i=0; i<data.length; i++) {
        this.totalData[i] = data[i];
      }
      for (int i=0; i<unsupDropoutData.length; i++) {
        this.totalData[i+unsupDropoutStartIndex] = unsupDropoutData[i];
      }
    } else {
      this.totalData = data;
    }
    initEdgeLabels();
    initializeDataFeatureHash();
  }

  private void initEdgeLabels() {
    if (labelIndices.size() < 2)
      return;
    Index<CRFLabel> edgeLabelIndex = labelIndices.get(1);
    edgeLabelIndexSize = edgeLabelIndex.size();
    Index<CRFLabel> nodeLabelIndex = labelIndices.get(0);
    nodeLabelIndexSize = nodeLabelIndex.size();
    currPrevLabelsMap = new HashMap<Integer, List<Integer>>();
    currNextLabelsMap = new HashMap<Integer, List<Integer>>();
    edgeLabels = new int[edgeLabelIndexSize][];
    for (int k=0; k < edgeLabelIndexSize; k++) {
      int[] labelPair = edgeLabelIndex.get(k).getLabel();
      edgeLabels[k] = labelPair;
      int curr = labelPair[1];
      int prev = labelPair[0];
      if (!currPrevLabelsMap.containsKey(curr))
        currPrevLabelsMap.put(curr, new ArrayList<Integer>(numClasses));
      currPrevLabelsMap.get(curr).add(prev);
      if (!currNextLabelsMap.containsKey(prev))
        currNextLabelsMap.put(prev, new ArrayList<Integer>(numClasses));
      currNextLabelsMap.get(prev).add(curr);
    }
  }

  private Map<Integer, double[]> sparseE(Set<Integer> activeFeatures) {
    Map<Integer, double[]> aMap = new HashMap<Integer, double[]>(activeFeatures.size());
    for (int f: activeFeatures) {
      // System.err.printf("aMap.put(%d, new double[%d])\n", f, map[f]+1);
      aMap.put(f,new double[map[f] == 0 ? nodeLabelIndexSize : edgeLabelIndexSize]);
    }
    return aMap;
  }

  private Map<Integer, double[]> sparseE(int[] activeFeatures) {
    Map<Integer, double[]> aMap = new HashMap<Integer, double[]>(activeFeatures.length);
    for (int f: activeFeatures) {
      // System.err.printf("aMap.put(%d, new double[%d])\n", f, map[f]+1);
      aMap.put(f,new double[map[f] == 0 ? nodeLabelIndexSize : edgeLabelIndexSize]);
    }
    return aMap;
  }

  private Quadruple<Integer, Double, Map<Integer, double[]>, Map<Integer, double[]>> expectedCountsAndValueForADoc(int docIndex,
      boolean skipExpectedCountCalc, boolean skipValCalc) {

    int[] activeFeatures = dataFeatureHashByDoc[docIndex];
    List<Set<Integer>> docDataHash = dataFeatureHash.get(docIndex);
    Map<Integer, List<Integer>> condensedFeaturesMap = condensedMap.get(docIndex);

    double prob = 0;
    int[][][] docData = totalData[docIndex];
    int[] docLabels = null;
    if (docIndex < labels.length)
      docLabels = labels[docIndex];

    Timing timer = new Timing();
    double[][][] featureVal3DArr = null;
    if (featureVal != null)
      featureVal3DArr = featureVal[docIndex];

    // make a clique tree for this document
    CRFCliqueTree cliqueTree = CRFCliqueTree.getCalibratedCliqueTree(docData, labelIndices, numClasses, classIndex, backgroundSymbol, cliquePotentialFunc, featureVal3DArr);

    if (!skipValCalc) {
      if (TIMED)
        timer.start();
      // compute the log probability of the document given the model with the parameters x
      int[] given = new int[window - 1];
      Arrays.fill(given, classIndex.indexOf(backgroundSymbol));
      if (docLabels.length>docData.length) { // only true for self-training
        // fill the given array with the extra docLabels
        System.arraycopy(docLabels, 0, given, 0, given.length);
        // shift the docLabels array left
        int[] newDocLabels = new int[docData.length];
        System.arraycopy(docLabels, docLabels.length-newDocLabels.length, newDocLabels, 0, newDocLabels.length);
        docLabels = newDocLabels;
      }

      double startPosLogProb = cliqueTree.logProbStartPos();
      if (VERBOSE)
        System.err.printf("P_-1(Background) = % 5.3f\n", startPosLogProb);
      prob += startPosLogProb;

      // iterate over the positions in this document
      for (int i = 0; i < docData.length; i++) {
        int label = docLabels[i];
        double p = cliqueTree.condLogProbGivenPrevious(i, label, given);
        if (VERBOSE) {
          System.err.println("P(" + label + "|" + ArrayMath.toString(given) + ")=" + Math.exp(p));
        }
        prob += p;
        System.arraycopy(given, 1, given, 0, given.length - 1);
        given[given.length - 1] = label;
      }
      if (TIMED) {
        long elapsedMs = timer.stop();
        System.err.println("Calculate objective took: " + Timing.toMilliSecondsString(elapsedMs) + " ms");
      }
    }

    Map<Integer, double[]> EForADoc = sparseE(activeFeatures);
    List<Map<Integer, double[]>> EForADocPos = null;
    if (dropoutApprox) {
      EForADocPos = new ArrayList<Map<Integer, double[]>>(docData.length);
    }

    if (!skipExpectedCountCalc) {
      if (TIMED)
        timer.start();
      // compute the expected counts for this document, which we will need to compute the derivative
      // iterate over the positions in this document
      double fVal = 1.0;
      for (int i = 0; i < docData.length; i++) {
        Set<Integer> docDataHashI = docDataHash.get(i);
        Map<Integer, double[]> EForADocPosAtI = null;
        if (dropoutApprox)
          EForADocPosAtI = sparseE(docDataHashI);

        for (int fIndex: docDataHashI) {
          int j= map[fIndex];
          Index<CRFLabel> labelIndex = labelIndices.get(j);
          // for each possible labeling for that clique
          for (int k = 0; k < labelIndex.size(); k++) {
            int[] label = labelIndex.get(k).getLabel();
            double p = cliqueTree.prob(i, label); // probability of these labels occurring in this clique with these features
            if (dropoutApprox)
              increScore(EForADocPosAtI, fIndex, k, fVal * p);
            increScore(EForADoc, fIndex, k, fVal * p);
          }
        }
        if (dropoutApprox) {
          for (int fIndex: docDataHashI) {
            if (condensedFeaturesMap.containsKey(fIndex)) {
              List<Integer> aList = condensedFeaturesMap.get(fIndex);
              for (int toCopyInto: aList) {
                double[] arr = EForADocPosAtI.get(fIndex);
                double[] targetArr = new double[arr.length];
                for (int q=0; q < arr.length; q++)
                  targetArr[q] = arr[q];
                EForADocPosAtI.put(toCopyInto, targetArr);
              }
            }
          }
          EForADocPos.add(EForADocPosAtI);
        }
      }

      // copy for condensedFeaturesMap
      for (Map.Entry<Integer, List<Integer>> entry: condensedFeaturesMap.entrySet()) {
        int key = entry.getKey();
        List<Integer> aList = entry.getValue();
        for (int toCopyInto: aList) {
          double[] arr = EForADoc.get(key);
          double[] targetArr = new double[arr.length];
          for (int i=0; i < arr.length; i++)
            targetArr[i] = arr[i];
          EForADoc.put(toCopyInto, targetArr);
        }
      }

      if (TIMED) {
        long elapsedMs = timer.stop();
        System.err.println("Expected count took: " + Timing.toMilliSecondsString(elapsedMs) + " ms");
      }
    }

    Map<Integer, double[]> dropoutPriorGrad = null;
    if (prior == DROPOUT_PRIOR) {
      if (TIMED)
        timer.start();
      // we can optimize this, this is too large, don't need this big
      dropoutPriorGrad = sparseE(activeFeatures);

      // System.err.print("computing dropout prior for doc " + docIndex + " ... ");
      prob -= getDropoutPrior(cliqueTree, docData, EForADoc, docDataHash, activeFeatures, dropoutPriorGrad, condensedFeaturesMap, EForADocPos);
      // System.err.println(" done!");
      if (TIMED) {
        long elapsedMs = timer.stop();
        System.err.println("Dropout took: " + Timing.toMilliSecondsString(elapsedMs) + " ms");
      }
    }

    return new Quadruple<Integer, Double, Map<Integer, double[]>, Map<Integer, double[]>>(docIndex, prob, EForADoc, dropoutPriorGrad);
  }

  private void increScore(Map<Integer, double[]> aMap, int fIndex, int k, double val) {
    aMap.get(fIndex)[k] += val;
  }

  private void increScoreAllowNull(Map<Integer, double[]> aMap, int fIndex, int k, double val) {
    if (!aMap.containsKey(fIndex)) {
      aMap.put(fIndex, new double[map[fIndex] == 0 ? nodeLabelIndexSize : edgeLabelIndexSize]);
    }
    aMap.get(fIndex)[k] += val;
  }

  private void initializeDataFeatureHash() {
    int macroActiveFeatureTotalCount = 0;
    int macroCondensedTotalCount = 0;
    int macroDocPosCount = 0;

    System.err.println("initializing data feature hash, sup-data size: " + data.length + ", unsup data size: " + (totalData.length-data.length));
    dataFeatureHash = new ArrayList<List<Set<Integer>>>(totalData.length);
    condensedMap = new ArrayList<Map<Integer, List<Integer>>>(totalData.length);
    dataFeatureHashByDoc = new int[totalData.length][];
    for (int m=0; m < totalData.length; m++) {
      Map<Integer, Integer> occurPos = new HashMap<Integer, Integer>();

      int[][][] aDoc = totalData[m];
      List<Set<Integer>> aList = new ArrayList<Set<Integer>>(aDoc.length);
      Set<Integer> setOfFeatures = new HashSet<Integer>();
      for (int i=0; i< aDoc.length; i++) { // positions in docI
        Set<Integer> aSet = new HashSet<Integer>();
        int[][] dataI = aDoc[i];
        for (int j=0; j < dataI.length; j++) {
          int[] dataJ = dataI[j];
          for (int item: dataJ) {
            if (j == 0) {
              if (occurPos.containsKey(item))
                occurPos.put(item, -1);
              else
                occurPos.put(item, i);
            }

            aSet.add(item);
          }
        }
        aList.add(aSet);
        setOfFeatures.addAll(aSet);
      }
      macroDocPosCount += aDoc.length;
      macroActiveFeatureTotalCount += setOfFeatures.size();

      if (CONDENSE) {
        if (DEBUG3)
          System.err.println("Before condense, activeFeatures = " + setOfFeatures.size());
        // examine all singletons, merge ones in the same position
        Map<Integer, List<Integer>> condensedFeaturesMap = new HashMap<Integer, List<Integer>>();
        int[] representFeatures = new int[aDoc.length];
        Arrays.fill(representFeatures, -1);

        for (Map.Entry<Integer, Integer> entry: occurPos.entrySet()) {
          int key = entry.getKey();
          int pos = entry.getValue();
          if (pos != -1) {
            if (representFeatures[pos] == -1) { // use this as representFeatures
              representFeatures[pos] = key;
              condensedFeaturesMap.put(key, new ArrayList<Integer>());
            } else { // condense this one
              int rep = representFeatures[pos];
              condensedFeaturesMap.get(rep).add(key);
              // remove key
              aList.get(pos).remove(key);
              setOfFeatures.remove(key);
            }
          }
        }
        int condensedCount = 0;
        for(Iterator<Map.Entry<Integer, List<Integer>>> it = condensedFeaturesMap.entrySet().iterator(); it.hasNext(); ) {
          Map.Entry<Integer, List<Integer>> entry = it.next();
          if(entry.getValue().size() == 0) {
            it.remove();
          } else {
            if (DEBUG3) {
              condensedCount += entry.getValue().size();
              for (int cond: entry.getValue())
                System.err.println("condense " + cond + " to " + entry.getKey());
            }
          }
        }
        if (DEBUG3)
          System.err.println("After condense, activeFeatures = " + setOfFeatures.size() + ", condensedCount = " + condensedCount);
        macroCondensedTotalCount += setOfFeatures.size();
        condensedMap.add(condensedFeaturesMap);
      }

      dataFeatureHash.add(aList);
      int[] arrOfIndex = new int[setOfFeatures.size()];
      int pos2 = 0;
      for(Integer ind: setOfFeatures)
        arrOfIndex[pos2++] = ind;
      dataFeatureHashByDoc[m] = arrOfIndex;
    }
    System.err.println("Avg. active features per position: " + (macroActiveFeatureTotalCount/ (macroDocPosCount+0.0)));
    System.err.println("Avg. condensed features per position: " + (macroCondensedTotalCount / (macroDocPosCount+0.0)));
    System.err.println("initializing data feature hash done!");
  }

  private double getDropoutPrior(CRFCliqueTree cliqueTree, int[][][] docData,
      Map<Integer, double[]> EForADoc, List<Set<Integer>> docDataHash, int[] activeFeatures, Map<Integer, double[]> dropoutPriorGrad,
      Map<Integer, List<Integer>> condensedFeaturesMap, List<Map<Integer, double[]>> EForADocPos) {

    Map<Integer, double[]> dropoutPriorGradFirstHalf = sparseE(activeFeatures);

    if (TIMED)
      System.err.println("activeFeatures size: "+activeFeatures.length + ", dataLen: " + docData.length);

    Timing timer = new Timing();
    if (TIMED)
      timer.start();

    double priorValue = 0;

    long elapsedMs = 0;
    Pair<double[][][], double[][][]> condProbs = getCondProbs(cliqueTree, docData);
    if (TIMED) {
      elapsedMs = timer.stop();
      System.err.println("\t cond prob took: " + Timing.toMilliSecondsString(elapsedMs) + " ms");
    }

    // first index position is curr index, second index curr-class, third index prev-class
    // e.g. [1][2][3] means curr is at position 1 with class 2, prev is at position 0 with class 3
    double[][][] prevGivenCurr = condProbs.first();
    // first index position is curr index, second index curr-class, third index next-class
    // e.g. [0][2][3] means curr is at position 0 with class 2, next is at position 1 with class 3
    double[][][] nextGivenCurr = condProbs.second();

    // first dim is doc length (i)
    // second dim is numOfFeatures (fIndex)
    // third dim is numClasses (y)
    // fourth dim is labelIndexSize (matching the clique type of fIndex, for \theta)
    double[][][][] FAlpha = null;
    double[][][][] FBeta  = null;
    if (!dropoutApprox) {
      FAlpha = new double[docData.length][][][];
      FBeta  = new double[docData.length][][][];
    }
    for (int i = 0; i < docData.length; i++) {
      if (!dropoutApprox) {
        FAlpha[i] = new double[activeFeatures.length][][];
        FBeta[i= new double[activeFeatures.length][][];
      }
    }

    if (!dropoutApprox) {
      if (TIMED) {
        timer.start();
      }
      // computing FAlpha
      int fIndex = 0;
      double aa, bb, cc = 0;
      boolean prevFeaturePresent  = false;
      for (int i = 1; i < docData.length; i++) {
        // for each possible clique at this position
        Set<Integer> docDataHashIMinusOne = docDataHash.get(i-1);
        for (int fIndexPos = 0; fIndexPos < activeFeatures.length; fIndexPos++) {
          fIndex = activeFeatures[fIndexPos];
          prevFeaturePresent = docDataHashIMinusOne.contains(fIndex);
          int j = map[fIndex];
          Index<CRFLabel> labelIndex = labelIndices.get(j);
          int labelIndexSize = labelIndex.size();

          if (FAlpha[i-1][fIndexPos] == null) {
            FAlpha[i-1][fIndexPos] = new double[numClasses][labelIndexSize];
            for (int q = 0; q < numClasses; q++)
              FAlpha[i-1][fIndexPos][q] = new double[labelIndexSize];
          }

          for (Map.Entry<Integer, List<Integer>> entry : currPrevLabelsMap.entrySet()) {
            int y = entry.getKey(); // value at i-1
            double[] sum = new double[labelIndexSize];
            for (int yPrime: entry.getValue()) { // value at i-2
              for (int kk = 0; kk < labelIndexSize; kk++) {
                int[] prevLabel = labelIndex.get(kk).getLabel();
                aa = (prevGivenCurr[i-1][y][yPrime]);
                bb = (prevFeaturePresent && ((j == 0 && prevLabel[0] == y) || (j == 1 && prevLabel[1] == y && prevLabel[0] == yPrime)) ? 1 : 0);
                cc = 0;
                if (FAlpha[i-1][fIndexPos][yPrime] != null)
                  cc = FAlpha[i-1][fIndexPos][yPrime][kk];
                sum[kk] +=  aa * (bb + cc);
                // sum[kk] += (prevGivenCurr[i-1][y][yPrime]) * ((prevFeaturePresent && ((j == 0 && prevLabel[0] == y) || (j == 1 && prevLabel[1] == y && prevLabel[0] == yPrime)) ? 1 : 0) + FAlpha[i-1][fIndexPos][yPrime][kk]);
                if (DEBUG2)
                  System.err.printf("alpha[%d][%d][%d][%d] += % 5.3f * (%d + % 5.3f), prevLabel=%s\n", i, fIndex, y, kk, (prevGivenCurr[i-1][y][yPrime]), (prevFeaturePresent && ((j == 0 && prevLabel[0] == y) || (j == 1 && prevLabel[1] == y && prevLabel[0] == yPrime)) ? 1 : 0) , FAlpha[i-1][fIndexPos][yPrime][kk], Arrays.toString(prevLabel));
              }
            }
            if (FAlpha[i][fIndexPos] == null) {
              FAlpha[i][fIndexPos] = new double[numClasses][];
            }
            FAlpha[i][fIndexPos][y] = sum;
            if (DEBUG2)
              System.err.println("FAlpha["+i+"]["+fIndexPos+"]["+y+"] = " + Arrays.toString(sum));

          }
        }
      }
      if (TIMED) {
        elapsedMs = timer.stop();
        System.err.println("\t alpha took: " + Timing.toMilliSecondsString(elapsedMs) + " ms");
        timer.start();
      }
      // computing FBeta
      int docDataLen = docData.length;
      for (int i = docDataLen-2; i >= 0; i--) {
        Set<Integer> docDataHashIPlusOne = docDataHash.get(i+1);
        // for each possible clique at this position
        for (int fIndexPos = 0; fIndexPos < activeFeatures.length; fIndexPos++) {
          fIndex = activeFeatures[fIndexPos];
          boolean nextFeaturePresent = docDataHashIPlusOne.contains(fIndex);
          int j = map[fIndex];
          Index<CRFLabel> labelIndex = labelIndices.get(j);
          int labelIndexSize = labelIndex.size();

          if (FBeta[i+1][fIndexPos] == null) {
            FBeta[i+1][fIndexPos] = new double[numClasses][labelIndexSize];
            for (int q = 0; q < numClasses; q++)
              FBeta[i+1][fIndexPos][q] = new double[labelIndexSize];
          }

          for (Map.Entry<Integer, List<Integer>> entry : currNextLabelsMap.entrySet()) {
            int y = entry.getKey(); // value at i
            double[] sum = new double[labelIndexSize];
            for (int yPrime: entry.getValue()) { // value at i+1
              for (int kk=0; kk < labelIndexSize; kk++) {
                int[] nextLabel = labelIndex.get(kk).getLabel();
                // System.err.println("labelIndexSize:"+labelIndexSize+", nextGivenCurr:"+nextGivenCurr+", nextLabel:"+nextLabel+", FBeta["+(i+1)+"]["+ fIndexPos +"]["+yPrime+"] :"+FBeta[i+1][fIndexPos][yPrime]);
                aa = (nextGivenCurr[i][y][yPrime]);
                bb = (nextFeaturePresent && ((j == 0 && nextLabel[0] == yPrime) || (j == 1 && nextLabel[0] == y && nextLabel[1] == yPrime)) ? 1 : 0);
                cc = 0;
                if (FBeta[i+1][fIndexPos][yPrime] != null)
                  cc = FBeta[i+1][fIndexPos][yPrime][kk];
                sum[kk] +=  aa * ( bb + cc);
                // sum[kk] += (nextGivenCurr[i][y][yPrime]) * ( (nextFeaturePresent && ((j == 0 && nextLabel[0] == yPrime) || (j == 1 && nextLabel[0] == y && nextLabel[1] == yPrime)) ? 1 : 0) + FBeta[i+1][fIndexPos][yPrime][kk]);
                if (DEBUG2)
                  System.err.printf("beta[%d][%d][%d][%d] += % 5.3f * (%d + % 5.3f)\n", i, fIndex, y, kk, (nextGivenCurr[i][y][yPrime]), (nextFeaturePresent && ((j == 0 && nextLabel[0] == yPrime) || (j == 1 && nextLabel[0] == y && nextLabel[1] == yPrime)) ? 1 : 0), FBeta[i+1][fIndexPos][yPrime][kk]);
              }
            }
            if (FBeta[i][fIndexPos] == null) {
              FBeta[i][fIndexPos] = new double[numClasses][];
            }
            FBeta[i][fIndexPos][y] = sum;
            if (DEBUG2)
              System.err.println("FBeta["+i+"]["+fIndexPos+"]["+y+"] = " + Arrays.toString(sum));
          }
        }
      }
      if (TIMED) {
        elapsedMs = timer.stop();
        System.err.println("\t beta took: " + Timing.toMilliSecondsString(elapsedMs) + " ms");
      }
    }
    if (TIMED) {
      timer.start();
    }

    // derivative equals: VarU' * PtYYp * (1-PtYYp) + VarU * PtYYp' * (1-PtYYp) + VarU * PtYYp * (1-PtYYp)'
    // derivative equals: VarU' * PtYYp * (1-PtYYp) + VarU * PtYYp' * (1-PtYYp) + VarU * PtYYp * -PtYYp'
    // derivative equals: VarU' * PtYYp * (1-PtYYp) + VarU * PtYYp' * (1 - 2 * PtYYp)

    double deltaDivByOneMinusDelta = delta / (1.0-delta);

    Timing innerTimer = new Timing();
    long eTiming = 0;
    long dropoutTiming= 0;

    boolean containsFeature = false;
    // iterate over the positions in this document
    for (int i = 1; i < docData.length; i++) {
      Set<Integer> docDataHashI = docDataHash.get(i);
      Map<Integer, double[]> EForADocPosAtI = null;
      if (dropoutApprox)
        EForADocPosAtI = EForADocPos.get(i);

      // for each possible clique at this position
      for (int k = 0; k < edgeLabelIndexSize; k++) { // sum over (y, y')
        int[] label = edgeLabels[k];
        int y = label[0];
        int yP = label[1];

        if (TIMED)
          innerTimer.start();

        // important to use label as an int[] for calculating cliqueTree.prob()
        // if it's a node clique, and label index is 2, if we don't use int[]{2} but just pass 2,
        // cliqueTree is going to treat it as index of the edge clique labels, and convert 2
        // into int[]{0,2}, and return the edge prob marginal instead of node marginal
        double PtYYp = cliqueTree.prob(i, label);
        double PtYYpTimesOneMinusPtYYp = PtYYp * (1.0 - PtYYp);
        double oneMinus2PtYYp = (1.0 - 2 * PtYYp);
        double USum = 0;
        int fIndex;
        for (int jjj=0; jjj<labelIndices.size(); jjj++) {
          for (int n = 0; n < docData[i][jjj].length; n++) {
            fIndex = docData[i][jjj][n];
            int valIndex;
            if (jjj == 1)
              valIndex = k;
            else
              valIndex = yP;
            double theta;
            try {
              theta = weights[fIndex][valIndex];
            }catch (Exception ex) {
              System.err.printf("weights[%d][%d], map[%d]=%d, labelIndices.get(map[%d]).size() = %d, weights.length=%d\n", fIndex, valIndex, fIndex, map[fIndex], fIndex, labelIndices.get(map[fIndex]).size(), weights.length);
              throw new RuntimeException(ex);
            }

            USum += weightSquare[fIndex][valIndex];

            // first half of derivative: VarU' * PtYYp * (1-PtYYp)
            double VarUp = deltaDivByOneMinusDelta * theta;
            increScoreAllowNull(dropoutPriorGradFirstHalf, fIndex, valIndex, VarUp * PtYYpTimesOneMinusPtYYp);
          }
        }

        if (TIMED) {
          eTiming += innerTimer.stop();
          innerTimer.start();
        }
        double VarU = 0.5 * deltaDivByOneMinusDelta * USum;

        // update function objective
        priorValue += VarU * PtYYpTimesOneMinusPtYYp;

        double VarUTimesOneMinus2PtYYp = VarU * oneMinus2PtYYp;

        // second half of derivative: VarU * PtYYp' * (1 - 2 * PtYYp)
        // boolean prevFeaturePresent = false;
        // boolean nextFeaturePresent = false;
        for (int fIndexPos = 0; fIndexPos < activeFeatures.length; fIndexPos++) {
          fIndex = activeFeatures[fIndexPos];
          containsFeature = docDataHashI.contains(fIndex);

          // if (!containsFeature) continue;
          int jj = map[fIndex];
          Index<CRFLabel> fLabelIndex = labelIndices.get(jj);
          for (int kk = 0; kk < fLabelIndex.size(); kk++) { // for all parameter \theta
            int[] fLabel = fLabelIndex.get(kk).getLabel();
            // if (FAlpha[i] != null)
            //   System.err.println("fIndex: " + fIndex+", FAlpha[i].size:"+FAlpha[i].length);
            double fCount = containsFeature && ((jj == 0 && fLabel[0] == yP) || (jj == 1 && k == kk)) ? 1 : 0;

            double alpha;
            double beta;
            double condE;
            double PtYYpPrime;
            if (!dropoutApprox) {
              alpha = ((FAlpha[i][fIndexPos] == null || FAlpha[i][fIndexPos][y] == null) ? 0 : FAlpha[i][fIndexPos][y][kk]);
              beta = ((FBeta[i][fIndexPos] == null || FBeta[i][fIndexPos][yP] == null) ? 0 : FBeta[i][fIndexPos][yP][kk]);
              condE = fCount + alpha + beta;
              if (DEBUG2)
                System.err.printf("fLabel=%s, yP = %d, fCount:%f = ((jj == 0 && fLabel[0] == yP)=%b || (jj == 1 && k == kk))=%b\n", Arrays.toString(fLabel),yP, fCount,(jj == 0 && fLabel[0] == yP) , (jj == 1 && k == kk));
              PtYYpPrime = PtYYp * (condE - EForADoc.get(fIndex)[kk]);
            } else {
              double E = 0;
              if (EForADocPosAtI.containsKey(fIndex))
                E = EForADocPosAtI.get(fIndex)[kk];
              condE = fCount;
              PtYYpPrime = PtYYp * (condE - E);
            }

            if (DEBUG2)
              System.err.printf("for i=%d, k=%d, y=%d, yP=%d, fIndex=%d, kk=%d, PtYYpPrime=% 5.3f, PtYYp=% 3.3f, (condE-E[fIndex][kk])=% 3.3f, condE=% 3.3f, E[fIndex][k]=% 3.3f, alpha=% 3.3f, beta=% 3.3f, fCount=% 3.3f\n", i, k, y, yP, fIndex, kk, PtYYpPrime, PtYYp, (condE - EForADoc.get(fIndex)[kk]), condE, EForADoc.get(fIndex)[kk], alpha, beta, fCount);

            increScore(dropoutPriorGrad, fIndex, kk, VarUTimesOneMinus2PtYYp * PtYYpPrime);
          }

          if (DEBUG2)
            System.err.println();
        }
        if (TIMED)
          dropoutTiming += innerTimer.stop();
      }
    }
    if (CONDENSE) {
      // copy for condensedFeaturesMap
      for (Map.Entry<Integer, List<Integer>> entry: condensedFeaturesMap.entrySet()) {
        int key = entry.getKey();
        List<Integer> aList = entry.getValue();
        for (int toCopyInto: aList) {
          double[] arr = dropoutPriorGrad.get(key);
          double[] targetArr = new double[arr.length];
          for (int i=0; i < arr.length; i++)
            targetArr[i] = arr[i];
          dropoutPriorGrad.put(toCopyInto, targetArr);
        }
      }
    }

    if (DEBUG3) {
      System.err.print("dropoutPriorGradFirstHalf.keys:[");
      for (int key: dropoutPriorGradFirstHalf.keySet())
        System.err.print(" "+key);
      System.err.println("]");

      System.err.print("dropoutPriorGrad.keys:[");
      for (int key: dropoutPriorGrad.keySet())
        System.err.print(" "+key);
      System.err.println("]");
    }

    for (Map.Entry<Integer, double[]> entry: dropoutPriorGrad.entrySet()) {
      Integer key = entry.getKey();
      double[] target = entry.getValue();
      if (dropoutPriorGradFirstHalf.containsKey(key)) {
        double[] source = dropoutPriorGradFirstHalf.get(key);
        for (int i=0; i<target.length; i++) {
          target[i] += source[i];
        }
      }
    }
    // for (int i=0;i<dropoutPriorGrad.length;i++)
    //   for (int j=0; j<dropoutPriorGrad[i].length;j++) {
    //     if (DEBUG3)
    //       System.err.printf("f=%d, k=%d, dropoutPriorGradFirstHalf[%d][%d]=% 5.3f, dropoutPriorGrad[%d][%d]=% 5.3f\n", i, j, i, j, dropoutPriorGradFirstHalf[i][j], i, j, dropoutPriorGrad[i][j]);
    //     dropoutPriorGrad[i][j] += dropoutPriorGradFirstHalf[i][j];
    //   }

    if (TIMED) {
      elapsedMs = timer.stop();
      System.err.println("\t grad took: " + Timing.toMilliSecondsString(elapsedMs) + " ms");
      System.err.println("\t\t exp took: " + Timing.toMilliSecondsString(eTiming) + " ms");
      System.err.println("\t\t dropout took: " + Timing.toMilliSecondsString(dropoutTiming) + " ms");
    }

    return dropoutScale * priorValue;
  }


  @Override
  public void setWeights(double[][] weights) {
    super.setWeights(weights);
    if (weightSquare == null) {
      weightSquare = new double[weights.length][];
      for (int i = 0; i < weights.length; i++)
        weightSquare[i] = new double[weights[i].length];
    }
    for (int i = 0; i < weights.length; i++) {
      for (int j=0; j < weights[i].length; j++) {
        double w = weights[i][j];
        weightSquare[i][j] = w * w;
      }
    }
  }

  /**
   * Calculates both value and partial derivatives at the point x, and save them internally.
   */
  @Override
  public void calculate(double[] x) {

    double prob = 0.0; // the log prob of the sequence given the model, which is the negation of value at this point
    // final double[][] weights = to2D(x);
    to2D(x, weights);

    setWeights(weights);

    // the expectations over counts
    // first index is feature index, second index is of possible labeling
    // double[][] E = empty2D();
    clear2D(E);
    clear2D(dropoutPriorGradTotal);

    MulticoreWrapper<Pair<Integer, Boolean>, Quadruple<Integer, Double, Map<Integer, double[]>, Map<Integer, double[]>>> wrapper =
      new MulticoreWrapper<Pair<Integer, Boolean>, Quadruple<Integer, Double, Map<Integer, double[]>, Map<Integer, double[]>>>(multiThreadGrad, dropoutPriorThreadProcessor);
    // supervised part
    for (int m = 0; m < totalData.length; m++) {
      boolean submitIsUnsup = (m >= unsupDropoutStartIndex);
      wrapper.put(new Pair<Integer, Boolean>(m, submitIsUnsup));
      while (wrapper.peek()) {
        Quadruple<Integer, Double, Map<Integer, double[]>, Map<Integer, double[]>> result = wrapper.poll();
        int docIndex = result.first();
        boolean isUnsup = docIndex >= unsupDropoutStartIndex;
        if (isUnsup) {
          prob += unsupDropoutScale * result.second();
        } else {
          prob += result.second();
        }

        Map<Integer, double[]> partialDropout = result.fourth();
        if (partialDropout != null) {
          if (isUnsup) {
            combine2DArr(dropoutPriorGradTotal, partialDropout, unsupDropoutScale);
          } else {
            combine2DArr(dropoutPriorGradTotal, partialDropout);
          }
        }

        if (!isUnsup) {
          Map<Integer, double[]> partialE = result.third();
          if (partialE != null)
            combine2DArr(E, partialE);
        }
      }
    }
    wrapper.join();
    while (wrapper.peek()) {
      Quadruple<Integer, Double, Map<Integer, double[]>, Map<Integer, double[]>> result = wrapper.poll();
      int docIndex = result.first();
      boolean isUnsup = docIndex >= unsupDropoutStartIndex;
      if (isUnsup) {
        prob += unsupDropoutScale * result.second();
      } else {
        prob += result.second();
      }

      Map<Integer, double[]> partialDropout = result.fourth();
      if (partialDropout != null) {
        if (isUnsup) {
          combine2DArr(dropoutPriorGradTotal, partialDropout, unsupDropoutScale);
        } else {
          combine2DArr(dropoutPriorGradTotal, partialDropout);
        }
      }

      if (!isUnsup) {
        Map<Integer, double[]> partialE = result.third();
        if (partialE != null)
          combine2DArr(E, partialE);
      }
    }


    if (Double.isNaN(prob)) { // shouldn't be the case
      throw new RuntimeException("Got NaN for prob in CRFLogConditionalObjectiveFunctionWithDropout.calculate()" +
              " - this may well indicate numeric underflow due to overly long documents.");
    }

    // because we minimize -L(\theta)
    value = -prob;
    if (VERBOSE) {
      System.err.println("value is " + Math.exp(-value));
    }

    // compute the partial derivative for each feature by comparing expected counts to empirical counts
    int index = 0;
    for (int i = 0; i < E.length; i++) {
      for (int j = 0; j < E[i].length; j++) {
        // because we minimize -L(\theta)
        derivative[index] = (E[i][j] - Ehat[i][j]);
        derivative[index] += dropoutScale * dropoutPriorGradTotal[i][j];
        if (VERBOSE) {
          System.err.println("deriv(" + i + ',' + j + ") = " + E[i][j] + " - " + Ehat[i][j] + " = " + derivative[index]);
        }
        index++;
      }
    }
  }

}
// TOP
//
// Related classes of edu.stanford.nlp.ie.crf.CRFLogConditionalObjectiveFunctionWithDropout
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and owned by Oracle Inc. Contact coftware#gmail.com.