/*
* Copyright 2011 JBoss Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.drools.pmml.pmml_4_2.predictive.models;
import org.dmg.pmml.pmml_4_2.descr.MISSINGVALUESTRATEGY;
import org.dmg.pmml.pmml_4_2.descr.PMML;
import org.dmg.pmml.pmml_4_2.descr.TreeModel;
import org.drools.pmml.pmml_4_2.DroolsAbstractPMMLTest;
import org.drools.pmml.pmml_4_2.PMML4Compiler;
import org.drools.pmml.pmml_4_2.PMML4Helper;
import org.junit.After;
import org.junit.Test;
import org.kie.api.definition.type.FactType;
import org.kie.api.runtime.ClassObjectFilter;
import org.kie.api.runtime.KieSession;
import org.kie.internal.io.ResourceFactory;
import java.io.InputStream;
import java.util.Collection;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
/**
 * Tests for PMML 4.2 decision trees ({@link TreeModel}) compiled to rules.
 *
 * <p>Each test builds a session from one of two classpath PMML documents, feeds input
 * values through the {@code in_FldN} entry points, fires the rules and then inspects
 * the single {@code TreeToken} fact and/or the target data field left in the session.
 * Several tests override the model's missing-value strategy before compiling it.
 */
public class DecisionTreeTest extends DroolsAbstractPMMLTest {

    /** When {@code true}, the generated theory is printed / verbose model loading is requested. */
    private static final boolean VERBOSE = false;

    private static final String SIMPLE_TREE_SOURCE = "org/drools/pmml/pmml_4_2/test_tree_simple.xml";
    private static final String MISSING_TREE_SOURCE = "org/drools/pmml/pmml_4_2/test_tree_missing.xml";
    private static final String PACKAGE_NAME = "org.drools.pmml.pmml_4_2.test";

    /** Name of the fact type (in the default PMML package) that tracks the tree traversal. */
    private static final String TOKEN_TYPE = "TreeToken";

    /** Name of the target field type checked by the tests using {@link #MISSING_TREE_SOURCE}. */
    private static final String TARGET_FIELD = "Fld9";

    /** Tolerance for confidence comparisons; confidences are computed floating-point values. */
    private static final double EPSILON = 1e-6;

    /**
     * Disposes the session created by the test, if any. The null guard keeps a failure
     * during session creation from being joined by an unrelated NullPointerException.
     */
    @After
    public void tearDown() {
        KieSession session = getKSession();
        if ( session != null ) {
            session.dispose();
        }
    }

    /**
     * Returns the single {@code TreeToken} fact in the session, asserting that the
     * token type is declared and that exactly one instance exists.
     *
     * @param kSession the session to inspect
     * @return the only token instance in the session
     */
    protected Object getToken( KieSession kSession ) {
        FactType tok = tokenType( kSession );
        assertNotNull( tok );

        Collection<?> tokens = kSession.getObjects( new ClassObjectFilter( tok.getFactClass() ) );
        assertEquals( 1, tokens.size() );
        return tokens.iterator().next();
    }

    @Test
    public void testSimpleTree() throws Exception {
        KieSession kSession = startModelSession( SIMPLE_TREE_SOURCE );

        insertInputs( kSession, 30.0, 60.0, "false", "optA" );
        kSession.fireAllRules();

        checkFirstDataFieldOfTypeStatus( factType( kSession, "Fld5" ), true, false, "Missing", "tgtY" );
    }

    @Test
    public void testMissingTree() throws Exception {
        KieSession kSession = startModelSession( MISSING_TREE_SOURCE );

        insertInputs( kSession, 45.0, 60.0, "optA" );
        kSession.fireAllRules();

        FactType tok = tokenType( kSession );
        Object token = getToken( kSession );
        assertEquals( 0.6, (Double) tok.get( token, "confidence" ), EPSILON );
        assertEquals( "null", tok.get( token, "current" ) );

        checkFirstDataFieldOfTypeStatus( factType( kSession, TARGET_FIELD ), true, false, "Missing", "tgtZ" );
    }

    @Test
    public void testMissingTreeWeighted1() throws Exception {
        KieSession kSession = startModelSession( MISSING_TREE_SOURCE );

        // presumably -1.0 is the model's missing-value marker for the numeric fields -- confirm
        // against test_tree_missing.xml
        insertInputs( kSession, -1.0, -1.0, "optA" );
        kSession.fireAllRules();

        FactType tok = tokenType( kSession );
        Object token = getToken( kSession );
        assertEquals( 0.8, (Double) tok.get( token, "confidence" ), EPSILON );
        assertEquals( "null", tok.get( token, "current" ) );
        assertEquals( 50.0, tok.get( token, "totalCount" ) );

        checkFirstDataFieldOfTypeStatus( factType( kSession, TARGET_FIELD ), true, false, "Missing", "tgtX" );
    }

