Package org.encog.workbench.tabs.mlmethod

Source Code of org.encog.workbench.tabs.mlmethod.MLMethodTab

/*
* Encog(tm) Workbench v3.0
* http://www.heatonresearch.com/encog/
* http://code.google.com/p/encog-java/
* Copyright 2008-2011 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*  
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.workbench.tabs.mlmethod;

import java.awt.BorderLayout;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.MouseEvent;
import java.util.ArrayList;
import java.util.List;

import javax.swing.JButton;
import javax.swing.JEditorPane;
import javax.swing.JScrollPane;
import javax.swing.JToolBar;

import org.encog.engine.network.activation.ActivationFunction;
import org.encog.engine.network.activation.ActivationTANH;
import org.encog.mathutil.randomize.ConsistentRandomizer;
import org.encog.mathutil.randomize.ConstRandomizer;
import org.encog.mathutil.randomize.Distort;
import org.encog.mathutil.randomize.FanInRandomizer;
import org.encog.mathutil.randomize.GaussianRandomizer;
import org.encog.mathutil.randomize.NguyenWidrowRandomizer;
import org.encog.mathutil.randomize.Randomizer;
import org.encog.mathutil.randomize.RangeRandomizer;
import org.encog.mathutil.rbf.RadialBasisFunction;
import org.encog.ml.MLClassification;
import org.encog.ml.MLContext;
import org.encog.ml.MLEncodable;
import org.encog.ml.MLInput;
import org.encog.ml.MLMethod;
import org.encog.ml.MLOutput;
import org.encog.ml.MLProperties;
import org.encog.ml.MLRegression;
import org.encog.ml.MLResettable;
import org.encog.neural.cpn.CPN;
import org.encog.neural.flat.FlatNetwork;
import org.encog.neural.neat.NEATNetwork;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.pattern.FeedForwardPattern;
import org.encog.neural.pattern.HopfieldPattern;
import org.encog.neural.prune.PruneSelective;
import org.encog.neural.rbf.RBFNetwork;
import org.encog.neural.thermal.HopfieldNetwork;
import org.encog.neural.thermal.ThermalNetwork;
import org.encog.util.Format;
import org.encog.util.HTMLReport;
import org.encog.workbench.EncogWorkBench;
import org.encog.workbench.WorkBenchError;
import org.encog.workbench.dialogs.RandomizeNetworkDialog;
import org.encog.workbench.dialogs.createnetwork.CreateFeedforward;
import org.encog.workbench.dialogs.createnetwork.CreateHopfieldDialog;
import org.encog.workbench.dialogs.select.SelectDialog;
import org.encog.workbench.dialogs.select.SelectItem;
import org.encog.workbench.frames.MapDataFrame;
import org.encog.workbench.frames.document.tree.ProjectEGFile;
import org.encog.workbench.process.TrainBasicNetwork;
import org.encog.workbench.tabs.EncogCommonTab;
import org.encog.workbench.tabs.query.general.ClassificationQueryTab;
import org.encog.workbench.tabs.query.general.RegressionQueryTab;
import org.encog.workbench.tabs.query.ocr.OCRQueryTab;
import org.encog.workbench.tabs.query.thermal.QueryThermalTab;
import org.encog.workbench.tabs.visualize.ThermalGrid.ThermalGridTab;
import org.encog.workbench.tabs.visualize.structure.StructureTab;
import org.encog.workbench.tabs.visualize.weights.AnalyzeWeightsTab;

public class MLMethodTab extends EncogCommonTab implements ActionListener {

  /**
   * Serialization version identifier for this tab.
   */
  private static final long serialVersionUID = 1L;
  // Non-floatable toolbar placed at the top of the tab; holds the buttons below.
  private JToolBar toolbar;
  // "Randomize/Reset" button.
  private JButton buttonRandomize;
  // "Query" button.
  private JButton buttonQuery;
  // "Train" button.
  private JButton buttonTrain;
  // "Restructure" button.
  private JButton buttonRestructure;
  // "Properties" button.
  private JButton buttonProperties;
  // "Visualize" button.
  private JButton buttonVisualize;
  // Scroll pane wrapping the report editor; occupies the center of the tab.
  private final JScrollPane scroll;
  // Read-only HTML pane that displays the text built by produceReport().
  private final JEditorPane editor;
  // The machine learning method shown by this tab, taken from the project file.
  private MLMethod method;

  public MLMethodTab(final ProjectEGFile data) {
    super(data);

    this.method = (MLMethod)data.getObject();
    setLayout(new BorderLayout());
    this.toolbar = new JToolBar();
    this.toolbar.setFloatable(false);
    this.toolbar.add(this.buttonRandomize = new JButton("Randomize/Reset"));
    this.toolbar.add(this.buttonQuery = new JButton("Query"));
    this.toolbar.add(this.buttonTrain = new JButton("Train"));
    this.toolbar.add(this.buttonRestructure = new JButton("Restructure"));
    this.toolbar.add(this.buttonProperties = new JButton("Properties"));
    this.toolbar.add(this.buttonVisualize = new JButton("Visualize"));

    this.buttonRandomize.addActionListener(this);
    this.buttonQuery.addActionListener(this);
    this.buttonTrain.addActionListener(this);
    this.buttonRestructure.addActionListener(this);
    this.buttonProperties.addActionListener(this);
    this.buttonVisualize.addActionListener(this);

    add(this.toolbar, BorderLayout.PAGE_START);

    this.editor = new JEditorPane("text/html", "");
    this.editor.setEditable(false);
    this.scroll = new JScrollPane(this.editor);
    add(this.scroll, BorderLayout.CENTER);
    produceReport();
  }

  public void actionPerformed(final ActionEvent action) {
    try {
      if (action.getSource() == this.buttonQuery) {
        performQuery();
      } else if (action.getSource() == this.buttonRandomize) {
        performRandomize();
      } else if (action.getSource() == this.buttonTrain) {
        performTrain();
      } else if (action.getSource() == this.buttonRestructure) {
        performRestructure();
      } else if (action.getSource() == this.buttonProperties) {
        performProperties();
      } else if (action.getSource() == this.buttonVisualize) {
        this.handleVisualize();
      }
    } catch (Throwable t) {
      EncogWorkBench.displayError("Error", t);
    }
  }

  private void performTrain() {
    TrainBasicNetwork t = new TrainBasicNetwork((ProjectEGFile)this.getEncogObject(),this);
    t.performTrain();
  }

  private void randomizeBasicNetwork() {
    RandomizeNetworkDialog dialog = new RandomizeNetworkDialog(
        EncogWorkBench.getInstance().getMainWindow());

    dialog.getHigh().setValue(1);
    dialog.getConstHigh().setValue(1);
    dialog.getLow().setValue(-1);
    dialog.getConstLow().setValue(-1);
    dialog.getSeedValue().setValue(1000);
    dialog.getConstantValue().setValue(0);
    dialog.getPerturbPercent().setValue(0.01);

    if (dialog.process()) {
      switch (dialog.getCurrentTab()) {
      case 0:
        optionRandomize(dialog);
        break;

      case 1:
        optionPerturb(dialog);
        break;
      case 2:
        optionGaussian(dialog);
        break;
      case 3:
        optionConsistent(dialog);
        break;

      case 4:
        optionConstant(dialog);
        break;
      }
    }

  }

  private void performRandomize() {

    if (EncogWorkBench.askQuestion("Are you sure?",
        "Randomize/reset network weights and lose all training?")) {
      if (this.method instanceof BasicNetwork) {
        randomizeBasicNetwork();
      } else if (method instanceof MLResettable) {
        ((MLResettable) method).reset();
      }
    }

  }

  private void optionConstant(RandomizeNetworkDialog dialog) {
    double value = dialog.getConstantValue().getValue();
    ConstRandomizer r = new ConstRandomizer(value);
    r.randomize((BasicNetwork)this.method);
    setDirty(true);
  }

  private void optionConsistent(RandomizeNetworkDialog dialog) {
    int seed = dialog.getSeedValue().getValue();
    double min = dialog.getConstLow().getValue();
    double max = dialog.getConstHigh().getValue();
    ConsistentRandomizer c = new ConsistentRandomizer(min, max, seed);
    c.randomize(this.method);
    setDirty(true);
  }

  private void optionPerturb(RandomizeNetworkDialog dialog) {
    double percent = dialog.getPerturbPercent().getValue();

    Distort distort = new Distort(percent);
    distort.randomize((BasicNetwork) this.method);
    setDirty(true);
  }

  private void optionGaussian(RandomizeNetworkDialog dialog) {
    double mean = dialog.getMean().getValue();
    double dev = dialog.getDeviation().getValue();

    GaussianRandomizer g = new GaussianRandomizer(mean, dev);
    g.randomize((BasicNetwork) this.method);
    setDirty(true);
  }

  private void optionRandomize(RandomizeNetworkDialog dialog) {
    Randomizer r = null;

    switch (dialog.getType().getSelectedIndex()) {
    case 0: // Random
      r = new RangeRandomizer(dialog.getLow().getValue(), dialog
          .getHigh().getValue());
      break;
    case 1: // Nguyen-Widrow
      r = new NguyenWidrowRandomizer(dialog.getLow().getValue(), dialog
          .getHigh().getValue());
      break;
    case 2: // Fan in
      r = new FanInRandomizer(dialog.getLow().getValue(), dialog
          .getHigh().getValue(), false);
      break;
    }

    if (r != null) {
      r.randomize((BasicNetwork) this.method);
      setDirty(true);
    }
  }

  private void performQuery() {
    try {
      if (this.method instanceof ThermalNetwork) {
        QueryThermalTab tab = new QueryThermalTab((ProjectEGFile)this.getEncogObject());
        EncogWorkBench.getInstance().getMainWindow().getTabManager()
            .openModalTab(tab, "Thermal Query");     
      }
      // only supports regression
      else {
        SelectItem selectClassification = null;
        SelectItem selectRegression = null;
        SelectItem selectOCR;
       
        List<SelectItem> list = new ArrayList<SelectItem>();
        if( this.method instanceof MLClassification ) {
        list.add(selectClassification = new SelectItem("Query Classification",
            "Machine Learning output is a class."));
        }
        if( this.method instanceof MLRegression ) {
        list.add(selectRegression = new SelectItem("Query Regression",
            "Machine Learning output is a number(s)."));
        }
        list.add(selectOCR = new SelectItem("Query OCR",
          "Query using drawn chars.  Supports regression or classification."));
        SelectDialog sel = new SelectDialog(EncogWorkBench.getInstance()
            .getMainWindow(), list);
        sel.setVisible(true);
       
        if( sel.getSelected()==selectClassification ) {
          ClassificationQueryTab tab = new ClassificationQueryTab(
              ((ProjectEGFile) this.getEncogObject()));
          EncogWorkBench.getInstance().getMainWindow().getTabManager()
              .openModalTab(tab, "Query Classification");         
        } else if( sel.getSelected()==selectRegression ) {
          RegressionQueryTab tab = new RegressionQueryTab(
              ((ProjectEGFile) this.getEncogObject()));
          EncogWorkBench.getInstance().getMainWindow().getTabManager()
              .openModalTab(tab, "Query Regression");         
        else if( sel.getSelected()==selectOCR ) {
          OCRQueryTab tab = new OCRQueryTab(
              ((ProjectEGFile) this.getEncogObject()));
          EncogWorkBench.getInstance().getMainWindow().getTabManager()
              .openModalTab(tab, "Query OCR");         
        }
      }
    } catch (Throwable t) {
      EncogWorkBench.displayError("Error", t);
    }
  }

  /**
   * @return the data
   */
  public BasicNetwork getData() {
    return (BasicNetwork) this.method;
  }

  /**
   * Does nothing; mouse clicks are ignored by this tab.
   *
   * NOTE(review): this class only declares ActionListener, so this method may
   * exist for a superclass contract or be a leftover - confirm before removing.
   *
   * @param e
   *            The mouse event (unused).
   */
  public void mouseClicked(MouseEvent e) {
    // Intentionally empty.

  }

  public void performProperties() {
    if ( this.method instanceof MLProperties) {
      MapDataFrame frame = new MapDataFrame(
          ((MLProperties)method).getProperties(),
          "Properties");
      frame.setVisible(true);
      setDirty(true);
    } else {
      EncogWorkBench
          .displayError("Error",
              "This Machine Learning Method type does not support properties.");
    }
  }

  public void handleVisualize() {
    SelectItem selectWeights;
    SelectItem selectStructure;
    SelectItem selectThermal;
    List<SelectItem> list = new ArrayList<SelectItem>();
    list.add(selectWeights = new SelectItem("Weights Histogram",
        "A histogram of the weights."));
    list.add(selectStructure = new SelectItem("Network Structure",
        "The structure of the neural network."));
    list.add(selectThermal = new SelectItem("Thermal Matrix",
        "Shows the matrix of a Hopfield or Boltzmann Machine."));
    SelectDialog sel = new SelectDialog(EncogWorkBench.getInstance()
        .getMainWindow(), list);
    sel.setVisible(true);

    if (sel.getSelected() == selectWeights) {
      analyzeWeights();
    } else if (sel.getSelected() == selectStructure) {
      analyzeStructure();
    } else if (sel.getSelected() == selectThermal) {
      analyzeThermal();
    }

  }

  private void analyzeThermal() {
    ThermalGridTab tab = new ThermalGridTab((ProjectEGFile) this.getEncogObject());
    EncogWorkBench.getInstance().getMainWindow().getTabManager()
        .openModalTab(tab, "Thermal Grid");
  }

  private void analyzeStructure() {

    if (method instanceof MLMethod) {
      StructureTab tab = new StructureTab(
          ((MLMethod)this.method));
      EncogWorkBench.getInstance().getMainWindow().getTabManager()
          .openModalTab(tab, "Network Structure");
    } else {
      throw new WorkBenchError("No analysis available for: "
          + this.method.getClass().getSimpleName());
    }

  }

  public void analyzeWeights() {
    AnalyzeWeightsTab tab = new AnalyzeWeightsTab((ProjectEGFile)this.getEncogObject());
    EncogWorkBench.getInstance().getMainWindow().getTabManager()
        .openModalTab(tab, "Analyze Weights");
  }

  public void produceReport() {
    HTMLReport report = new HTMLReport();
    report.beginHTML();
    report.title("MLMethod");
    report.beginBody();
    report.h1(this.method.getClass().getSimpleName());

    report.beginTable();

    if (method instanceof MLInput) {
      MLInput reg = (MLInput) method;
      report.tablePair("Input Count",
          Format.formatInteger(reg.getInputCount()));
    }

    if (method instanceof MLOutput) {
      MLOutput reg = (MLOutput) method;
      report.tablePair("Output Count",
          Format.formatInteger(reg.getOutputCount()));
    }

    if (method instanceof MLEncodable) {
      MLEncodable encode = (MLEncodable)method;
      report.tablePair("Encoded Length",
          Format.formatInteger(encode.encodedArrayLength()));
    }

    report.tablePair("Resettable",
        (method instanceof MLResettable) ? "true" : "false");
   
    report.tablePair("Context",
        (method instanceof MLContext) ? "true" : "false");
   
   
    if( method instanceof NEATNetwork ) {
      NEATNetwork neat = (NEATNetwork)method;
      report.tablePair("Output Activation Function", neat.getOutputActivationFunction().getClass().getSimpleName());
      report.tablePair("NEAT Activation Function", neat.getActivationFunction().getClass().getSimpleName());
    }
   
    if( method instanceof CPN ) {
      CPN cpn = (CPN)method;
      report.tablePair("Instar Count", Format.formatInteger(cpn.getInstarCount()));
      report.tablePair("Outstar Count", Format.formatInteger(cpn.getOutstarCount()));     
    }
   
    report.endTable();
   
    if (this.method instanceof RBFNetwork) {
      RBFNetwork rbfNetwork = (RBFNetwork)this.method;
     
      report.h3("RBF Centers");
      report.beginTable();
      report.beginRow();
      report.header("RBF");
      report.header("Peak");
      report.header("Width");
      for(int i=1;i<=rbfNetwork.getInputCount();i++) {
        report.header("Center " + i);
      }
      report.endRow();
     
     
      for( RadialBasisFunction rbf : rbfNetwork.getRBF() ) {
        report.beginRow();
        report.cell(rbf.getClass().getSimpleName());
        report.cell(Format.formatDouble(rbf.getPeak(), 5));
        report.cell(Format.formatDouble(rbf.getWidth(), 5));
        for(int i=0;i<rbfNetwork.getInputCount();i++) {
          report.cell(Format.formatDouble(rbf.getCenter(i), 5));
        }
        report.endRow();
      }
    }

    if (this.method instanceof BasicNetwork) {
      report.h3("Layers");
      report.beginTable();
      report.beginRow();
      report.header("Layer #");
      report.header("Total Count");
      report.header("Neuron Count");
      report.header("Activation Function");
      report.header("Bias");
      report.header("Context Target Size");
      report.header("Context Target Offset");
      report.header("Context Count");
      report.endRow();

      BasicNetwork network = (BasicNetwork) method;
      FlatNetwork flat = network.getStructure().getFlat();
      int layerCount = network.getLayerCount();

      for (int l = 0; l < layerCount; l++) {
        report.beginRow();
        StringBuilder str = new StringBuilder();
        str.append(Format.formatInteger(l + 1));
        if (l == 0) {
          str.append(" (Output)");
        } else if (l == network.getLayerCount() - 1) {
          str.append(" (Input)");
        }
        report.cell(str.toString());
        report.cell(Format.formatInteger(flat.getLayerCounts()[l]));
        report.cell(Format.formatInteger(flat.getLayerFeedCounts()[l]));
        report.cell(flat.getActivationFunctions()[l].getClass()
            .getSimpleName());
        report.cell(Format.formatDouble(flat.getBiasActivation()[l], 4));
        report.cell(Format.formatInteger(flat.getContextTargetSize()[l]));
        report.cell(Format.formatInteger(flat.getContextTargetOffset()[l]));
        report.cell(Format.formatInteger(flat.getLayerContextCount()[l]));
        report.endRow();
      }
      report.endTable();
    }
   


    report.endBody();
    report.endHTML();
    this.editor.setText(report.toString());
  }

  private void restructureHopfield() {
    HopfieldNetwork hopfield = (HopfieldNetwork) method;
    CreateHopfieldDialog dialog = new CreateHopfieldDialog(EncogWorkBench
        .getInstance().getMainWindow());
    dialog.getNeuronCount().setValue(hopfield.getNeuronCount());

    if (dialog.process()
        && (hopfield.getNeuronCount() != dialog.getNeuronCount()
            .getValue())) {
      HopfieldPattern pattern = new HopfieldPattern();
      pattern.setInputNeurons(dialog.getNeuronCount().getValue());
      setDirty(true);
      produceReport();
    }
  }

  private void restructureFeedforward() {
    CreateFeedforward dialog = new CreateFeedforward(EncogWorkBench
        .getInstance().getMainWindow());
    BasicNetwork network = (BasicNetwork)method;

    ActivationFunction oldActivationOutput = network.getActivation(network
        .getLayerCount() - 1);
    dialog.setActivationFunctionOutput(oldActivationOutput);
    dialog.getInputCount().setValue(network.getInputCount());
    dialog.getOutputCount().setValue(network.getOutputCount());
    int hiddenLayerCount = network.getLayerCount() - 2;

    ActivationFunction oldActivationHidden = new ActivationTANH();
    for (int i = 0; i < hiddenLayerCount; i++) {
      int num = network.getLayerNeuronCount(i + 1);
      String str = "Hidden Layer " + (i + 1) + ": " + num + " neurons";
      dialog.getHidden().getModel().addElement(str);
    }
    dialog.setActivationFunctionHidden(oldActivationHidden);

    if (dialog.process()) {
      // decide if entire network is to be recreated
      if ((dialog.getActivationFunctionHidden() != oldActivationHidden)
          || (dialog.getActivationFunctionOutput() != oldActivationOutput)
          || dialog.getHidden().getModel().size() != (network
              .getLayerCount() - 2)) {
        FeedForwardPattern feedforward = new FeedForwardPattern();
        feedforward.setActivationFunction(dialog
            .getActivationFunctionHidden());
        feedforward.setInputNeurons(dialog.getInputCount().getValue());
        for (int i = 0; i < dialog.getHidden().getModel().size(); i++) {
          String str = (String) dialog.getHidden().getModel()
              .getElementAt(i);
          int i1 = str.indexOf(':');
          int i2 = str.indexOf("neur");
          if (i1 != -1 && i2 != -1) {
            str = str.substring(i1 + 1, i2).trim();
            int neuronCount = Integer.parseInt(str);
            feedforward.addHiddenLayer(neuronCount);
          }
        }
        feedforward.setInputNeurons(dialog.getInputCount().getValue());
        feedforward.setOutputNeurons(dialog.getOutputCount().getValue());
        BasicNetwork obj = (BasicNetwork) feedforward.generate();
      } else {
        // try to prune it
        PruneSelective prune = new PruneSelective(network);
        int newInputCount = dialog.getInputCount().getValue();
        int newOutputCount = dialog.getOutputCount().getValue();

        // did input neurons change?
        if (newInputCount != network.getInputCount()) {
          prune.changeNeuronCount(0, newInputCount);
        }

        // did output neurons change?
        if (newOutputCount != network.getOutputCount()) {
          prune.changeNeuronCount(0, newOutputCount);
        }

        // did the hidden layers change?
        for (int i = 0; i < network.getLayerCount() - 2; i++) {
          int newHiddenCount = 1;
          String str = (String) dialog.getHidden().getModel()
              .getElementAt(i);
          int i1 = str.indexOf(':');
          int i2 = str.indexOf("neur");
          if (i1 != -1 && i2 != -1) {
            str = str.substring(i1 + 1, i2).trim();
            newHiddenCount = Integer.parseInt(str);
          }

          // did this hidden layer change?
          if (network.getLayerNeuronCount(i) != newHiddenCount) {
            prune.changeNeuronCount(i + 1, newHiddenCount);
          }
        }
      }
      setDirty(true);
      produceReport();
    }
  }

  private void performRestructure() {
    if (method instanceof HopfieldNetwork) {
      restructureHopfield();
    } else if (method instanceof BasicNetwork) {
      restructureFeedforward();
    } else {
      EncogWorkBench.displayError("Error",
          "This Machine Learning Method cannot be restructured.");
    }
  }

  /**
   * @return The name reported by the Encog object this tab was opened for.
   */
  @Override
  public String getName() {
    return this.getEncogObject().getName();
  }

}
TOP

Related Classes of org.encog.workbench.tabs.mlmethod.MLMethodTab

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.