Package com.opengamma.analytics.financial.model.volatility.smile.fitting.interpolation

Source Code of com.opengamma.analytics.financial.model.volatility.smile.fitting.interpolation.SmileInterpolator

/**
* Copyright (C) 2012 - present by OpenGamma Inc. and the OpenGamma group of companies
*
* Please see distribution for license.
*/
package com.opengamma.analytics.financial.model.volatility.smile.fitting.interpolation;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.List;

import org.apache.commons.lang.ObjectUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import cern.jet.random.engine.MersenneTwister;
import cern.jet.random.engine.MersenneTwister64;
import cern.jet.random.engine.RandomEngine;

import com.opengamma.analytics.financial.model.option.pricing.analytic.formula.EuropeanVanillaOption;
import com.opengamma.analytics.financial.model.volatility.smile.fitting.SmileModelFitter;
import com.opengamma.analytics.financial.model.volatility.smile.function.SmileModelData;
import com.opengamma.analytics.financial.model.volatility.smile.function.VolatilityFunctionProvider;
import com.opengamma.analytics.math.MathException;
import com.opengamma.analytics.math.function.Function1D;
import com.opengamma.analytics.math.function.ParameterizedFunction;
import com.opengamma.analytics.math.matrix.DoubleMatrix1D;
import com.opengamma.analytics.math.statistics.leastsquare.LeastSquareResults;
import com.opengamma.analytics.math.statistics.leastsquare.LeastSquareResultsWithTransform;
import com.opengamma.analytics.math.statistics.leastsquare.NonLinearLeastSquare;
import com.opengamma.util.ArgumentChecker;

/**
* Interpolate a smile, i.e. fit every data point (market volatility/price), by fitting a smile model (e.g. SABR) through consecutive sets of 3 strikes, so for N data points (prices) there will be N-2
* 3-point fits. In the interior where smile fits overlap, a weighting between the two smiles is taken, which varies from giving 100% weight to the left smile at the mid point of that fit, down to 0%
* at the mid point of the right fit.
*
* @param <T> The type of the smile model data
*/
public abstract class SmileInterpolator<T extends SmileModelData> implements GeneralSmileInterpolator {

  private static final double FIT_ERROR = 1e-4; //1bps
  private static final double LARGE_ERROR = 0.1;
  private static final WeightingFunction DEFAULT_WEIGHTING_FUNCTION = WeightingFunctionFactory.SINE_WEIGHTING_FUNCTION;

  /**
   * The logger
   */
  protected static final Logger s_logger = LoggerFactory.getLogger(SmileInterpolator.class);

  private final VolatilityFunctionProvider<T> _model;
  private final WeightingFunction _weightingFunction;
  private final RandomEngine _random;

  public SmileInterpolator(final VolatilityFunctionProvider<T> model) {
    this(MersenneTwister.DEFAULT_SEED, model);
  }

  public SmileInterpolator(final int seed, final VolatilityFunctionProvider<T> model) {
    this(seed, model, DEFAULT_WEIGHTING_FUNCTION);
  }

  public SmileInterpolator(final VolatilityFunctionProvider<T> model, final WeightingFunction weightFunction) {
    this(MersenneTwister.DEFAULT_SEED, model, weightFunction);
  }

  public SmileInterpolator(final int seed, final VolatilityFunctionProvider<T> model, final WeightingFunction weightFunction) {
    ArgumentChecker.notNull(model, "model");
    ArgumentChecker.notNull(weightFunction, "weightFunction");
    _random = new MersenneTwister64(seed);
    _model = model;
    _weightingFunction = weightFunction;
  }

