Package cc.mallet.fst

Source Code of cc.mallet.fst.SegmentationEvaluator

/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */




/**
   @author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a>
*/

package cc.mallet.fst;

import java.util.logging.Logger;
import java.util.regex.Pattern;

import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Sequence;
import cc.mallet.types.TokenSequence;

import cc.mallet.util.MalletLogger;

public class SegmentationEvaluator extends TransducerEvaluator
{
  private static Logger logger = MalletLogger.getLogger(SegmentationEvaluator.class.getName());

  // equals() is called on these objects to determine if this token is the start (end) of a segment
  // "segmentEndTag" should return "true" for the token *after* the end of the segment (i.e. that token
  // is not part of the segment).
  static Pattern startRegex = Pattern.compile ("^B.*");
  //static Pattern endRegex = Pattern.compile ("^O.*");
  Object segmentStartTag = new Object () { public boolean equals (Object o) { return startRegex.matcher(o.toString()).matches(); } };
  Object segmentEndTag = new Object () { public boolean equals (Object o) { return false; } };
     
  public SegmentationEvaluator (InstanceList[] instanceLists, String[] descriptions) {
    super (instanceLists, descriptions);
  }
 
  public SegmentationEvaluator (InstanceList instanceList1, String description1) {
    this (new InstanceList[] {instanceList1}, new String[] {description1});
  }

  public SegmentationEvaluator (InstanceList instanceList1, String description1,
      InstanceList instanceList2, String description2) {
    this (new InstanceList[] {instanceList1, instanceList2}, new String[] {description1, description2});
  }

  public SegmentationEvaluator (InstanceList instanceList1, String description1,
      InstanceList instanceList2, String description2,
      InstanceList instanceList3, String description3) {
    this (new InstanceList[] {instanceList1, instanceList2, instanceList3}, new String[] {description1, description2, description3});
  }

  public SegmentationEvaluator setSegmentStartTag (Object o) { this.segmentStartTag = o; return this; }
  public SegmentationEvaluator setSegmentEndTag (Object o) { this.segmentEndTag = o; return this; }

  public void evaluateInstanceList (TransducerTrainer tt, InstanceList data, String description)
  {
    Transducer model = tt.getTransducer();
    int numCorrectTokens, totalTokens;
    int numTrueSegments, numPredictedSegments, numCorrectSegments;
    int numCorrectSegmentsInAlphabet, numCorrectSegmentsOOV;
    int numIncorrectSegmentsInAlphabet, numIncorrectSegmentsOOV;
    TokenSequence sourceTokenSequence = null;

    totalTokens = numCorrectTokens = 0;
    numTrueSegments = numPredictedSegments = numCorrectSegments = 0;
    numCorrectSegmentsInAlphabet = numCorrectSegmentsOOV = 0;
    numIncorrectSegmentsInAlphabet = numIncorrectSegmentsOOV = 0;
    for (int i = 0; i < data.size(); i++) {
      Instance instance = data.get(i);
      Sequence input = (Sequence) instance.getData();
      //String tokens = null;
      //if (instance.getSource() != null)
      //tokens = (String) instance.getSource().toString();
      Sequence trueOutput = (Sequence) instance.getTarget();
      assert (input.size() == trueOutput.size());
      Sequence predOutput = model.transduce (input);
      assert (predOutput.size() == trueOutput.size());
      boolean trueStart, predStart;
      for (int j = 0; j < trueOutput.size(); j++) {
        totalTokens++;
        trueStart = predStart = false;
        if (segmentStartTag.equals(trueOutput.get(j))) {
          numTrueSegments++;
          trueStart = true;
        }
        if (segmentStartTag.equals(predOutput.get(j))) {
          predStart = true;
          numPredictedSegments++;
        }
        if (trueStart && predStart) {
          int m;
          //StringBuffer sb = new StringBuffer();
          //sb.append (tokens.charAt(j));
          for (m = j+1; m < trueOutput.size(); m++) {
            trueStart = predStart = false; // Here, these actually mean "end", not "start"
            if (segmentEndTag.equals(trueOutput.get(m)))
              trueStart = true;
            if (segmentEndTag.equals(predOutput.get(m)))
              predStart = true;
            if (trueStart || predStart) {
              if (trueStart && predStart) {
                // It is a correct segment
                numCorrectSegments++;
                //if (HashFile.allLexicons.contains(sb.toString()))
                //numCorrectSegmentsInAlphabet++;
                //else
                //numCorrectSegmentsOOV++;
              } else {
                // It is an incorrect segment; let's find out if it was in the lexicon
                //for (int mm = m; mm < trueOutput.size(); mm++) {
                //if (segmentEndTag.equals(predOutput.get(mm)))
                //break;
                //sb.append (tokens.charAt(mm));
                //}
                //if (HashFile.allLexicons.contains(sb.toString()))
                //numIncorrectSegmentsInAlphabet++;
                //else
                //numIncorrectSegmentsOOV++;
              }
              break;
            }
            //sb.append (tokens.charAt(m));
          }
          // for the case of the end of the sequence
          if(m==trueOutput.size()) {
            if (trueStart==predStart) {
              numCorrectSegments++;
              //if (HashFile.allLexicons.contains(sb.toString()))
              //numCorrectSegmentsInAlphabet++;
              //else
              //numCorrectSegmentsOOV++;
            } else {
              //if (HashFile.allLexicons.contains(sb.toString()))
              //numIncorrectSegmentsInAlphabet++;
              //else
              //numIncorrectSegmentsOOV++;
            }
          }
        } else if (predStart) {
          // Here is an incorrect predicted start, find out if the word is in the lexicon
          //StringBuffer sb = new StringBuffer();
          //sb.append (tokens.charAt(j));
          //for (int mm = j+1; mm < trueOutput.size(); mm++) {
          //if (segmentEndTag.equals(predOutput.get(mm)))
          //break;
          //sb.append (tokens.charAt(mm));
          //}
          //if (HashFile.allLexicons.contains(sb.toString()))
          //numIncorrectSegmentsInAlphabet++;
          //else
          //numIncorrectSegmentsOOV++;
        }
        if (trueOutput.get(j).equals(predOutput.get(j)))
          numCorrectTokens++;
      }
    }
    logger.info (description +" accuracy="+((double)numCorrectTokens)/totalTokens);
    double precision = numPredictedSegments == 0 ? 1 : ((double)numCorrectSegments) / numPredictedSegments;
    double recall = numTrueSegments == 0 ? 1 : ((double)numCorrectSegments) / numTrueSegments;
    double f1 = recall+precision == 0.0 ? 0.0 : (2.0 * recall * precision) / (recall + precision);
    logger.info (" precision="+precision+" recall="+recall+" f1="+f1);
    logger.info ("segments true="+numTrueSegments+" pred="+numPredictedSegments+" correct="+numCorrectSegments+" misses="+(numTrueSegments-numCorrectSegments)+" alarms="+(numPredictedSegments-numCorrectSegments));
    //System.out.println ("correct segments OOV="+numCorrectSegmentsOOV+" IV="+numCorrectSegmentsInAlphabet);
    //System.out.println ("incorrect segments OOV="+numIncorrectSegmentsOOV+" IV="+numIncorrectSegmentsInAlphabet);
  }
}
TOP

Related Classes of cc.mallet.fst.SegmentationEvaluator

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.