Package edu.stanford.nlp.parser.lexparser

Source Code of edu.stanford.nlp.parser.lexparser.TreeAnnotatorAndBinarizer$TreeNullAnnotator

package edu.stanford.nlp.parser.lexparser;

import java.util.*;
import java.io.FileFilter;
import java.io.PrintWriter;

import edu.stanford.nlp.io.NumberRangesFileFilter;
import edu.stanford.nlp.io.NumberRangeFileFilter;
import edu.stanford.nlp.ling.Label;
import edu.stanford.nlp.ling.CategoryWordTag;
import edu.stanford.nlp.ling.CategoryWordTagFactory;
import edu.stanford.nlp.ling.StringLabelFactory;
import edu.stanford.nlp.ling.Word;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.trees.*;
import edu.stanford.nlp.util.Triple;


public class TreeAnnotatorAndBinarizer implements TreeTransformer {

  private final TreeFactory tf;
  private final TreebankLanguagePack tlp;
  private final TreeTransformer annotator;
  private final TreeBinarizer binarizer;
  private final PostSplitter postSplitter;
  private final boolean forceCNF;
  private final TrainOptions trainOptions;
  private final ClassicCounter<Tree> annotatedRuleCounts;
  private final ClassicCounter<String> annotatedStateCounts;

  /**
   * Convenience constructor that asks {@code tlpParams} for a head finder
   * twice (once for annotation, once for binarization) and otherwise
   * delegates to the seven-argument constructor.
   *
   * @param tlpParams Language-specific parser parameters; supplies the head
   *     finders and the treebank language pack
   * @param forceCNF If true, {@link #transformTree} additionally passes the
   *     binarized tree through {@code CNFTransformers.ToCNFTransformer}
   * @param insideFactor Passed through unchanged to the {@code TreeBinarizer}
   * @param doSubcategorization If true a {@code TreeAnnotator} is used for
   *     annotation, otherwise a {@code TreeNullAnnotator}
   * @param op Parser options; {@code op.trainOptions} is retained by this object
   */
  public TreeAnnotatorAndBinarizer(TreebankLangParserParams tlpParams, boolean forceCNF, boolean insideFactor, boolean doSubcategorization, Options op) {
    this(tlpParams.headFinder(), tlpParams.headFinder(), tlpParams, forceCNF, insideFactor, doSubcategorization, op);
  }

  public TreeAnnotatorAndBinarizer(HeadFinder annotationHF, HeadFinder binarizationHF, TreebankLangParserParams tlpParams, boolean forceCNF, boolean insideFactor, boolean doSubcategorization, Options op) {
    this.trainOptions = op.trainOptions;
    if (doSubcategorization) {
      annotator = new TreeAnnotator(annotationHF, tlpParams, op);
    } else {
      annotator = new TreeNullAnnotator(annotationHF);
    }
    binarizer = new TreeBinarizer(binarizationHF, tlpParams.treebankLanguagePack(), insideFactor, trainOptions.markovFactor, trainOptions.markovOrder, trainOptions.compactGrammar() > 0, trainOptions.compactGrammar() > 1, trainOptions.HSEL_CUT, trainOptions.markFinalStates, trainOptions.simpleBinarizedLabels, trainOptions.noRebinarization);
    if (trainOptions.selectivePostSplit) {
      postSplitter = new PostSplitter(tlpParams, op);
    } else {
      postSplitter = null;
    }
    this.tf = new LabeledScoredTreeFactory(new CategoryWordTagFactory());
    this.tlp = tlpParams.treebankLanguagePack();
    this.forceCNF = forceCNF;
    if (trainOptions.printAnnotatedRuleCounts) {
      annotatedRuleCounts = new ClassicCounter<Tree>();
    } else {
      annotatedRuleCounts = null;
    }
    if (trainOptions.printAnnotatedStateCounts) {
      annotatedStateCounts = new ClassicCounter<String>();
    } else {
      annotatedStateCounts = null;
    }
  }

  public void dumpStats() {
    if (trainOptions.selectivePostSplit) {
      postSplitter.dumpStats();
    }
  }

