package edu.stanford.nlp.parser.metrics;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import edu.stanford.nlp.ling.TaggedWord;
import edu.stanford.nlp.trees.Constituent;
import edu.stanford.nlp.trees.LabeledConstituent;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.TreebankLanguagePack;
/**
* Provides a method for deciding how similar two trees are.
*
* @author John Bauer
*/
public class TreeSpanScoring {
  private TreeSpanScoring() {} // static only

  /**
   * Counts how many spans are present in goldTree, including
   * preterminals, but not present in guessTree, along with how many
   * spans are present in guessTree and not goldTree. Each one counts
   * as an error, meaning that something like a mislabeled span or
   * preterminal counts as two errors.
   * <br>
   * Span labels are compared using the basicCategory() function
   * from the passed in TreebankLanguagePack.
   * <br>
   * Preterminals are compared position by position along the two
   * tagged yields. If the yields have different lengths, every
   * preterminal with no counterpart at the same position in the other
   * tree is a span present in only one tree, and counts as one error.
   *
   * @param tlp language pack whose basicCategory() is applied to span
   *     labels and preterminal tags before comparison
   * @param goldTree the reference tree
   * @param guessTree the tree being scored against goldTree
   * @return the number of span errors; 0 when the trees agree on every
   *     (simplified) span and preterminal
   */
  public static int countSpanErrors(TreebankLanguagePack tlp, Tree goldTree, Tree guessTree) {
    Set<Constituent> simpleGoldConstituents =
        simplifyConstituents(tlp, goldTree.constituents(LabeledConstituent.factory()));
    Set<Constituent> simpleGuessConstituents =
        simplifyConstituents(tlp, guessTree.constituents(LabeledConstituent.factory()));

    // Symmetric difference: spans only in gold plus spans only in guess
    int errors = countMissing(simpleGoldConstituents, simpleGuessConstituents)
        + countMissing(simpleGuessConstituents, simpleGoldConstituents);

    // The spans returned by constituents() don't include the
    // preterminals, so we need to count those ourselves
    errors += countPreterminalErrors(tlp, goldTree.taggedYield(), guessTree.taggedYield());
    return errors;
  }

  /**
   * Returns how many elements of {@code source} are not contained in
   * {@code other}.
   */
  private static int countMissing(Set<Constituent> source, Set<Constituent> other) {
    int missing = 0;
    for (Constituent constituent : source) {
      if (!other.contains(constituent)) {
        ++missing;
      }
    }
    return missing;
  }

  /**
   * Counts preterminal errors between two tagged yields, comparing tags
   * after basicCategory() normalization.
   * <br>
   * A tag mismatch at a shared position counts as two errors: the gold
   * preterminal is missing from the guess, and the guess preterminal is
   * missing from the gold. Positions beyond the shorter yield exist in
   * only one tree and count as one error each.
   */
  private static int countPreterminalErrors(TreebankLanguagePack tlp,
                                            List<TaggedWord> goldWords,
                                            List<TaggedWord> guessWords) {
    int shared = Math.min(goldWords.size(), guessWords.size());
    int errors = 0;
    for (int i = 0; i < shared; ++i) {
      String goldTag = tlp.basicCategory(goldWords.get(i).tag());
      String guessTag = tlp.basicCategory(guessWords.get(i).tag());
      if (!goldTag.equals(guessTag)) {
        errors += 2;
      }
    }
    // Previously these unmatched preterminals were silently ignored,
    // which undercounted errors when the yields differed in length.
    errors += Math.abs(goldWords.size() - guessWords.size());
    return errors;
  }

  /**
   * Builds a new set of constituents whose labels have been reduced to
   * their basic category using {@code tlp.basicCategory()}. Start and
   * end positions are preserved. The input set is not modified.
   * <br>
   * NOTE(review): because the result is a set, two input constituents
   * over the same span whose labels reduce to the same basic category
   * presumably collapse into one entry (this depends on
   * LabeledConstituent's equals/hashCode; confirm).
   *
   * @param tlp language pack supplying basicCategory()
   * @param constituents constituents to simplify; every element must be
   *     a LabeledConstituent
   * @return a new mutable set of simplified LabeledConstituents
   * @throws AssertionError if any element is not a LabeledConstituent
   */
  public static Set<Constituent> simplifyConstituents(TreebankLanguagePack tlp, Set<Constituent> constituents) {
    Set<Constituent> newConstituents = new HashSet<>();
    for (Constituent con : constituents) {
      if (!(con instanceof LabeledConstituent)) {
        throw new AssertionError("Unexpected constituent type " + con.getClass());
      }
      LabeledConstituent labeled = (LabeledConstituent) con;
      newConstituents.add(new LabeledConstituent(labeled.start(), labeled.end(), tlp.basicCategory(labeled.value())));
    }
    return newConstituents;
  }
}