Package org.apache.ctakes.coreference.util

Source Code of org.apache.ctakes.coreference.util.SvmVectorCreator

/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements.  See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership.  The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License.  You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied.  See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.ctakes.coreference.util;

import java.io.File;
import java.io.IOException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedList;

import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;

import org.apache.ctakes.constituency.parser.treekernel.TreeExtractor;
import org.apache.ctakes.constituency.parser.util.TreeUtils;
import org.apache.ctakes.core.resource.FileLocator;
import org.apache.ctakes.coreference.type.Markable;
import org.apache.ctakes.utils.tree.FragmentUtils;
import org.apache.ctakes.utils.tree.SimpleTree;
import org.apache.ctakes.utils.wiki.WikiIndex;
import org.apache.uima.jcas.JCas;
import org.apache.uima.pear.util.FileUtil;

public class SvmVectorCreator {
  HashSet<String> stopwords = null;
  private svm_model anaph_model = null;
  ArrayList<SimpleTree> frags = new ArrayList<SimpleTree>();
  WikiIndex wiki = null;
  static final int NUM_WIKI_HITS = 5;
 
  public SvmVectorCreator(HashSet<String> stopwords){
    this.stopwords = stopwords;
    try{
      wiki = new WikiIndex(NUM_WIKI_HITS, FileLocator.locateFile("org/apache/ctakes/coreference/models/index_med_5k").getAbsolutePath(), "text");
      wiki.initialize();
    }catch(IOException e){
      e.printStackTrace();
      wiki = null;
    }
  }

  public svm_node[] createAnaphoricityVector(Markable m, JCas aJCas) {
//    String vec = isGoldMarkable(m)?"1":"0";
    String[] feats = FeatureVector.getAnaphoricityFeatures();
    ArrayList<svm_node> nodeList = new ArrayList<svm_node>();
   
    AnaphoricityAttributeCalculator aac = new AnaphoricityAttributeCalculator(aJCas, m);
    aac.setStopWordsList(stopwords);
    for (int i = 0; i < feats.length; i++) {
      try {
        Object val = aac.getClass().getMethod("calcm"+feats[i]).invoke(aac);
        if (val instanceof String) {
          String s = (String) val;
          if (s.equals("yes") || s.equals("Y") || s.equals("C") || s.equals("P")) {
//            vec += " " + (i+1) + ":1";
            svm_node node = new svm_node();
            node.index = i+1;
            node.value = 1;
            nodeList.add(node);
          } else if (s.equals("S")) {
//            vec += " " + (i+1) + ":0.5";
            svm_node node = new svm_node();
            node.index = i+1;
            node.value = 0.5;
            nodeList.add(node);
          } else if (!s.equals("U") && !s.equals("N") && !s.equals("no") && !s.equals("I")){ //TODO: cleanup
            int ii = Integer.parseInt(s);
            if (ii>=0 && ii<=6){
//              vec += " " + (i+1) + ":" + (ii/6.0);
              svm_node node = new svm_node();
              node.index = i+1;
              node.value = ii/6.0;
              nodeList.add(node);
            }
           
          }
        }
      } catch (Exception e) { e.printStackTrace(); }
    }
    return nodeList.toArray(new svm_node[]{});
//    AnaphoricityVecInstance avi = new AnaphoricityVecInstance(aJCas);
//    avi.setVector(vec);
//    avi.setM(m.getCoveredText());
//    avi.setOffset(m.getBegin()+"-"+m.getEnd());
//    avi.addToIndexes();
  }


  private double calcAnaphoricity (JCas aJCas, Markable m) {
    svm_node[] nodes = createAnaphoricityVector(m, aJCas);
    double[] prob = new double[2];
    svm.svm_predict_probability(anaph_model, nodes, prob);
    int[] labels = new int[2];
    svm.svm_get_labels(anaph_model, labels);
    int anaph_idx = labels[0]==1 ? 0 : 1;
    return prob[anaph_idx];
  }

 
  public svm_node[] getNodeFeatures(Markable anaphor, Markable antecedent, JCas aJCas) {
    return getNodeFeatures(anaphor, antecedent, aJCas, false);
  }
 