  /**
   * Forwards the selective-split switch to the underlying binarizer.
   * {@code getAnnotatedBinaryTreebankFromTreebank} turns it off for the
   * count-gathering pass and back on before the real transformation.
   *
   * @param doSelectiveSplit Value passed on to {@code TreeBinarizer.setDoSelectiveSplit}
   */
  public void setDoSelectiveSplit(boolean doSelectiveSplit) {
    binarizer.setDoSelectiveSplit(doSelectiveSplit);
  }

  /**
   * Changes the ROOT label, and adds a Lexicon.BOUNDARY daughter to it.
   * This is needed for the dependency parser.
   * <i>Note:</i> This is a destructive operation on the tree passed in!!
   *
   * @param t The current tree into which a boundary is inserted
   */
  public void addRoot(Tree t) {
    if (t.isLeaf()) {
      System.err.println("Warning: tree is leaf: " + t);
      t = tf.newTreeNode(tlp.startSymbol(), Collections.singletonList(t));
    }
    t.setLabel(new CategoryWordTag(tlp.startSymbol(), Lexicon.BOUNDARY, Lexicon.BOUNDARY_TAG));
    List<Tree> preTermChildList = new ArrayList<Tree>();
    Tree boundaryTerm = tf.newLeaf(new Word(Lexicon.BOUNDARY));//CategoryWordTag(Lexicon.BOUNDARY,Lexicon.BOUNDARY,""));
    preTermChildList.add(boundaryTerm);
    Tree boundaryPreTerm = tf.newTreeNode(new CategoryWordTag(Lexicon.BOUNDARY_TAG, Lexicon.BOUNDARY, Lexicon.BOUNDARY_TAG), preTermChildList);
    List<Tree> childList = t.getChildrenAsList();
    childList.add(boundaryPreTerm);
    t.setChildren(childList);
  }

  /** The tree t is normally expected to be a Penn-Treebank-style tree
   *  in which the top node is an extra node that has a unary expansion.
   *  If this isn't the case, an extra node is added and the user is warned.
   */
  @Override
  public Tree transformTree(Tree t) {
    if (trainOptions.printTreeTransformations > 0) {
      trainOptions.printTrainTree(null, "ORIGINAL TREE:", t);
    }
    Tree trTree = annotator.transformTree(t);
    if (trainOptions.selectivePostSplit) {
      trTree = postSplitter.transformTree(trTree);
    }
    if (trainOptions.printTreeTransformations > 0) {
      trainOptions.printTrainTree(trainOptions.printAnnotatedPW, "ANNOTATED TREE:", trTree);
    }
    if (trainOptions.printAnnotatedRuleCounts) {
      Tree tr2 = trTree.deepCopy(new LabeledScoredTreeFactory(), new StringLabelFactory());
      Set<Tree> localTrees = tr2.localTrees();
      for (Tree tr : localTrees) {
        annotatedRuleCounts.incrementCount(tr);
      }
    }
    if (trainOptions.printAnnotatedStateCounts) {
      for (Tree subt : trTree) {
        if ( ! subt.isLeaf()) {
          annotatedStateCounts.incrementCount(subt.label().value());
        }
      }
    }

    // if we add the ROOT first, then we don't know how to percolate the heads at the top
    addRoot(trTree); // this creates a few non-binarized rules at the top

    Tree binarizedTree = binarizer.transformTree(trTree);
    if (trainOptions.printTreeTransformations > 0) {
      trainOptions.printTrainTree(trainOptions.printBinarizedPW, "BINARIZED TREE:", binarizedTree);
      trainOptions.printTreeTransformations--;
    }
    if (forceCNF) {
      binarizedTree = new CNFTransformers.ToCNFTransformer().transformTree(binarizedTree);
      //        System.out.println("BinarizedCNF:\n");
      //        binarizedTree.pennPrint();
    }
    return binarizedTree;
  }

