// Package: aima.test.core.unit.learning.reinforcement
//
// Source code of aima.test.core.unit.learning.reinforcement.ReinforcementLearningTest

package aima.test.core.unit.learning.reinforcement;

import java.util.Hashtable;

import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import aima.core.environment.cellworld.CellWorld;
import aima.core.environment.cellworld.CellWorldPosition;
import aima.core.learning.reinforcement.PassiveADPAgent;
import aima.core.learning.reinforcement.PassiveTDAgent;
import aima.core.learning.reinforcement.QLearningAgent;
import aima.core.learning.reinforcement.QTable;
import aima.core.probability.mdp.MDP;
import aima.core.probability.mdp.MDPFactory;
import aima.core.probability.mdp.MDPPerception;
import aima.core.probability.mdp.MDPPolicy;
import aima.core.probability.mdp.MDPUtilityFunction;
import aima.core.util.MockRandomizer;
import aima.core.util.Randomizer;
import aima.core.util.datastructure.Pair;

/**
* @author Ravi Mohan
*
*/
public class ReinforcementLearningTest {
  MDP<CellWorldPosition, String> fourByThree;

  MDPPolicy<CellWorldPosition, String> policy;

  @Before
  public void setUp() {
    fourByThree = MDPFactory.createFourByThreeMDP();

    policy = new MDPPolicy<CellWorldPosition, String>();

    policy.setAction(new CellWorldPosition(1, 1), CellWorld.UP);
    policy.setAction(new CellWorldPosition(1, 2), CellWorld.LEFT);
    policy.setAction(new CellWorldPosition(1, 3), CellWorld.LEFT);
    policy.setAction(new CellWorldPosition(1, 4), CellWorld.LEFT);

    policy.setAction(new CellWorldPosition(2, 1), CellWorld.UP);
    policy.setAction(new CellWorldPosition(2, 3), CellWorld.UP);

    policy.setAction(new CellWorldPosition(3, 1), CellWorld.RIGHT);
    policy.setAction(new CellWorldPosition(3, 2), CellWorld.RIGHT);
    policy.setAction(new CellWorldPosition(3, 3), CellWorld.RIGHT);
  }

  @Test
  public void testPassiveADPAgent() {

    PassiveADPAgent<CellWorldPosition, String> agent = new PassiveADPAgent<CellWorldPosition, String>(
        fourByThree, policy);

    // Randomizer r = new JavaRandomizer();
    Randomizer r = new MockRandomizer(new double[] { 0.1, 0.9, 0.2, 0.8,
        0.3, 0.7, 0.4, 0.6, 0.5 });
    MDPUtilityFunction<CellWorldPosition> uf = null;
    for (int i = 0; i < 100; i++) {
      agent.executeTrial(r);
      uf = agent.getUtilityFunction();

    }

    Assert.assertEquals(0.676, uf.getUtility(new CellWorldPosition(1, 1)),
        0.001);
    Assert.assertEquals(0.626, uf.getUtility(new CellWorldPosition(1, 2)),
        0.001);
    Assert.assertEquals(0.573, uf.getUtility(new CellWorldPosition(1, 3)),
        0.001);
    Assert.assertEquals(0.519, uf.getUtility(new CellWorldPosition(1, 4)),
        0.001);

    Assert.assertEquals(0.746, uf.getUtility(new CellWorldPosition(2, 1)),
        0.001);
    Assert.assertEquals(0.865, uf.getUtility(new CellWorldPosition(2, 3)),
        0.001);
    // assertEquals(-1.0, uf.getUtility(new
    // CellWorldPosition(2,4)),0.001);//the pseudo random genrator never
    // gets to this square

    Assert.assertEquals(0.796, uf.getUtility(new CellWorldPosition(3, 1)),
        0.001);
    Assert.assertEquals(0.906, uf.getUtility(new CellWorldPosition(3, 3)),
        0.001);
    Assert.assertEquals(1.0, uf.getUtility(new CellWorldPosition(3, 4)),
        0.001);
  }

