/**
* Copyright 2010 JBoss Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.drools.planner.benchmark;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.Reader;
import java.io.UnsupportedEncodingException;
import java.io.Writer;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.Map;
import javax.imageio.ImageIO;
import com.thoughtworks.xstream.XStream;
import com.thoughtworks.xstream.annotations.XStreamAlias;
import com.thoughtworks.xstream.annotations.XStreamImplicit;
import org.apache.commons.io.FilenameUtils;
import org.apache.commons.io.IOUtils;
import org.drools.planner.config.localsearch.LocalSearchSolverConfig;
import org.drools.planner.core.Solver;
import org.drools.planner.core.score.Score;
import org.drools.planner.core.score.definition.ScoreDefinition;
import org.drools.planner.core.solution.Solution;
import org.drools.planner.benchmark.statistic.BestScoreStatistic;
import org.drools.planner.benchmark.statistic.SolverStatistic;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.labels.CategoryItemLabelGenerator;
import org.jfree.chart.labels.StandardCategoryItemLabelGenerator;
import org.jfree.chart.plot.CategoryPlot;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.chart.renderer.category.CategoryItemRenderer;
import org.jfree.data.category.DefaultCategoryDataset;
/**
 * A suite of {@link SolverBenchmark}s that are each run against their unsolved solution files,
 * after which a summary (bar chart, table, optional per-file statistics and a benchmark result XML file)
 * is written to the benchmark directory.
 * <p>
 * The XStream annotations map this class onto the {@code solverBenchmarkSuite} XML element.
 *
 * @author Geoffrey De Smet
 */
@XStreamAlias("solverBenchmarkSuite")
public class SolverBenchmarkSuite {

    /** Formats the time spent (in milliseconds) for use in solved solution file names. */
    public static final NumberFormat TIME_FORMAT = NumberFormat.getIntegerInstance(Locale.ENGLISH);

    /** Root output directory; required. */
    private File benchmarkDirectory = null;
    /** Defaults to {@code benchmarkDirectory/solved} in {@link #benchmarkingStarted()}. */
    private File solvedSolutionFilesDirectory = null;
    /** Defaults to {@code benchmarkDirectory/statistic} in {@link #benchmarkingStarted()}. */
    private File solverStatisticFilesDirectory = null;
    private SolverStatisticType solverStatisticType = SolverStatisticType.NONE;
    /** Defaults to a {@link TotalScoreSolverBenchmarkComparator} in {@link #benchmarkingStarted()}. */
    private Comparator<SolverBenchmark> solverBenchmarkComparator = null;

    /** Optional solver config that every benchmark inherits from. */
    @XStreamAlias("inheritedLocalSearchSolver")
    private LocalSearchSolverConfig inheritedLocalSearchSolverConfig = null;
    /** Optional unsolved solution files that every benchmark inherits; may be null. */
    @XStreamImplicit(itemFieldName = "inheritedUnsolvedSolutionFile")
    private List<File> inheritedUnsolvedSolutionFileList = null;

    @XStreamImplicit(itemFieldName = "solverBenchmark")
    private List<SolverBenchmark> solverBenchmarkList = null;

    public File getBenchmarkDirectory() {
        return benchmarkDirectory;
    }

    public void setBenchmarkDirectory(File benchmarkDirectory) {
        this.benchmarkDirectory = benchmarkDirectory;
    }

    public File getSolvedSolutionFilesDirectory() {
        return solvedSolutionFilesDirectory;
    }

    public void setSolvedSolutionFilesDirectory(File solvedSolutionFilesDirectory) {
        this.solvedSolutionFilesDirectory = solvedSolutionFilesDirectory;
    }

    public File getSolverStatisticFilesDirectory() {
        return solverStatisticFilesDirectory;
    }

    public void setSolverStatisticFilesDirectory(File solverStatisticFilesDirectory) {
        this.solverStatisticFilesDirectory = solverStatisticFilesDirectory;
    }

