Package mia.classifier.ch16.train

Source Code of mia.classifier.ch16.train.TrainNewsGroups

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package mia.classifier.ch16.train;

import com.google.common.base.Charsets;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Ordering;
import com.google.common.io.Files;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.util.Version;
import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression;
import org.apache.mahout.classifier.sgd.CrossFoldLearner;
import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.ModelDissector;
import org.apache.mahout.classifier.sgd.ModelSerializer;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.ep.State;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.DoubleFunction;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
import org.apache.mahout.vectorizer.encoders.Dictionary;
import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
import org.apache.mahout.vectorizer.encoders.TextValueEncoder;

import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.text.SimpleDateFormat;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Random;
import java.util.Set;

/**
* Reads and trains an adaptive logistic regression model on the 20 newsgroups data.
* The first command line argument gives the path of the directory holding the training
* data.  The optional second argument, leakType, defines which classes of features to use.
* Importantly, leakType controls whether a synthetic date is injected into the data as
* a target leak and if so, how.
* <p>
* The value of leakType % 3 determines whether the target leak is injected according to
* the following table:
* <p>
* <table>
* <tr><td valign='top'>0</td><td>No leak injected</td></tr>
* <tr><td valign='top'>1</td><td>Synthetic date injected in MMM-yyyy format. This will be a single token and
* is a perfect target leak since each newsgroup is given a different month</td></tr>
* <tr><td valign='top'>2</td><td>Synthetic date injected in dd-MMM-yyyy HH:mm:ss format.  The day varies
* and thus there are more leak symbols that need to be learned.  Ultimately this is just
* as big a leak as case 1.</td></tr>
* </table>
* <p>
* Leaktype also determines what other text will be indexed.  If leakType is greater
* than or equal to 6, then neither headers nor text body will be used for features and the leak is the only
* source of data.  If leakType is greater than or equal to 3, then subject words will be used as features.
* If leakType is less than 3, then both subject and body text will be used as features.
* <p>
* A leakType of 0 gives no leak and all textual features.
* <p>
* See the following table for a summary of commonly used values for leakType
* <p>
* <table>
* <tr><td><b>leakType</b></td><td><b>Leak?</b></td><td><b>Subject?</b></td><td><b>Body?</b></td></tr>
* <tr><td colspan=4><hr></td></tr>
* <tr><td>0</td><td>no</td><td>yes</td><td>yes</td></tr>
* <tr><td>1</td><td>mmm-yyyy</td><td>yes</td><td>yes</td></tr>
* <tr><td>2</td><td>dd-mmm-yyyy</td><td>yes</td><td>yes</td></tr>
* <tr><td colspan=4><hr></td></tr>
* <tr><td>3</td><td>no</td><td>yes</td><td>no</td></tr>
* <tr><td>4</td><td>mmm-yyyy</td><td>yes</td><td>no</td></tr>
* <tr><td>5</td><td>dd-mmm-yyyy</td><td>yes</td><td>no</td></tr>
* <tr><td colspan=4><hr></td></tr>
* <tr><td>6</td><td>no</td><td>no</td><td>no</td></tr>
* <tr><td>7</td><td>mmm-yyyy</td><td>no</td><td>no</td></tr>
* <tr><td>8</td><td>dd-mmm-yyyy</td><td>no</td><td>no</td></tr>
* <tr><td colspan=4><hr></td></tr>
* </table>
*/
public final class TrainNewsGroups {
  private static final int FEATURES = 10000;

  // 1997-01-15 00:01:00 GMT
  private static final long DATE_REFERENCE = 853286460;
  private static final long MONTH = 30 * 24 * 3600;
  private static final long WEEK = 7 * 24 * 3600;

  private static final Random rand = RandomUtils.getRandom();

  private static final String[] LEAK_LABELS = {"none", "month-year", "day-month-year"};
  private static final SimpleDateFormat[] DATE_FORMATS = {
    new SimpleDateFormat("", Locale.ENGLISH),
    new SimpleDateFormat("MMM-yyyy", Locale.ENGLISH),
    new SimpleDateFormat("dd-MMM-yyyy HH:mm:ss", Locale.ENGLISH)
  };

