Package weka.classifiers.meta

Source Code of weka.classifiers.meta.WeightedVote

package weka.classifiers.meta;

import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.RandomizableMultipleClassifiersCombiner;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;

/**
*
* @author rlichten
*/
public class WeightedVote extends RandomizableMultipleClassifiersCombiner implements TechnicalInformationHandler {

  protected double[] m_Weights;

  public void setWeights( double[] weights ) {
    double min = weights[Utils.minIndex(weights)];
    if( min < 0 ) {
      for( int i = 0; i < weights.length; i++ ) {
        weights[i] += -min;
      }
    }
    double sum = Utils.sum( weights );
    if( sum <= 0 ) {
      for( int i = 0; i < weights.length; i++ ) {
        weights[i] = 1.0 / (double)weights.length;
      }
    }

    Utils.normalize( weights );
    m_Weights = weights;
  }

  public double[] getWeights() {
    return m_Weights;
  }

  /**
   * Returns an enumeration describing the available options.
   *
   * @return an enumeration of all the available options.
   */
  public Enumeration listOptions() {
    Vector result = new Vector();
    Enumeration optionEnum = super.listOptions();
    while( optionEnum.hasMoreElements() ) {
      result.addElement( optionEnum.nextElement() );
    }
    result.addElement( new Option(
        "\tThe weights to use (default: even weighting)\n",
        "w", 1, "-w <comma-separated-weights>" ) );

    return result.elements();
  }

  public void setOptions( String[] options ) throws Exception {
    String weightStr = Utils.getOption( 'w', options );
    if( weightStr.length() != 0 ) {
      String[] weightStrArray = weightStr.split( "," );
      double[] weightsArray = new double[weightStrArray.length];
      for( int i = 0; i < weightStrArray.length; i++ ) {
        weightsArray[i] = new Double( weightStrArray[i] ).doubleValue();
      }
      setWeights( weightsArray );
    }
  }

  public String[] getOptions() {
    String[] options = new String[4];
    int current = 0;

    String weightStr = "";
    for( int i = 0; i < m_Weights.length; i++ ) {
      weightStr += m_Weights[i] + ",";
    }
    weightStr.replace( ",$", "" );

    options[current++] = "-w";
    options[current++] = "" + weightStr;

    while( current < options.length ) {
      options[current++] = "";
    }
    return options;
  }

  /**
   * Classifies the given test instance.
   *
   * @param instance the instance to be classified
   * @return the predicted most likely class for the instance or
   * Instance.missingValue() if no prediction is made
   * @throws Exception if an error occurred during the prediction
   */
  public double classifyInstance( Instance instance ) throws Exception {
    double result;

    double[] dist = distributionForInstance( instance );
    if( instance.classAttribute().isNominal() ) {
      int index = Utils.maxIndex( dist );
      if( dist[index] == 0 ) {
        result = Instance.missingValue();
      } else {
        result = index;
      }
    } else if( instance.classAttribute().isNumeric() ) {
      result = dist[0];
    } else {
      result = Instance.missingValue();
    }
    return result;
  }

  /**
   * Classifies a given instance using the selected combination rule.
   *
   * @param instance the instance to be classified
   * @return the distribution
   * @throws Exception if instance could not be classified
   * successfully
   */
  public double[] distributionForInstance( Instance instance ) throws Exception {
    double[] result = new double[instance.numClasses()];

    for( int i = 0; i < m_Classifiers.length; i++ ) {
      double[] dist = getClassifier( i ).distributionForInstance( instance );
      for( int j = 0; j < result.length; j++ ) {
        result[j] += m_Weights[i] * dist[j];
      }

    }

    if( !instance.classAttribute().isNumeric() && ( Utils.sum( result ) > 0 ) ) {
      Utils.normalize( result );
    }

    return result;
  }

  /**
   * This method selects a classifier from the set of classifiers
   * by minimising error on the training data.
   *
   * @param data the training data to be used for generating the
   * classifier.
   * @throws Exception if the classifier could not be built successfully
   */
  public void buildClassifier( Instances data ) throws Exception {

    getCapabilities().testWithFail( data );

    Instances newData = new Instances( data );
    newData.deleteWithMissingClass();

    for( int i = 0; i < m_Classifiers.length; i++ ) {
      getClassifier( i ).buildClassifier( newData );
    }
  }

  public String getRevision() {
    return RevisionUtils.extract( "$Revision: 1.0$" );
  }

  public TechnicalInformation getTechnicalInformation() {
    throw new UnsupportedOperationException( "Not supported yet." );
  }

  /**
   * Main method for testing this class.
   *
   * @param argv should contain the following arguments:
   * -t training file [-T test file] [-c class index]
   */
  public static void main( String[] argv ) {
    runClassifier( new WeightedVote(), argv );
  }
}
TOP

Related Classes of weka.classifiers.meta.WeightedVote

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.