    public SolverStatisticType getSolverStatisticType() {
        return solverStatisticType;
    }

    public void setSolverStatisticType(SolverStatisticType solverStatisticType) {
        this.solverStatisticType = solverStatisticType;
    }

    public Comparator<SolverBenchmark> getSolverBenchmarkComparator() {
        return solverBenchmarkComparator;
    }

    public void setSolverBenchmarkComparator(Comparator<SolverBenchmark> solverBenchmarkComparator) {
        this.solverBenchmarkComparator = solverBenchmarkComparator;
    }

    public LocalSearchSolverConfig getInheritedLocalSearchSolverConfig() {
        return inheritedLocalSearchSolverConfig;
    }

    public void setInheritedLocalSearchSolverConfig(LocalSearchSolverConfig inheritedLocalSearchSolverConfig) {
        this.inheritedLocalSearchSolverConfig = inheritedLocalSearchSolverConfig;
    }

    public List<File> getInheritedUnsolvedSolutionFileList() {
        return inheritedUnsolvedSolutionFileList;
    }

    public void setInheritedUnsolvedSolutionFileList(List<File> inheritedUnsolvedSolutionFileList) {
        this.inheritedUnsolvedSolutionFileList = inheritedUnsolvedSolutionFileList;
    }

    public List<SolverBenchmark> getSolverBenchmarkList() {
        return solverBenchmarkList;
    }

    public void setSolverBenchmarkList(List<SolverBenchmark> solverBenchmarkList) {
        this.solverBenchmarkList = solverBenchmarkList;
    }

    // ************************************************************************
    // Builder methods
    // ************************************************************************

    /**
     * Prepares the suite for benchmarking:
     * validates the configuration, pushes the inherited solver config and unsolved solution files
     * into every benchmark, generates a unique name ({@code Config_<n>}) for each unnamed benchmark,
     * creates the output directories and fills in defaults for unset optional settings.
     *
     * @throws IllegalArgumentException if the solverBenchmarkList or the benchmarkDirectory is null
     * @throws IllegalStateException if 2 benchmarks are configured with the same name
     */
    public void benchmarkingStarted() {
        // Validate before anything is mutated, so a misconfigured suite fails fast and untouched.
        if (solverBenchmarkList == null) {
            throw new IllegalArgumentException("The solverBenchmarkList (" + solverBenchmarkList
                    + ") must not be null.");
        }
        if (benchmarkDirectory == null) {
            throw new IllegalArgumentException("The benchmarkDirectory (" + benchmarkDirectory + ") must not be null.");
        }
        Set<String> nameSet = new HashSet<String>(solverBenchmarkList.size());
        // LinkedHashSet so generated names follow the configuration order
        Set<SolverBenchmark> noNameBenchmarkSet = new LinkedHashSet<SolverBenchmark>(solverBenchmarkList.size());
        for (SolverBenchmark solverBenchmark : solverBenchmarkList) {
            if (solverBenchmark.getName() != null) {
                boolean unique = nameSet.add(solverBenchmark.getName());
                if (!unique) {
                    throw new IllegalStateException("The benchmark name (" + solverBenchmark.getName()
                            + ") is used in more than 1 benchmark.");
                }
            } else {
                noNameBenchmarkSet.add(solverBenchmark);
            }
            if (inheritedLocalSearchSolverConfig != null) {
                solverBenchmark.inheritLocalSearchSolverConfig(inheritedLocalSearchSolverConfig);
            }
            if (inheritedUnsolvedSolutionFileList != null) {
                solverBenchmark.inheritUnsolvedSolutionFileList(inheritedUnsolvedSolutionFileList);
            }
        }
        // The index only ever increases, so generated names are unique among themselves;
        // the nameSet check skips names that a user already configured explicitly.
        int generatedNameIndex = 0;
        for (SolverBenchmark solverBenchmark : noNameBenchmarkSet) {
            String generatedName = "Config_" + generatedNameIndex;
            while (nameSet.contains(generatedName)) {
                generatedNameIndex++;
                generatedName = "Config_" + generatedNameIndex;
            }
            solverBenchmark.setName(generatedName);
            generatedNameIndex++;
        }
        benchmarkDirectory.mkdirs();
        if (solvedSolutionFilesDirectory == null) {
            solvedSolutionFilesDirectory = new File(benchmarkDirectory, "solved");
        }
        solvedSolutionFilesDirectory.mkdirs();
        if (solverStatisticFilesDirectory == null) {
            solverStatisticFilesDirectory = new File(benchmarkDirectory, "statistic");
        }
        solverStatisticFilesDirectory.mkdirs();
        if (solverBenchmarkComparator == null) {
            solverBenchmarkComparator = new TotalScoreSolverBenchmarkComparator();
        }
    }

