Package cc.mallet.topics

Source Code of cc.mallet.topics.PolylingualTopicModel$TopicAssignment

/* Copyright (C) 2005 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.topics;

import java.util.*;
import java.util.zip.*;

import java.io.*;
import java.text.NumberFormat;

import cc.mallet.types.*;
import cc.mallet.util.Randoms;

/**
* Latent Dirichlet Allocation for loosely parallel corpora in arbitrary languages
*
* @author David Mimno, Andrew McCallum
*/

public class PolylingualTopicModel implements Serializable {
 
  // Analogous to a cc.mallet.classify.Classification
  public class TopicAssignment implements Serializable {
    public Instance[] instances;
    public LabelSequence[] topicSequences;
    public Labeling topicDistribution;
   
    public TopicAssignment (Instance[] instances, LabelSequence[] topicSequences) {
      this.instances = instances;
      this.topicSequences = topicSequences;
    }
  }

  int numLanguages = 1;

  protected ArrayList<TopicAssignment> data;  // the training instances and their topic assignments
  protected LabelAlphabet topicAlphabet;  // the alphabet for the topics

  protected int numStopwords = 0;
 
  protected int numTopics; // Number of topics to be fit

  HashSet<String> testingIDs = null;

  // These values are used to encode type/topic counts as
  //  count/topic pairs in a single int.
  protected int topicMask;
  protected int topicBits;
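
  // Illustrative example (not in the original source): with topicBits = 7 and
  // topicMask = 127, a count of 3 for topic 5 is stored as (3 << 7) + 5 = 389;
  // the topic is recovered as 389 & topicMask (= 5) and the count as
  // 389 >> topicBits (= 3).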

  protected Alphabet[] alphabets;
  protected int[] vocabularySizes;

  protected double[] alpha;   // Dirichlet(alpha[0], alpha[1], ...) prior over per-document topic distributions
  protected double alphaSum;
  protected double[] betas;   // per-language symmetric prior on each topic's distribution over words
  protected double[] betaSums;

  protected int[] languageMaxTypeCounts;

  public static final double DEFAULT_BETA = 0.01;
 
  protected double[] languageSmoothingOnlyMasses;
  protected double[][] languageCachedCoefficients;
  int topicTermCount = 0;
  int betaTopicCount = 0;
  int smoothingOnlyCount = 0;

  // An array to put the topic counts for the current document.
  // Initialized locally below.  Defined here to avoid
  // garbage collection overhead.
  protected int[] oneDocTopicCounts; // indexed by <topic index>

  protected int[][][] languageTypeTopicCounts; // indexed by <language index, feature index, topic index>
  protected int[][] languageTokensPerTopic; // indexed by <language index, topic index>

  // for Dirichlet hyperparameter estimation
  protected int[] docLengthCounts; // histogram of document sizes, summed over languages
  protected int[][] topicDocCounts; // histogram of document/topic counts, indexed by <topic index, count>

  protected int iterationsSoFar = 1;
  public int numIterations = 1000;
  public int burninPeriod = 5;
  public int saveSampleInterval = 5; // was 10; 
  public int optimizeInterval = 10;
  public int showTopicsInterval = 10; // was 50;
  public int wordsPerTopic = 7;

  protected int outputModelInterval = 0;
  protected String outputModelFilename;

  protected int saveStateInterval = 0;
  protected String stateFilename = null;
 
  protected Randoms random;
  protected NumberFormat formatter;
  protected boolean printLogLikelihood = false;
 
  public PolylingualTopicModel (int numberOfTopics) {
    this (numberOfTopics, numberOfTopics);
  }
 
  public PolylingualTopicModel (int numberOfTopics, double alphaSum) {
    this (numberOfTopics, alphaSum, new Randoms());
  }
 
  private static LabelAlphabet newLabelAlphabet (int numTopics) {
    LabelAlphabet ret = new LabelAlphabet();
    for (int i = 0; i < numTopics; i++)
      ret.lookupIndex("topic"+i);
    return ret;
  }
 
  public PolylingualTopicModel (int numberOfTopics, double alphaSum, Randoms random) {
    this (newLabelAlphabet (numberOfTopics), alphaSum, random);
  }
 
  public PolylingualTopicModel (LabelAlphabet topicAlphabet, double alphaSum, Randoms random)
  {
    this.data = new ArrayList<TopicAssignment>();
    this.topicAlphabet = topicAlphabet;
    this.numTopics = topicAlphabet.size();

    if (Integer.bitCount(numTopics) == 1) {
      // exact power of 2
      topicMask = numTopics - 1;
      topicBits = Integer.bitCount(topicMask);
    }
    else {
      // otherwise add an extra bit
      topicMask = Integer.highestOneBit(numTopics) * 2 - 1;
      topicBits = Integer.bitCount(topicMask);
    }
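
    // Example (illustrative): numTopics = 100 gives highestOneBit(100) = 64,
    // so topicMask = 127 and topicBits = 7; numTopics = 64 (a power of 2)
    // gives topicMask = 63 and topicBits = 6.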


    this.alphaSum = alphaSum;
    this.alpha = new double[numTopics];
    Arrays.fill(alpha, alphaSum / numTopics);
    this.random = random;
   
    formatter = NumberFormat.getInstance();
    formatter.setMaximumFractionDigits(5);

    System.err.println("Polylingual LDA: " + numTopics + " topics, " + topicBits + " topic bits, " +
               Integer.toBinaryString(topicMask) + " topic mask");
  }

  public void loadTestingIDs(File testingIDFile) throws IOException {
    testingIDs = new HashSet<String>();

    BufferedReader in = new BufferedReader(new FileReader(testingIDFile));
    String id = null;
    while ((id = in.readLine()) != null) {
      testingIDs.add(id);
    }
    in.close();
  }
 
  public LabelAlphabet getTopicAlphabet() { return topicAlphabet; }
  public int getNumTopics() { return numTopics; }
  public ArrayList<TopicAssignment> getData() { return data; }
 
  public void setNumIterations (int numIterations) {
    this.numIterations = numIterations;
  }

  public void setBurninPeriod (int burninPeriod) {
    this.burninPeriod = burninPeriod;
  }

  public void setTopicDisplay(int interval, int n) {
    this.showTopicsInterval = interval;
    this.wordsPerTopic = n;
  }

  public void setRandomSeed(int seed) {
    random = new Randoms(seed);
  }

  public void setOptimizeInterval(int interval) {
    this.optimizeInterval = interval;
  }

  public void setModelOutput(int interval, String filename) {
    this.outputModelInterval = interval;
    this.outputModelFilename = filename;
  }
 
  /** Define how often and where to save the state
   *
   * @param interval Save a copy of the state every <code>interval</code> iterations.
   * @param filename Save the state to this file, with the iteration number as a suffix
   */
  public void setSaveState(int interval, String filename) {
    this.saveStateInterval = interval;
    this.stateFilename = filename;
  }
 
  public void addInstances (InstanceList[] training) {

    numLanguages = training.length;

    languageTokensPerTopic = new int[numLanguages][numTopics];
   
    alphabets = new Alphabet[ numLanguages ];
    vocabularySizes = new int[ numLanguages ];
    betas = new double[ numLanguages ];
    betaSums = new double[ numLanguages ];
    languageMaxTypeCounts = new int[ numLanguages ];
    languageTypeTopicCounts = new int[ numLanguages ][][];
   
    int numInstances = training[0].size();

    HashSet[] stoplists = new HashSet[ numLanguages ];

    for (int language = 0; language < numLanguages; language++) {

      if (training[language].size() != numInstances) {
        System.err.println("Warning: language " + language + " has " +
                   training[language].size() + " instances, lang 0 has " +
                   numInstances);
      }

      alphabets[ language ] = training[ language ].getDataAlphabet();
      vocabularySizes[ language ] = alphabets[ language ].size();
     
      betas[language] = DEFAULT_BETA;
      betaSums[language] = betas[language] * vocabularySizes[ language ];
   
      languageTypeTopicCounts[language] = new int[ vocabularySizes[language] ][];

      int[][] typeTopicCounts = languageTypeTopicCounts[language];

      // Get the total number of occurrences of each word type
      int[] typeTotals = new int[ vocabularySizes[language] ];
     
      for (Instance instance : training[language]) {
        if (testingIDs != null &&
          testingIDs.contains(instance.getName())) {
          continue;
        }

        FeatureSequence tokens = (FeatureSequence) instance.getData();
        for (int position = 0; position < tokens.getLength(); position++) {
          int type = tokens.getIndexAtPosition(position);
          typeTotals[ type ]++;
        }
      }

      /* Automatic stoplist creation, currently disabled
      TreeSet<IDSorter> sortedWords = new TreeSet<IDSorter>();
      for (int type = 0; type < vocabularySizes[language]; type++) {
        sortedWords.add(new IDSorter(type, typeTotals[type]));
      }

      stoplists[language] = new HashSet<Integer>();
      Iterator<IDSorter> typeIterator = sortedWords.iterator();
      int totalStopwords = 0;

      while (typeIterator.hasNext() && totalStopwords < numStopwords) {
        stoplists[language].add(typeIterator.next().getID());
      }
      */
     
      // Allocate enough space so that we never have to worry about
      //  overflows: either the number of topics or the number of times
      //  the type occurs.
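      //  (A word type can occupy at most min(numTopics, typeTotals[type])
      //   distinct topics, so that many packed count/topic entries always suffice.)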
      for (int type = 0; type < vocabularySizes[language]; type++) {
        if (typeTotals[type] > languageMaxTypeCounts[language]) {
          languageMaxTypeCounts[language] = typeTotals[type];
        }
        typeTopicCounts[type] = new int[ Math.min(numTopics, typeTotals[type]) ];
      }
    }
   
    for (int doc = 0; doc < numInstances; doc++) {

      if (testingIDs != null &&
        testingIDs.contains(training[0].get(doc).getName())) {
        continue;
      }

      Instance[] instances = new Instance[ numLanguages ];
      LabelSequence[] topicSequences = new LabelSequence[ numLanguages ];

      for (int language = 0; language < numLanguages; language++) {
       
        int[][] typeTopicCounts = languageTypeTopicCounts[language];
        int[] tokensPerTopic = languageTokensPerTopic[language];

        instances[language] = training[language].get(doc);
        FeatureSequence tokens = (FeatureSequence) instances[language].getData();
        topicSequences[language] =
          new LabelSequence(topicAlphabet, new int[ tokens.size() ]);
     
        int[] topics = topicSequences[language].getFeatures();
        for (int position = 0; position < tokens.size(); position++) {
         
          int type = tokens.getIndexAtPosition(position);
          int[] currentTypeTopicCounts = typeTopicCounts[ type ];
         
          int topic = random.nextInt(numTopics);

          // If the word is one of the [numStopwords] most
          //  frequent words, put it in a non-sampled topic.
          //if (stoplists[language].contains(type)) {
          //  topic = -1;
          //}

          topics[position] = topic;
          tokensPerTopic[topic]++;
         
          // The format for these arrays is
          //  the topic in the rightmost bits
          //  the count in the remaining (left) bits.
          // Since the count is in the high bits, sorting (desc)
          //  by the numeric value of the int guarantees that
          //  higher counts will be before the lower counts.
         
          // Start by assuming that the array is either empty
          //  or is in sorted (descending) order.
         
          // Here we are only adding counts, so if we find
          //  an existing location with the topic, we only need
          //  to ensure that it is not larger than its left neighbor.
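          // Illustrative example (not in the original source): with topicBits = 7,
          // the array [ (4 << 7) + 2, (1 << 7) + 9, 0 ] means topic 2 has count 4
          // and topic 9 has count 1 for this word type; the packed ints stay in
          // descending order because the count occupies the high bits.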
         
          int index = 0;
          int currentTopic = currentTypeTopicCounts[index] & topicMask;
          int currentValue;
         
          while (currentTypeTopicCounts[index] > 0 && currentTopic != topic) {
            index++;
           
            /*
              // Debugging output...
             if (index >= currentTypeTopicCounts.length) {
              for (int i=0; i < currentTypeTopicCounts.length; i++) {
                System.out.println((currentTypeTopicCounts[i] & topicMask) + ":" +
                           (currentTypeTopicCounts[i] >> topicBits) + " ");
              }
             
              System.out.println(type + " " + typeTotals[type]);
            }
            */
            currentTopic = currentTypeTopicCounts[index] & topicMask;
          }
          currentValue = currentTypeTopicCounts[index] >> topicBits;
         
          if (currentValue == 0) {
            // new value is 1, so we don't have to worry about sorting
            //  (except by topic suffix, which doesn't matter)
           
            currentTypeTopicCounts[index] =
              (1 << topicBits) + topic;
          }
          else {
            currentTypeTopicCounts[index] =
              ((currentValue + 1) << topicBits) + topic;
           
            // Now ensure that the array is still sorted by
            //  bubbling this value up.
            while (index > 0 &&
                 currentTypeTopicCounts[index] > currentTypeTopicCounts[index - 1]) {
              int temp = currentTypeTopicCounts[index];
              currentTypeTopicCounts[index] = currentTypeTopicCounts[index - 1];
              currentTypeTopicCounts[index - 1] = temp;
             
              index--;
            }
          }
        }
      }

      TopicAssignment t = new TopicAssignment (instances, topicSequences);
      data.add (t);
    }

    initializeHistograms();

    languageSmoothingOnlyMasses = new double[ numLanguages ];
    languageCachedCoefficients = new double[ numLanguages ][ numTopics ];

    cacheValues();
  }

  /**
   *  Gather statistics on the size of documents
   *  and create histograms for use in Dirichlet hyperparameter
   *  optimization.
   */
  private void initializeHistograms() {

    int maxTokens = 0;
    int totalTokens = 0;

    for (int doc = 0; doc < data.size(); doc++) {
      int length = 0;
      for (LabelSequence sequence : data.get(doc).topicSequences) {
        length += sequence.getLength();
      }

      if (length > maxTokens) {
        maxTokens = length;
      }

      totalTokens += length;
    }

    System.err.println("max tokens: " + maxTokens);
    System.err.println("total tokens: " + totalTokens);

    docLengthCounts = new int[maxTokens + 1];
    topicDocCounts = new int[numTopics][maxTokens + 1];
   
  }

  private void cacheValues() {

    for (int language = 0; language < numLanguages; language++) {
      languageSmoothingOnlyMasses[language] = 0.0;
     
      for (int topic=0; topic < numTopics; topic++) {
        languageSmoothingOnlyMasses[language] +=
          alpha[topic] * betas[language] /
          (languageTokensPerTopic[language][topic] + betaSums[language]);
        languageCachedCoefficients[language][topic] =
          alpha[topic] / (languageTokensPerTopic[language][topic] + betaSums[language]);
      }
     
    }
   
  }
 
  private void clearHistograms() {
    Arrays.fill(docLengthCounts, 0);
    for (int topic = 0; topic < topicDocCounts.length; topic++)
      Arrays.fill(topicDocCounts[topic], 0);
  }

  public void estimate () throws IOException {
    estimate (numIterations);
  }
 
  public void estimate (int iterationsThisRound) throws IOException {

    long startTime = System.currentTimeMillis();
    int maxIteration = iterationsSoFar + iterationsThisRound;

    long totalTime = 0;
 
    for ( ; iterationsSoFar <= maxIteration; iterationsSoFar++) {
      long iterationStart = System.currentTimeMillis();
     
      if (showTopicsInterval != 0 && iterationsSoFar != 0 && iterationsSoFar % showTopicsInterval == 0) {
        System.out.println();
        printTopWords (System.out, wordsPerTopic, false);

      }

      if (saveStateInterval != 0 && iterationsSoFar % saveStateInterval == 0) {
        this.printState(new File(stateFilename + '.' + iterationsSoFar));
      }

      /*
        if (outputModelInterval != 0 && iterations % outputModelInterval == 0) {
        this.write (new File(outputModelFilename+'.'+iterations));
        }
      */

      // TODO this condition should also check that we have more than one sample to work with here
      // (The number of samples actually obtained is not yet tracked.)
      if (iterationsSoFar > burninPeriod && optimizeInterval != 0 &&
        iterationsSoFar % optimizeInterval == 0) {

        alphaSum = Dirichlet.learnParameters(alpha, topicDocCounts, docLengthCounts);
        optimizeBetas();
        clearHistograms();
        cacheValues();
      }

      // Loop over every document in the corpus
      topicTermCount = betaTopicCount = smoothingOnlyCount = 0;

      for (int doc = 0; doc < data.size(); doc++) {

        sampleTopicsForOneDoc (data.get(doc),
                     (iterationsSoFar >= burninPeriod &&
                    iterationsSoFar % saveSampleInterval == 0));
      }
   
      long elapsedMillis = System.currentTimeMillis() - iterationStart;
      totalTime += elapsedMillis;

      if ((iterationsSoFar + 1) % 10 == 0) {
       
        double ll = modelLogLikelihood();
        System.out.println(elapsedMillis + "\t" + totalTime + "\t" +
                   ll);
      }
      else {
        System.out.print(elapsedMillis + " ");
      }
    }

    /*
    long seconds = Math.round((System.currentTimeMillis() - startTime)/1000.0);
    long minutes = seconds / 60;  seconds %= 60;
    long hours = minutes / 60;  minutes %= 60;
    long days = hours / 24;  hours %= 24;
    System.out.print ("\nTotal time: ");
    if (days != 0) { System.out.print(days); System.out.print(" days "); }
    if (hours != 0) { System.out.print(hours); System.out.print(" hours "); }
    if (minutes != 0) { System.out.print(minutes); System.out.print(" minutes "); }
    System.out.print(seconds); System.out.println(" seconds");
    */
  }
 
  public void optimizeBetas() {
   
    for (int language = 0; language < numLanguages; language++) {
     
      // The histogram starts at count 0, so if all of the
      //  tokens of the most frequent type were assigned to one topic,
      //  we would need to store a maxTypeCount + 1 count.
      int[] countHistogram = new int[languageMaxTypeCounts[language] + 1];
     
      // Now count the number of type/topic pairs that have
      //  each number of tokens.
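      //  (e.g. countHistogram[5] ends up holding the number of (type, topic)
      //   pairs with exactly 5 tokens assigned.)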
     
      int[][] typeTopicCounts = languageTypeTopicCounts[language];
      int[] tokensPerTopic = languageTokensPerTopic[language];

      int index;
      for (int type = 0; type < vocabularySizes[language]; type++) {
        int[] counts = typeTopicCounts[type];
        index = 0;
        while (index < counts.length &&
             counts[index] > 0) {
          int count = counts[index] >> topicBits;
          countHistogram[count]++;
          index++;
        }
      }
     
      // Figure out how large we need to make the "observation lengths"
      //  histogram.
      int maxTopicSize = 0;
      for (int topic = 0; topic < numTopics; topic++) {
        if (tokensPerTopic[topic] > maxTopicSize) {
          maxTopicSize = tokensPerTopic[topic];
        }
      }
     
      // Now allocate it and populate it.
      int[] topicSizeHistogram = new int[maxTopicSize + 1];
      for (int topic = 0; topic < numTopics; topic++) {
        topicSizeHistogram[ tokensPerTopic[topic] ]++;
      }
     
      betaSums[language] = Dirichlet.learnSymmetricConcentration(countHistogram,
                                     topicSizeHistogram,
                                     vocabularySizes[ language ],
                                     betaSums[language]);
      betas[language] = betaSums[language] / vocabularySizes[ language ];
    }
  }

  protected void sampleTopicsForOneDoc (TopicAssignment topicAssignment,
                      boolean shouldSaveState) {

    int[] currentTypeTopicCounts;
    int type, oldTopic, newTopic;
    double topicWeightsSum;

    int[] localTopicCounts = new int[numTopics];
    int[] localTopicIndex = new int[numTopics];

    for (int language = 0; language < numLanguages; language++) {

      int[] oneDocTopics =
        topicAssignment.topicSequences[language].getFeatures();
      int docLength =
        topicAssignment.topicSequences[language].getLength();
     
      //    populate topic counts
      for (int position = 0; position < docLength; position++) {
        localTopicCounts[oneDocTopics[position]]++;
      }
    }

    // Build an array that densely lists the topics that
    //  have non-zero counts.
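    //  (e.g. if only topics 3 and 7 occur in this document, localTopicIndex
    //   starts [3, 7, ...] and nonZeroTopics is set to 2 below.)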
    int denseIndex = 0;
    for (int topic = 0; topic < numTopics; topic++) {
      if (localTopicCounts[topic] != 0) {
        localTopicIndex[denseIndex] = topic;
        denseIndex++;
      }
    }

    // Record the total number of non-zero topics
    int nonZeroTopics = denseIndex;

    for (int language = 0; language < numLanguages; language++) {

      int[] oneDocTopics =
        topicAssignment.topicSequences[language].getFeatures();
      int docLength =
        topicAssignment.topicSequences[language].getLength();
      FeatureSequence tokenSequence =
        (FeatureSequence) topicAssignment.instances[language].getData();

      int[][] typeTopicCounts = languageTypeTopicCounts[language];
      int[] tokensPerTopic = languageTokensPerTopic[language];
      double beta = betas[language];
      double betaSum = betaSums[language];

      // Initialize the smoothing-only sampling bucket
      double smoothingOnlyMass = languageSmoothingOnlyMasses[language];
      //for (int topic = 0; topic < numTopics; topic++)
      //smoothingOnlyMass += alpha[topic] * beta / (tokensPerTopic[topic] + betaSum);
     
      // Initialize the cached coefficients, using only smoothing.
      //cachedCoefficients = new double[ numTopics ];
      //for (int topic=0; topic < numTopics; topic++)
      //  cachedCoefficients[topic] =  alpha[topic] / (tokensPerTopic[topic] + betaSum);
     
      double[] cachedCoefficients =
        languageCachedCoefficients[language];

      //    Initialize the topic count/beta sampling bucket
      double topicBetaMass = 0.0;
     
      // Initialize cached coefficients and the topic/beta
      //  normalizing constant.
     
      for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
        int topic = localTopicIndex[denseIndex];
        int n = localTopicCounts[topic];
       
        //  initialize the normalization constant for the (B * n_{t|d}) term
        topicBetaMass += beta * n / (tokensPerTopic[topic] + betaSum);

        //  update the coefficients for the non-zero topics
        cachedCoefficients[topic] = (alpha[topic] + n) / (tokensPerTopic[topic] + betaSum);
      }

      double topicTermMass = 0.0;
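
      // Taken together, the sampling mass for each token factors into three
      // buckets: smoothingOnlyMass (alpha * beta terms), topicBetaMass
      // (beta * n_{t|d} terms for topics active in this document), and
      // topicTermMass (terms for topics with non-zero counts for this word type).
      // This is the sparse sampling decomposition of Yao, Mimno & McCallum (2009).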

      double[] topicTermScores = new double[numTopics];
      int[] topicTermIndices;
      int[] topicTermValues;
      int i;
      double score;

      //  Iterate over the positions (words) in the document
      for (int position = 0; position < docLength; position++) {
        type = tokenSequence.getIndexAtPosition(position);
        oldTopic = oneDocTopics[position];
        if (oldTopic == -1) { continue; }

        currentTypeTopicCounts = typeTopicCounts[type];
       
        //  Remove this token from all counts.
       
        // Remove this topic's contribution to the
        //  normalizing constants
        smoothingOnlyMass -= alpha[oldTopic] * beta /
          (tokensPerTopic[oldTopic] + betaSum);
        topicBetaMass -= beta * localTopicCounts[oldTopic] /
          (tokensPerTopic[oldTopic] + betaSum);
       
        // Decrement the local doc/topic counts
       
        localTopicCounts[oldTopic]--;
       
        // Maintain the dense index, if we are deleting
        //  the old topic
        if (localTopicCounts[oldTopic] == 0) {
         
          // First get to the dense location associated with
          //  the old topic.
         
          denseIndex = 0;
         
          // We know it's in there somewhere, so we don't
          //  need bounds checking.
          while (localTopicIndex[denseIndex] != oldTopic) {
            denseIndex++;
          }
         
          // shift all remaining dense indices to the left.
          while (denseIndex < nonZeroTopics) {
            if (denseIndex < localTopicIndex.length - 1) {
              localTopicIndex[denseIndex] =
                localTopicIndex[denseIndex + 1];
            }
            denseIndex++;
          }
         
          nonZeroTopics --;
        }
       
        // Decrement the global topic count totals
        tokensPerTopic[oldTopic]--;
        //assert(tokensPerTopic[oldTopic] >= 0) : "old Topic " + oldTopic + " below 0";
         
       
        // Add the old topic's contribution back into the
        //  normalizing constants.
        smoothingOnlyMass += alpha[oldTopic] * beta /
          (tokensPerTopic[oldTopic] + betaSum);
        topicBetaMass += beta * localTopicCounts[oldTopic] /
          (tokensPerTopic[oldTopic] + betaSum);
       
        // Reset the cached coefficient for this topic
        cachedCoefficients[oldTopic] =
          (alpha[oldTopic] + localTopicCounts[oldTopic]) /
          (tokensPerTopic[oldTopic] + betaSum);
       
       
        // Now go over the type/topic counts, decrementing
        //  where appropriate, and calculating the score
        //  for each topic at the same time.
       
        int index = 0;
        int currentTopic, currentValue;
       
        boolean alreadyDecremented = false;
       
        topicTermMass = 0.0;
       
        while (index < currentTypeTopicCounts.length &&
             currentTypeTopicCounts[index] > 0) {
          currentTopic = currentTypeTopicCounts[index] & topicMask;
          currentValue = currentTypeTopicCounts[index] >> topicBits;
         
          if (! alreadyDecremented &&
            currentTopic == oldTopic) {
           
            // We're decrementing and adding up the
            //  sampling weights at the same time, but
            //  decrementing may require us to reorder
            //  the topics, so after we're done here,
            //  look at this cell in the array again.
           
            currentValue --;
            if (currentValue == 0) {
              currentTypeTopicCounts[index] = 0;
            }
            else {
              currentTypeTopicCounts[index] =
                (currentValue << topicBits) + oldTopic;
            }
           
            // Shift the reduced value to the right, if necessary.
           
            int subIndex = index;
            while (subIndex < currentTypeTopicCounts.length - 1 &&
                 currentTypeTopicCounts[subIndex] < currentTypeTopicCounts[subIndex + 1]) {
              int temp = currentTypeTopicCounts[subIndex];
              currentTypeTopicCounts[subIndex] = currentTypeTopicCounts[subIndex + 1];
              currentTypeTopicCounts[subIndex + 1] = temp;
             
              subIndex++;
            }
           
            alreadyDecremented = true;
          }
          else {
            score =
              cachedCoefficients[currentTopic] * currentValue;
            topicTermMass += score;
            topicTermScores[index] = score;
           
            index++;
          }
        }
       
        double sample = random.nextUniform() * (smoothingOnlyMass + topicBetaMass + topicTermMass);
        double origSample = sample;
       
        //  Make sure it actually gets set
        newTopic = -1;
       
        if (sample < topicTermMass) {
          //topicTermCount++;
         
          i = -1;
          while (sample > 0) {
            i++;
            sample -= topicTermScores[i];
          }
         
          newTopic = currentTypeTopicCounts[i] & topicMask;
          currentValue = currentTypeTopicCounts[i] >> topicBits;
         
          currentTypeTopicCounts[i] = ((currentValue + 1) << topicBits) + newTopic;
         
          // Bubble the new value up, if necessary
         
          while (i > 0 &&
               currentTypeTopicCounts[i] > currentTypeTopicCounts[i - 1]) {
            int temp = currentTypeTopicCounts[i];
            currentTypeTopicCounts[i] = currentTypeTopicCounts[i - 1];
            currentTypeTopicCounts[i - 1] = temp;
           
            i--;
          }
         
        }
        else {
          sample -= topicTermMass;
         
          if (sample < topicBetaMass) {
            //betaTopicCount++;
           
            sample /= beta;
           
            for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
              int topic = localTopicIndex[denseIndex];
             
              sample -= localTopicCounts[topic] /
                (tokensPerTopic[topic] + betaSum);

              if (sample <= 0.0) {
                newTopic = topic;
                break;
              }
            }
           
          }
          else {
            //smoothingOnlyCount++;
           
            sample -= topicBetaMass;
           
            sample /= beta;
           
            newTopic = 0;
            sample -= alpha[newTopic] /
              (tokensPerTopic[newTopic] + betaSum);
           
            while (sample > 0.0) {
              newTopic++;
              sample -= alpha[newTopic] /
                (tokensPerTopic[newTopic] + betaSum);
            }
         
          }
         
          // Move to the position for the new topic,
          //  which may be the first empty position if this
          //  is a new topic for this word.
         
          index = 0;
          while (currentTypeTopicCounts[index] > 0 &&
               (currentTypeTopicCounts[index] & topicMask) != newTopic) {
            index++;
          }
         
          // index should now be set to the position of the new topic,
          //  which may be an empty cell at the end of the list.
         
          if (currentTypeTopicCounts[index] == 0) {
            // inserting a new topic, guaranteed to be in
            //  order w.r.t. count, if not topic.
            currentTypeTopicCounts[index] = (1 << topicBits) + newTopic;
          }
          else {
            currentValue = currentTypeTopicCounts[index] >> topicBits;
            currentTypeTopicCounts[index] = ((currentValue + 1) << topicBits) + newTopic;
           
            // Bubble the increased value left, if necessary
            while (index > 0 &&
                 currentTypeTopicCounts[index] > currentTypeTopicCounts[index - 1]) {
              int temp = currentTypeTopicCounts[index];
              currentTypeTopicCounts[index] = currentTypeTopicCounts[index - 1];
              currentTypeTopicCounts[index - 1] = temp;
             
              index--;
            }
          }
         
        }
       
        if (newTopic == -1) {
          System.err.println("PolylingualTopicModel sampling error: "+ origSample + " " + sample + " " + smoothingOnlyMass + " " +
                     topicBetaMass + " " + topicTermMass);
          newTopic = numTopics-1; // TODO is this appropriate
          //throw new IllegalStateException ("PolylingualTopicModel: New topic not sampled.");
        }
        //assert(newTopic != -1);
       
        //      Put that new topic into the counts
        oneDocTopics[position] = newTopic;
       
        smoothingOnlyMass -= alpha[newTopic] * beta /
          (tokensPerTopic[newTopic] + betaSum);
        topicBetaMass -= beta * localTopicCounts[newTopic] /
          (tokensPerTopic[newTopic] + betaSum);
       
        localTopicCounts[newTopic]++;
       
        // If this is a new topic for this document,
        //  add the topic to the dense index.
        if (localTopicCounts[newTopic] == 1) {
         
          // First find the point where we
          //  should insert the new topic by going to
          //  the end (which is the only reason we're keeping
          //  track of the number of non-zero
          //  topics) and working backwards
         
          denseIndex = nonZeroTopics;
         
          while (denseIndex > 0 &&
               localTopicIndex[denseIndex - 1] > newTopic) {
           
            localTopicIndex[denseIndex] =
              localTopicIndex[denseIndex - 1];
            denseIndex--;
          }
         
          localTopicIndex[denseIndex] = newTopic;
          nonZeroTopics++;
        }
       
        tokensPerTopic[newTopic]++;
       
        //  update the coefficients for the non-zero topics
        cachedCoefficients[newTopic] =
          (alpha[newTopic] + localTopicCounts[newTopic]) /
          (tokensPerTopic[newTopic] + betaSum);
       
        smoothingOnlyMass += alpha[newTopic] * beta /
          (tokensPerTopic[newTopic] + betaSum);
        topicBetaMass += beta * localTopicCounts[newTopic] /
          (tokensPerTopic[newTopic] + betaSum);
       
        // Save the smoothing-only mass to the global cache
        languageSmoothingOnlyMasses[language] = smoothingOnlyMass;

      }
    }

    if (shouldSaveState) {
      // Update the document-topic count histogram,
      //  for dirichlet estimation
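      //  (topicDocCounts[topic][n] counts documents with exactly n tokens
      //   assigned to that topic; docLengthCounts[len] counts documents of
      //   total length len, summed over languages.)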

      int totalLength = 0;

      for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
        int topic = localTopicIndex[denseIndex];
       
        topicDocCounts[topic][ localTopicCounts[topic] ]++;
        totalLength += localTopicCounts[topic];
      }

      docLengthCounts[ totalLength ]++;

    }

  }

  public void printTopWords (File file, int numWords, boolean useNewLines) throws IOException {
    PrintStream out = new PrintStream (file);
    printTopWords(out, numWords, useNewLines);
    out.close();
  }
 
  public void printTopWords (PrintStream out, int numWords, boolean usingNewLines) {

    TreeSet[][] languageTopicSortedWords = new TreeSet[numLanguages][numTopics];

    for (int language = 0; language < numLanguages; language++) {
      TreeSet[] topicSortedWords = languageTopicSortedWords[language];
      int[][] typeTopicCounts = languageTypeTopicCounts[language];

      for (int topic = 0; topic < numTopics; topic++) {
        topicSortedWords[topic] = new TreeSet<IDSorter>();
      }

      for (int type = 0; type < vocabularySizes[language]; type++) {
       
        int[] topicCounts = typeTopicCounts[type];
       
        int index = 0;
        while (index < topicCounts.length &&
             topicCounts[index] > 0) {
         
          int topic = topicCounts[index] & topicMask;
          int count = topicCounts[index] >> topicBits;
         
          topicSortedWords[topic].add(new IDSorter(type, count));

          index++;
        }
      }
    }

    for (int topic = 0; topic < numTopics; topic++) {

      out.println (topic + "\t" + formatter.format(alpha[topic]));
       
      for (int language = 0; language < numLanguages; language++) {
       
        out.print(" " + language + "\t" + languageTokensPerTopic[language][topic] + "\t" + betas[language] + "\t");

        TreeSet<IDSorter> sortedWords = languageTopicSortedWords[language][topic];
        Alphabet alphabet = alphabets[language];

        int word = 1;
        Iterator<IDSorter> iterator = sortedWords.iterator();
        while (iterator.hasNext() && word < numWords) {
          IDSorter info = iterator.next();
         
          out.print(alphabet.lookupObject(info.getID()) + " ");
          word++;
        }
       
        out.println();
      }
    }
  }

  public void printDocumentTopics (File f) throws IOException {
    PrintWriter pw = new PrintWriter (f, "UTF-8");
    printDocumentTopics (pw);
    pw.close();
  }

  public void printDocumentTopics (PrintWriter pw) {
    printDocumentTopics (pw, 0.0, -1);
  }

  /**
   *  @param pw          A print writer
   *  @param threshold   Only print topics with proportion greater than this number
   *  @param max         Print no more than this many topics
   */
  public void printDocumentTopics (PrintWriter pw, double threshold, int max)  {
    pw.print ("#doc source topic proportion ...\n");
    int docLength;
    int[] topicCounts = new int[ numTopics ];

    IDSorter[] sortedTopics = new IDSorter[ numTopics ];
    for (int topic = 0; topic < numTopics; topic++) {
      // Initialize the sorters with dummy values
      sortedTopics[topic] = new IDSorter(topic, topic);
    }

    if (max < 0 || max > numTopics) {
      max = numTopics;
    }

    for (int di = 0; di < data.size(); di++) {

      pw.print (di); pw.print (' ');

      int totalLength = 0;

      for (int language = 0; language < numLanguages; language++) {
     
        LabelSequence topicSequence = (LabelSequence) data.get(di).topicSequences[language];
        int[] currentDocTopics = topicSequence.getFeatures();
       
        docLength = topicSequence.getLength();
        totalLength += docLength;
       
        // Count up the tokens
        for (int token=0; token < docLength; token++) {
          topicCounts[ currentDocTopics[token] ]++;
        }
      }
       
      // And normalize
      for (int topic = 0; topic < numTopics; topic++) {
        sortedTopics[topic].set(topic, (float) topicCounts[topic] / totalLength);
      }
     
      Arrays.sort(sortedTopics);

      for (int i = 0; i < max; i++) {
        if (sortedTopics[i].getWeight() < threshold) { break; }
       
        pw.print (sortedTopics[i].getID() + " " +
              sortedTopics[i].getWeight() + " ");
      }
      pw.print (" \n");

      Arrays.fill(topicCounts, 0);
    }
   
  }
 
  public void printState (File f) throws IOException {
    PrintStream out =
      new PrintStream(new GZIPOutputStream(new BufferedOutputStream(new FileOutputStream(f))),
              false, "UTF-8");
    printState(out);
    out.close();
  }
 
  public void printState (PrintStream out) {

    out.println ("#doc lang pos typeindex type topic");

    for (int doc = 0; doc < data.size(); doc++) {
      for (int language =0; language < numLanguages; language++) {
        FeatureSequence tokenSequence =  (FeatureSequence) data.get(doc).instances[language].getData();
        LabelSequence topicSequence =  (LabelSequence) data.get(doc).topicSequences[language];
       
        for (int pi = 0; pi < topicSequence.getLength(); pi++) {
          int type = tokenSequence.getIndexAtPosition(pi);
          int topic = topicSequence.getIndexAtPosition(pi);
          out.print(doc); out.print(' ');
          out.print(language); out.print(' ');
          out.print(pi); out.print(' ');
          out.print(type); out.print(' ');
          out.print(alphabets[language].lookupObject(type)); out.print(' ');
          out.print(topic); out.println();
        }
      }
    }
  }

  public double modelLogLikelihood() {
    double logLikelihood = 0.0;
    int nonZeroTopics;

    // The likelihood of the model is a combination of a
    // Dirichlet-multinomial for the words in each topic
    // and a Dirichlet-multinomial for the topics in each
    // document.

    // The likelihood function of a Dirichlet-multinomial is
    //   [ Gamma( sum_i alpha_i ) / prod_i Gamma( alpha_i ) ] *
    //   [ prod_i Gamma( alpha_i + N_i ) / Gamma( sum_i (alpha_i + N_i) ) ]

    // So the log likelihood is
    //  logGamma ( sum_i alpha_i ) - logGamma ( sum_i (alpha_i + N_i) ) +
    //   sum_i [ logGamma( alpha_i + N_i) - logGamma( alpha_i ) ]
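    //
    // In the polylingual model the topic|document term pools token counts
    // across languages for each document, while the word|topic term is
    // computed separately for each language in the loop further below.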

    // Do the documents first

    int[] topicCounts = new int[numTopics];
    double[] topicLogGammas = new double[numTopics];
    int[] docTopics;

    for (int topic=0; topic < numTopics; topic++) {
      topicLogGammas[ topic ] = Dirichlet.logGammaStirling( alpha[topic] );
    }
 
    for (int doc=0; doc < data.size(); doc++) {

      int totalLength = 0;

      for (int language = 0; language < numLanguages; language++) {

        LabelSequence topicSequence = (LabelSequence) data.get(doc).topicSequences[language];
        int[] currentDocTopics = topicSequence.getFeatures();

        totalLength += topicSequence.getLength();

        // Count up the tokens
        for (int token = 0; token < topicSequence.getLength(); token++) {
          topicCounts[ currentDocTopics[token] ]++;
        }
      }

      for (int topic=0; topic < numTopics; topic++) {
        if (topicCounts[topic] > 0) {
          logLikelihood += (Dirichlet.logGammaStirling(alpha[topic] + topicCounts[topic]) -
                    topicLogGammas[ topic ]);
        }
      }

      // subtract the (count + parameter) sum term
      logLikelihood -= Dirichlet.logGammaStirling(alphaSum + totalLength);

      Arrays.fill(topicCounts, 0);
    }
 
    // add the parameter sum term
    logLikelihood += data.size() * Dirichlet.logGammaStirling(alphaSum);

    // And the topics

    for (int language = 0; language < numLanguages; language++) {
      int[][] typeTopicCounts = languageTypeTopicCounts[language];
      int[] tokensPerTopic = languageTokensPerTopic[language];
      double beta = betas[language];

      // Count the number of type-topic pairs
      int nonZeroTypeTopics = 0;
     
      for (int type=0; type < vocabularySizes[language]; type++) {
        // reuse this array as a pointer
       
        topicCounts = typeTopicCounts[type];
       
        int index = 0;
        while (index < topicCounts.length &&
             topicCounts[index] > 0) {
          int topic = topicCounts[index] & topicMask;
          int count = topicCounts[index] >> topicBits;
         
          nonZeroTypeTopics++;
          logLikelihood += Dirichlet.logGammaStirling(beta + count);
         
          if (Double.isNaN(logLikelihood)) {
            System.out.println(count);
            System.exit(1);
          }
         
          index++;
        }
      }
     
      for (int topic=0; topic < numTopics; topic++) {
        logLikelihood -=
          Dirichlet.logGammaStirling( (beta * numTopics) +
                        tokensPerTopic[ topic ] );
        if (Double.isNaN(logLikelihood)) {
          System.out.println("after topic " + topic + " " + tokensPerTopic[ topic ]);
          System.exit(1);
        }
       
      }
     
      logLikelihood +=
        (Dirichlet.logGammaStirling(beta * numTopics)) -
        (Dirichlet.logGammaStirling(beta) * nonZeroTypeTopics);
    }

    if (Double.isNaN(logLikelihood)) {
      System.out.println("at the end");
      System.exit(1);
    }


    return logLikelihood;
  }
 
  public static void main (String[] args) throws IOException {

    if (args.length < 4) {
      System.err.println("Usage: PolylingualTopicModel [num topics] [file to save state] [testing IDs file] [language 0 instances] ...");
      System.exit(1);
    }

    int numTopics = Integer.parseInt(args[0]);
    String stateFileName = args[1];
    File testingIDsFile = new File(args[2]);

    InstanceList[] training = new InstanceList[ args.length - 3 ];
    for (int language=0; language < training.length; language++) {
      training[language] = InstanceList.load(new File(args[language + 3]));
      System.err.println("loaded " + args[language + 3]);
    }

    PolylingualTopicModel lda = new PolylingualTopicModel (numTopics, 2.0);
    lda.printLogLikelihood = true;
    lda.setTopicDisplay(50, 7);
    lda.loadTestingIDs(testingIDsFile);
    lda.addInstances(training);
    lda.setSaveState(200, stateFileName);
   
    lda.estimate();
    lda.printState(new File(stateFileName));
  }
 
}