Package de.jungblut.math.activation

Source Code of de.jungblut.math.activation.AbstractActivationFunction

package de.jungblut.math.activation;

import java.util.Iterator;

import de.jungblut.math.DoubleMatrix;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.DoubleVector.DoubleVectorElement;
import de.jungblut.math.dense.DenseDoubleMatrix;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.sparse.SparseDoubleRowMatrix;
import de.jungblut.math.sparse.SparseDoubleVector;

/**
* Implements the boiler plate code for applying functions on container classes
* like vectors and matrices by applying the function on every element. This
* implementation is aware of the type of the vector and matrix, so it is also
* optimized for sparse as well as dense types.
*
* @author thomas.jungblut
*
*/
public abstract class AbstractActivationFunction implements ActivationFunction {

  @Override
  public DoubleVector apply(DoubleVector vector) {
    DoubleVector newInstance = newInstance(vector);
    if (vector.isSparse()) {
      Iterator<DoubleVectorElement> iterateNonZero = vector.iterateNonZero();
      while (iterateNonZero.hasNext()) {
        DoubleVectorElement next = iterateNonZero.next();
        newInstance.set(next.getIndex(), apply(next.getValue()));
      }
    } else {
      for (int i = 0; i < vector.getDimension(); i++) {
        newInstance.set(i, apply(vector.get(i)));
      }
    }
    return newInstance;
  }

  @Override
  public DoubleMatrix apply(DoubleMatrix matrix) {
    DoubleMatrix newInstance = newInstance(matrix);
    if (matrix.isSparse()) {
      // if we have a sparse matrix, it is more efficient to loop over the
      // sparse row vectors
      int[] rows = matrix.rowIndices();
      for (int row : rows) {
        DoubleVector rowVector = matrix.getRowVector(row);
        if (rowVector.getLength() > 0) {
          DoubleVector apply = apply(rowVector);
          newInstance.setRowVector(row, apply);
        }
      }
    } else {
      // on dense matrices we can be faster by directly looping over the items
      for (int i = 0; i < matrix.getRowCount(); i++) {
        for (int j = 0; j < matrix.getColumnCount(); j++) {
          newInstance.set(i, j, apply(matrix.get(i, j)));
        }
      }
    }
    return newInstance;
  }

  @Override
  public DoubleVector gradient(DoubleVector vector) {
    DoubleVector newInstance = newInstance(vector);
    if (vector.isSparse()) {
      Iterator<DoubleVectorElement> iterateNonZero = vector.iterateNonZero();
      while (iterateNonZero.hasNext()) {
        DoubleVectorElement next = iterateNonZero.next();
        newInstance.set(next.getIndex(), gradient(next.getValue()));
      }
    } else {
      for (int i = 0; i < vector.getDimension(); i++) {
        newInstance.set(i, gradient(vector.get(i)));
      }
    }
    return newInstance;
  }

  @Override
  public DoubleMatrix gradient(DoubleMatrix matrix) {
    DoubleMatrix newInstance = newInstance(matrix);
    if (matrix.isSparse()) {
      // if we have a sparse matrix, it is more efficient to loop over the
      // sparse column vectors
      int[] columnIndices = matrix.columnIndices();
      for (int col : columnIndices) {
        newInstance.setColumnVector(col, gradient(matrix.getColumnVector(col)));
      }
    } else {
      // on dense matrices we can be faster by directly looping over the items
      for (int i = 0; i < matrix.getRowCount(); i++) {
        for (int j = 0; j < matrix.getColumnCount(); j++) {
          newInstance.set(i, j, gradient(matrix.get(i, j)));
        }
      }
    }
    return newInstance;
  }

  protected DoubleMatrix newInstance(DoubleMatrix mat) {
    if (mat.isSparse()) {
      return new SparseDoubleRowMatrix(mat.getRowCount(), mat.getColumnCount());
    } else {
      return new DenseDoubleMatrix(mat.getRowCount(), mat.getColumnCount());
    }
  }

  protected DoubleVector newInstance(DoubleVector v) {
    if (v.isSparse()) {
      return new SparseDoubleVector(v.getDimension());
    } else {
      return new DenseDoubleVector(v.getDimension());
    }
  }

  @Override
  public String toString() {
    return getClass().getSimpleName();
  }

}
TOP

Related Classes of de.jungblut.math.activation.AbstractActivationFunction

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.