Package com.jgaap.classifiers

Source Code of com.jgaap.classifiers.BaggingNearestNeighborDriver

package com.jgaap.classifiers;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import org.apache.log4j.Logger;

import com.google.common.collect.ImmutableMultimap;
import com.jgaap.generics.AnalyzeException;
import com.jgaap.generics.DistanceCalculationException;
import com.jgaap.generics.NeighborAnalysisDriver;
import com.jgaap.util.Document;
import com.jgaap.util.EventBagging;
import com.jgaap.util.EventMap;
import com.jgaap.util.EventSet;
import com.jgaap.util.Pair;

public class BaggingNearestNeighborDriver extends NeighborAnalysisDriver {

  private static Logger logger = Logger
      .getLogger("com.jgaap.classifiers.BaggingNearestNeighborDriver");

  private ImmutableMultimap<String,EventMap> authorHistograms;

  public BaggingNearestNeighborDriver() {
    addParams("Samples", "Samples", "5", new String[] { "1", "5", "10",
        "15", "20", "25", "30", "40", "50" }, true);
    addParams("SampleSize", "Sample Size", "500",
        new String[] { "500", "750", "1000", "1500", "2000", "3000",
            "4000", "5000", "10000" }, true);
    addParams("Score", "Score", "False", new String[] {"True", "False"}, false);
  }

  @Override
  public String displayName() {
    return "Bagging Nearest Neighbor";
  }

  @Override
  public String tooltipText() {
    return "Nearest Neighbor using bagging to generated the traing data.";
  }

  @Override
  public boolean showInGUI() {
    return false;
  }

  @Override
  public void train(List<Document> knownDocuments) throws AnalyzeException {
    Map<String, EventBagging> authorBags = new HashMap<String, EventBagging>();
    for(Document knownDocument : knownDocuments){
      for (EventSet eventSet : knownDocument.getEventSets().values()) {
        EventBagging eventBag = authorBags.get(knownDocument.getAuthor());
        if (eventBag == null) {
          eventBag = new EventBagging(eventSet);
          authorBags.put(knownDocument.getAuthor(), eventBag);
        } else {
          eventBag.addAll(eventSet);
        }
      }
    }
    int samples = getParameter("samples", 5);
    int sampleSize = getParameter("sampleSize", 500);
    ImmutableMultimap.Builder<String, EventMap> authorHistogramsBuilder = ImmutableMultimap.builder();
    for (Entry<String, EventBagging> entry : authorBags.entrySet()) {
      List<EventMap> histograms = new ArrayList<EventMap>(samples);
      for (int i = 0; i < samples; i++) {
        EventSet eventSet = new EventSet(sampleSize);
        for (int j = 0; j < sampleSize; j++) {
          eventSet.addEvent(entry.getValue().next());
        }
        EventMap histogram = new EventMap(eventSet);
        histograms.add(histogram);
      }
      authorHistogramsBuilder.putAll(entry.getKey(), histograms);
    }
    authorHistograms = authorHistogramsBuilder.build();
  }

  @Override
  public List<Pair<String, Double>> analyze(Document unknownDocument)
      throws AnalyzeException {
    List<Pair<String, Double>> rawResults = new ArrayList<Pair<String, Double>>();
    EventSet unknownEventSet = new EventSet();
    for(EventSet eventSet : unknownDocument.getEventSets().values()){
      unknownEventSet.addEvents(eventSet);
    }
    EventMap unknownHistogram = new EventMap(unknownEventSet);
    for (Entry<String, Collection<EventMap>> entry : authorHistograms.asMap().entrySet()) {
      for (EventMap knownHistogram : entry.getValue()) {
        try {
          rawResults.add(new Pair<String, Double>(entry.getKey(), distance.distance(unknownHistogram, knownHistogram), 2));
        } catch (DistanceCalculationException e) {
          logger.fatal("Distance " + distance.displayName() + " failed", e);
          throw new AnalyzeException("Distance " + distance.displayName() + " failed");
        }
      }
    }
    Collections.sort(rawResults);
   
    if(getParameter("score").equalsIgnoreCase("true")){
      Map<String, Integer> authorScores = new HashMap<String, Integer>();
      int samples = getParameter("samples", 5);
      for(int i = 0; i < samples; i++){
        Integer current = authorScores.get(rawResults.get(i).getFirst());
        if (current == null){
          current = 1;
        } else {
          current ++;
        }
        authorScores.put(rawResults.get(i).getFirst(), current);
      }
      List<Pair<String, Double>> scoreResults = new ArrayList<Pair<String,Double>>();
      double samplesDouble = samples;
      for(Entry<String, Integer> entry : authorScores.entrySet()){
        scoreResults.add(new Pair<String, Double>(entry.getKey(), entry.getValue()/samplesDouble, 2));
      }
      Collections.sort(scoreResults);
      Collections.reverse(scoreResults);
      return scoreResults;
    } else {
      return rawResults;
    }
  }

}
TOP

Related Classes of com.jgaap.classifiers.BaggingNearestNeighborDriver

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.