    /**
     * Runs every benchmark against each of its unsolved solution files and writes all output.
     * <p>
     * Each unsolved solution file is read with the given XStream. The solved solution is written back
     * with the same XStream. When a statistic type is configured, a statistic is collected per
     * unsolved solution file.
     *
     * @param xStream the XStream used to read unsolved and write solved solutions and the result file
     */
    public void benchmark(XStream xStream) { // TODO refactor out xstream
        benchmarkingStarted();
        // LinkedHashMap because order of unsolvedSolutionFile should be respected in output
        Map<File, SolverStatistic> unsolvedSolutionFileToStatisticMap = new LinkedHashMap<File, SolverStatistic>();
        for (SolverBenchmark solverBenchmark : solverBenchmarkList) {
            Solver solver = solverBenchmark.getLocalSearchSolverConfig().buildSolver();
            for (SolverBenchmarkResult result : solverBenchmark.getSolverBenchmarkResultList()) {
                File unsolvedSolutionFile = result.getUnsolvedSolutionFile();
                Solution unsolvedSolution = readUnsolvedSolution(xStream, unsolvedSolutionFile);
                solver.setStartingSolution(unsolvedSolution);
                SolverStatistic statistic = null;
                if (solverStatisticType != SolverStatisticType.NONE) {
                    // All benchmarks solving the same file share one SolverStatistic, keyed by that file
                    statistic = unsolvedSolutionFileToStatisticMap.get(unsolvedSolutionFile);
                    if (statistic == null) {
                        statistic = solverStatisticType.create();
                        unsolvedSolutionFileToStatisticMap.put(unsolvedSolutionFile, statistic);
                    }
                    statistic.addListener(solver, solverBenchmark.getName());
                }
                solver.solve();
                result.setTimeMillisSpend(solver.getTimeMillisSpend());
                Solution solvedSolution = solver.getBestSolution();
                result.setScore(solvedSolution.getScore());
                if (statistic != null) {
                    // Detach so the reused solver does not report into this statistic for other files
                    statistic.removeListener(solver, solverBenchmark.getName());
                }
                writeSolvedSolution(xStream, solverBenchmark, result, solvedSolution);
            }
        }
        benchmarkingEnded(xStream, unsolvedSolutionFileToStatisticMap);
    }

    /**
     * Reads an unsolved solution from a UTF-8 encoded XML file.
     *
     * @throws IllegalArgumentException if the file cannot be read
     */
    private Solution readUnsolvedSolution(XStream xStream, File unsolvedSolutionFile) {
        Solution unsolvedSolution;
        Reader reader = null;
        try {
            reader = new InputStreamReader(new FileInputStream(unsolvedSolutionFile), "utf-8");
            unsolvedSolution = (Solution) xStream.fromXML(reader);
        } catch (IOException e) {
            throw new IllegalArgumentException("Problem reading unsolvedSolutionFile: " + unsolvedSolutionFile, e);
        } finally {
            IOUtils.closeQuietly(reader);
        }
        return unsolvedSolution;
    }

