Package statechum.analysis.learning

Source Code of statechum.analysis.learning.RPNIBlueFringeLearnerTestComponentOpt

package statechum.analysis.learning;

import java.awt.Frame;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Stack;
import java.util.TreeMap;
import java.util.Map.Entry;
import java.util.concurrent.atomic.AtomicInteger;

import java.io.StringWriter;

import statechum.JUConstants;
import statechum.analysis.learning.observers.Learner;
import statechum.analysis.learning.observers.Learner.RestartLearningEnum;
import statechum.xmachine.model.testset.PTATestSequenceEngine;

import edu.uci.ics.jung.graph.Graph;
import edu.uci.ics.jung.graph.Vertex;
import edu.uci.ics.jung.graph.impl.DirectedSparseGraph;
import edu.uci.ics.jung.utils.UserData;

import static statechum.analysis.learning.TestRpniLearner.isAccept;

public class RPNIBlueFringeLearnerTestComponentOpt extends RPNIBlueFringeLearnerTestComponent
{

  /**
   * Creates the learner component; the supplied frame is forwarded unchanged
   * to the superclass constructor.
   *
   * @param parent frame handed to the superclass.
   */
  public RPNIBlueFringeLearnerTestComponentOpt(Frame parent) {
    super(parent);
  }
 
  protected void update(StatePair pair)
  {
    pair.getQ().setUserDatum("pair", pair, UserData.SHARED);
    pair.getR().setUserDatum("pair", pair, UserData.SHARED);// since this copy of the graph will really not be used, changes to it are immaterial at this stage
    updateGraph(scoreComputer.getGraph());
  }
 
  //protected computeStateScores scoreComputer = new computeStateScores(0);

  /**
   * Per-run statistics, all reset to zero at the start of {@code learnMachine()}:
   * questions answered "accepted", questions answered "rejected", number of
   * restarts of learning, and number of times a question set was requested
   * but turned out to be empty.
   */
  protected int counterAccepted =0, counterRejected =0, counterRestarted = 0, counterEmptyQuestions = 0;

  /** Takes the candidates for merging and computes the number of times different scores are encountered. */
  public static void populateScores(Collection<computeStateScores.PairScore> data, Map<Integer,AtomicInteger> histogram)
  {
    for(computeStateScores.PairScore pair:data)
    {
    int pairScore = pair.getScore();
      AtomicInteger count = histogram.get(pairScore);
      if (count == null)
      {
        count = new AtomicInteger();histogram.put(pairScore,count);
      }
      count.incrementAndGet();
    }
  }
 
  /** Takes the candidates for merging and computes the number of times different scores (increments of 10) are encountered. */
  public static void populateHistogram(Collection<computeStateScores.PairScore> data, Map<Integer,AtomicInteger> histogram)
  {
    for(computeStateScores.PairScore pair:data)
    {
    int pairScore = pair.getScore()>= 200? pair.getScore()-pair.getScore() % 100: pair.getScore()>=10? pair.getScore()-pair.getScore()%10: pair.getScore()>0?1:0;
      AtomicInteger count = histogram.get(pairScore);
      if (count == null)
      {
        count = new AtomicInteger();histogram.put(pairScore,count);
      }
      count.incrementAndGet();
    }
  }
 
  public static String HistogramToString(Map<Integer,AtomicInteger> histogram, String Name)
  {
    final String FS=",";
    String result="\n"+Name;
    Map<Integer, AtomicInteger> tmp = new TreeMap<Integer,AtomicInteger>();
    tmp.putAll(histogram);
    for(Entry<Integer,AtomicInteger> sc:tmp.entrySet())
      result = result+FS+sc.getValue();
    result=result+"\n"+Name;
    for(Entry<Integer,AtomicInteger> sc:tmp.entrySet())
      result = result+FS+sc.getKey();

    return result+"\n";
  }
 
  public static String HistogramToSeries(Map<Integer,AtomicInteger> histogram, String Name)
  {
    final String FS=",";
    String result="\n"+Name;
    Map<Integer, AtomicInteger> tmp = new TreeMap<Integer,AtomicInteger>();
    tmp.putAll(histogram);
    int limit = 0;
    for(Entry<Integer,AtomicInteger> sc:tmp.entrySet()){
      limit = sc.getValue().get();
      for(int i = 0;i<limit;i++){
        result = result+FS+sc.getKey();
      }
    }

    return result+"\n";
  }
 
