/*
* Encog(tm) Workbench v3.0
* http://www.heatonresearch.com/encog/
* http://code.google.com/p/encog-java/
* Copyright 2008-2011 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.workbench.tabs.visualize.structure;
import java.awt.BorderLayout;
import java.awt.Color;
import java.awt.Dimension;
import java.awt.FlowLayout;
import java.awt.Paint;
import java.awt.Point;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.geom.Point2D;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.swing.BorderFactory;
import javax.swing.JButton;
import javax.swing.JPanel;
import javax.swing.border.Border;
import org.apache.commons.collections15.Transformer;
import org.encog.ml.MLMethod;
import org.encog.neural.flat.FlatNetwork;
import org.encog.neural.neat.NEATLink;
import org.encog.neural.neat.NEATNetwork;
import org.encog.neural.neat.NEATNeuron;
import org.encog.neural.networks.BasicNetwork;
import org.encog.workbench.WorkBenchError;
import org.encog.workbench.tabs.EncogCommonTab;
import edu.uci.ics.jung.algorithms.layout.StaticLayout;
import edu.uci.ics.jung.graph.Graph;
import edu.uci.ics.jung.graph.SparseMultigraph;
import edu.uci.ics.jung.graph.util.EdgeType;
import edu.uci.ics.jung.visualization.GraphZoomScrollPane;
import edu.uci.ics.jung.visualization.Layer;
import edu.uci.ics.jung.visualization.VisualizationViewer;
import edu.uci.ics.jung.visualization.control.AbstractModalGraphMouse;
import edu.uci.ics.jung.visualization.control.CrossoverScalingControl;
import edu.uci.ics.jung.visualization.control.DefaultModalGraphMouse;
import edu.uci.ics.jung.visualization.control.ScalingControl;
import edu.uci.ics.jung.visualization.decorators.ToStringLabeller;
import edu.uci.ics.jung.visualization.renderers.Renderer;
/**
 * Workbench tab that draws the structure of a machine learning method as a
 * graph: neurons become vertices and connections become directed edges.
 *
 * <p>Two kinds of method are supported: {@link BasicNetwork} (drawn from its
 * flat network) and {@link NEATNetwork}. Any other method is rejected with a
 * {@link WorkBenchError}. The tab also offers zoom-in, zoom-out and reset
 * buttons above the graph.</p>
 */
public class StructureTab extends EncogCommonTab {

    /** Factor applied to a neuron's x coordinate to obtain layout units. */
    private static final int LAYOUT_SCALE_X = 600;

    /** Factor applied to a neuron's y coordinate to obtain layout units. */
    private static final int LAYOUT_SCALE_Y = 300;

    /** Fixed horizontal shift added to every neuron position. */
    private static final int LAYOUT_OFFSET_X = 32;

    /** Scale factor used by a single click on the "+" button. */
    private static final float ZOOM_STEP = 1.1f;

    /** The JUNG viewer that renders the neuron graph. */
    private final VisualizationViewer<DrawnNeuron, DrawnConnection> vv;

    /**
     * Construct the tab and lay out the graph for the given method.
     *
     * @param method the method to visualize; must be a {@link BasicNetwork}
     *               or a {@link NEATNetwork}
     * @throws WorkBenchError if the method is {@code null} or of a type that
     *                        cannot be visualized
     */
    public StructureTab(MLMethod method) {
        super(null);

        Graph<DrawnNeuron, DrawnConnection> g = createGraph(method);
        if (g == null) {
            // Guard against a null method so the intended error is reported
            // instead of a NullPointerException while building the message.
            String typeName = (method == null)
                    ? "null" : method.getClass().getSimpleName();
            throw new WorkBenchError("Can't visualize network: " + typeName);
        }

        this.vv = createViewer(g);

        final GraphZoomScrollPane panel = new GraphZoomScrollPane(this.vv);
        this.setLayout(new BorderLayout());
        add(panel, BorderLayout.CENTER);

        final AbstractModalGraphMouse graphMouse = new DefaultModalGraphMouse();
        this.vv.setGraphMouse(graphMouse);
        this.vv.addKeyListener(graphMouse.getModeKeyListener());

        add(createControls(), BorderLayout.NORTH);
    }

