package org.sf.mustru.test;
import com.aliasi.hmm.HiddenMarkovModel;
import com.aliasi.hmm.HmmDecoder;
import com.aliasi.hmm.TagWordLattice;
import com.aliasi.util.ScoredObject;
import com.aliasi.util.Strings;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.Date;
import java.util.Iterator;
import java.util.Vector;
import org.sf.mustru.utils.Constants;
import org.sf.mustru.utils.StringTools;
import org.apache.oro.text.perl.Perl5Util;
import org.apache.log4j.Logger;
import org.apache.log4j.PropertyConfigurator;
/**
 * Command-line harness that evaluates the part-of-speech tagger trained in TrainPOSTagger
 * (using the Brown Corpus).
 * <p>
 * The serialized HMM is read from {@code $MUSTRU_HOME/data/training/pos/pos_tagger}. The tagger's
 * first-best output is compared token by token with a pre-tagged reference. The reference is either
 * the file {@code $MUSTRU_HOME/data/testing/pos/c.txt} or, when {@link #line_test} is set, a single
 * built-in sentence. Reference text consists of whitespace-separated {@code word/tag} pairs.
 * <p>
 * Counts are kept in unsynchronized static fields, so this class is meant to be run as a
 * single-threaded program.
 */
public class TestPOSTagger
{
  static String MUSTRU_HOME = System.getProperty("MUSTRU_HOME");
  static int correct_tags, total_tags;
  static Perl5Util p5util = new Perl5Util();
  static Logger logger = null;
  static boolean line_test = false; //*-- test a single line of text or the test file
  /** Number of n-best tag sequences and of per-token tag alternatives printed as diagnostics. */
  static final int TOP_N = 3;

  /**
   * Test the part of speech tagger trained in TrainPOSTagger using the Brown Corpus.
   * @param args unused
   * @throws ClassNotFoundException if the serialized model references an unknown class
   * @throws IOException if the model or the test file cannot be read
   * @throws IllegalStateException if the MUSTRU_HOME system property is not set
   */
  public static void main(String[] args) throws ClassNotFoundException, IOException
  {
    PropertyConfigurator.configure (Constants.LOG4J_FILE);
    logger = Logger.getLogger(TestPOSTagger.class.getName());
    logger.debug("Started TestPOSTagger");
    //*-- without MUSTRU_HOME every path below would silently start with "null"
    if (MUSTRU_HOME == null)
      throw new IllegalStateException("System property MUSTRU_HOME is not set");
    correct_tags = total_tags = 0;

    //*-- read the POS Tagger model and generate the decoder
    String modelFile = MUSTRU_HOME + File.separator + "data" + File.separator + "training" + File.separator + "pos" + File.separator + "pos_tagger";
    logger.info("Reading POS tagger model from " + modelFile);
    long startReadTime = System.currentTimeMillis();
    HiddenMarkovModel hmm = readModel(modelFile);
    long readTime = System.currentTimeMillis() - startReadTime;
    System.out.println("Time to read model " + readTime + " msecs.");
    HmmDecoder decoder = new HmmDecoder(hmm);

    if (line_test)
    {
      //*-- test a single tagged sentence
      // String tline = "Before/cs you/ppss let/vb loose/jj a/at howl/nn saying/vbg we/ppss announced/vbd its/pp$ coming/nn ,/, not/* once/rb but/cc several/ap times/nns ,/, indeed/rb we/ppss did/dod ./.";
      // String tline = "This/dt seems/vbz like/cs a/at large/jj order/nn ./.";
      String tline = "Plastic/nn signs/nns work/vb around/in the/at clock/nn ./.";
      parseSentence(tline, decoder, true);
    }
    else
    {
      //*-- test a tagged file, one sentence per line; always release the file handle
      String taggedFile = MUSTRU_HOME + File.separator + "data" + File.separator + "testing" + File.separator + "pos" + File.separator + "c.txt";
      BufferedReader is = new BufferedReader( new FileReader(taggedFile));
      try
      {
        String iline;
        while ( (iline = is.readLine()) != null)
          parseSentence(iline, decoder, false);
      }
      finally
      {
        is.close();
      }
    }

    logger.info("Correctly assigned " + correct_tags + " out of " + total_tags + " tags.");
    //*-- avoid reporting NaN when nothing was evaluated (e.g. an empty test file)
    if (total_tags > 0)
      logger.info("Precision: " + format( 100.0 * ((double) correct_tags / total_tags)) + "%");
    else
      logger.warn("No tags were evaluated; precision is undefined");
    logger.info("Finished TestPOSTagger");
  }

  /**
   * Deserialize the HMM written by the trainer, closing the file even if deserialization fails.
   * NOTE: Java native deserialization must only be used on a trusted, locally produced model file.
   * @param modelFile path of the serialized model
   * @return the deserialized model
   * @throws IOException if the file cannot be read
   * @throws ClassNotFoundException if the stream references an unknown class
   */
  private static HiddenMarkovModel readModel(String modelFile) throws IOException, ClassNotFoundException
  {
    FileInputStream fis = new FileInputStream(modelFile);
    try
    {
      ObjectInputStream oi = new ObjectInputStream(fis);
      return (HiddenMarkovModel) oi.readObject();
    }
    finally
    {
      //*-- closing the underlying stream also covers a failure inside the ObjectInputStream constructor
      fis.close();
    }
  }

