Package upenn.junto.graph

Source Code of upenn.junto.graph.CrossValidationGenerator

package upenn.junto.graph;

import upenn.junto.util.ObjectDoublePair;
import upenn.junto.util.Constants;
import upenn.junto.util.CollectionUtil;

import gnu.trove.map.hash.TObjectDoubleHashMap;
import gnu.trove.iterator.TObjectDoubleIterator;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.Random;

public class CrossValidationGenerator {
  // seed used to initialize the random number generator
  static long _kDeterministicSeed = 100;
 
  public static void Split(Graph g, double trainFract) {   
    Random r = new Random(_kDeterministicSeed);
    // Random r = new Random();
   
    TObjectDoubleHashMap instanceVertices = new TObjectDoubleHashMap();
    Iterator vIter = g.vertices().keySet().iterator();
    while (vIter.hasNext()) {
      Vertex v = g.vertices().get(vIter.next());
     
      // nodes without feature prefix and those with at least one
      // gold labels are considered valid instances
      if (!v.name().startsWith(Constants.GetFeatPrefix()) &&
          v.goldLabels().size() > 0) {
        instanceVertices.put(v, r.nextDouble());
      }
    }
   
    ArrayList<ObjectDoublePair> sortedRandomInstances =
      CollectionUtil.ReverseSortMap(instanceVertices);

    int totalInstances = sortedRandomInstances.size();
    double totalTrainInstances = Math.ceil(totalInstances * trainFract);
    for (int vi = 0; vi < totalInstances; ++vi) {
      Vertex v = (Vertex) sortedRandomInstances.get(vi).GetLabel();
     
      // mark train and test nodes
      if (vi < totalTrainInstances) {
        v.setIsSeedNode(true);
       
        // we expect that the gold labels for the node has already been
        // set, we only need to copy them as injected labels
        TObjectDoubleIterator goldLabIter = v.goldLabels().iterator();
        while (goldLabIter.hasNext()) {
          goldLabIter.advance();
          v.SetInjectedLabelScore((String) goldLabIter.key(), goldLabIter.value());
        }
      } else {
        v.setIsTestNode(true);
      }
    }
   
    //    // for sanity check, count the number of train and test nodes
    //    int totalTrainNodes = 0;
    //    int totalTestNodes = 0;
    //    for (int vi = 0; vi < totalInstances; ++vi) {
    //      Vertex v = (Vertex) sortedRandomInstances.get(vi).GetLabel();
    //      if (v.isSeedNode()) {
    //        ++totalTrainNodes;
    //      }
    //      if (v.isTestNode()) {
    //        ++totalTestNodes;
    //      }
    //    }
    //    MessagePrinter.Print("Total train nodes: " + totalTrainNodes);
    //    MessagePrinter.Print("Total test nodes: " + totalTestNodes);
  }
}
TOP

Related Classes of upenn.junto.graph.CrossValidationGenerator

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.