  public void printRuleCounts() {
    System.err.println();
    for (Tree t : annotatedRuleCounts.keySet()) {
      System.err.print(annotatedRuleCounts.getCount(t) + "\t" +
                       t.label().value() + " -->");
      for (Tree dtr : t.getChildrenAsList()) {
        System.err.print(" ");
        System.err.print(dtr.label().value());
      }
      System.err.println();
    }
  }

  public void printStateCounts() {
    System.err.println();
    System.err.println("Annotated state counts");
    Set<String> keys = annotatedStateCounts.keySet();
    List<String> keyList = new ArrayList<String>(keys);
    Collections.sort(keyList);
    for (String s : keyList) {
      System.err.println(s + "\t" + annotatedStateCounts.getCount(s));
    }
  }


  // main helper function
  private static int numSubArgs(String[] args, int index) {
    int i = index;
    while (i + 1 < args.length && args[i + 1].charAt(0) != '-') {
      i++;
    }
    return i - index;
  }


  private static void removeDeleteSplittersFromSplitters(TreebankLanguagePack tlp, Options op) {
    if (op.trainOptions.deleteSplitters != null) {
      List<String> deleted = new ArrayList<String>();
      for (String del : op.trainOptions.deleteSplitters) {
        String baseDel = tlp.basicCategory(del);
        boolean checkBasic = del.equals(baseDel);
        for (Iterator<String> it = op.trainOptions.splitters.iterator(); it.hasNext(); ) {
          String elem = it.next();
          String baseElem = tlp.basicCategory(elem);
          boolean delStr = checkBasic && baseElem.equals(baseDel) || elem.equals(del);
          if (delStr) {
            it.remove();
            deleted.add(elem);
          }
        }
      }
      if (op.testOptions.verbose) {
        System.err.println("Removed from vertical splitters: " + deleted);
      }
    }
  }