  /**
   * Accept a tagged line of text and a HMM decoder and count the number of correctly matched tags.
   * Blank lines are ignored. A line containing a malformed pair (no tag after the last '/') is
   * skipped with a warning and does not contribute to the counts.
   * @param iline String containing the pre-tagged line of whitespace-separated word/tag pairs
   * @param decoder A HMM decoder
   * @param diagnostics if true, log tagging errors plus the n-best and confidence tables
   */
  static void parseSentence(String iline, HmmDecoder decoder, boolean diagnostics)
  {
    iline = iline.trim();
    if (iline.length() == 0) return;   //*-- blank separator lines carry no tokens

    //*-- First split the line into pairs : word / tag (default Perl split on whitespace)
    Vector word_tags = new Vector();
    p5util.split(word_tags, iline);

    //*-- next create string arrays of tokens and tags
    String[] tokens = new String[word_tags.size()];
    String[] tags = new String[word_tags.size()];
    for (int i = 0; i < word_tags.size(); i++)
    {
      //*-- split on the LAST '/' so tokens that themselves contain '/' (e.g. "1/2/cd") stay intact
      String pair = word_tags.elementAt(i).toString();
      int slash = pair.lastIndexOf('/');
      if (slash <= 0 || slash == pair.length() - 1)
      {
        logger.warn("Skipping sentence with malformed word/tag pair '" + pair + "': " + iline);
        return;
      }
      tokens[i] = pair.substring(0, slash);
      tags[i] = pair.substring(slash + 1);
    }

    //*-- Get the generated tags and check with the correct tags
    String[] gen_tags = decoder.firstBest(tokens);
    int num_correct = 0;
    for (int i = 0; i < tokens.length; i++)
    {
      if ( tags[i].equalsIgnoreCase(gen_tags[i]) ) { num_correct++; }
      else if (diagnostics)
      { logger.info("Error Token: " + tokens[i] + " Correct: " + tags[i] + " Assigned: " + gen_tags[i]); }
    }

    //*-- update the static counts
    correct_tags += num_correct; total_tags += tokens.length;

    //*-- dump the diagnostics if necessary
    if (diagnostics)
    { nBest(tokens, decoder); confidence(tokens, decoder); }
  }

  /**
   * Print the top n best sequences (at most TOP_N) in the order the decoder's n-best iterator returns them.
   * @param tokens the tokens of one sentence
   * @param decoder the HMM decoder
   */
  static void nBest(String[] tokens, HmmDecoder decoder)
  {
    logger.info("-------------------- NBEST --------------------------------");
    logger.info(StringTools.fillin("JointLogProb", 15, true, ' ') + StringTools.fillin("Tags", 80, true, ' '));
    Iterator nBestIt = decoder.nBest(tokens);
    for (int n = 0; n < TOP_N && nBestIt.hasNext(); ++n)
    {
      //*-- get the ScoredObject to fetch the score and tags
      ScoredObject tagScores = (ScoredObject) nBestIt.next();

      //*-- format the score
      StringBuffer sb = new StringBuffer();
      sb.append(StringTools.fillin(format(tagScores.score()), 15, true, ' '));

      //*-- format the tags
      String[] tags = (String[]) tagScores.getObject();
      for (int i = 0; i < tokens.length; ++i)
        sb.append(StringTools.fillin(tokens[i] + "_" + tags[i], 16, true, ' ' ) );
      logger.info(sb.toString());
    } //*-- end of for
  }

  /**
   * Print the confidence of tags for individual tokens: up to TOP_N conditional probabilities
   * per token, converted from the lattice's base-2 log scores.
   * @param tokens the tokens of one sentence
   * @param decoder the HMM decoder
   */
  static void confidence(String[] tokens, HmmDecoder decoder)
  {
    logger.info("----------------- CONFIDENCE ------------------------");
    logger.info(StringTools.fillin("Token", 10, true, ' ') + StringTools.fillin("PROB: TAG", 60, true, ' '));

    //*-- get the lattice of tag scores and print
    TagWordLattice lattice = decoder.lattice(tokens);
    for (int i = 0; i < tokens.length; ++i)
    {
      StringBuffer sb = new StringBuffer();
      ScoredObject[] tagScores = lattice.log2ConditionalTags(i);
      sb.append(StringTools.fillin(tokens[i], 10, true, ' '));
      //*-- a model with fewer than TOP_N tags must not index past the end of the array
      int shown = Math.min(TOP_N, tagScores.length);
      for (int j = 0; j < shown; ++j)
      {
        double logProb = tagScores[j].score();
        double conditionalProb = Math.pow(2.0, logProb);
        String tag = (String) tagScores[j].getObject();
        sb.append(StringTools.fillin(format(conditionalProb) + ": " + tag, 16, true, ' ') );
      } //*-- end of inner for
      logger.info(sb.toString());
    } //*-- end of outer for
  }

  /** Format a number with three decimals, padded to width 9, via LingPipe's Strings.decimalFormat. */
  static String format(double x) { return Strings.decimalFormat(x, "#,##0.000", 9); }
}