  public static String pairScoresAndIterations(Map<computeStateScores.PairScore,Integer> map, String name){
    final String FS=",";
    String result="\n"+name+"-score"+FS;
    for(computeStateScores.PairScore score:map.keySet())
      result=result+score.getScore()+FS;
    result = result+"\n"+name+"-iteration"+FS;
    for(Integer i:map.values())
      result = result+i+FS;
    return result;
  }

  public DirectedSparseGraph init(Collection<List<String>> plus, Collection<List<String>> minus)
  {
    sPlus = plus;sMinus = minus;
    scoreComputer.initPTA();
    scoreComputer.augmentPTA(sMinus, false);
    scoreComputer.augmentPTA(sPlus, true);
    return scoreComputer.getGraph();
  }
 
  /**
   * Returns the score computer currently held by this learner. The field is
   * reassigned by {@code learnMachine()} after merges and restarts, so the
   * instance returned may differ between calls.
   */
  public computeStateScores getScoreComputer()
  {
    return scoreComputer;
  }
 
  @Override
  public DirectedSparseGraph learnMachine(
      @SuppressWarnings("unused"DirectedSparseGraph model,
      @SuppressWarnings("unused") Collection<List<String>> sP,
      @SuppressWarnings("unused"Collection<List<String>> sM) {
    //init(sPlus, sMinus);
    //return learnMachine();
    throw new UnsupportedOperationException();
  }

  /**
   * Returns a short statistics string: the number of restarts so far, then
   * ", ", then the extent of completeness reported by the score computer.
   */
  public String getStats()
  {
    return counterRestarted+ ", "+ scoreComputer.getExtentOfCompleteness();
  }

  /** Returns the built-in {@link Learner} implementation of this component ({@code thisLearner}). */
  public Learner getLearner()
  {
    return thisLearner;
  }
 
  protected final Learner thisLearner = new Learner()
  {

    public void AugmentPTA(computeStateScores pta, @SuppressWarnings("unused") RestartLearningEnum ptaKind,
        List<String> sequence, boolean accepted, @SuppressWarnings("unused") JUConstants newColour) {
      pta.augmentPTA(sequence, accepted);
    }

    /** Displays a tentative graph and asks user a supplied question.
     * Options are to be shown as choices in addition to yes/element_not_accepted.
     */
    public int CheckWithEndUser(computeStateScores graph, List<String> question, Object[] options) {
      return checkWithEndUser(graph.getGraph(),question, options);
    }

    /** Identifies a collection of states to merge, sorted in the order of scores. */
    public Stack<computeStateScores.PairScore> ChooseStatePairs(computeStateScores graph)
    {
      return graph.chooseStatePairs();
    }

    /** Given a pair of graphs, computes the set of questions to validate the merge which
     * resulted in the second graph
     *
     * @param original the original graph
     * @param temp the merged graph
     * @param pair the pair of states merged in the original graph
     */
    public List<List<String>> ComputeQuestions(computeStateScores.PairScore pair, computeStateScores original, computeStateScores temp)
    {
      List<List<String>> questions = new LinkedList<List<String>>();
      int score = pair.getScore();
      if(score <certaintyThreshold&&score>minCertaintyThreshold)
      {
        questions = sort(original.computeQS(pair, temp));
        if (questions.isEmpty())
          ++counterEmptyQuestions;
      }
      return questions;
    }

    /** Given a graph, merges a pair of states from it and returns the result. */
    public computeStateScores MergeAndDeterminize(computeStateScores original, StatePair pair)
    {
      return computeStateScores.mergeAndDeterminize(original, pair);   
    }

    public void Restart(@SuppressWarnings("unused"RestartLearningEnum mode) {
      // does nothing
    }

    /** Returns statistics reflecting the learning.
     */
    public String getResult() {
      return null;
    }

    public DirectedSparseGraph init(Collection<List<String>> plus,  Collection<List<String>> minus)
    {
      RPNIBlueFringeLearnerTestComponentOpt.this.init(plus, minus);
      return RPNIBlueFringeLearnerTestComponentOpt.this.scoreComputer.getGraph();
    }

    public DirectedSparseGraph init(@SuppressWarnings("unused"PTATestSequenceEngine engine,
        @SuppressWarnings("unused"int plusSize,
        @SuppressWarnings("unused"int minusSize)
    {
      throw new UnsupportedOperationException("cannot use PTAs here");
    }

    public DirectedSparseGraph learnMachine() {
      return RPNIBlueFringeLearnerTestComponentOpt.this.learnMachine();
    }

    public DirectedSparseGraph learnMachine(PTATestSequenceEngine engine, int plusSize, int minusSize)
    {
      topLearner.init(engine, plusSize, minusSize);
      return RPNIBlueFringeLearnerTestComponentOpt.this.learnMachine();
    }

    public DirectedSparseGraph learnMachine(Collection<List<String>> plus, Collection<List<String>> minus)
    {
      topLearner.init(plus, minus);
      return RPNIBlueFringeLearnerTestComponentOpt.this.learnMachine();
    }

   
    public void setTopLevelListener(Learner top) {
      topLearner = top;
    }
   
  };
 
  /**
   * Learner whose callbacks are invoked by {@code learnMachine()}. Defaults to
   * {@code thisLearner}; replaced through {@code setTopLevelListener}.
   */
  Learner topLearner = thisLearner;
 
  /**
   * Holds the graph currently being learnt; reassigned by {@code learnMachine()}
   * after every successful merge and on every restart.
   * NOTE(review): the meaning of the constructor argument 0 is not visible here - confirm against computeStateScores.
   */
  protected computeStateScores scoreComputer = new computeStateScores(0);
 
  /**
   * Runs the learning loop on the PTA held by {@code scoreComputer}.
   * <p>
   * While the top-level learner proposes candidate pairs, the top pair is
   * merged into a temporary graph and the questions for that merge are put to
   * the user one by one. Every answer is added to the original PTA (the
   * instance held by {@code scoreComputer} on entry). If an answer contradicts
   * the merged graph, learning restarts from that augmented PTA with colours
   * cleared; otherwise the merged graph becomes the current one. When no
   * candidate pairs remain, a textual report is attached to the resulting
   * graph under {@code JUConstants.STATS}, the graph is passed to
   * {@code updateGraph} and returned.
   *
   * @return the learnt graph, or null if the user cancelled a question.
   * @throws IllegalArgumentException if the user's answer is neither cancel, accept nor a non-negative position.
   */
  public DirectedSparseGraph learnMachine()
  {
    // histograms: score of the merged pair -> number of merges / number of restarts with that score
    Map<Integer, AtomicInteger> whichScoresWereUsedForMerging = new HashMap<Integer,AtomicInteger>(),
      restartScoreDistribution = new HashMap<Integer,AtomicInteger>();
    // pair -> iteration number (counted since the last restart) at which it was merged / caused a restart
    Map<computeStateScores.PairScore, Integer> scoresToIterations = new HashMap<computeStateScores.PairScore, Integer>();
    Map<computeStateScores.PairScore, Integer> restartsToIterations = new HashMap<computeStateScores.PairScore, Integer>();
    computeStateScores newPTA = scoreComputer;// no need to clone - this is the job of mergeAndDeterminize anyway
    String pairsMerged = "";
    StringWriter report = new StringWriter();
    counterAccepted =0;counterRejected =0;counterRestarted = 0;counterEmptyQuestions = 0;report.write("\n[ PTA: "+scoreComputer.getStatistics(false)+" ] ");

    setChanged();
    /*dumpPTA(scoreComputer, "/tmp/initial_pta.xml");
     */
    Stack<computeStateScores.PairScore> possibleMerges = topLearner.ChooseStatePairs(scoreComputer);
    int plusSize = sPlus.size(), minusSize = sMinus.size(), iterations = 0;
    // debugging aid: extra output is printed during the restart with this number;
    // counterRestarted is never negative in this method, so a negative value disables the output.
    final int restartOfInterest = -21;
   
    while(!possibleMerges.isEmpty()){
      iterations++;
      //populateScores(possibleMerges,possibleMergeScoreDistribution);
      computeStateScores.PairScore pair = possibleMerges.pop();
      if (counterRestarted == restartOfInterest) System.out.println("merging "+pair);
      computeStateScores temp = topLearner.MergeAndDeterminize(scoreComputer,pair);
      setChanged();
      Collection<List<String>> questions = topLearner.ComputeQuestions(pair,scoreComputer,temp);
     
      boolean restartLearning = false;// whether we need to rebuild a PTA and restart learning.

      Iterator<List<String>> questionIt = questions.iterator();
      while(questionIt.hasNext()){
        List<String> question = questionIt.next();
        // only used by the assertion in the rejection branch below
        String accepted = pair.getQ().getUserDatum(JUConstants.ACCEPTED).toString();
        int answer = topLearner.CheckWithEndUser(scoreComputer,question,new Object[] {"Test"});
       
        this.questionCounter++;
        if (answer == USER_CANCELLED)
        {
          System.out.println("CANCELLED");
          return null;
        }
       
        // the vertex the question leads to in the merged graph
        Vertex tempVertex = temp.getVertex(question);
        if (tempVertex == null)
          System.out.println();
       
        if(answer == USER_ACCEPTED)
        {
          ++counterAccepted;
          //sPlus.add(question);
          topLearner.AugmentPTA(newPTA, RestartLearningEnum.restartHARD, question, true,null);++plusSize;
          //System.out.println(setByAuto+question.toString()+ " <yes>");
          if (counterRestarted == restartOfInterest) System.out.println(question.toString()+ " <yes>");
         
          // the user accepted a sequence which the merged graph does not accept
          if(!TestRpniLearner.isAccept(tempVertex))
          {
            pairsMerged=pairsMerged+"ABOUT TO RESTART due to acceptance of a reject vertex for a pair "+pair+" ========\n";
            restartLearning = true;break;
          }
        }
        else
          if(answer >= 0)
          {// The sequence has been rejected by a user
            assert answer < question.size();
            ++counterRejected;
            // the rejected prefix: elements 0..answer of the question, inclusive
            LinkedList<String> subAnswer = new LinkedList<String>();subAnswer.addAll(question.subList(0, answer+1));
            //sMinus.add(subAnswer);
            topLearner.AugmentPTA(newPTA, RestartLearningEnum.restartHARD, subAnswer, false,null);++minusSize ;// important: since vertex IDs are
            // only unique for each instance of computeStateScores, only one
            // instance should ever receive calls to augmentPTA
            if (counterRestarted == restartOfInterest) System.out.println(question.toString()+ " <no> at position "+answer+", element "+question.get(answer));
            //System.out.println(setByAuto+question.toString()+ " <no> at position "+answer+", element "+question.get(answer));
            // restart if the rejection happened before the last element, or if the merged graph accepts the whole question
            if( (answer < question.size()-1) || isAccept(tempVertex))
            {
              assert accepted.equals("true");
              pairsMerged=pairsMerged+"ABOUT TO RESTART because accept vertex was rejected for a pair "+pair+" ========\n";
              restartLearning = true;break;
            }
          }
          else
            throw new IllegalArgumentException("unexpected user choice");
       
      }

      if (restartLearning)
      {// restart learning
        //computeStateScores expected = createAugmentedPTA(sPlus, sMinus);// KIRR: node labelling is done by createAugmentedPTA
        scoreComputer = newPTA;// no need to clone - this is the job of mergeAndDeterminize anyway
        scoreComputer.clearColours();
        setChanged();++counterRestarted;
        //System.out.println("restarts - "+counterRestarted+" questions: "+(counterAccepted+counterRejected)+" states in PTA: "+newPTA.getStateNumber());
        //dumpPTA(scoreComputer,"/tmp/orig_restart"+counterRestarted);
        pairsMerged=pairsMerged+"========== RESTART "+counterRestarted+" ==========\n";
        // record the score of the pair which caused this restart
        AtomicInteger count = restartScoreDistribution.get(pair.getScore());
        if (count == null)
        {
          count = new AtomicInteger();restartScoreDistribution.put(pair.getScore(),count);
        }
        count.incrementAndGet();
        restartsToIterations.put(pair, iterations);
        iterations = 0;
        topLearner.Restart(RestartLearningEnum.restartHARD);
      }
      else
      {
        // At this point, scoreComputer may have been modified because it may point to
        // the original PTA which will be modified as a result of new sequences being added to it.
        // temp is different too, hence there is no way for me to compute compatibility score here.
        // This is hence computed inside the obtainPair method.
        pairsMerged=pairsMerged+pair+" questions: "+questions.size()+"\n";
       
        // keep going with the existing model
        scoreComputer = temp;
        // now update the statistics
        AtomicInteger count = whichScoresWereUsedForMerging.get(pair.getScore());
        if (count == null)
        {
          count = new AtomicInteger();whichScoresWereUsedForMerging.put(pair.getScore(),count);
        }
        count.incrementAndGet();
        scoresToIterations.put(pair, iterations);
        topLearner.Restart(RestartLearningEnum.restartNONE);
      }
     
      // candidates are recomputed from scratch for whichever graph is now current
      possibleMerges = topLearner.ChooseStatePairs(scoreComputer);
      //System.out.println(possibleMerges);
    }
    // assemble the textual report which is attached to the resulting graph
    report.write("\n[ Questions: "+counterAccepted+" accepted "+counterRejected+" rejected resulting in "+counterRestarted+ " restarts; "+counterEmptyQuestions+" empty sets of questions ]\n[ Learned automaton: "+scoreComputer.getStatistics(true)+" ] ");
    report.write("\n[ final sets of questions, plus: "+plusSize+" minus: "+minusSize+" ] ");
    report.write("\n[ Pair scores to iteration numbers:"+pairScoresAndIterations(scoresToIterations,"MERGED-ITERATIONS"));
    report.write("\n[ Restart scores to iteration numbers:"+pairScoresAndIterations(restartsToIterations,"RESTART-ITERATIONS"));
    report.write("\n[ Pairs merged (score-number of times):"+HistogramToSeries(whichScoresWereUsedForMerging,"MERGED"));
    report.write("\n[ Pairs restarted (score-number of times):"+HistogramToSeries(restartScoreDistribution,"RESTARTED"));
    report.write("\n Pair merge details: \n"+pairsMerged);
    DirectedSparseGraph result = scoreComputer.getGraph();result.addUserDatum(JUConstants.STATS, report.toString(), UserData.SHARED);
    updateGraph(result);
    return result;
  }
/* 
  protected static void dumpSets(String output, Collection<List<String>> sPlus, Collection<List<String>> sMinus)
  { 
    try
    {
      System.out.println("dumping sets");
      XMLEncoder encoder = new XMLEncoder(new BufferedOutputStream(new FileOutputStream(output)));
      encoder.writeObject(sPlus);
      encoder.writeObject(sMinus);
      encoder.close();
      throw new IllegalArgumentException("finished");
    }
    catch(FileNotFoundException e)
    {
      IllegalArgumentException ex = new IllegalArgumentException("failed to write output file");
      ex.initCause(e);throw ex;
    }   
  }
*/
  @Override
  protected List<List<String>> generateQuestions(
      @SuppressWarnings("unused") DirectedSparseGraph model,
      @SuppressWarnings("unused") DirectedSparseGraph temp,
      @SuppressWarnings("unused"StatePair pair) {
    throw new UnsupportedOperationException();
  }

  @Override
  protected Stack chooseStatePairs(
      @SuppressWarnings("unused"DirectedSparseGraph g,
      @SuppressWarnings("unused"Collection<List<String>> plus,
      @SuppressWarnings("unused"Collection<List<String>> minus) {
    throw new UnsupportedOperationException();
  }

  @Override
  protected int computeScore(
      @SuppressWarnings("unused"DirectedSparseGraph original,
      @SuppressWarnings("unused"StatePair blueRed) {
    throw new UnsupportedOperationException();
  }

  @Override
  protected DirectedSparseGraph createAugmentedPTA(
      @SuppressWarnings("unused"DirectedSparseGraph model,
      @SuppressWarnings("unused"Collection<List<String>> plus,
      @SuppressWarnings("unused"Collection<List<String>> minus) {
    throw new UnsupportedOperationException();
  }

  @Override
  protected List<List<String>> generateQuestions(
      @SuppressWarnings("unused"DirectedSparseGraph model,
      @SuppressWarnings("unused"StatePair pair) {
    throw new UnsupportedOperationException();
  }

  @Override
  protected DirectedSparseGraph mergeAndDeterminize(
      @SuppressWarnings("unused"Graph model,
      @SuppressWarnings("unused"StatePair pair) {
    throw new UnsupportedOperationException();
  }
 
    public static List<List<String>> sort(Collection<List<String>> data)
    {
      LinkedList<List<String>> result = new LinkedList<List<String>>();result.addAll(data);
      Collections.sort(result, new Comparator<List<String>>() {

      public int compare(List<String> o1, List<String> o2) {
        int len1 = o1.size(),len2 = o2.size();
        if (len1 < len2) return -1;else if (len1 > len2) return 1;
        Iterator<String> it1 = o1.iterator(),it2 = o2.iterator();
        while(it1.hasNext())
        {
          int cmpResult = it1.next().compareTo(it2.next());
          if (cmpResult != 0) return cmpResult;
        }
        return 0;
      }});
      return result;
    }
}
TOP

Related Classes of statechum.analysis.learning.RPNIBlueFringeLearnerTestComponentOpt

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.