Package org.apache.commons.math.optimization

Source Code of org.apache.commons.math.optimization.MultiDirectionalTest

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*      http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.commons.math.optimization;

import org.apache.commons.math.optimization.ConvergenceChecker;
import org.apache.commons.math.optimization.CostException;
import org.apache.commons.math.optimization.CostFunction;
import org.apache.commons.math.optimization.MultiDirectional;
import org.apache.commons.math.ConvergenceException;
import org.apache.commons.math.optimization.PointCostPair;

import junit.framework.*;

public class MultiDirectionalTest
  extends TestCase {

  public MultiDirectionalTest(String name) {
    super(name);
  }

  public void testCostExceptions() throws ConvergenceException {
      CostFunction wrong =
          new CostFunction() {
            public double cost(double[] x) throws CostException {
                if (x[0] < 0) {
                    throw new CostException("{0}", new Object[] { "oops"});
                } else if (x[0] > 1) {
                    throw new CostException(new RuntimeException("oops"));
                } else {
                    return x[0] * (1 - x[0]);
                }
            }
      };
      try {
          new MultiDirectional(1.9, 0.4).minimize(wrong, 10, new ValueChecker(1.0e-3),
                                                  new double[] { -0.5 }, new double[] { 0.5 });
          fail("an exception should have been thrown");
      } catch (CostException ce) {
          // expected behavior
          assertNull(ce.getCause());
      } catch (Exception e) {
          fail("wrong exception caught: " + e.getMessage());
      }
      try {
          new MultiDirectional(1.9, 0.4).minimize(wrong, 10, new ValueChecker(1.0e-3),
                                                  new double[] { 0.5 }, new double[] { 1.5 });
          fail("an exception should have been thrown");
      } catch (CostException ce) {
          // expected behavior
          assertNotNull(ce.getCause());
      } catch (Exception e) {
          fail("wrong exception caught: " + e.getMessage());
      }
  }

  public void testRosenbrock()
    throws CostException, ConvergenceException {

    CostFunction rosenbrock =
      new CostFunction() {
        public double cost(double[] x) {
          ++count;
          double a = x[1] - x[0] * x[0];
          double b = 1.0 - x[0];
          return 100 * a * a + b * b;
        }
      };

    count = 0;
    PointCostPair optimum =
      new MultiDirectional().minimize(rosenbrock, 100, new ValueChecker(1.0e-3),
                                      new double[][] {
                                        { -1.21.0 }, { 0.9, 1.2 } , 3.5, -2.3 }
                                      });

    assertTrue(count > 60);
    assertTrue(optimum.getCost() > 0.01);

  }

  public void testPowell()
    throws CostException, ConvergenceException {

    CostFunction powell =
      new CostFunction() {
        public double cost(double[] x) {
          ++count;
          double a = x[0] + 10 * x[1];
          double b = x[2] - x[3];
          double c = x[1] - 2 * x[2];
          double d = x[0] - x[3];
          return a * a + 5 * b * b + c * c * c * c + 10 * d * d * d * d;
        }
      };

    count = 0;
    PointCostPair optimum =
      new MultiDirectional().minimize(powell, 1000, new ValueChecker(1.0e-3),
                                      new double[] {  3.0, -1.0, 0.0, 1.0 },
                                      new double[] {  4.00.0, 1.0, 2.0 });
    assertTrue(count > 850);
    assertTrue(optimum.getCost() > 0.015);

  }

  private static class ValueChecker implements ConvergenceChecker {

    public ValueChecker(double threshold) {
      this.threshold = threshold;
    }

    public boolean converged(PointCostPair[] simplex) {
      PointCostPair smallest = simplex[0];
      PointCostPair largest  = simplex[simplex.length - 1];
      return (largest.getCost() - smallest.getCost()) < threshold;
    }

    private double threshold;

  };

  public static Test suite() {
    return new TestSuite(MultiDirectionalTest.class);
  }

  private int count;

}
TOP

Related Classes of org.apache.commons.math.optimization.MultiDirectionalTest

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.