  private static final Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_30);
  private static final TextValueEncoder encoder = new TextValueEncoder("body");
  private static final FeatureVectorEncoder bias = new ConstantValueEncoder("Intercept");

  private TrainNewsGroups() {
  }

  public static void main(String[] args) throws IOException {
    File base = new File(args[0]);

    int leakType = 0;
    if (args.length > 1) {
      leakType = Integer.parseInt(args[1]);
    }

    Dictionary newsGroups = new Dictionary();

    encoder.setProbes(2);
    AdaptiveLogisticRegression learningAlgorithm = new AdaptiveLogisticRegression(20, FEATURES, new L1());
    learningAlgorithm.setInterval(800);
    learningAlgorithm.setAveragingWindow(500);

    List<File> files = Lists.newArrayList();
    File[] directories = base.listFiles();
    Arrays.sort(directories, Ordering.usingToString());
    for (File newsgroup : directories) {
      if (newsgroup.isDirectory()) {
        newsGroups.intern(newsgroup.getName());
        files.addAll(Arrays.asList(newsgroup.listFiles()));
      }
    }
    Collections.shuffle(files);
    System.out.printf("%d training files\n", files.size());
    System.out.printf("%s\n", Arrays.asList(directories));

    double averageLL = 0;
    double averageCorrect = 0;

    int k = 0;
    double step = 0;
    int[] bumps = {1, 2, 5};
    for (File file : files) {
      String ng = file.getParentFile().getName();
      int actual = newsGroups.intern(ng);

      Vector v = encodeFeatureVector(file);
      learningAlgorithm.train(actual, v);

      k++;

      int bump = bumps[(int) Math.floor(step) % bumps.length];
      int scale = (int) Math.pow(10, Math.floor(step / bumps.length));
      State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learningAlgorithm.getBest();
      double maxBeta;
      double nonZeros;
      double positive;
      double norm;

      double lambda = 0;
      double mu = 0;

      if (best != null) {
        CrossFoldLearner state = best.getPayload().getLearner();
        averageCorrect = state.percentCorrect();
        averageLL = state.logLikelihood();

        OnlineLogisticRegression model = state.getModels().get(0);
        // finish off pending regularization
        model.close();
       
        Matrix beta = model.getBeta();
        maxBeta = beta.aggregate(Functions.MAX, Functions.ABS);
        nonZeros = beta.aggregate(Functions.PLUS, new DoubleFunction() {
          @Override
          public double apply(double v) {
            return Math.abs(v) > 1.0e-6 ? 1 : 0;
          }
        });
        positive = beta.aggregate(Functions.PLUS, new DoubleFunction() {
          @Override
          public double apply(double v) {
            return v > 0 ? 1 : 0;
          }
        });
        norm = beta.aggregate(Functions.PLUS, Functions.ABS);

        lambda = learningAlgorithm.getBest().getMappedParams()[0];
        mu = learningAlgorithm.getBest().getMappedParams()[1];
      } else {
        maxBeta = 0;
        nonZeros = 0;
        positive = 0;
        norm = 0;
      }
      if (k % (bump * scale) == 0) {
        if (learningAlgorithm.getBest() != null) {
          ModelSerializer.writeBinary("/tmp/news-group-" + k + ".model",
            learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));
        }

        step += 0.25;
        System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t", maxBeta, nonZeros, positive, norm, lambda, mu);
        System.out.printf("%d\t%.3f\t%.2f\t%s\n",
          k, averageLL, averageCorrect * 100, LEAK_LABELS[leakType % 3]);
      }
    }
    learningAlgorithm.close();
    dissect(newsGroups, learningAlgorithm, files);
    System.out.println("exiting main");

    ModelSerializer.writeBinary("/tmp/news-group.model",
                                learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));
  }

  private static void dissect(Dictionary newsGroups,
                              AdaptiveLogisticRegression learningAlgorithm,
                              Iterable<File> files) throws IOException {
    CrossFoldLearner model = learningAlgorithm.getBest().getPayload().getLearner();
    model.close();

    Map<String, Set<Integer>> traceDictionary = Maps.newTreeMap();
    ModelDissector md = new ModelDissector();

    encoder.setTraceDictionary(traceDictionary);
    bias.setTraceDictionary(traceDictionary);

    for (File file : permute(files, rand).subList(0, 500)) {
      traceDictionary.clear();
      Vector v = encodeFeatureVector(file);
      md.update(v, traceDictionary, model);
    }

    List<String> ngNames = Lists.newArrayList(newsGroups.values());
    List<ModelDissector.Weight> weights = md.summary(100);
    for (ModelDissector.Weight w : weights) {
      System.out.printf("%s\t%.1f\t%s\t%.1f\t%s\t%.1f\t%s\n",
                        w.getFeature(), w.getWeight(), ngNames.get(w.getMaxImpact() + 1),
                        w.getCategory(1), w.getWeight(1), w.getCategory(2), w.getWeight(2));
    }
  }

  private static Vector encodeFeatureVector(File file) throws IOException {
    BufferedReader reader = Files.newReader(file, Charsets.UTF_8);
    try {
      String line = reader.readLine();
      while (line != null && line.length() > 0) {
        boolean countHeader = (
          line.startsWith("From:") || line.startsWith("Subject:") ||
            line.startsWith("Keywords:") || line.startsWith("Summary:")) ;
        do {
          if (countHeader) {
            line = line.replaceAll(".*:", "");
            encoder.addText(line.toLowerCase());
          }
          line = reader.readLine();
        } while (line != null && line.startsWith(" "));
      }

      if (line != null) {
        line = reader.readLine();
      }

      while (line != null) {
        encoder.addText(line.toLowerCase());
        line = reader.readLine();
      }
    } finally {
      reader.close();
    }

    Vector v = new RandomAccessSparseVector(FEATURES);
    bias.addToVector((byte[]) null, 1, v);
    encoder.flush(1, v);
    return v;
  }

  private static List<File> permute(Iterable<File> files, Random rand) {
    List<File> r = Lists.newArrayList();
    for (File file : files) {
      int i = rand.nextInt(r.size() + 1);
      if (i == r.size()) {
        r.add(file);
      } else {
        r.add(r.get(i));
        r.set(i, file);
      }
    }
    return r;
  }
}
TOP

Related Classes of mia.classifier.ch16.train.TrainNewsGroups

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., now owned by Oracle Inc. Contact coftware#gmail.com.