Package hivemall.classifier

Source Code of hivemall.classifier.PassiveAggressiveUDTFTest

/*
* Hivemall: Hive scalable Machine Learning Library
*
* Copyright (C) 2013
*   National Institute of Advanced Industrial Science and Technology (AIST)
*   Registration Number: H25PRO-1520
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
*/
package hivemall.classifier;

import hivemall.io.PredictionResult;

import java.util.ArrayList;
import java.util.List;

import static org.junit.Assert.assertEquals;

import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.junit.Test;

public class PassiveAggressiveUDTFTest {

    @Test
    public void testInitialize() throws UDFArgumentException {
        PassiveAggressiveUDTF udtf = new PassiveAggressiveUDTF();
        ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
        ListObjectInspector intListOI = ObjectInspectorFactory.getStandardListObjectInspector(intOI);

        /* test for INT_TYPE_NAME feature */
        StructObjectInspector intListSOI = udtf.initialize(
            new ObjectInspector[]{intListOI, intOI});
        assertEquals("struct<feature:int,weight:float>", intListSOI.getTypeName());

        /* test for STRING_TYPE_NAME feature */
        ObjectInspector stringOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
        ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(stringOI);
        StructObjectInspector stringListSOI = udtf.initialize(
            new ObjectInspector[]{stringListOI, intOI});
        assertEquals("struct<feature:string,weight:float>", stringListSOI.getTypeName());

        /* test for BIGINT_TYPE_NAME feature */
        ObjectInspector longOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector;
        ListObjectInspector longListOI = ObjectInspectorFactory.getStandardListObjectInspector(longOI);
        StructObjectInspector longListSOI = udtf.initialize(
            new ObjectInspector[]{longListOI, intOI});
        assertEquals("struct<feature:bigint,weight:float>", longListSOI.getTypeName());
    }

    @Test
    public void testTrain() throws HiveException {
        PassiveAggressiveUDTF udtf = new PassiveAggressiveUDTF();
        ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
        ListObjectInspector intListOI = ObjectInspectorFactory.getStandardListObjectInspector(intOI);
        udtf.initialize(new ObjectInspector[]{intListOI, intOI});

        /* train weights by List<Object> */
        List<Integer> features1 = new ArrayList<Integer>();
        features1.add(1);
        features1.add(2);
        features1.add(3);
        udtf.train(features1, 1);

        /* check weights */
        assertEquals(0.3333333f, udtf.model.get(1).get(), 1e-5f);
        assertEquals(0.3333333f, udtf.model.get(2).get(), 1e-5f);
        assertEquals(0.3333333f, udtf.model.get(3).get(), 1e-5f);

        /* train weights by Object[] */
        List<?> features2 = (List<?>) intListOI.getList(new Object[]{3, 4, 5});
        udtf.train(features2, 1);

        /* check weights */
        assertEquals(0.3333333f, udtf.model.get(1).get(), 1e-5f);
        assertEquals(0.3333333f, udtf.model.get(2).get(), 1e-5f);
        assertEquals(0.5555555f, udtf.model.get(3).get(), 1e-5f);
        assertEquals(0.2222222f, udtf.model.get(4).get(), 1e-5f);
        assertEquals(0.2222222f, udtf.model.get(5).get(), 1e-5f);
    }

    @Test
    public void testEta() {
        PassiveAggressiveUDTF udtf = new PassiveAggressiveUDTF();
        float loss = 0.1f;

        PredictionResult margin1 = new PredictionResult(0.5f).squaredNorm(0.05f);
        float expectedLearningRate1 = 2.0f;
        assertEquals(expectedLearningRate1, udtf.eta(loss, margin1), 1e-5f);

        PredictionResult margin2 = new PredictionResult(0.5f).squaredNorm(0.01f);
        float expectedLearningRate2 = 10.0f;
        assertEquals(expectedLearningRate2, udtf.eta(loss, margin2), 1e-5f);
    }

    @Test
    public void testPA1Eta() throws UDFArgumentException {
        PassiveAggressiveUDTF udtf = new PassiveAggressiveUDTF.PA1();
        ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
        ListObjectInspector intListOI = ObjectInspectorFactory.getStandardListObjectInspector(intOI);
        ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector,
            new String("-c 3.0")
        );

        /* do initialize() with aggressiveness parameter */
        udtf.initialize(new ObjectInspector[]{intListOI, intOI, param});
        float loss = 0.1f;

        PredictionResult margin1 = new PredictionResult(0.5f).squaredNorm(0.05f);
        float expectedLearningRate1 = 2.0f;
        assertEquals(expectedLearningRate1, udtf.eta(loss, margin1), 1e-5f);