    /**
     * Writes a solved solution as UTF-8 XML into the solvedSolutionFilesDirectory.
     * <p>
     * The file name combines the unsolved file's base name, the sanitized benchmark name, the score and
     * the time spent. Nothing is written if the directory is null.
     *
     * @throws IllegalArgumentException if the file cannot be written
     */
    private void writeSolvedSolution(XStream xStream, SolverBenchmark solverBenchmark, SolverBenchmarkResult result,
            Solution solvedSolution) {
        if (solvedSolutionFilesDirectory == null) {
            return;
        }
        String baseName = FilenameUtils.getBaseName(result.getUnsolvedSolutionFile().getName());
        // Keep only file-name-safe characters from the benchmark name and score
        String solverBenchmarkName = solverBenchmark.getName().replaceAll(" ", "_").replaceAll("[^\\w\\d_\\-]", "");
        String scoreString = result.getScore().toString().replaceAll("[\\/ ]", "_");
        String timeString = TIME_FORMAT.format(result.getTimeMillisSpend()) + "ms";
        File solvedSolutionFile = new File(solvedSolutionFilesDirectory, baseName + "_" + solverBenchmarkName
                + "_score" + scoreString + "_time" + timeString + ".xml");
        Writer writer = null;
        try {
            writer = new OutputStreamWriter(new FileOutputStream(solvedSolutionFile), "utf-8");
            xStream.toXML(solvedSolution, writer);
        } catch (IOException e) {
            throw new IllegalArgumentException("Problem writing solvedSolutionFile: " + solvedSolutionFile, e);
        } finally {
            IOUtils.closeQuietly(writer);
        }
    }

    /**
     * Ranks the benchmarks and writes the report:
     * the summary chart and table plus one statistic section per unsolved solution file (as index.html),
     * and the benchmark result XML.
     *
     * @param xStream the XStream used to write the benchmark result file
     * @param unsolvedSolutionFileToStatisticMap the statistic collected per unsolved solution file
     */
    public void benchmarkingEnded(XStream xStream, Map<File, SolverStatistic> unsolvedSolutionFileToStatisticMap) {
        determineRankings();
        // 2 lines at 80 chars per line give a max of 160 per entry
        StringBuilder htmlFragment = new StringBuilder(unsolvedSolutionFileToStatisticMap.size() * 160);
        htmlFragment.append(" <h1>Summary</h1>\n");
        htmlFragment.append(" <h2>Summary chart</h2>\n");
        // The chart is rendered and written exactly once, here
        htmlFragment.append(writeBestScoreSummaryChart());
        htmlFragment.append(" <h2>Summary table</h2>\n");
        htmlFragment.append(writeBestScoreSummaryTable());
        htmlFragment.append(" <h1>Statistic ").append(solverStatisticType.toString()).append("</h1>\n");
        for (Map.Entry<File, SolverStatistic> entry : unsolvedSolutionFileToStatisticMap.entrySet()) {
            File unsolvedSolutionFile = entry.getKey();
            SolverStatistic statistic = entry.getValue();
            String baseName = FilenameUtils.getBaseName(unsolvedSolutionFile.getName());
            htmlFragment.append(" <h2>").append(baseName).append("</h2>\n");
            htmlFragment.append(statistic.writeStatistic(solverStatisticFilesDirectory, baseName));
        }
        writeHtmlOverview(htmlFragment);
        writeBenchmarkResult(xStream);
    }

    /**
     * Sets each benchmark's ranking.
     * <p>
     * The ranking is the benchmark's index in the list sorted by the solverBenchmarkComparator and then
     * reversed, so ranking 0 is the best. Benchmark list order itself is left untouched.
     */
    private void determineRankings() {
        List<SolverBenchmark> sortedSolverBenchmarkList = new ArrayList<SolverBenchmark>(solverBenchmarkList);
        Collections.sort(sortedSolverBenchmarkList, solverBenchmarkComparator);
        Collections.reverse(sortedSolverBenchmarkList); // Best results first, worst results last
        for (SolverBenchmark solverBenchmark : solverBenchmarkList) {
            solverBenchmark.setRanking(sortedSolverBenchmarkList.indexOf(solverBenchmark));
        }
    }