    /**
     * Build the graph for a supported method type.
     *
     * @param method the method to convert, may be {@code null}
     * @return the graph, or {@code null} if the method type is not supported
     */
    private Graph<DrawnNeuron, DrawnConnection> createGraph(MLMethod method) {
        if (method instanceof BasicNetwork) {
            BasicNetwork network = (BasicNetwork) method;
            return buildGraph(network.getStructure().getFlat());
        }
        if (method instanceof NEATNetwork) {
            return buildGraph((NEATNetwork) method);
        }
        return null;
    }

    /**
     * Create the viewer for a graph, including the static layout, the vertex
     * and edge colors, the vertex labels and the tool tips.
     *
     * @param g the graph to display
     * @return the configured viewer
     */
    private VisualizationViewer<DrawnNeuron, DrawnConnection> createViewer(
            Graph<DrawnNeuron, DrawnConnection> g) {

        // Positions come straight from each neuron's own coordinates, scaled
        // into layout units and truncated to whole numbers.
        Transformer<DrawnNeuron, Point2D> positions = new Transformer<DrawnNeuron, Point2D>() {
            public Point2D transform(DrawnNeuron n) {
                int x = (int) (n.getX() * LAYOUT_SCALE_X);
                int y = (int) (n.getY() * LAYOUT_SCALE_Y);
                return new Point(x + LAYOUT_OFFSET_X, y);
            }
        };

        // One fill color per neuron type; anything not listed is drawn red.
        Transformer<DrawnNeuron, Paint> vertexPaint = new Transformer<DrawnNeuron, Paint>() {
            public Paint transform(DrawnNeuron neuron) {
                switch (neuron.getType()) {
                case Bias:
                    return Color.yellow;
                case Input:
                    return Color.white;
                case Output:
                    return Color.green;
                case Context:
                    return Color.cyan;
                default:
                    return Color.red;
                }
            }
        };

        // Context links are drawn lighter than regular connections.
        Transformer<DrawnConnection, Paint> edgePaint = new Transformer<DrawnConnection, Paint>() {
            public Paint transform(DrawnConnection connection) {
                return connection.isContext() ? Color.lightGray : Color.black;
            }
        };

        StaticLayout<DrawnNeuron, DrawnConnection> layout =
                new StaticLayout<DrawnNeuron, DrawnConnection>(g, positions);
        // Initial size of the layout space.
        layout.setSize(new Dimension(5000, 5000));

        VisualizationViewer<DrawnNeuron, DrawnConnection> viewer =
                new VisualizationViewer<DrawnNeuron, DrawnConnection>(layout);

        viewer.getRenderer().getVertexLabelRenderer()
                .setPosition(Renderer.VertexLabel.Position.CNTR);
        viewer.getRenderContext().setVertexLabelTransformer(
                new ToStringLabeller<DrawnNeuron>());
        viewer.getRenderContext().setVertexFillPaintTransformer(vertexPaint);
        viewer.getRenderContext().setEdgeDrawPaintTransformer(edgePaint);
        viewer.getRenderContext().setArrowDrawPaintTransformer(edgePaint);
        viewer.getRenderContext().setArrowFillPaintTransformer(edgePaint);

        viewer.setVertexToolTipTransformer(new Transformer<DrawnNeuron, String>() {
            public String transform(DrawnNeuron neuron) {
                return neuron.getToolTip();
            }
        });
        viewer.setEdgeToolTipTransformer(new Transformer<DrawnConnection, String>() {
            public String transform(DrawnConnection edge) {
                return edge.getToolTip();
            }
        });

        return viewer;
    }

    /**
     * Create the panel holding the zoom-in, zoom-out and reset buttons. Must
     * be called after the viewer has been created.
     *
     * @return the control panel
     */
    private JPanel createControls() {
        final ScalingControl scaler = new CrossoverScalingControl();

        JButton plus = new JButton("+");
        plus.addActionListener(new ActionListener() {
            public void actionPerformed(ActionEvent e) {
                scaler.scale(vv, ZOOM_STEP, vv.getCenter());
            }
        });

        JButton minus = new JButton("-");
        minus.addActionListener(new ActionListener() {
            public void actionPerformed(ActionEvent e) {
                scaler.scale(vv, 1 / ZOOM_STEP, vv.getCenter());
            }
        });

        JButton reset = new JButton("reset");
        reset.addActionListener(new ActionListener() {
            public void actionPerformed(ActionEvent e) {
                // Undo all zooming and panning on both transform layers.
                vv.getRenderContext().getMultiLayerTransformer()
                        .getTransformer(Layer.LAYOUT).setToIdentity();
                vv.getRenderContext().getMultiLayerTransformer()
                        .getTransformer(Layer.VIEW).setToIdentity();
            }
        });

        JPanel controls = new JPanel();
        controls.setLayout(new FlowLayout(FlowLayout.LEFT));
        controls.add(plus);
        controls.add(minus);
        controls.add(reset);
        Border border = BorderFactory.createEtchedBorder();
        controls.setBorder(border);
        return controls;
    }

