Package net.javlov.policy

Source Code of net.javlov.policy.EGreedyPolicy

/*
* Javlov - a Java toolkit for reinforcement learning with multi-agent support.
*
* Copyright (c) 2009-2011 Matthijs Snel
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/
package net.javlov.policy;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

import net.javlov.Option;
import net.javlov.Policy;
import net.javlov.QFunction;
import net.javlov.State;
import net.javlov.util.ArrayUtil;

/**
* Epsilon-greedy (e-greedy) policy. Chooses action with maximum Q-value with probability
* 1 - e, and a (uniform) random action with probability e. This means that the greedy action
* will be chosen with probability 1 - e + e/(nr of actions).
* @author Matthijs Snel
*
*/
public class EGreedyPolicy implements Policy {

  /**
   * The Q-value function.
   */
  private QFunction q;
 
  /**
   * Epsilon.
   */
  protected double e;
 
  /**
   * List of allowed actions. Index of actions in the list should correspond to their
   * ID.
   */
  protected List<? extends Option> optionPool;
 
  /**
   * Random number generator.
   */
  private Random rng;
 
  /**
   * Creates an epsilon-greedy policy with the specified Q-function, epsilon, and pool
   * of allowed actions.
   *
   * @param q the Q-value function.
   * @param epsilon action with maximum Q-value is chosen with probability
   * 1 - e, and a (uniform) random action with probability e.
   * @param actions list of allowed actions. Index of actions in the list should correspond to their
   * ID.
   */
  public EGreedyPolicy(QFunction q, double epsilon, List<? extends Option> options) {
    setQFunction(q);
    setEpsilon(epsilon);
    optionPool = options;
    rng = new Random();
  }

  /**
   * Creates an epsilon-greedy policy with the specified Q-function, epsilon, and pool
   * of allowed actions.
   *
   * @param q the Q-value function.
   * @param epsilon action with maximum Q-value is chosen with probability
   * 1 - e, and a (uniform) random action with probability e.
   * @param actions list of allowed actions. Index of actions in the list should correspond to their
   * ID.
   * @param rng random number generator used to pick actions.
   */
  public EGreedyPolicy(QFunction q, double epsilon, List<? extends Option> options, Random rng) {
    setQFunction(q);
    setEpsilon(epsilon);
    optionPool = options;
    this.rng = rng;
  }
 
  public void setQFunction(QFunction q) {
    this.q = q;
  }

  public QFunction getQFunction() {
    return q;
  }
 
  public double getEpsilon() {
    return e;
  }

  public void setEpsilon(double epsilon) {
    e = epsilon;
  }

  /**
   * Chooses option with maximum Q-value (greedy option) with probability
   * 1 - e, and a (uniformly distributed) random option with probability e.
   * This means that the greedy option will be chosen with probability
   * {@code 1 - e + e/(nr of options)}. If there is more than one greedy option, ties
   * are broken randomly.
   *
   * @param s the state based on which to choose the option. The Q-value function will
   * be queried first to determine the Q-values of the options for this state, after which
   * the option will be determined by calling {@link #pickNewOption(State, double[])}.
   * @return an {@code Option} chosen according to the rule as specified above.
   */
  @Override
  public <T> Option getOption(State<T> s) {
    return getOption(s, q.getValues(s));
  }

  //TODO inefficient implementation of determining this
  protected <T> List<Option> getStateOptionSet(State<T> s) {
    List<Option> stateOptionSet = new ArrayList<Option>(optionPool.size());
    for ( Option o : optionPool )
      if ( o.isEligible(s) )
        stateOptionSet.add(o);
   
    if ( stateOptionSet.size() == 0 )
      throw new RuntimeException("No eligible options for state: " + s);
   
    return stateOptionSet;
  }
 
  @Override
  public <T> Option getOption(State<T> s, double[] qvalues) {
    List<Option> stateOptionSet = getStateOptionSet(s);
   
    //System.out.println(stateOptionSet + "--" + Arrays.toString(qvalues));
   
    if ( stateOptionSet == null || stateOptionSet.size() == qvalues.length ) {
      //choose greedy option, randomly break ties if there is more than one max option
      if ( rng.nextDouble() > e ) {
        //get indices of options with max Q-value (returns 1 or more)
        int a[] = ArrayUtil.multimaxIndex(qvalues);
        if ( a.length < 1 ) {
          throw new RuntimeException("Impossible: " + stateOptionSet.size() + "," + Arrays.toString(qvalues) );
        }
        //System.out.println(Arrays.toString(a));
        return optionPool.get( a[rng.nextInt(a.length)] );
      }
      //choose random action
      return optionPool.get( rng.nextInt(qvalues.length) );
    }
    else {
      if ( rng.nextDouble() > e ) {
        List<Option> maxOpts = getMaxOpts(stateOptionSet, qvalues);
        //System.out.println(maxOpts);
        Option ret = ( maxOpts.size() == 1 ? maxOpts.get(0) : maxOpts.get( rng.nextInt(maxOpts.size()) ) );
        return ret;
      }
      Option ret = stateOptionSet.get( rng.nextInt(stateOptionSet.size()) );
      return ret;
    }
  }

  @Override
  public <T> double[] getOptionProbabilities( State<T> s, double[] qvalues ) {
    List<? extends Option> stateOptionSet = getStateOptionSet(s);
    double[] probs = new double[qvalues.length];

    if ( stateOptionSet == null || stateOptionSet.size() == qvalues.length ) {
      int a[] = ArrayUtil.multimaxIndex(qvalues);
      double otherProb = e / qvalues.length,
          maxProb = (1 - e) / a.length + otherProb;
     
      for ( int i = 0; i < a.length; i++ )
        probs[a[i]] = maxProb;
      for ( int i = 0; i < qvalues.length; i++ )
        if ( probs[i] == 0 )
          probs[i] = otherProb;
    }
    else {
      List<Option> maxOpts = getMaxOpts(stateOptionSet, qvalues);
      double otherProb = e / stateOptionSet.size(),
          maxProb = (1 - e) / maxOpts.size() + otherProb;
      stateOptionSet.removeAll(maxOpts);
     
      for ( Option opt : maxOpts )
        probs[opt.getID()] = maxProb;
      for ( Option opt : stateOptionSet )
        probs[opt.getID()] = otherProb;
    }
    return probs;
  }
 
  protected List<Option> getMaxOpts(List<? extends Option> stateOptionSet, double[] qvalues) {
    List<Option> maxOpts = new ArrayList<Option>();
    double maxVal = Double.NEGATIVE_INFINITY,
        val;
    for ( Option o : stateOptionSet ) {
      val = qvalues[o.getID()];
      if ( val > maxVal ) {
        maxVal = val;
        maxOpts.clear();
        maxOpts.add(o);
      } else if ( val == maxVal )
        maxOpts.add(o);
    }
    return maxOpts;
  }
 
  @Override
  public void init() {
    for ( Option o : optionPool )
      o.init();
  }

  @Override
  public void reset() {
    for ( Option o : optionPool )
      o.reset();
  }
}
TOP

Related Classes of net.javlov.policy.EGreedyPolicy

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.