    /**
     * Renders a bar chart of every result's score (translated to a graph value) per unsolved file.
     * The chart is written as summary.png in the solverStatisticFilesDirectory.
     *
     * @return an HTML img fragment referencing the written chart
     * @throws IllegalArgumentException if the chart file cannot be written
     */
    private CharSequence writeBestScoreSummaryChart() {
        DefaultCategoryDataset dataset = new DefaultCategoryDataset();
        for (SolverBenchmark solverBenchmark : solverBenchmarkList) {
            ScoreDefinition scoreDefinition = solverBenchmark.getLocalSearchSolverConfig().getScoreDefinitionConfig()
                    .buildScoreDefinition();
            for (SolverBenchmarkResult result : solverBenchmark.getSolverBenchmarkResultList()) {
                Score score = result.getScore();
                Double scoreGraphValue = scoreDefinition.translateScoreToGraphValue(score);
                String solverLabel = solverBenchmark.getName();
                if (solverBenchmark.getRanking() == 0) {
                    solverLabel += " (winner)";
                }
                dataset.addValue(scoreGraphValue, solverLabel, result.getUnsolvedSolutionFile().getName());
            }
        }
        JFreeChart chart = ChartFactory.createBarChart(
                "Best score summary (higher score is better)", "Data", "Score",
                dataset, PlotOrientation.VERTICAL, true, true, false
        );
        CategoryItemRenderer renderer = ((CategoryPlot) chart.getPlot()).getRenderer();
        CategoryItemLabelGenerator generator = new StandardCategoryItemLabelGenerator();
        renderer.setBaseItemLabelGenerator(generator);
        renderer.setBaseItemLabelsVisible(true);
        BufferedImage chartImage = chart.createBufferedImage(1024, 768);
        File chartSummaryFile = new File(solverStatisticFilesDirectory, "summary.png");
        OutputStream out = null;
        try {
            out = new FileOutputStream(chartSummaryFile);
            ImageIO.write(chartImage, "png", out);
        } catch (IOException e) {
            throw new IllegalArgumentException("Problem writing graphStatisticFile: " + chartSummaryFile, e);
        } finally {
            IOUtils.closeQuietly(out);
        }
        return " <img src=\"" + chartSummaryFile.getName() + "\"/>\n";
    }

    /**
     * Builds an HTML table with one row per benchmark.
     * <p>
     * Each row holds the score of each result, the average score and the ranking. The winner (ranking 0)
     * is highlighted; other rows alternate between white and gray.
     *
     * @return the HTML table fragment
     */
    private CharSequence writeBestScoreSummaryTable() {
        StringBuilder htmlFragment = new StringBuilder(solverBenchmarkList.size() * 160);
        htmlFragment.append(" <table border=\"1\">\n");
        htmlFragment.append(" <tr><th/>");
        for (File unsolvedSolutionFile : determineUnsolvedSolutionFileColumnList()) {
            htmlFragment.append("<th>").append(unsolvedSolutionFile.getName()).append("</th>");
        }
        htmlFragment.append("<th>Average</th><th>Ranking</th></tr>\n");
        boolean oddLine = true;
        for (SolverBenchmark solverBenchmark : solverBenchmarkList) {
            String backgroundColor = solverBenchmark.getRanking() == 0 ? "Yellow" : oddLine ? "White" : "Gray";
            htmlFragment.append(" <tr style=\"background-color: ").append(backgroundColor).append("\"><th>")
                    .append(solverBenchmark.getName()).append("</th>");
            for (SolverBenchmarkResult result : solverBenchmark.getSolverBenchmarkResultList()) {
                Score score = result.getScore();
                htmlFragment.append("<td>").append(score.toString()).append("</td>");
            }
            htmlFragment.append("<td>").append(solverBenchmark.getAverageScore().toString())
                    .append("</td><td>").append(solverBenchmark.getRanking()).append("</td>");
            htmlFragment.append("</tr>\n");
            oddLine = !oddLine;
        }
        htmlFragment.append(" </table>\n");
        return htmlFragment.toString();
    }

