Package com.github.neuralnetworks.calculation.neuronfunctions

Source Code of com.github.neuralnetworks.calculation.neuronfunctions.ConnectionCalculatorConv

package com.github.neuralnetworks.calculation.neuronfunctions;

import java.util.List;

import com.github.neuralnetworks.architecture.Connections;
import com.github.neuralnetworks.architecture.Conv2DConnection;
import com.github.neuralnetworks.architecture.Layer;
import com.github.neuralnetworks.calculation.ConnectionCalculator;
import com.github.neuralnetworks.calculation.ValuesProvider;
import com.github.neuralnetworks.util.Util;

/**
* Default implementation of Connection calculator for convolutional/subsampling
* layers
*/
public class ConnectionCalculatorConv implements ConnectionCalculator {

    private static final long serialVersionUID = -5405654469496055017L;

    protected AparapiConv2D inputFunction;
    protected Layer currentLayer;
    protected int miniBatchSize;

    @Override
    public void calculate(List<Connections> connections, ValuesProvider valuesProvider, Layer targetLayer) {
  Conv2DConnection c = null;
  Conv2DConnection bias = null;

  for (Connections con : connections) {
      if (con instanceof Conv2DConnection) {
    if (Util.isBias(con.getInputLayer())) {
        bias = (Conv2DConnection) con;
    } else {
        c = (Conv2DConnection) con;
    }
      }
  }

  if (c != null) {
      // currently works only as a feedforward (including bp)
      if (inputFunction == null || miniBatchSize != valuesProvider.getColumns()) {
    miniBatchSize = valuesProvider.getColumns();
    inputFunction = createInputFunction(c, miniBatchSize);
      }

      calculateBias(bias, valuesProvider);

      if (targetLayer == c.getOutputLayer()) {
    inputFunction.calculate(c, valuesProvider.getValues(Util.getOppositeLayer(c, targetLayer), c), valuesProvider.getValues(targetLayer, c));
      } else {
    inputFunction.calculate(c, valuesProvider.getValues(targetLayer, c), valuesProvider.getValues(Util.getOppositeLayer(c, targetLayer), c));
      }
  }
    }

    protected AparapiConv2D createInputFunction(Conv2DConnection c, int miniBatchSize) {
  return new AparapiConv2DFF(c, miniBatchSize);
    }

    protected void calculateBias(Conv2DConnection bias, ValuesProvider vp) {
  if (bias != null) {
      float[] biasValue = vp.getValues(bias.getInputLayer(), bias).getElements();
      if (biasValue[0] == 0) {
    Util.fillArray(biasValue, 1);
      }

      float[] a = vp.getValues(bias.getOutputLayer(), bias).getElements();
      int fm = a.length / bias.getWeights().length;
      for (int i = 0; i < a.length; i++) {
    a[i] += bias.getWeights()[i / fm];
      }
  }
    }

    public AparapiConv2D getInputFunction() {
  return inputFunction;
    }

    public void setInputFunction(AparapiConv2D inputFunction) {
  this.inputFunction = inputFunction;
    }
}
TOP

Related Classes of com.github.neuralnetworks.calculation.neuronfunctions.ConnectionCalculatorConv

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.