Package cc.mallet.fst

Examples of cc.mallet.fst.CRF


    int numStates = 5;
    Alphabet inputAlphabet = new Alphabet();
    for (int i = 0; i < inputVocabSize; i++)
      inputAlphabet.lookupIndex("feature" + i);
    Alphabet outputAlphabet = new Alphabet();
    CRF crf = new CRF(inputAlphabet, outputAlphabet);
    String[] stateNames = new String[numStates];
    for (int i = 0; i < numStates; i++)
      stateNames[i] = "state" + i;
    crf.addFullyConnectedStates(stateNames);
    CRFTrainerByLabelLikelihood crft = new CRFTrainerByLabelLikelihood(crf);
    Optimizable.ByGradientValue mcrf = crft
        .getOptimizableCRF(new InstanceList(null));
    TestOptimizable.testGetSetParameters(mcrf);
  }
View Full Code Here


    Alphabet inputAlphabet = new Alphabet();
    for (int i = 0; i < inputVocabSize; i++)
      inputAlphabet.lookupIndex("feature" + i);
    Alphabet outputAlphabet = new Alphabet();

    CRF crf = new CRF(inputAlphabet, outputAlphabet);

    String[] stateNames = new String[numStates];
    for (int i = 0; i < numStates; i++)
      stateNames[i] = "state" + i;
    crf.addFullyConnectedStates(stateNames);

    crf.setWeightsDimensionDensely();
    crf.getState(0).setInitialWeight(1.0);
    crf.getState(1).setInitialWeight(Transducer.IMPOSSIBLE_WEIGHT);
    crf.getState(0).setFinalWeight(0.0);
    crf.getState(1).setFinalWeight(0.0);
    crf.setParameter(0, 0, 0, Transducer.IMPOSSIBLE_WEIGHT); // state0
    // self-transition
    crf.setParameter(0, 1, 0, 1.0); // state0->state1
    crf.setParameter(1, 1, 0, 1.0); // state1 self-transition
    crf.setParameter(1, 0, 0, Transducer.IMPOSSIBLE_WEIGHT); // state1->state0

    FeatureVectorSequence fvs = new FeatureVectorSequence(
        new FeatureVector[] {
            new FeatureVector((Alphabet) crf.getInputAlphabet(),
                new double[] { 1 }),
            new FeatureVector((Alphabet) crf.getInputAlphabet(),
                new double[] { 1 }),
            new FeatureVector((Alphabet) crf.getInputAlphabet(),
                new double[] { 1 }), });

    SumLattice lattice = new SumLatticeDefault(crf, fvs, true);
    // We start in state0
    assertTrue(lattice.getGammaProbability(0, crf.getState(0)) == 1.0);
    assertTrue(lattice.getGammaProbability(0, crf.getState(1)) == 0.0);
    // We go to state1
    assertTrue(lattice.getGammaProbability(1, crf.getState(0)) == 0.0);
    assertTrue(lattice.getGammaProbability(1, crf.getState(1)) == 1.0);
    // And on through a self-transition
    assertTrue(lattice
        .getXiProbability(1, crf.getState(1), crf.getState(1)) == 1.0);
    assertTrue(lattice
        .getXiProbability(1, crf.getState(1), crf.getState(0)) == 0.0);
    assertTrue("Lattice weight = " + lattice.getTotalWeight(), lattice
        .getTotalWeight() == 4.0);
    // Gammas at all times sum to 1.0
    for (int time = 0; time < lattice.length() - 1; time++) {
      double gammasum = lattice
          .getGammaProbability(time, crf.getState(0))
          + lattice.getGammaProbability(time, crf.getState(1));
      assertEquals("Gammas at time step " + time + " sum to " + gammasum,
          1.0, gammasum, 0.0001);
    }
    // Xis at all times sum to 1.0
    for (int time = 0; time < lattice.length() - 1; time++) {
      double xissum = lattice.getXiProbability(time, crf.getState(0), crf
          .getState(0))
          + lattice.getXiProbability(time, crf.getState(0), crf
              .getState(1))
          + lattice.getXiProbability(time, crf.getState(1), crf
              .getState(0))
          + lattice.getXiProbability(time, crf.getState(1), crf
              .getState(1));
      assertEquals("Xis at time step " + time + " sum to " + xissum, 1.0,
          xissum, 0.0001);
    }
  }
