/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.types;
import java.io.Serializable;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.Arrays;
import cc.mallet.classify.Classification;
import cc.mallet.classify.Classifier;
import cc.mallet.classify.Trial;
import cc.mallet.types.Alphabet;
import cc.mallet.types.AlphabetCarrying;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Label;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelVector;
/**
* Tracks ROC data for instances in {@link Trial} results.
*
* @see Trial
* @see InstanceList
* @see Classifier
* @see Classification
*
* @author Michael Bond <a href="mailto:mikejbond@gmail.com">mikejbond@gmail.com</a>
*/
public class ROCData implements AlphabetCarrying, Serializable {
private static final long serialVersionUID = -2060194953037720640L;
public static final int TRUE_POSITIVE = 0;
public static final int FALSE_POSITIVE = 1;
public static final int FALSE_NEGATIVE = 2;
public static final int TRUE_NEGATIVE = 3;
private final LabelAlphabet labelAlphabet;
/** Matrix of class, threshold, [tp, fp, fn, tn] */
private final int[][][] counts;
private final double[] thresholds;
/**
* Constructs a new object
*
* @param thresholds Array of thresholds to track counts for
* @param labelAlphabet Label alphabet for instances in {@link Trial}
*/
public ROCData(double[] thresholds, final LabelAlphabet labelAlphabet) {
// ensure that thresholds are sorted
Arrays.sort(thresholds);
this.counts = new int[labelAlphabet.size()][thresholds.length][4];
this.labelAlphabet = labelAlphabet;
this.thresholds = thresholds;
}
/**
* Adds classification results to the ROC data
*
* @param trial Trial results to add to ROC data
*/
public void add(Classification classification) {
int correctIndex = classification.getInstance().getLabeling().getBestIndex();
LabelVector lv = classification.getLabelVector();
double[] values = lv.getValues();
if (!Alphabet.alphabetsMatch(this, lv)) {
throw new IllegalArgumentException ("Alphabets do not match");
}
int numLabels = this.labelAlphabet.size();
for (int label = 0; label < numLabels; label++) {
double labelValue = values[label];
int[][] thresholdCounts = this.counts[label];
int threshold = 0;
// add the trial to all the thresholds it would be positive for
for (; threshold < this.thresholds.length && labelValue >= this.thresholds[threshold]; threshold++) {
if (correctIndex == label) {
thresholdCounts[threshold][TRUE_POSITIVE]++;
} else {
thresholdCounts[threshold][FALSE_POSITIVE]++;
}
}
// add the trial to the thresholds it would be negative for
for (; threshold < this.thresholds.length; threshold++) {
if (correctIndex == label) {
thresholdCounts[threshold][FALSE_NEGATIVE]++;
} else {
thresholdCounts[threshold][TRUE_NEGATIVE]++;
}
}
}
}
/**
* Adds trial results to the ROC data
*
* @param trial Trial results to add to ROC data
*/
public void add(Trial trial) {
for (Classification classification : trial) {
add(classification);
}
}
/**
* Adds existing ROC data to this ROC data
*
* @param rocData ROC data to add
*/
public void add(ROCData rocData) {
if (!Alphabet.alphabetsMatch(this, rocData)) {
throw new IllegalArgumentException ("Alphabets do not match");
}
if (!Arrays.equals(this.thresholds, rocData.thresholds)) {
throw new IllegalArgumentException ("Thresholds do not match");
}
int countsLength = this.counts.length;
for (int c = 0; c < countsLength; c++) {
int[][] thisClassCounts = this.counts[c];
int[][] otherClassCounts = rocData.counts[c];
int classLength = thisClassCounts.length;
for (int t = 0; t < classLength; t++) {
int[] thisThrCounts = thisClassCounts[t];
int[] otherThrCounts = otherClassCounts[t];
int thrLength = thisThrCounts.length;
for (int s = 0; s < thrLength; s++) {
thisThrCounts[s] += otherThrCounts[s];
}
}
}
}
//@Override
public Alphabet getAlphabet() {
return this.labelAlphabet;
}
//@Override
public Alphabet[] getAlphabets() {
return new Alphabet[] { this.labelAlphabet };
}
/**
* Gets the raw counts for a specified label.
*
* @param label Label to get counts for
* @see #TRUE_POSITIVE
* @see #FALSE_POSITIVE
* @see #FALSE_NEGATIVE
* @see #TRUE_NEGATIVE
* @return Array of raw counts for specified label
*/
public int[][] getCounts(Label label) {
return this.counts[label.getIndex()];
}
/**
* Gets the raw counts for a specified label and threshold.
*
* If data was not collected for the exact threshold specified, then results
* for the highest threshold <= the specified threshold will be returned.
*
* @param label Label to get counts for
* @param threshold Threshold to get counts for
* @see #TRUE_POSITIVE
* @see #FALSE_POSITIVE
* @see #FALSE_NEGATIVE
* @see #TRUE_NEGATIVE
* @return Array of raw counts for specified label and threshold
*/
public int[] getCounts(Label label, double threshold) {
int index = Arrays.binarySearch(this.thresholds, threshold);
if (index < 0) {
index = (-index) - 2;
}
return this.counts[label.getIndex()][index];
}
/**
* Gets the label alphabet
*/
public LabelAlphabet getLabelAlphabet() {
return this.labelAlphabet;
}
/**
* Gets the precision for a specified label and threshold.
*
* If data was not collected for the exact threshold specified, then results
* will for the highest threshold <= the specified threshold will be
* returned.
*
* @param label Label to get precision for
* @param threshold Threshold to get precision for
* @return Precision for specified label and threshold
*/
public double getPrecision(Label label, double threshold) {
int[] counts = getCounts(label, threshold);
return (double) counts[TRUE_POSITIVE] / (double) (counts[TRUE_POSITIVE] + counts[FALSE_POSITIVE]);
}
/**
* Gets the precision for a specified label and score. This differs from
* {@link ROCData.getPrecision(Label, double)} in that it is the precision
* for only scores falling in the one score value, not for all scores
* above the threshold.
*
* If data was not collected for the exact threshold specified, then results
* will for the highest threshold <= the specified threshold will be
* returned.
*
* @param label Label to get precision for
* @param threshold Threshold to get precision for
* @return Precision for specified label and score
*/
public double getPrecisionForScore(Label label, double score) {
final int[][] buckets = this.counts[label.getIndex()];
int index = Arrays.binarySearch(this.thresholds, score);
if (index < 0) {
index = (-index) - 2;
}
final double tp;
final double fp;
if (index == this.thresholds.length - 1) {
tp = buckets[index][TRUE_POSITIVE];
fp = buckets[index][FALSE_POSITIVE];
} else {
tp = buckets[index][TRUE_POSITIVE] - buckets[index + 1][TRUE_POSITIVE];
fp = buckets[index][FALSE_POSITIVE] - buckets[index + 1][FALSE_POSITIVE];
}
return (double) tp / (double) (tp + fp);
}
/**
* Gets the estimated percentage of training events that exceed the
* threshold.
*
* @param label Label to get precision for
* @param threshold Threshold to get precision for
* @return Estimated percentage of events exceeding threshold
*/
public double getPositivePercent(Label label, double threshold) {
final int[] counts = getCounts(label, threshold);
final int positive = counts[TRUE_POSITIVE] + counts[FALSE_POSITIVE];
return ((double) positive / (double) (positive + counts[FALSE_NEGATIVE] + counts[TRUE_NEGATIVE])) * 100.0;
}
/**
* Gets the recall rate for a specified label and threshold.
*
* If data was not collected for the exact threshold specified, then results
* will for the highest threshold <= the specified threshold will be
* returned.
*
* @param label Label to get recall for
* @param threshold Threshold to get recall for
* @return Recall rate for specified label and threshold
*/
public double getRecall(Label label, double threshold) {
int[] counts = getCounts(label, threshold);
return (double) counts[TRUE_POSITIVE] / (double) (counts[TRUE_POSITIVE] + counts[FALSE_NEGATIVE]);
}
/**
* Gets the thresholds being tracked
*
* @return Array of thresholds
*/
public double[] getThresholds() {
return this.thresholds;
}
/**
* Sets the raw counts for a specified label and threshold.
*
* If data is not collected for the exact threshold specified, then counts
* for the highest threshold <= the specified threshold will be set.
*
* @param label Label to get counts for
* @param threshold Threshold to get counts for
* @param newCounts New count values for the label and threshold
* @see #TRUE_POSITIVE
* @see #FALSE_POSITIVE
* @see #FALSE_NEGATIVE
* @see #TRUE_NEGATIVE
*/
public void setCounts(Label label, double threshold, int[] newCounts) {
int index = Arrays.binarySearch(this.thresholds, threshold);
if (index < 0) {
index = (-index) - 2;
}
final int[] oldCounts = this.counts[label.getIndex()][index];
if (newCounts.length != oldCounts.length) {
throw new IllegalArgumentException ("Array of counts must contain " + oldCounts.length + " elements.");
}
for (int i = 0; i < oldCounts.length; i++) {
oldCounts[i] = newCounts[i];
}
}
//@Override
public String toString() {
final StringBuilder buf = new StringBuilder();
final NumberFormat format = new DecimalFormat("0.####");
for (int i = 0; i < this.labelAlphabet.size(); i++) {
int[][] labelData = this.counts[i];
buf.append("ROC data for ");
buf.append(this.labelAlphabet.lookupObject(i).toString());
buf.append('\n');
buf.append("THR\tTP\tFP\tFN\tTN\tPrecis\tRecall\n");
// add one row for each threshold
for (int t = 0; t < this.thresholds.length; t++) {
buf.append(this.thresholds[t]);
for (int res : labelData[t]) {
buf.append('\t').append(res);
}
double tp = labelData[t][TRUE_POSITIVE];
double sum = tp + labelData[t][FALSE_POSITIVE];
double precision = 0.0;
if (sum != 0) {
precision = tp / sum;
}
sum = tp + labelData[t][FALSE_NEGATIVE];
double recall = 0.0;
if (sum != 0) {
recall = tp / sum;
}
buf.append('\t').append(format.format(precision));
buf.append('\t').append(format.format(recall));
buf.append('\n');
}
buf.append('\n');
}
return buf.toString();
}
}