    /**
     * Add every neuron as a vertex and every outbound connection of every
     * neuron as a directed edge.
     *
     * @param graph   the graph to fill
     * @param neurons the neurons to add
     */
    private static void addNeurons(Graph<DrawnNeuron, DrawnConnection> graph,
            List<DrawnNeuron> neurons) {
        for (DrawnNeuron neuron : neurons) {
            graph.addVertex(neuron);
            for (DrawnConnection connection : neuron.getOutbound()) {
                graph.addEdge(connection, connection.getFrom(),
                        connection.getTo(), EdgeType.DIRECTED);
            }
        }
    }

    /**
     * Build the graph for a NEAT network. Neurons are named by type and a
     * running number (B1, I1, O1, H1, ...) and positioned at their split
     * x/y values; links keep their weight.
     *
     * @param neat the NEAT network to convert
     * @return the graph of drawn neurons and connections
     */
    private Graph<DrawnNeuron, DrawnConnection> buildGraph(NEATNetwork neat) {
        int inputCount = 1;
        int outputCount = 1;
        int hiddenCount = 1;
        int biasCount = 1;

        List<DrawnNeuron> neurons = new ArrayList<DrawnNeuron>();
        Map<NEATNeuron, DrawnNeuron> neuronMap = new HashMap<NEATNeuron, DrawnNeuron>();

        // place all the neurons
        for (NEATNeuron neatNeuron : neat.getNeurons()) {
            // Neuron types not handled below are drawn as unnamed hidden neurons.
            String name = "";
            DrawnNeuronType t = DrawnNeuronType.Hidden;

            switch (neatNeuron.getNeuronType()) {
            case Bias:
                t = DrawnNeuronType.Bias;
                name = "B" + (biasCount++);
                break;
            case Input:
                t = DrawnNeuronType.Input;
                name = "I" + (inputCount++);
                break;
            case Output:
                t = DrawnNeuronType.Output;
                name = "O" + (outputCount++);
                break;
            case Hidden:
                t = DrawnNeuronType.Hidden;
                name = "H" + (hiddenCount++);
                break;
            }

            DrawnNeuron neuron = new DrawnNeuron(t, name,
                    neatNeuron.getSplitX(), neatNeuron.getSplitY());
            neurons.add(neuron);
            neuronMap.put(neatNeuron, neuron);
        }

        // place all the connections
        for (NEATNeuron neatNeuron : neat.getNeurons()) {
            for (NEATLink neatLink : neatNeuron.getOutputboundLinks()) {
                DrawnNeuron fromNeuron = neuronMap.get(neatLink.getFromNeuron());
                DrawnNeuron toNeuron = neuronMap.get(neatLink.getToNeuron());
                DrawnConnection connection = new DrawnConnection(
                        fromNeuron, toNeuron, neatLink.getWeight());
                fromNeuron.getOutbound().add(connection);
                toNeuron.getInbound().add(connection);
            }
        }

        Graph<DrawnNeuron, DrawnConnection> result =
                new SparseMultigraph<DrawnNeuron, DrawnConnection>();
        addNeurons(result, neurons);
        return result;
    }