    /**
     * Determines the unsolved solution files used as the summary table's column headers.
     * <p>
     * Uses the inherited list when one is configured. Otherwise it collects the distinct unsolved files
     * of all results, in first-seen order. This avoids a NullPointerException when every benchmark
     * configures its own files.
     */
    private List<File> determineUnsolvedSolutionFileColumnList() {
        if (inheritedUnsolvedSolutionFileList != null) {
            return inheritedUnsolvedSolutionFileList;
        }
        Set<File> unsolvedSolutionFileSet = new LinkedHashSet<File>();
        for (SolverBenchmark solverBenchmark : solverBenchmarkList) {
            for (SolverBenchmarkResult result : solverBenchmark.getSolverBenchmarkResultList()) {
                unsolvedSolutionFileSet.add(result.getUnsolvedSolutionFile());
            }
        }
        return new ArrayList<File>(unsolvedSolutionFileSet);
    }

    /**
     * Wraps the given HTML fragment in a minimal page.
     * The page is written as UTF-8 index.html in the solverStatisticFilesDirectory.
     *
     * @throws IllegalArgumentException if the file cannot be written
     */
    private void writeHtmlOverview(CharSequence htmlFragment) {
        File htmlOverviewFile = new File(solverStatisticFilesDirectory, "index.html");
        Writer writer = null;
        try {
            writer = new OutputStreamWriter(new FileOutputStream(htmlOverviewFile), "utf-8");
            writer.append("<html>\n");
            writer.append("<head>\n");
            writer.append(" <title>Statistic</title>\n");
            writer.append("</head>\n");
            writer.append("<body>\n");
            writer.append(htmlFragment);
            writer.append("</body>\n");
            writer.append("</html>\n");
        } catch (IOException e) {
            throw new IllegalArgumentException("Problem writing htmlOverviewFile: " + htmlOverviewFile, e);
        } finally {
            IOUtils.closeQuietly(writer);
        }
    }

    /**
     * Serializes this whole suite, including the results, with the given XStream.
     * The output is UTF-8 benchmarkResult.xml in the benchmarkDirectory.
     *
     * @throws IllegalStateException if the JVM lacks UTF-8 support
     * @throws IllegalArgumentException if the file cannot be created
     */
    public void writeBenchmarkResult(XStream xStream) {
        File benchmarkResultFile = new File(benchmarkDirectory, "benchmarkResult.xml");
        OutputStreamWriter writer = null;
        try {
            writer = new OutputStreamWriter(new FileOutputStream(benchmarkResultFile), "utf-8");
            xStream.toXML(this, writer);
        } catch (UnsupportedEncodingException e) {
            throw new IllegalStateException("This JVM does not support utf-8 encoding.", e);
        } catch (FileNotFoundException e) {
            throw new IllegalArgumentException(
                    "Could not create benchmarkResultFile (" + benchmarkResultFile + ").", e);
        } finally {
            IOUtils.closeQuietly(writer);
        }
    }

    /**
     * Which statistic, if any, to collect per unsolved solution file during benchmarking.
     */
    public static enum SolverStatisticType {
        NONE,
        BEST_SOLUTION_CHANGED;

        /**
         * Creates a fresh statistic for this type.
         *
         * @return a new statistic, or null for {@link #NONE}
         */
        public SolverStatistic create() {
            switch (this) {
                case NONE:
                    return null;
                case BEST_SOLUTION_CHANGED:
                    return new BestScoreStatistic();
                default:
                    throw new IllegalStateException("The solverStatisticType (" + this + ") is not implemented");
            }
        }
    }

}