Package cc.mallet.classify

Source Code of cc.mallet.classify.BalancedWinnow

/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */


package cc.mallet.classify;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

import cc.mallet.pipe.Pipe;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.LabelVector;
import cc.mallet.types.MatrixOps;


/**
* Classification methods of BalancedWinnow algorithm.
*
* @see BalancedWinnowTrainer
* @author Gary Huang <a href="mailto:ghuang@cs.umass.edu">ghuang@cs.umass.edu</a>
*/
public class BalancedWinnow extends Classifier implements Serializable
{
    double [][] m_weights;
 
    /**
     * Passes along data pipe and weights from
     * {@link #BalancedWinnowTrainer BalancedWinnowTrainer}
     * @param dataPipe needed for dictionary, labels, feature vectors, etc
     * @param weights weights calculated during training phase
     */
    public BalancedWinnow (Pipe dataPipe, double [][] weights)
    {
        super (dataPipe);
        m_weights = new double[weights.length][weights[0].length];
        for (int i = 0; i < weights.length; i++)
      for (int j = 0; j < weights[0].length; j++)
        m_weights[i][j] = weights[i][j];
    }
 
    /**
     * @return a copy of the weight vectors
     */
    public double[][] getWeights()
    {
        int numCols = m_weights[0].length;
        double[][] ret = new double[m_weights.length][numCols];
        for (int i = 0; i < ret.length; i++)
      System.arraycopy(m_weights[i], 0, ret[i], 0, numCols);
        return ret;
    }
 
    /**
     * Classifies an instance using BalancedWinnow's weights
     *
     * <p>Returns a Classification containing the normalized
     * dot products between class weight vectors and the instance
     * feature vector.
     *
     * <p>One can obtain the confidence of the classification by
     * calculating weight(j')/weight(j), where j' is the
     * highest weight prediction and j is the 2nd-highest.
     * Another possibility is to calculate
     * <br><tt><center>e^{dot(w_j', x} / sum_j[e^{dot(w_j, x)}]</center></tt>
     */
    public Classification classify (Instance instance)
    {
        int numClasses = getLabelAlphabet().size();
        int numFeats = getAlphabet().size();
        double[] scores = new double[numClasses];
        FeatureVector fv = (FeatureVector) instance.getData ();

        // Make sure the feature vector's feature dictionary matches
        // what we are expecting from our data pipe (and thus our notion
        // of feature probabilities.
        assert (instancePipe == null || fv.getAlphabet () == this.instancePipe.getDataAlphabet ());
        int fvisize = fv.numLocations();

        // Take dot products
        double sum = 0;
        for (int ci = 0; ci < numClasses; ci++) {
      for (int fvi = 0; fvi < fvisize; fvi++) {
        int fi = fv.indexAtLocation (fvi);
        double vi = fv.valueAtLocation(fvi);

        if ( m_weights[ci].length > fi ) {
        scores[ci] += vi * m_weights[ci][fi];
        sum += vi * m_weights[ci][fi];
        }
      }
      scores[ci] += m_weights[ci][numFeats];
      sum += m_weights[ci][numFeats];
        }
        MatrixOps.timesEquals(scores, 1.0 / sum);

        // Create and return a Classification object
        return new Classification (instance, this, new LabelVector (getLabelAlphabet(), scores));
    }

    // Serialization
    // serialVersionUID is overriden to prevent innocuous changes in this
    // class from making the serialization mechanism think the external
    // format has changed.

    private static final long serialVersionUID = 1;
    private static final int CURRENT_SERIAL_VERSION = 1;

    private void writeObject(ObjectOutputStream out) throws IOException
    {
        out.writeInt(CURRENT_SERIAL_VERSION);
        out.writeObject(getInstancePipe());
       
        // write weight vector for each class
        out.writeObject(m_weights);
    }
   
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        int version = in.readInt();
        if (version != CURRENT_SERIAL_VERSION)
      throw new ClassNotFoundException("Mismatched BalancedWinnow versions: wanted " +
                       CURRENT_SERIAL_VERSION + ", got " +
                       version);
        instancePipe = (Pipe) in.readObject();
        m_weights = (double[][]) in.readObject();
       
    }
   
}
TOP

Related Classes of cc.mallet.classify.BalancedWinnow

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle, Inc. Contact coftware#gmail.com.