        PredictionResult margin2 = new PredictionResult(0.5f).squaredNorm(0.01f);
        float expectedLearningRate2 = 3.0f;
        assertEquals(expectedLearningRate2, udtf.eta(loss, margin2), 1e-5f);
    }

    @Test
    public void testPA1EtaDefaultParameter() throws UDFArgumentException {
        PassiveAggressiveUDTF udtf = new PassiveAggressiveUDTF.PA1();
        ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
        ListObjectInspector intListOI = ObjectInspectorFactory.getStandardListObjectInspector(intOI);

        udtf.initialize(new ObjectInspector[]{intListOI, intOI});
        float loss = 0.1f;

        PredictionResult margin = new PredictionResult(0.5f).squaredNorm(0.05f);
        float expectedLearningRate = 1.0f;
        assertEquals(expectedLearningRate, udtf.eta(loss, margin), 1e-5f);
    }

    @Test
    public void testPA1TrainWithoutParameter() throws UDFArgumentException {
        PassiveAggressiveUDTF udtf = new PassiveAggressiveUDTF.PA1();
        ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
        ListObjectInspector intListOI = ObjectInspectorFactory.getStandardListObjectInspector(intOI);

        /* define aggressive parameter */
        udtf.initialize(new ObjectInspector[]{intListOI, intOI});

        /* train weights */
        List<?> features = (List<?>) intListOI.getList(new Object[]{1, 2, 3});
        udtf.train(features, 1);

        /* check weights */
        assertEquals(0.3333333f, udtf.model.get(1).get(), 1e-5f);
        assertEquals(0.3333333f, udtf.model.get(2).get(), 1e-5f);
        assertEquals(0.3333333f, udtf.model.get(3).get(), 1e-5f);
    }

    @Test
    public void testPA1TrainWithParameter() throws UDFArgumentException {
        PassiveAggressiveUDTF udtf = new PassiveAggressiveUDTF.PA1();
        ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
        ListObjectInspector intListOI = ObjectInspectorFactory.getStandardListObjectInspector(intOI);

        ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector,
            new String("-c 0.1")
        );
        /* define aggressive parameter */
        udtf.initialize(new ObjectInspector[]{intListOI, intOI, param});

        /* train weights */
        List<?> features = (List<?>) intListOI.getList(new Object[]{1, 2, 3});
        udtf.train(features, 1);

        /* check weights */
        assertEquals(0.1000000f, udtf.model.get(1).get(), 1e-5f);
        assertEquals(0.1000000f, udtf.model.get(2).get(), 1e-5f);
        assertEquals(0.1000000f, udtf.model.get(3).get(), 1e-5f);
    }

    @Test
    public void testPA2EtaWithoutParameter() throws UDFArgumentException {
        PassiveAggressiveUDTF udtf = new PassiveAggressiveUDTF.PA2();
        ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
        ListObjectInspector intListOI = ObjectInspectorFactory.getStandardListObjectInspector(intOI);

        /* do initialize() with aggressiveness parameter */
        udtf.initialize(new ObjectInspector[]{intListOI, intOI});
        float loss = 0.1f;

        PredictionResult margin1 = new PredictionResult(0.5f).squaredNorm(0.05f);
        float expectedLearningRate1 = 0.1818181f;
        assertEquals(expectedLearningRate1, udtf.eta(loss, margin1), 1e-5f);

        PredictionResult margin2 = new PredictionResult(0.5f).squaredNorm(0.01f);
        float expectedLearningRate2 = 0.1960784f;
        assertEquals(expectedLearningRate2, udtf.eta(loss, margin2), 1e-5f);
    }

    @Test
    public void testPA2EtaWithParameter() throws UDFArgumentException {
        PassiveAggressiveUDTF udtf = new PassiveAggressiveUDTF.PA2();
        ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
        ListObjectInspector intListOI = ObjectInspectorFactory.getStandardListObjectInspector(intOI);
        ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector,
            new String("-c 3.0")
        );

        /* do initialize() with aggressiveness parameter */
        udtf.initialize(new ObjectInspector[]{intListOI, intOI, param});
        float loss = 0.1f;

        PredictionResult margin1 = new PredictionResult(0.5f).squaredNorm(0.05f);
        float expectedLearningRate1 = 0.4615384f;
        assertEquals(expectedLearningRate1, udtf.eta(loss, margin1), 1e-5f);

        PredictionResult margin2 = new PredictionResult(0.5f).squaredNorm(0.01f);
        float expectedLearningRate2 = 0.5660377f;
        assertEquals(expectedLearningRate2, udtf.eta(loss, margin2), 1e-5f);
    }

}
TOP

Related Classes of hivemall.classifier.PassiveAggressiveUDTFTest

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.