Package org.renjin.stats.internals.models

Source Code of org.renjin.stats.internals.models.FormulaInterpreter

package org.renjin.stats.internals.models;

import org.renjin.eval.EvalException;
import org.renjin.sexp.FunctionCall;
import org.renjin.sexp.SEXP;
import org.renjin.sexp.Symbol;
import org.renjin.sexp.Vector;


/**
* Model formulas in R are defined with a Domain Specific Language (DSL)
* that describe a {@link Formula}. The DSL uses the same Abstract Syntax
* Tree (AST) as R expressions, but certain functions are interpreted differently
* in this context.
*
*/
public class FormulaInterpreter {
 
  private SEXP response;
  private int intercept = 1;
 
  private static final Symbol tilde = Symbol.get("~");
  private static final Symbol UNION = Symbol.get("+");
  private static final Symbol EXPAND_TERMS = Symbol.get("*");
  private static final Symbol DIFFERENCE = Symbol.get("-");
  private static final Symbol GROUP = Symbol.get("(");
 
  public Formula interpret(FunctionCall call) {
    if(call.getFunction() != tilde) {
      throw new EvalException("expected model formula (~)");
    }
   
    TermList terms = new TermList();
    if(call.getArguments().length() == 1) {
      response = null;
      add(terms, call.getArgument(0));
    } else if(call.getArguments().length() == 2) {
      response = call.getArgument(0);
      add(terms, call.getArgument(1));
    }
   
    return new Formula(response, intercept, terms.sorted());
  }
 
  /**
   * Build and return a TermList from the given SEXP.
   * @param argument the SEXP to interpret
   * @param subtracting true if we are subtracting and the intercept should be interpreted as
   *        negative
   * @return
   */
  private TermList buildTermList(SEXP argument, boolean subtracting) {
    TermList list = new TermList();
    add(list, argument, subtracting);
    return list;
  }
 
  private TermList buildTermList(SEXP argument) {
    return buildTermList(argument, false);
  }

  private void add(TermList list, SEXP argument, boolean subtracting) {
    if(argument instanceof Symbol) {
      list.add(new Term(argument));
    } else if(argument instanceof Vector) {
      intercept((Vector)argument, subtracting);
    } else if(argument instanceof FunctionCall) {
      FunctionCall call = (FunctionCall)argument;
      if(call.getFunction() == UNION) {
        unionTerms(list, call);
      } else if(call.getFunction() == EXPAND_TERMS) {
        multiply(list, call);
      } else if(call.getFunction() == DIFFERENCE) {
        difference(list, call);
      } else if(call.getFunction() == GROUP) {
        add(list, call.getArgument(0), subtracting);
      } else {
        list.add(new TermBuilder().build(call));
      }
    }
  }
 
  private void add(TermList list, SEXP argument) {
    add(list, argument, false);
  }
 
  private void multiply(TermList terms, FunctionCall call) {
    TermList a = buildTermList(call.getArgument(0));
    TermList b = buildTermList(call.getArgument(1));
   
    terms.add(a);
    terms.add(b);
   
    for(Term a_i : a) {
      for(Term b_i : b) {
        terms.add(new Term(a_i, b_i));
      }
    }
  }

  private void unionTerms(TermList terms, FunctionCall call) {
    for(SEXP argument : call.getArguments().values()) {
      add(terms, argument);
    }
  }
 
  private void difference(TermList terms, FunctionCall call) {
   
    if(call.getArguments().length() == 1) {
      // the difference between an empty set and any other set is the empty set,
      // so we don't add any terms to our parent list, but we do
      // need to look for a negative intercept
      buildTermList(call.getArgument(0), true);
     
    } else {
      TermList a = buildTermList(call.getArgument(0));
      TermList b = buildTermList(call.getArgument(1), true);
     
      a.subtract(b);
     
      terms.add(a);
    }
  }

  private void intercept(Vector vector, boolean subtracting) {
 
    if(vector.length() != 1) {
      throw new EvalException("Invalid intercept: " + vector.toString() + ", expected 0 or 1");
    }
    intercept = vector.getElementAsInt(0);
    if(intercept != 0 && intercept != 1) {
      throw new EvalException("Invalid intercept: " + intercept + ", expected 0 or 1");
    }
    if(subtracting) {
      intercept = (intercept == 0) ? 1 : 0;
    }
  }
}
TOP

Related Classes of org.renjin.stats.internals.models.FormulaInterpreter

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.