  public List<T> getFittedModelParameters(final double forward, final double[] strikes, final double expiry, final double[] impliedVols) {
    ArgumentChecker.notNull(strikes, "strikes");
    ArgumentChecker.notNull(impliedVols, "implied volatilities");
    final int n = strikes.length;
    ArgumentChecker.isTrue(n > 2, "cannot fit less than three points; have {}", n);
    ArgumentChecker.isTrue(impliedVols.length == n, "#strikes != # vols; have {} and {}", impliedVols.length, n);
    validateStrikes(strikes);

    final List<T> modelParameters = new ArrayList<>(n);

    final double[] errors = new double[n];
    Arrays.fill(errors, FIT_ERROR);

    final SmileModelFitter<T> globalFitter = getFitter(forward, strikes, expiry, impliedVols, errors);
    final BitSet gFixed = getGlobalFixedValues();
    LeastSquareResultsWithTransform gBest = null;
    double chiSqr = Double.POSITIVE_INFINITY;

    //TODO set these in sub classes
    int tries = 0;
    int count = 0;
    while (chiSqr > 100.0 * n && count < 5) { //10bps average error
      final DoubleMatrix1D gStart = getGlobalStart(forward, strikes, expiry, impliedVols);
      try {
        final LeastSquareResultsWithTransform glsRes = globalFitter.solve(gStart, gFixed);
        if (glsRes.getChiSq() < chiSqr) {
          gBest = glsRes;
          chiSqr = gBest.getChiSq();
        }
        count++;
      } catch (final Exception e) {
      }
      tries++;
      if (tries > 20) {
        throw new MathException("Cannot fit data");
      }
    }
    if (gBest == null) {
      throw new IllegalStateException("Global estimate was null; should never happen");
    }
    if (n == 3) {
      if (gBest.getChiSq() / n > 1.0) {
        s_logger.debug("chi^2 on fit to ", +n + " points is " + gBest.getChiSq());
      }
      modelParameters.add(toSmileModelData(gBest.getModelParameters()));
    } else {
      final BitSet lFixed = getLocalFixedValues();
      DoubleMatrix1D lStart = gBest.getModelParameters();
      for (int i = 0; i < n - 2; i++) {
        final double[][] temp = getStrikesVolsAndErrors(i, strikes, impliedVols, errors);
        final double[] tStrikes = temp[0];
        final double[] tVols = temp[1];
        final double[] tErrors = temp[2];
        final SmileModelFitter<T> localFitter = getFitter(forward, tStrikes, expiry, tVols, tErrors);
        LeastSquareResultsWithTransform lRes = localFitter.solve(lStart, lFixed);
        LeastSquareResultsWithTransform best = lRes;

        count = 0;
        while (lRes.getChiSq() > 3.0 && count < 10) {
          lStart = getGlobalStart(forward, strikes, expiry, impliedVols);
          lRes = localFitter.solve(lStart, lFixed);
          if (lRes.getChiSq() < best.getChiSq()) {
            best = lRes;
          }
          count++;
        }

        if (best.getChiSq() > 3.0) {
          s_logger.debug("chi^2 on 3-point fit #" + i + " is " + best.getChiSq());
        }
        modelParameters.add(toSmileModelData(best.getModelParameters()));
      }
    }
    return modelParameters;
  }

  /**
   * Returns the random number generator for seeding any interpolation algorithms.
   *
   * @return the random number generator, not null
   */
  protected RandomEngine getRandom() {
    return _random;
  }

  public VolatilityFunctionProvider<T> getModel() {
    return _model;
  }

  public WeightingFunction getWeightingFunction() {
    return _weightingFunction;
  }

  protected double[][] getStrikesVolsAndErrors(final int index, final double[] strikes, final double[] impliedVols, final double[] errors) {
    return getStrikesVolsAndErrorsForThreePoints(index, strikes, impliedVols, errors);
  }

  /**
   * Use this for models that can be expressed as having 3 parameters (e.g. SABR with beta fixed). It picks out 3 consecutive strike-volatility pairs for the 3 parameter fit (so the chi^2 should be
   * zero if the model is capable of fitting the data)
   *
   * @param index Index of first strike
   * @param strikes Array of all strikes
   * @param impliedVols Array of all vols
   * @param errors Array of all errors
   * @return array containing the 3 strikes, vols and errors
   */
  protected static double[][] getStrikesVolsAndErrorsForThreePoints(final int index, final double[] strikes, final double[] impliedVols, final double[] errors) {
    ArgumentChecker.notNull(strikes, "strikes");
    ArgumentChecker.notNull(impliedVols, "implied vols");
    ArgumentChecker.notNull(errors, "errors");
    double[] tStrikes = new double[3];
    double[] tVols = new double[3];
    double[] tErrors = new double[3];
    tStrikes = Arrays.copyOfRange(strikes, index, index + 3);
    tVols = Arrays.copyOfRange(impliedVols, index, index + 3);
    tErrors = Arrays.copyOfRange(errors, index, index + 3);
    final double[][] res = new double[][] {tStrikes, tVols, tErrors };
    return res;
  }

  /**
   * Use this for models that cannot be easily expressed as having 3 parameters (e.g. mixed log-normal). It picks out 3 consecutive strikes and gives them a small error (1bps by default), while the
   * rest of the data has a relatively large error (100bps by default). The fit is then made to all data (n > 3) which allows more than 3 parameters to be fitted (recall, the start position is set
   * from a true global fit). The chi^2 should be close to zero if the model is capable of fitting the data.
   *
   * @param index Index of first strike
   * @param strikes Array of all strikes
   * @param impliedVols Array of all vols
   * @param errors Array of all errors
   * @return array containing the 3 strikes, vols and errors
   */
  protected static double[][] getStrikesVolsAndErrorsForAllPoints(final int index, final double[] strikes, final double[] impliedVols, final double[] errors) {
    ArgumentChecker.notNull(strikes, "strikes");
    ArgumentChecker.notNull(impliedVols, "implied vols");
    ArgumentChecker.notNull(errors, "errors");
    final int n = errors.length;
    final double[] lErrors = new double[n];
    Arrays.fill(lErrors, LARGE_ERROR);
    System.arraycopy(errors, index, lErrors, index, 3); //copy the original errors for the points we really want to fit
    final double[][] res = new double[][] {strikes, impliedVols, lErrors };
    return res;
  }