View Full Code Here


  private static CRF loadCrf (File crfFile) throws IOException
  {
     ObjectInputStream ois = new ObjectInputStream( new FileInputStream( crfFile ) );
    CRF crf = null;

    // We shouldn't run into a ClassNotFound exception...
    try {
      crf = (CRF)ois.readObject();
    } catch (ClassNotFoundException e) {
View Full Code Here

     InstanceList training = new InstanceList (pipe);
     training.addThruPipe (new ArrayIterator (data0));
     InstanceList testing = new InstanceList (pipe);
     testing.addThruPipe (new ArrayIterator (data1));

     CRF crf = new CRF (pipe, null);
     crf.addFullyConnectedStatesForLabels ();
     CRFTrainerByLabelLikelihood crft = new CRFTrainerByLabelLikelihood (crf);
     crft.trainIncremental (training);

     CRFExtractor extor = TestLatticeViewer.hackCrfExtor (crf);
     Extraction extraction = extor.extract (new ArrayIterator (data1));
View Full Code Here

    Alphabet inputAlphabet = new Alphabet();
    for (int i = 0; i < inputVocabSize; i++)
      inputAlphabet.lookupIndex("feature" + i);
    Alphabet outputAlphabet = new Alphabet();

    CRF crf = new CRF(inputAlphabet, outputAlphabet);

    String[] stateNames = new String[numStates];
    for (int i = 0; i < numStates; i++)
      stateNames[i] = "state" + i;
    crf.addFullyConnectedStates(stateNames);

    crf.setWeightsDimensionDensely();
    crf.getState(0).setInitialWeight(1.0);
    crf.getState(1).setInitialWeight(Transducer.IMPOSSIBLE_WEIGHT);
    crf.getState(0).setFinalWeight(0.0);
    crf.getState(1).setFinalWeight(0.0);
    crf.setParameter(0, 0, 0, Transducer.IMPOSSIBLE_WEIGHT); // state0
    // self-transition
    crf.setParameter(0, 1, 0, 1.0); // state0->state1
    crf.setParameter(1, 1, 0, 1.0); // state1 self-transition
    crf.setParameter(1, 0, 0, Transducer.IMPOSSIBLE_WEIGHT); // state1->state0

    FeatureVectorSequence fvs = new FeatureVectorSequence(
        new FeatureVector[] {
            new FeatureVector((Alphabet) crf.getInputAlphabet(),
                new double[] { 1 }),
            new FeatureVector((Alphabet) crf.getInputAlphabet(),
                new double[] { 1 }),
            new FeatureVector((Alphabet) crf.getInputAlphabet(),
                new double[] { 1 }), });

    MaxLattice lattice = new MaxLatticeDefault(crf, fvs);
    Sequence<Transducer.State> viterbiPath = lattice.bestStateSequence();
    // We start in state0
    assertTrue(viterbiPath.get(0) == crf.getState(0));
    // We go to state1
    assertTrue(viterbiPath.get(1) == crf.getState(1));
    // And on through a self-transition to state1 again
    assertTrue(viterbiPath.get(2) == crf.getState(1));
  }
View Full Code Here

    for (int i = 0; i < numStates; i++) {
      stateNames[i] = "state" + i;
      outputAlphabet.lookupIndex(stateNames[i]);
    }

    CRF crf = new CRF(inputAlphabet, outputAlphabet);
    CRF saveCRF = crf;
    // inputAlphabet = (Feature.Alphabet) crf.getInputAlphabet();
    FeatureVectorSequence fvs = new FeatureVectorSequence(
        new FeatureVector[] {
            new FeatureVector(crf.getInputAlphabet(), new int[] {
                1, 2, 3 }),
View Full Code Here

    InstanceList instances = new InstanceList(p);
    instances.addThruPipe(new ArrayIterator(data));
    InstanceList[] lists = instances.split(new Random(1), new double[] {
        .5, .5 });
    CRF crf = new CRF(p, p2);
    crf.addFullyConnectedStatesForLabels();
    CRFTrainerByLabelLikelihood crft = new CRFTrainerByLabelLikelihood(crf);
    if (testValueAndGradient) {
      Optimizable.ByGradientValue optable = crft
          .getOptimizableCRF(lists[0]);
      // TestOptimizable.testValueAndGradient(minable);
      double[] gradient = new double[optable.getNumParameters()];
      optable.getValueGradient(gradient);
      // TestOptimizable.testValueAndGradientInDirection(optable,
      // gradient);
      // TestOptimizable.testValueAndGradientCurrentParameters(optable);
      TestOptimizable.testValueAndGradient(optable); // This tests at
      // current
      // parameters and at
      // parameters
      // purturbed toward
      // the gradient
    } else {
      System.out.println("Training Accuracy before training = "
          + crf.averageTokenAccuracy(lists[0]));
      System.out.println("Testing  Accuracy before training = "
          + crf.averageTokenAccuracy(lists[1]));
      System.out.println("Training...");
      crft.trainIncremental(lists[0]);
      System.out.println("Training Accuracy after training = "
          + crf.averageTokenAccuracy(lists[0]));
      System.out.println("Testing  Accuracy after training = "
          + crf.averageTokenAccuracy(lists[1]));
      System.out.println("Training results:");
      for (int i = 0; i < lists[0].size(); i++) {
        Instance inst = lists[0].get(i);
        Sequence input = (Sequence) inst.getData();
        Sequence output = crf.transduce(input);
        System.out.println(output);
      }
      System.out.println("Testing results:");
      for (int i = 0; i < lists[1].size(); i++) {
        Instance inst = lists[1].get(i);
        Sequence input = (Sequence) inst.getData();
        Sequence output = crf.transduce(input);
        System.out.println(output);
      }
    }
  }
View Full Code Here

  public void doTestSpacePrediction(boolean testValueAndGradient,
      boolean useSaved, boolean useSparseWeights) {
    Pipe p = makeSpacePredictionPipe();

    CRF savedCRF;
    File f = new File("TestObject.obj");
    InstanceList instances = new InstanceList(p);
    instances.addThruPipe(new ArrayIterator(data));
    InstanceList[] lists = instances.split(new double[] { .5, .5 });
    CRF crf = new CRF(p.getDataAlphabet(), p.getTargetAlphabet());
    crf.addFullyConnectedStatesForLabels();
    CRFTrainerByLabelLikelihood crft = new CRFTrainerByLabelLikelihood(crf);
    crft.setUseSparseWeights(useSparseWeights);
    if (testValueAndGradient) {
      Optimizable.ByGradientValue minable = crft
          .getOptimizableCRF(lists[0]);
      TestOptimizable.testValueAndGradient(minable);
    } else {
      System.out.println("Training Accuracy before training = "
          + crf.averageTokenAccuracy(lists[0]));
      System.out.println("Testing  Accuracy before training = "
          + crf.averageTokenAccuracy(lists[1]));
      savedCRF = crf;
      System.out.println("Training serialized crf.");
      crft.trainIncremental(lists[0]);
      double preTrainAcc = crf.averageTokenAccuracy(lists[0]);
      double preTestAcc = crf.averageTokenAccuracy(lists[1]);
      System.out.println("Training Accuracy after training = "
          + preTrainAcc);
      System.out.println("Testing  Accuracy after training = "
          + preTestAcc);
      try {
        ObjectOutputStream oos = new ObjectOutputStream(
            new FileOutputStream(f));
        oos.writeObject(crf);
        oos.close();
      } catch (IOException e) {
        System.err.println("Exception writing file: " + e);
      }
      System.err.println("Wrote out CRF");
      System.err.println("CRF parameters. hyperbolicPriorSlope: "
          + crft.getUseHyperbolicPriorSlope()
          + ". hyperbolicPriorSharpness: "
          + crft.getUseHyperbolicPriorSharpness()
          + ". gaussianPriorVariance: "
          + crft.getGaussianPriorVariance());
      // And read it back in
      if (useSaved) {
        crf = null;
        try {
          ObjectInputStream ois = new ObjectInputStream(
              new FileInputStream(f));
          crf = (CRF) ois.readObject();
          ois.close();
        } catch (IOException e) {
          System.err.println("Exception reading file: " + e);
        } catch (ClassNotFoundException cnfe) {
          System.err
              .println("Cound not find class reading in object: "
                  + cnfe);
        }
        System.err.println("Read in CRF.");
        crf = savedCRF;

        double postTrainAcc = crf.averageTokenAccuracy(lists[0]);
        double postTestAcc = crf.averageTokenAccuracy(lists[1]);
        System.out.println("Training Accuracy after saving = "
            + postTrainAcc);
        System.out.println("Testing  Accuracy after saving = "
            + postTestAcc);
View Full Code Here

        new double[] { .5, .5 });

    // Compare 3 CRFs trained with addOrderNStates, and make sure
    // that having more features leads to a higher likelihood

    CRF crf1 = new CRF(p.getDataAlphabet(), p.getTargetAlphabet());
    crf1.addOrderNStates(lists[0], new int[] { 1, },
        new boolean[] { false, }, "START", null, null, false);
    new CRFTrainerByLabelLikelihood(crf1).trainIncremental(lists[0]);

    CRF crf2 = new CRF(p.getDataAlphabet(), p.getTargetAlphabet());
    crf2.addOrderNStates(lists[0], new int[] { 1, 2, }, new boolean[] {
        false, true }, "START", null, null, false);
    new CRFTrainerByLabelLikelihood(crf2).trainIncremental(lists[0]);

    CRF crf3 = new CRF(p.getDataAlphabet(), p.getTargetAlphabet());
    crf3.addOrderNStates(lists[0], new int[] { 1, 2, }, new boolean[] {
        false, false }, "START", null, null, false);
    new CRFTrainerByLabelLikelihood(crf3).trainIncremental(lists[0]);

    // Prevent cached values
    double lik1 = getLikelihood(crf1, lists[0]);
View Full Code Here

    Pipe p = makeSpacePredictionPipe();

    InstanceList instances = new InstanceList(p);
    instances.addThruPipe(new ArrayIterator(data));

    CRF crf1 = new CRF(p.getDataAlphabet(), p.getTargetAlphabet());
    crf1.addFullyConnectedStatesForLabels();
    CRFTrainerByLabelLikelihood crft1 = new CRFTrainerByLabelLikelihood(
        crf1);
    crft1.trainIncremental(instances);

    CRF crf2 = new CRF(p.getDataAlphabet(), p.getTargetAlphabet());
    crf2.addFullyConnectedStatesForLabels();
    // Freeze some weights, before training
    for (int i = 0; i < crf2.getWeights().length; i += 2)
      crf2.freezeWeights(i);
    CRFTrainerByLabelLikelihood crft2 = new CRFTrainerByLabelLikelihood(
        crf2);
    crft2.trainIncremental(instances);

    SparseVector[] w = crf2.getWeights();
    double[] b = crf2.getDefaultWeights();
    for (int i = 0; i < w.length; i += 2) {
      assertEquals(0.0, b[i], 1e-10);
      for (int loc = 0; loc < w[i].numLocations(); loc++) {
        assertEquals(0.0, w[i].valueAtLocation(loc), 1e-10);
      }
View Full Code Here

TOP

Related Classes of cc.mallet.fst.CRF

Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., now owned by Oracle Corporation. Contact coftware#gmail.com.