/**
* Copyright (C) 2012 - present by OpenGamma Inc. and the OpenGamma group of companies
*
* Please see distribution for license.
*/
package com.opengamma.analytics.financial.model.volatility.smile.fitting.interpolation;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.List;
import org.apache.commons.lang.ObjectUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import cern.jet.random.engine.MersenneTwister;
import cern.jet.random.engine.MersenneTwister64;
import cern.jet.random.engine.RandomEngine;
import com.opengamma.analytics.financial.model.option.pricing.analytic.formula.EuropeanVanillaOption;
import com.opengamma.analytics.financial.model.volatility.smile.fitting.SmileModelFitter;
import com.opengamma.analytics.financial.model.volatility.smile.function.SmileModelData;
import com.opengamma.analytics.financial.model.volatility.smile.function.VolatilityFunctionProvider;
import com.opengamma.analytics.math.MathException;
import com.opengamma.analytics.math.function.Function1D;
import com.opengamma.analytics.math.function.ParameterizedFunction;
import com.opengamma.analytics.math.matrix.DoubleMatrix1D;
import com.opengamma.analytics.math.statistics.leastsquare.LeastSquareResults;
import com.opengamma.analytics.math.statistics.leastsquare.LeastSquareResultsWithTransform;
import com.opengamma.analytics.math.statistics.leastsquare.NonLinearLeastSquare;
import com.opengamma.util.ArgumentChecker;
/**
* Interpolate a smile, i.e. fit every data point (market volatility/price), by fitting a smile model (e.g. SABR) through consecutive sets of 3 strikes, so for N data points (prices) there will be N-2
* 3-point fits. In the interior where smile fits overlap, a weighting between the two smiles is taken, which varies from giving 100% weight to the left smile at the mid point of that fit, down to 0%
* at the mid point of the right fit.
*
* @param <T> The type of the smile model data
*/
public abstract class SmileInterpolator<T extends SmileModelData> implements GeneralSmileInterpolator {
private static final double FIT_ERROR = 1e-4; //1bps
private static final double LARGE_ERROR = 0.1;
private static final WeightingFunction DEFAULT_WEIGHTING_FUNCTION = WeightingFunctionFactory.SINE_WEIGHTING_FUNCTION;
/**
* The logger
*/
protected static final Logger s_logger = LoggerFactory.getLogger(SmileInterpolator.class);
private final VolatilityFunctionProvider<T> _model;
private final WeightingFunction _weightingFunction;
private final RandomEngine _random;
public SmileInterpolator(final VolatilityFunctionProvider<T> model) {
this(MersenneTwister.DEFAULT_SEED, model);
}
public SmileInterpolator(final int seed, final VolatilityFunctionProvider<T> model) {
this(seed, model, DEFAULT_WEIGHTING_FUNCTION);
}
public SmileInterpolator(final VolatilityFunctionProvider<T> model, final WeightingFunction weightFunction) {
this(MersenneTwister.DEFAULT_SEED, model, weightFunction);
}
public SmileInterpolator(final int seed, final VolatilityFunctionProvider<T> model, final WeightingFunction weightFunction) {
ArgumentChecker.notNull(model, "model");
ArgumentChecker.notNull(weightFunction, "weightFunction");
_random = new MersenneTwister64(seed);
_model = model;
_weightingFunction = weightFunction;
}
public List<T> getFittedModelParameters(final double forward, final double[] strikes, final double expiry, final double[] impliedVols) {
ArgumentChecker.notNull(strikes, "strikes");
ArgumentChecker.notNull(impliedVols, "implied volatilities");
final int n = strikes.length;
ArgumentChecker.isTrue(n > 2, "cannot fit less than three points; have {}", n);
ArgumentChecker.isTrue(impliedVols.length == n, "#strikes != # vols; have {} and {}", impliedVols.length, n);
validateStrikes(strikes);
final List<T> modelParameters = new ArrayList<>(n);
final double[] errors = new double[n];
Arrays.fill(errors, FIT_ERROR);
final SmileModelFitter<T> globalFitter = getFitter(forward, strikes, expiry, impliedVols, errors);
final BitSet gFixed = getGlobalFixedValues();
LeastSquareResultsWithTransform gBest = null;
double chiSqr = Double.POSITIVE_INFINITY;
//TODO set these in sub classes
int tries = 0;
int count = 0;
while (chiSqr > 100.0 * n && count < 5) { //10bps average error
final DoubleMatrix1D gStart = getGlobalStart(forward, strikes, expiry, impliedVols);
try {
final LeastSquareResultsWithTransform glsRes = globalFitter.solve(gStart, gFixed);
if (glsRes.getChiSq() < chiSqr) {
gBest = glsRes;
chiSqr = gBest.getChiSq();
}
count++;
} catch (final Exception e) {
}
tries++;
if (tries > 20) {
throw new MathException("Cannot fit data");
}
}
if (gBest == null) {
throw new IllegalStateException("Global estimate was null; should never happen");
}
if (n == 3) {
if (gBest.getChiSq() / n > 1.0) {
s_logger.debug("chi^2 on fit to ", +n + " points is " + gBest.getChiSq());
}
modelParameters.add(toSmileModelData(gBest.getModelParameters()));
} else {
final BitSet lFixed = getLocalFixedValues();
DoubleMatrix1D lStart = gBest.getModelParameters();
for (int i = 0; i < n - 2; i++) {
final double[][] temp = getStrikesVolsAndErrors(i, strikes, impliedVols, errors);
final double[] tStrikes = temp[0];
final double[] tVols = temp[1];
final double[] tErrors = temp[2];
final SmileModelFitter<T> localFitter = getFitter(forward, tStrikes, expiry, tVols, tErrors);
LeastSquareResultsWithTransform lRes = localFitter.solve(lStart, lFixed);
LeastSquareResultsWithTransform best = lRes;
count = 0;
while (lRes.getChiSq() > 3.0 && count < 10) {
lStart = getGlobalStart(forward, strikes, expiry, impliedVols);
lRes = localFitter.solve(lStart, lFixed);
if (lRes.getChiSq() < best.getChiSq()) {
best = lRes;
}
count++;
}
if (best.getChiSq() > 3.0) {
s_logger.debug("chi^2 on 3-point fit #" + i + " is " + best.getChiSq());
}
modelParameters.add(toSmileModelData(best.getModelParameters()));
}
}
return modelParameters;
}
/**
* Returns the random number generator for seeding any interpolation algorithms.
*
* @return the random number generator, not null
*/
protected RandomEngine getRandom() {
return _random;
}
public VolatilityFunctionProvider<T> getModel() {
return _model;
}
public WeightingFunction getWeightingFunction() {
return _weightingFunction;
}
protected double[][] getStrikesVolsAndErrors(final int index, final double[] strikes, final double[] impliedVols, final double[] errors) {
return getStrikesVolsAndErrorsForThreePoints(index, strikes, impliedVols, errors);
}
/**
* Use this for models that can be expressed as having 3 parameters (e.g. SABR with beta fixed). It picks out 3 consecutive strike-volatility pairs for the 3 parameter fit (so the chi^2 should be
* zero if the model is capable of fitting the data)
*
* @param index Index of first strike
* @param strikes Array of all strikes
* @param impliedVols Array of all vols
* @param errors Array of all errors
* @return array containing the 3 strikes, vols and errors
*/
protected static double[][] getStrikesVolsAndErrorsForThreePoints(final int index, final double[] strikes, final double[] impliedVols, final double[] errors) {
ArgumentChecker.notNull(strikes, "strikes");
ArgumentChecker.notNull(impliedVols, "implied vols");
ArgumentChecker.notNull(errors, "errors");
double[] tStrikes = new double[3];
double[] tVols = new double[3];
double[] tErrors = new double[3];
tStrikes = Arrays.copyOfRange(strikes, index, index + 3);
tVols = Arrays.copyOfRange(impliedVols, index, index + 3);
tErrors = Arrays.copyOfRange(errors, index, index + 3);
final double[][] res = new double[][] {tStrikes, tVols, tErrors };
return res;
}
/**
* Use this for models that cannot be easily expressed as having 3 parameters (e.g. mixed log-normal). It picks out 3 consecutive strikes and gives them a small error (1bps by default), while the
* rest of the data has a relatively large error (100bps by default). The fit is then made to all data (n > 3) which allows more than 3 parameters to be fitted (recall, the start position is set
* from a true global fit). The chi^2 should be close to zero if the model is capable of fitting the data.
*
* @param index Index of first strike
* @param strikes Array of all strikes
* @param impliedVols Array of all vols
* @param errors Array of all errors
* @return array containing the 3 strikes, vols and errors
*/
protected static double[][] getStrikesVolsAndErrorsForAllPoints(final int index, final double[] strikes, final double[] impliedVols, final double[] errors) {
ArgumentChecker.notNull(strikes, "strikes");
ArgumentChecker.notNull(impliedVols, "implied vols");
ArgumentChecker.notNull(errors, "errors");
final int n = errors.length;
final double[] lErrors = new double[n];
Arrays.fill(lErrors, LARGE_ERROR);
System.arraycopy(errors, index, lErrors, index, 3); //copy the original errors for the points we really want to fit
final double[][] res = new double[][] {strikes, impliedVols, lErrors };
return res;
}
protected abstract DoubleMatrix1D getGlobalStart(final double forward, final double[] strikes, final double expiry, final double[] impliedVols);
protected BitSet getGlobalFixedValues() {
return new BitSet();
}
protected BitSet getLocalFixedValues() {
return new BitSet();
}
protected abstract T toSmileModelData(final DoubleMatrix1D modelParameters); //TODO have the same thing in SmileModelFitter - could combine
protected abstract SmileModelFitter<T> getFitter(final double forward, final double[] strikes, final double expiry, final double[] impliedVols, final double[] errors);
@Override
public Function1D<Double, Double> getVolatilityFunction(final double forward, final double[] strikes, final double expiry, final double[] impliedVols) {
final List<T> modelParams = getFittedModelParameters(forward, strikes, expiry, impliedVols);
final int n = strikes.length;
return new Function1D<Double, Double>() {
@SuppressWarnings("synthetic-access")
@Override
public Double evaluate(final Double strike) {
final EuropeanVanillaOption option = new EuropeanVanillaOption(strike, expiry, true);
final Function1D<T, Double> volFunc = _model.getVolatilityFunction(option, forward);
final int index = SurfaceArrayUtils.getLowerBoundIndex(strikes, strike);
if (index == 0) {
return volFunc.evaluate(modelParams.get(0));
}
if (index >= n - 2) {
return volFunc.evaluate(modelParams.get(n - 3));
}
final double w = _weightingFunction.getWeight(strikes, index, strike);
if (w == 1) {
return volFunc.evaluate(modelParams.get(index - 1));
} else if (w == 0) {
return volFunc.evaluate(modelParams.get(index));
} else {
return w * volFunc.evaluate(modelParams.get(index - 1)) + (1 - w) * volFunc.evaluate(modelParams.get(index));
}
}
};
}
protected DoubleMatrix1D getPolynomialFit(final double forward, final double[] strikes, final double[] impliedVols) {
final int n = strikes.length;
final double[] x = new double[n];
for (int i = 0; i < n; i++) {
x[i] = Math.log(strikes[i] / forward);
}
final ParameterizedFunction<Double, DoubleMatrix1D, Double> func = new ParameterizedFunction<Double, DoubleMatrix1D, Double>() {
@Override
public Double evaluate(final Double x1, final DoubleMatrix1D parameters) {
final double a = parameters.getEntry(0);
final double b = parameters.getEntry(1);
final double c = parameters.getEntry(2);
return a + b * x1 + c * x1 * x1;
}
};
//TODO replace this with an explicit polynomial fitter
final NonLinearLeastSquare ls = new NonLinearLeastSquare();
final LeastSquareResults lsRes = ls.solve(new DoubleMatrix1D(x), new DoubleMatrix1D(impliedVols), func, new DoubleMatrix1D(0.1, 0.0, 0.0));
final DoubleMatrix1D fitP = lsRes.getFitParameters();
return fitP;
}
private void validateStrikes(final double[] strikes) {
final int n = strikes.length;
for (int i = 1; i < n; i++) {
ArgumentChecker.isTrue(strikes[i] > strikes[i - 1],
"strikes must be in ascending order; have {} (element {}) and {} (element {})", strikes[i - 1], i - 1, strikes[i], i);
}
}
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + _model.hashCode();
result = prime * result + _weightingFunction.hashCode();
return result;
}
@Override
public boolean equals(final Object obj) {
if (this == obj) {
return true;
}
if (obj == null) {
return false;
}
if (getClass() != obj.getClass()) {
return false;
}
final SmileInterpolator<?> other = (SmileInterpolator<?>) obj;
if (!ObjectUtils.equals(_model, other._model)) {
return false;
}
if (!ObjectUtils.equals(_weightingFunction, other._weightingFunction)) {
return false;
}
return true;
}
}