    /**
     * Build the graph for a flat network.
     *
     * <p>Layers are visited from index 0 upward; layer 0 is drawn as the
     * output layer and the last layer as the input layer. Within a layer, the
     * first "feed count" neurons are regular neurons, the neuron right after
     * them is labelled as the bias, and any further neurons are labelled as
     * context neurons. Every neuron is connected to the output and hidden
     * neurons of the previously visited layer. Context links are added last
     * and flagged as such.</p>
     *
     * <p>Connection weights are not looked up; every connection is created
     * with a weight of 0.</p>
     *
     * @param flat the flat network to convert
     * @return the graph of drawn neurons and connections
     */
    public Graph<DrawnNeuron, DrawnConnection> buildGraph(FlatNetwork flat) {
        int inputCount = 1;
        int outputCount = 1;
        int hiddenCount = 1;
        int biasCount = 1;
        int contextCount = 1;

        int layerCount = flat.getLayerCounts().length;
        double layerSize = 1.0 / layerCount;

        List<DrawnNeuron> neurons = new ArrayList<DrawnNeuron>();
        // Output/hidden neurons of the previously visited layer; these are
        // the targets for the connections of the layer being processed.
        List<DrawnNeuron> targets = new ArrayList<DrawnNeuron>();

        for (int currentLayer = 0; currentLayer < layerCount; currentLayer++) {
            List<DrawnNeuron> fedNeurons = new ArrayList<DrawnNeuron>();

            // Layer 0 ends up at the right-most column, the last layer at x = 0.
            double x = (double) (layerCount - currentLayer - 1)
                    / (double) layerCount;

            int neuronCount = flat.getLayerCounts()[currentLayer];
            int feedCount = flat.getLayerFeedCounts()[currentLayer];

            // Half of the space left over after stacking the neurons, used to
            // center the column vertically. Same for every neuron in the layer.
            double margin = ((double) (neuronCount - 1) / (double) neuronCount);
            margin = 1.0 - margin;
            margin /= 2.0;

            for (int currentNeuron = 0; currentNeuron < neuronCount; currentNeuron++) {
                DrawnNeuronType type;
                double xOffset = 0;
                String name;

                if (currentNeuron < feedCount) {
                    // not a bias or context
                    if (currentLayer == 0) {
                        type = DrawnNeuronType.Output;
                        name = "O" + (outputCount++);
                    } else if (currentLayer == (layerCount - 1)) {
                        type = DrawnNeuronType.Input;
                        name = "I" + (inputCount++);
                    } else {
                        type = DrawnNeuronType.Hidden;
                        name = "H" + (hiddenCount++);
                    }
                } else if (currentNeuron == feedCount) {
                    // is a bias
                    type = DrawnNeuronType.Bias;
                    name = "B" + (biasCount++);
                } else {
                    // is a context; shift it sideways so it stands apart
                    type = DrawnNeuronType.Context;
                    name = "C" + (contextCount++);
                    xOffset = layerSize / 4;
                }

                double y = (double) currentNeuron / (double) neuronCount;

                DrawnNeuron neuron = new DrawnNeuron(type, name, x + xOffset, y + margin);
                neurons.add(neuron);

                if (neuron.getType() == DrawnNeuronType.Hidden
                        || neuron.getType() == DrawnNeuronType.Output) {
                    fedNeurons.add(neuron);
                }

                for (DrawnNeuron connectTo : targets) {
                    double w = 0; // weights are not resolved for flat networks
                    DrawnConnection connection = new DrawnConnection(neuron, connectTo, w);
                    neuron.getOutbound().add(connection);
                    // The connection arrives at the target neuron, so it is
                    // inbound there, not on the source neuron.
                    connectTo.getInbound().add(connection);
                }
            }

            targets = fedNeurons;
        }

        Graph<DrawnNeuron, DrawnConnection> result =
                new SparseMultigraph<DrawnNeuron, DrawnConnection>();
        addNeurons(result, neurons);

        // draw context links
        for (int currentLayer = 0; currentLayer < layerCount; currentLayer++) {
            int count = flat.getContextTargetSize()[currentLayer];
            if (count > 0) {
                int offset = flat.getContextTargetOffset()[currentLayer];
                int source = flat.getLayerIndex()[currentLayer];

                // The neurons list was filled layer by layer, so these values
                // are used directly as positions in that list.
                for (int i = 0; i < count; i++) {
                    DrawnNeuron n1 = neurons.get(source + i);
                    DrawnNeuron n2 = neurons.get(offset + i);
                    DrawnConnection connection = new DrawnConnection(n1, n2, 0);
                    result.addEdge(connection, connection.getFrom(),
                            connection.getTo(), EdgeType.DIRECTED);
                    connection.setContext(true);
                }
            }
        }

        return result;
    }

    /**
     * Get the title of this tab.
     *
     * @return "Structure: " followed by the name of the tab's Encog object,
     *         or just "Structure" when the tab has no Encog object
     */
    @Override
    public String getName() {
        // NOTE(review): the constructor passes null to super(...); assuming
        // getEncogObject() then returns null, the unguarded call would fail.
        if (this.getEncogObject() == null) {
            return "Structure";
        }
        return "Structure: " + this.getEncogObject().getName();
    }
}