  protected abstract DoubleMatrix1D getGlobalStart(final double forward, final double[] strikes, final double expiry, final double[] impliedVols);

  protected BitSet getGlobalFixedValues() {
    return new BitSet();
  }

  protected BitSet getLocalFixedValues() {
    return new BitSet();
  }

  protected abstract T toSmileModelData(final DoubleMatrix1D modelParameters); //TODO have the same thing in SmileModelFitter - could combine

  protected abstract SmileModelFitter<T> getFitter(final double forward, final double[] strikes, final double expiry, final double[] impliedVols, final double[] errors);

  @Override
  public Function1D<Double, Double> getVolatilityFunction(final double forward, final double[] strikes, final double expiry, final double[] impliedVols) {

    final List<T> modelParams = getFittedModelParameters(forward, strikes, expiry, impliedVols);
    final int n = strikes.length;

    return new Function1D<Double, Double>() {
      @SuppressWarnings("synthetic-access")
      @Override
      public Double evaluate(final Double strike) {
        final EuropeanVanillaOption option = new EuropeanVanillaOption(strike, expiry, true);
        final Function1D<T, Double> volFunc = _model.getVolatilityFunction(option, forward);
        final int index = SurfaceArrayUtils.getLowerBoundIndex(strikes, strike);
        if (index == 0) {
          return volFunc.evaluate(modelParams.get(0));
        }
        if (index >= n - 2) {
          return volFunc.evaluate(modelParams.get(n - 3));
        }
        final double w = _weightingFunction.getWeight(strikes, index, strike);
        if (w == 1) {
          return volFunc.evaluate(modelParams.get(index - 1));
        } else if (w == 0) {
          return volFunc.evaluate(modelParams.get(index));
        } else {
          return w * volFunc.evaluate(modelParams.get(index - 1)) + (1 - w) * volFunc.evaluate(modelParams.get(index));
        }
      }
    };
  }

  protected DoubleMatrix1D getPolynomialFit(final double forward, final double[] strikes, final double[] impliedVols) {
    final int n = strikes.length;
    final double[] x = new double[n];
    for (int i = 0; i < n; i++) {
      x[i] = Math.log(strikes[i] / forward);
    }

    final ParameterizedFunction<Double, DoubleMatrix1D, Double> func = new ParameterizedFunction<Double, DoubleMatrix1D, Double>() {
      @Override
      public Double evaluate(final Double x1, final DoubleMatrix1D parameters) {
        final double a = parameters.getEntry(0);
        final double b = parameters.getEntry(1);
        final double c = parameters.getEntry(2);
        return a + b * x1 + c * x1 * x1;
      }
    };

    //TODO replace this with an explicit polynomial fitter
    final NonLinearLeastSquare ls = new NonLinearLeastSquare();
    final LeastSquareResults lsRes = ls.solve(new DoubleMatrix1D(x), new DoubleMatrix1D(impliedVols), func, new DoubleMatrix1D(0.1, 0.0, 0.0));
    final DoubleMatrix1D fitP = lsRes.getFitParameters();
    return fitP;
  }

  private void validateStrikes(final double[] strikes) {
    final int n = strikes.length;
    for (int i = 1; i < n; i++) {
      ArgumentChecker.isTrue(strikes[i] > strikes[i - 1],
          "strikes must be in ascending order; have {} (element {}) and {} (element {})", strikes[i - 1], i - 1, strikes[i], i);
    }
  }

  @Override
  public int hashCode() {
    final int prime = 31;
    int result = 1;
    result = prime * result + _model.hashCode();
    result = prime * result + _weightingFunction.hashCode();
    return result;
  }

  @Override
  public boolean equals(final Object obj) {
    if (this == obj) {
      return true;
    }
    if (obj == null) {
      return false;
    }
    if (getClass() != obj.getClass()) {
      return false;
    }
    final SmileInterpolator<?> other = (SmileInterpolator<?>) obj;
    if (!ObjectUtils.equals(_model, other._model)) {
      return false;
    }
    if (!ObjectUtils.equals(_weightingFunction, other._weightingFunction)) {
      return false;
    }
    return true;
  }

}
TOP

Related Classes of com.opengamma.analytics.financial.model.volatility.smile.fitting.interpolation.SmileInterpolator

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.