    @Test
    public void testMissingTreeWeighted2() throws Exception {
        KieSession kSession = startModelSession( MISSING_TREE_SOURCE );

        insertInputs( kSession, -1.0, -1.0, "miss" );
        kSession.fireAllRules();

        FactType tok = tokenType( kSession );
        Object token = getToken( kSession );
        assertEquals( 0.6, (Double) tok.get( token, "confidence" ), EPSILON );
        assertEquals( "null", tok.get( token, "current" ) );
        assertEquals( 100.0, tok.get( token, "totalCount" ) );

        checkFirstDataFieldOfTypeStatus( factType( kSession, TARGET_FIELD ), true, false, "Missing", "tgtX" );
    }

    @Test
    public void testMissingTreeDefault() throws Exception {
        KieSession kSession = startSessionWithStrategy( MISSINGVALUESTRATEGY.DEFAULT_CHILD );

        insertInputs( kSession, 70.0, 40.0, "miss" );
        kSession.fireAllRules();

        FactType tok = tokenType( kSession );
        Object token = getToken( kSession );
        assertEquals( 0.72, (Double) tok.get( token, "confidence" ), EPSILON );
        assertEquals( "null", tok.get( token, "current" ) );
        assertEquals( 40.0, tok.get( token, "totalCount" ) );

        checkFirstDataFieldOfTypeStatus( factType( kSession, TARGET_FIELD ), true, false, "Missing", "tgtX" );
    }

    @Test
    public void testMissingTreeAllMissingDefault() throws Exception {
        KieSession kSession = startSessionWithStrategy( MISSINGVALUESTRATEGY.DEFAULT_CHILD );

        insertInputs( kSession, -1.0, -1.0, "miss" );
        kSession.fireAllRules();

        FactType tok = tokenType( kSession );
        Object token = getToken( kSession );
        assertEquals( 1.0, (Double) tok.get( token, "confidence" ), EPSILON );
        assertEquals( "null", tok.get( token, "current" ) );
        assertEquals( 0.0, tok.get( token, "totalCount" ) );

        // NOTE(review): the target-field check below was already disabled in this test;
        // the reason is not recorded here -- confirm whether it should be re-enabled.
        // checkFirstDataFieldOfTypeStatus( factType( kSession, TARGET_FIELD ), true, false, "Missing", "tgtX" );
    }

    @Test
    public void testMissingTreeLastChoice() throws Exception {
        KieSession kSession = startSessionWithStrategy( MISSINGVALUESTRATEGY.LAST_PREDICTION );

        insertInputs( kSession, -1.0, -1.0, "optA" );
        kSession.fireAllRules();

        FactType tok = tokenType( kSession );
        Object token = getToken( kSession );
        assertEquals( 0.8, (Double) tok.get( token, "confidence" ), EPSILON );
        assertEquals( "null", tok.get( token, "current" ) );
        assertEquals( 50.0, tok.get( token, "totalCount" ) );

        checkFirstDataFieldOfTypeStatus( factType( kSession, TARGET_FIELD ), true, false, "Missing", "tgtX" );
    }

    @Test
    public void testMissingTreeNull() throws Exception {
        KieSession kSession = startSessionWithStrategy( MISSINGVALUESTRATEGY.NULL_PREDICTION );

        insertInputs( kSession, -1.0, -1.0, "optA" );
        kSession.fireAllRules();

        FactType tok = tokenType( kSession );
        Object token = getToken( kSession );
        assertEquals( 0.0, (Double) tok.get( token, "confidence" ), EPSILON );
        assertEquals( "null", tok.get( token, "current" ) );
        assertEquals( 0.0, tok.get( token, "totalCount" ) );

        // With this strategy no instance of the target field type is expected in the session.
        FactType tgt = factType( kSession, TARGET_FIELD );
        assertEquals( 0, kSession.getObjects( new ClassObjectFilter( tgt.getFactClass() ) ).size() );
    }

    @Test
    public void testMissingAggregate() throws Exception {
        KieSession kSession = startSessionWithStrategy( MISSINGVALUESTRATEGY.AGGREGATE_NODES );

        insertInputs( kSession, 45.0, 90.0, "miss" );
        kSession.fireAllRules();

        FactType tok = tokenType( kSession );
        Object token = getToken( kSession );
        // Looser tolerance than EPSILON: the expected value is only given to two decimals.
        assertEquals( 0.47, (Double) tok.get( token, "confidence" ), 1e-2 );
        assertEquals( "null", tok.get( token, "current" ) );
        assertEquals( 60.0, tok.get( token, "totalCount" ) );

        checkFirstDataFieldOfTypeStatus( factType( kSession, TARGET_FIELD ), true, false, "Missing", "tgtY" );
    }