  /** Annotates and binarizes up to three treebanks with one shared
   *  transformer pipeline configured from {@code op}. The pipeline is, in
   *  order: the optional {@code op.trainOptions.preTransformer}, an optional
   *  {@code CollinsPuncTransformer}, a {@code TreeAnnotatorAndBinarizer}, and
   *  an optional {@code TreeLeafLabelTransformer} for {@code op.wordFunction}.
   *  Depending on the options, {@code op.trainOptions.splitters} and
   *  {@code op.trainOptions.postSplitters} are overwritten as a side effect.
   *
   *  @param trainTreebank The primary training treebank
   *  @param secondaryTreebank An additional training treebank, or null
   *  @param tuneTreebank A tuning treebank, or null
   *  @param op The parser options; read and partly modified
   *  @return a triple of (binaryTrainTreebank, binarySecondaryTreebank,
   *     binaryTuneTreebank); the second and third elements are null when the
   *     corresponding argument was null.
   */
  public static Triple<Treebank, Treebank, Treebank> getAnnotatedBinaryTreebankFromTreebank(Treebank trainTreebank,
      Treebank secondaryTreebank,
      Treebank tuneTreebank,
      Options op) {
    // setup tree transforms
    TreebankLangParserParams tlpParams = op.tlpParams;
    TreebankLanguagePack tlp = tlpParams.treebankLanguagePack();

    if (op.testOptions.verbose) {
      PrintWriter pwErr = tlpParams.pw(System.err);
      pwErr.print("Training ");
      pwErr.println(trainTreebank.textualSummary(tlp));
      if (secondaryTreebank != null) {
        pwErr.print("Secondary training ");
        pwErr.println(secondaryTreebank.textualSummary(tlp));
      }
    }

    CompositeTreeTransformer trainTransformer =
      new CompositeTreeTransformer();
    if (op.trainOptions.preTransformer != null) {
      trainTransformer.addTransformer(op.trainOptions.preTransformer);
    }
    if (op.trainOptions.collinsPunc) {
      CollinsPuncTransformer collinsPuncTransformer =
        new CollinsPuncTransformer(tlp);
      trainTransformer.addTransformer(collinsPuncTransformer);
    }

    System.err.print("Binarizing trees...");
    TreeAnnotatorAndBinarizer binarizer;
    if (!op.trainOptions.leftToRight) {
      binarizer = new TreeAnnotatorAndBinarizer(tlpParams, op.forceCNF, !op.trainOptions.outsideFactor(), !op.trainOptions.predictSplits, op);
    } else {
      // left-to-right: annotate with the language's head finder but binarize with a LeftHeadFinder
      binarizer = new TreeAnnotatorAndBinarizer(tlpParams.headFinder(), new LeftHeadFinder(), tlpParams, op.forceCNF, !op.trainOptions.outsideFactor(), !op.trainOptions.predictSplits, op);
    }
    trainTransformer.addTransformer(binarizer);

    if (op.wordFunction != null) {
      TreeTransformer wordFunctionTransformer =
        new TreeLeafLabelTransformer(op.wordFunction);
      trainTransformer.addTransformer(wordFunctionTransformer);
    }

    // Statistics below are gathered over the primary and secondary training data together.
    Treebank wholeTreebank;
    if (secondaryTreebank == null) {
      wholeTreebank = trainTreebank;
    } else {
      wholeTreebank = new CompositeTreebank(trainTreebank, secondaryTreebank);
    }

    if (op.trainOptions.selectiveSplit) {
      op.trainOptions.splitters = ParentAnnotationStats.getSplitCategories(wholeTreebank, op.trainOptions.tagSelectiveSplit, 0, op.trainOptions.selectiveSplitCutOff, op.trainOptions.tagSelectiveSplitCutOff, tlp);
      removeDeleteSplittersFromSplitters(tlp, op);
      if (op.testOptions.verbose) {
        // sort a copy so the log output is in a stable order
        List<String> list = new ArrayList<String>(op.trainOptions.splitters);
        Collections.sort(list);
        System.err.println("Parent split categories: " + list);
      }
    }

    if (op.trainOptions.selectivePostSplit) {
      // Do all the transformations once just to learn selective splits on annotated categories
      TreeTransformer myTransformer = new TreeAnnotator(tlpParams.headFinder(), tlpParams, op);
      // NOTE(review): wholeTreebank now refers to the annotated copy, so the
      // hSelSplit counting pass below (if enabled) iterates over annotated
      // trees rather than the raw ones -- confirm this is intended.
      wholeTreebank = wholeTreebank.transform(myTransformer);
      op.trainOptions.postSplitters = ParentAnnotationStats.getSplitCategories(wholeTreebank, true, 0, op.trainOptions.selectivePostSplitCutOff, op.trainOptions.tagSelectivePostSplitCutOff, tlp);
      if (op.testOptions.verbose) {
        System.err.println("Parent post annotation split categories: " + op.trainOptions.postSplitters);
      }
    }
    if (op.trainOptions.hSelSplit) {
      // We run through all the trees once just to gather counts for hSelSplit!
      // Tree printing is suppressed during this pass and restored afterwards.
      int ptt = op.trainOptions.printTreeTransformations;
      op.trainOptions.printTreeTransformations = 0;
      binarizer.setDoSelectiveSplit(false);
      for (Tree tree : wholeTreebank) {
        // result is discarded; the pass is run only for its side effects
        trainTransformer.transformTree(tree);
      }
      binarizer.setDoSelectiveSplit(true);
      op.trainOptions.printTreeTransformations = ptt;
    }
    // we've done all the setup now. here's where the train treebank is transformed.
    trainTreebank = trainTreebank.transform(trainTransformer);
    if (secondaryTreebank != null) {
      secondaryTreebank = secondaryTreebank.transform(trainTransformer);
    }
    // Counts are printed before the tune treebank is transformed.
    if (op.trainOptions.printAnnotatedStateCounts) {
      binarizer.printStateCounts();
    }
    if (op.trainOptions.printAnnotatedRuleCounts) {
      binarizer.printRuleCounts();
    }

    if (tuneTreebank != null) {
      tuneTreebank = tuneTreebank.transform(trainTransformer);
    }

    if (op.testOptions.verbose) {
      binarizer.dumpStats();
    }

    return new Triple<Treebank, Treebank, Treebank>(trainTreebank, secondaryTreebank, tuneTreebank);
  }