  public svm_node[] getNodeFeatures(Markable anaphor, Markable antecedent, JCas aJCas, boolean needsAnaph) {
    LinkedList<svm_node> nodes = new LinkedList<svm_node>();
    String[] feats = FeatureVector.getNECorefFeatures();
    SyntaxAttributeCalculator sac = new SyntaxAttributeCalculator(aJCas, antecedent, anaphor, wiki);
    sac.setStopWordsList(stopwords);
    int ind = 0;
    for (int i = 0; i < feats.length; i++, ind++) {
      try {
        if (feats[i].equals("Anaph")) {
          double anaph_prob = needsAnaph ? calcAnaphoricity(aJCas, anaphor) : 1.0;
//          vec += " " + (ind+1) + ":" + anaph_prob;
          svm_node n = new svm_node();
          n.index = ind+1;
          n.value = anaph_prob;
          nodes.add(n);
          continue;
        }
        if (feats[i].contains(":")){
          String[] catFeat = feats[i].split(":");
          String featType = catFeat[0];
          String featName = catFeat[1];
          // if catFeat[0].equalsIgnoreCase("cat"){}   // Not necessary, no other feature namespaces used yet.
          String methodName = "num" + featName;
          //          Class<String> strClass = new Class<String>();
          int num = (Integer) sac.getClass().getMethod(methodName, Markable.class).invoke(sac, anaphor);
          methodName = "calc" + featType + featName;
          Method method = sac.getClass().getMethod(methodName, Integer.class, Markable.class);
          for(int j = 0; j < num; j++, ind++){
            String val = (String) method.invoke(sac, j, anaphor);
            if(val.equalsIgnoreCase("Y")){
              svm_node n = new svm_node();
              n.index = ind+1;
              n.value = 1;
              nodes.add(n);
//              vec += " " + (ind+1) + ":1";
            }
          }
          continue;
        }else{
          Object val = sac.getClass().getMethod("calc"+feats[i]).invoke(sac);
          if (val instanceof String) {
            String s = (String) val;
            if (s.equals("yes") || s.equals("Y") || s.equals("C")) {
              //            vec += " " + (ind+1) + ":1";
              svm_node n = new svm_node();
              n.index = ind+1;
              n.value = 1;
              nodes.add(n);
            }
          }
          else if (val instanceof Integer) {
            int v = ((Integer)val).intValue();
            if (v!=0) {
              //            vec += " " + (ind+1) + ":" + ((double)v/(i==0?600:10));
              svm_node n = new svm_node();
              n.index = ind+1;
              n.value = (double) v; // ((double)v/(i==0?600:10));
              nodes.add(n);
            }
          }
          else if (val instanceof Double) {
            //          vec += " " + (ind+1) + ":" + val;
            if((Double)val != 0.0){
              svm_node n = new svm_node();
              n.index = ind+1;
              n.value = (Double) val;
              nodes.add(n);
            }
          }else if (val instanceof Boolean) {
            if((Boolean) val == true){
              svm_node n = new svm_node();
              n.index = ind + 1;
              n.value = 1.0;
              nodes.add(n);
            }
          }

        }
      } catch (Exception e) { e.printStackTrace(); }
    }
   
    if(frags != null && frags.size() > 0){
      SimpleTree tn = TreeExtractor.extractPathTree(MarkableTreeUtils.markableNode(aJCas, antecedent.getBegin(), antecedent.getEnd()),
          MarkableTreeUtils.markableNode(aJCas, anaphor.getBegin(), anaphor.getEnd()));
//      SimpleTree tn = TreeExtractor.extractPathEnclosedTree(MarkableTreeUtils.markableNode(aJCas, antecedent.getBegin(), antecedent.getEnd()),
//          MarkableTreeUtils.markableNode(aJCas, anaphor.getBegin(), anaphor.getEnd()),
//          aJCas);
      // now go over the tree fragment features:
      for(SimpleTree frag : frags){
        if(TreeUtils.contains(tn, frag)){
          svm_node n = new svm_node();
          n.index = ind+1;
          n.value = 1.0;
          nodes.add(n);
        }
        ind++;
      }
    }
    return nodes.toArray(new svm_node[]{});
  }


  public void setFrags(Collection<String> treeFrags) {
    for(String frag : treeFrags){
      frags.add(FragmentUtils.frag2tree(frag));
    }
  }
}
TOP

Related Classes of org.apache.ctakes.coreference.util.SvmVectorCreator

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and owned by ORACLE Inc. Contact coftware#gmail.com.