    @Test
    public void testMissingTreeNone() throws Exception {
        KieSession kSession = startSessionWithStrategy( MISSINGVALUESTRATEGY.NONE );

        insertInputs( kSession, -1.0, -1.0, "miss" );
        kSession.fireAllRules();

        FactType tok = tokenType( kSession );
        Object token = getToken( kSession );
        assertEquals( 0.6, (Double) tok.get( token, "confidence" ), EPSILON );
        assertEquals( "null", tok.get( token, "current" ) );
        assertEquals( 100.0, tok.get( token, "totalCount" ) );

        checkFirstDataFieldOfTypeStatus( factType( kSession, TARGET_FIELD ), true, false, "Missing", "tgtX" );
    }

    @Test
    public void testSimpleTreeOutput() throws Exception {
        KieSession kSession = startModelSession( MISSING_TREE_SOURCE );

        insertInputs( kSession, -1.0, -1.0, "optA" );
        kSession.fireAllRules();

        FactType tok = tokenType( kSession );
        Object token = getToken( kSession );
        assertEquals( 0.8, (Double) tok.get( token, "confidence" ), EPSILON );
        assertEquals( "null", tok.get( token, "current" ) );
        assertEquals( 50.0, tok.get( token, "totalCount" ) );

        checkFirstDataFieldOfTypeStatus( factType( kSession, TARGET_FIELD ), true, false, "Missing", "tgtX" );
        // The output fields are checked against the same class label and the token's confidence.
        checkFirstDataFieldOfTypeStatus( factType( kSession, "OutClass" ), true, false, "Missing", "tgtX" );
        checkFirstDataFieldOfTypeStatus( factType( kSession, "OutProb" ), true, false, "Missing", 0.8 );
    }

    /**
     * Builds a session straight from a classpath PMML document, registers it (and its
     * knowledge base) with the base test class and fires the rules once to initialise the model.
     *
     * @param source classpath location of the PMML document
     * @return the initialised session
     */
    private KieSession startModelSession( String source ) throws Exception {
        setKSession( getModelSession( source, VERBOSE ) );
        return initRegisteredSession();
    }

    /**
     * Loads {@link #MISSING_TREE_SOURCE}, forces the given missing-value strategy on every
     * {@link TreeModel} it contains, compiles the result and starts an initialised session.
     *
     * @param strategy the missing-value strategy to set on the tree model(s)
     * @return the initialised session
     */
    private KieSession startSessionWithStrategy( MISSINGVALUESTRATEGY strategy ) throws Exception {
        PMML4Compiler compiler = new PMML4Compiler();

        PMML pmml;
        InputStream stream = ResourceFactory.newClassPathResource( MISSING_TREE_SOURCE ).getInputStream();
        try {
            pmml = compiler.loadModel( PMML, stream );
        } finally {
            // Closed here explicitly; nothing visible in this class shows loadModel closing it.
            stream.close();
        }

        for ( Object model : pmml.getAssociationModelsAndBaselineModelsAndClusteringModels() ) {
            if ( model instanceof TreeModel ) {
                ( (TreeModel) model ).setMissingValueStrategy( strategy );
            }
        }

        String theory = compiler.generateTheory( pmml );
        if ( VERBOSE ) {
            System.out.println( theory );
        }

        setKSession( getSession( theory ) );
        return initRegisteredSession();
    }

    /**
     * Registers the current session's knowledge base with the base class and fires the
     * rules once so the model is initialised before any input is inserted.
     */
    private KieSession initRegisteredSession() {
        KieSession kSession = getKSession();
        setKbase( kSession.getKieBase() );
        kSession.fireAllRules(); // init model
        return kSession;
    }

    /**
     * Inserts each value into the entry point named after its position:
     * the first into {@code in_Fld1}, the second into {@code in_Fld2}, and so on.
     */
    private static void insertInputs( KieSession kSession, Object... values ) {
        for ( int i = 0; i < values.length; i++ ) {
            kSession.getEntryPoint( "in_Fld" + ( i + 1 ) ).insert( values[ i ] );
        }
    }

    /** Looks up a fact type declared in the test model's package. */
    private static FactType factType( KieSession kSession, String typeName ) {
        return kSession.getKieBase().getFactType( PACKAGE_NAME, typeName );
    }

    /** Looks up the {@code TreeToken} fact type in the default PMML package. */
    private static FactType tokenType( KieSession kSession ) {
        return kSession.getKieBase().getFactType( PMML4Helper.pmmlDefaultPackageName(), TOKEN_TYPE );
    }
}