  /** Lets you test out the TreeAnnotatorAndBinarizer on the command line.
   *
   *  @param args Command line arguments: All flags accepted by FactoredParser.setOptionFlag
   *     and -train treebankPath [fileRanges]
   */
  public static void main(String[] args) {
    Options op = new Options();
    String treebankPath = null;
    FileFilter trainFilter = null;

    int i = 0;
    while (i < args.length && args[i].startsWith("-")) {
      if (args[i].equalsIgnoreCase("-train")) {
        int numSubArgs = numSubArgs(args, i);
        i++;
        if (numSubArgs >= 1) {
          treebankPath = args[i];
          i++;
        } else {
          throw new RuntimeException("Error: -train option must have treebankPath as first argument.");
        }
        if (numSubArgs == 2) {
          trainFilter = new NumberRangesFileFilter(args[i++], true);
        } else if (numSubArgs >= 3) {
          int low = Integer.parseInt(args[i]);
          int high = Integer.parseInt(args[i + 1]);
          trainFilter = new NumberRangeFileFilter(low, high, true);
          i += 2;
        }
      } else {
        i = op.setOption(args, i);
      }
    }
    if (i < args.length) {
      System.err.println("usage: java TreeAnnotatorAndBinarizer options*");
      System.err.println("  Options are like for lexicalized parser including -train treebankPath fileRange]");
      return;
    }

    System.err.println("Annotating from treebank dir: " + treebankPath);
    Treebank trainTreebank = op.tlpParams.diskTreebank();
    if (trainFilter == null) {
      trainTreebank.loadPath(treebankPath);
    } else {
      trainTreebank.loadPath(treebankPath, trainFilter);
    }

    Treebank binaryTrainTreebank = getAnnotatedBinaryTreebankFromTreebank(trainTreebank, null, null, op).first();

    Iterator<Tree> it = trainTreebank.iterator();
    for (Tree t : binaryTrainTreebank) {
      System.out.println("Original tree:");
      it.next().pennPrint();
      System.out.println("Binarized tree:");
      t.pennPrint();
      System.out.println();
    }
  } // end main



  /** This does nothing but a function to change the tree nodes into
   *  CategoryWordTag, while the leaves are StringLabels. That's what the
   *  rest of the code assumes.
   */
  static class TreeNullAnnotator implements TreeTransformer {

    private final TreeFactory tf =
      new LabeledScoredTreeFactory(new CategoryWordTagFactory());
    private final HeadFinder hf;

    public Tree transformTree(Tree t) {
      // make a defensive copy which the helper method can then mangle
      Tree copy = t.treeSkeletonCopy(tf);
      return transformTreeHelper(copy);
    }

    private Tree transformTreeHelper(Tree t) {
      if (t != null) {
        String cat = t.label().value();
        if (t.isLeaf()) {
          Label label = new Word(cat); //new CategoryWordTag(cat,cat,"");
          t.setLabel(label);
        } else {
          Tree[] kids = t.children();
          for (Tree child : kids) {
            transformTreeHelper(child); // recursive call
          }
          Tree headChild = hf.determineHead(t);
          String tag;
          String word;
          if (headChild == null) {
            System.err.println("ERROR: null head for tree\n" + t.toString());
            word = null;
            tag = null;
          } else if (headChild.isLeaf()) {
            tag = cat;
            word = headChild.label().value();
          } else {
            CategoryWordTag headLabel = (CategoryWordTag) headChild.label();
            word = headLabel.word();
            tag = headLabel.tag();
          }
          Label label = new CategoryWordTag(cat, word, tag);
          t.setLabel(label);
        }
      }
      return t;
    }

    public TreeNullAnnotator(HeadFinder hf) {
      this.hf = hf;
    }

  } // end static class TreeNullAnnotator

} // end class TreeAnnotatorAndBinarizer
TOP

Related Classes of edu.stanford.nlp.parser.lexparser.TreeAnnotatorAndBinarizer$TreeNullAnnotator

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle, Inc. Contact coftware#gmail.com.