  @Test
  public void testPassiveTDAgent() {
    PassiveTDAgent<CellWorldPosition, String> agent = new PassiveTDAgent<CellWorldPosition, String>(
        fourByThree, policy);
    // Randomizer r = new JavaRandomizer();
    Randomizer r = new MockRandomizer(new double[] { 0.1, 0.9, 0.2, 0.8,
        0.3, 0.7, 0.4, 0.6, 0.5 });
    MDPUtilityFunction<CellWorldPosition> uf = null;
    for (int i = 0; i < 200; i++) {
      agent.executeTrial(r);
      uf = agent.getUtilityFunction();
      // System.out.println(uf);

    }

    Assert.assertEquals(0.662, uf.getUtility(new CellWorldPosition(1, 1)),
        0.001);
    Assert.assertEquals(0.610, uf.getUtility(new CellWorldPosition(1, 2)),
        0.001);
    Assert.assertEquals(0.553, uf.getUtility(new CellWorldPosition(1, 3)),
        0.001);
    Assert.assertEquals(0.496, uf.getUtility(new CellWorldPosition(1, 4)),
        0.001);

    Assert.assertEquals(0.735, uf.getUtility(new CellWorldPosition(2, 1)),
        0.001);
    Assert.assertEquals(0.835, uf.getUtility(new CellWorldPosition(2, 3)),
        0.001);
    // assertEquals(-1.0, uf.getUtility(new
    // CellWorldPosition(2,4)),0.001);//the pseudo random genrator never
    // gets to this square

    Assert.assertEquals(0.789, uf.getUtility(new CellWorldPosition(3, 1)),
        0.001);
    Assert.assertEquals(0.889, uf.getUtility(new CellWorldPosition(3, 3)),
        0.001);
    Assert.assertEquals(1.0, uf.getUtility(new CellWorldPosition(3, 4)),
        0.001);
  }

  @SuppressWarnings("unused")
  @Test
  public void testQLearningAgent() {
    QLearningAgent<CellWorldPosition, String> qla = new QLearningAgent<CellWorldPosition, String>(
        fourByThree);
    Randomizer r = new MockRandomizer(new double[] { 0.1, 0.9, 0.2, 0.8,
        0.3, 0.7, 0.4, 0.6, 0.5 });

    // Randomizer r = new JavaRandomizer();
    Hashtable<Pair<CellWorldPosition, String>, Double> q = null;
    QTable<CellWorldPosition, String> qTable = null;
    for (int i = 0; i < 100; i++) {
      qla.executeTrial(r);
      q = qla.getQ();
      qTable = qla.getQTable();

    }
    // qTable.normalize();
    // System.out.println(qTable);
    // System.out.println(qTable.getPolicy());
  }

  @Test
  public void testFirstStepsOfQLAAgentUnderNormalProbability() {
    QLearningAgent<CellWorldPosition, String> qla = new QLearningAgent<CellWorldPosition, String>(
        fourByThree);

    Randomizer alwaysLessThanEightyPercent = new MockRandomizer(
        new double[] { 0.7 });
    CellWorldPosition startingPosition = new CellWorldPosition(1, 4);
    String action = qla.decideAction(new MDPPerception<CellWorldPosition>(
        startingPosition, -0.04));
    Assert.assertEquals(CellWorld.LEFT, action);
    Assert.assertEquals(0.0,
        qla.getQTable().getQValue(startingPosition, action), 0.001);

    qla.execute(action, alwaysLessThanEightyPercent);
    Assert.assertEquals(new CellWorldPosition(1, 3), qla.getCurrentState());
    Assert.assertEquals(-0.04, qla.getCurrentReward(), 0.001);
    Assert.assertEquals(0.0,
        qla.getQTable().getQValue(startingPosition, action), 0.001);
    qla.decideAction(new MDPPerception<CellWorldPosition>(
        new CellWorldPosition(1, 3), -0.04));

    Assert.assertEquals(-0.04,
        qla.getQTable().getQValue(startingPosition, action), 0.001);
  }

  @Test
  public void testFirstStepsOfQLAAgentWhenFirstStepTerminates() {
    QLearningAgent<CellWorldPosition, String> qla = new QLearningAgent<CellWorldPosition, String>(
        fourByThree);

    CellWorldPosition startingPosition = new CellWorldPosition(1, 4);
    String action = qla.decideAction(new MDPPerception<CellWorldPosition>(
        startingPosition, -0.04));
    Assert.assertEquals(CellWorld.LEFT, action);

    Randomizer betweenEightyANdNinetyPercent = new MockRandomizer(
        new double[] { 0.85 }); // to force left to become an "up"
    qla.execute(action, betweenEightyANdNinetyPercent);
    Assert.assertEquals(new CellWorldPosition(2, 4), qla.getCurrentState());
    Assert.assertEquals(-1.0, qla.getCurrentReward(), 0.001);
    Assert.assertEquals(0.0,
        qla.getQTable().getQValue(startingPosition, action), 0.001);
    String action2 = qla.decideAction(new MDPPerception<CellWorldPosition>(
        new CellWorldPosition(2, 4), -1));
    Assert.assertNull(action2);
    Assert.assertEquals(-1.0,
        qla.getQTable().getQValue(startingPosition, action), 0.001);
  }
}
// Related classes: aima.test.core.unit.learning.reinforcement.ReinforcementLearningTest
//
// Copyright (c) 2018 www.massapi.com. All rights reserved.
// All source code is the property of its respective owners. Java is a trademark
// of Sun Microsystems, Inc. and owned by Oracle Inc. Contact: coftware#gmail.com.