Package org.ejml.simple

Examples of org.ejml.simple.SimpleMatrix


        DebugQR qr = new DebugQR(width,width);

        double gamma = 0.2;
        double tau = 0.75;

        SimpleMatrix U = new SimpleMatrix(width,1);
        SimpleMatrix A = new SimpleMatrix(width,width);

        RandomMatrices.setRandom(U.getMatrix(),rand);
        RandomMatrices.setRandom(A.getMatrix(),rand);

        qr.getQR().set(A.getMatrix());

        // compute the results using standard matrix operations
        SimpleMatrix I = SimpleMatrix.identity(width-w);

        SimpleMatrix u_sub = U.extractMatrix(w,width,0,1);
        SimpleMatrix A_sub = A.extractMatrix(w,width,w,width);
        SimpleMatrix expected = I.minus(u_sub.mult(u_sub.transpose()).scale(gamma)).mult(A_sub);

        qr.updateA(w,U.getMatrix().getData(),gamma,tau);

        DenseMatrix64F found = qr.getQR();

        assertEquals(-tau,found.get(w,w),1e-8);

        for( int i = w+1; i < width; i++ ) {
            assertEquals(U.get(i,0),found.get(i,w),1e-8);
        }

        // the right should be the same
        for( int i = w; i < width; i++ ) {
            for( int j = w+1; j < width; j++ ) {
                double a = expected.get(i-w,j-w);
                double b = found.get(i,j);

                assertEquals(a,b,1e-6);
            }
        }
View Full Code Here


    /**
     * Creates a Q matrix for debugging purposes.
     */
    private SimpleMatrix createQ(int x1, int x2 , double c, double s , boolean transposed ) {
        SimpleMatrix Q = SimpleMatrix.identity(N);
        Q.set(x1,x1,c);
        if( transposed ) {
            Q.set(x1,x2,s);
            Q.set(x2,x1,-s);
        } else {
            Q.set(x1,x2,-s);
            Q.set(x2,x1,s);
        }
        Q.set(x2,x2,c);
        return Q;
    }
View Full Code Here

        Q.set(x2,x2,c);
        return Q;
    }

    private SimpleMatrix createB() {
        SimpleMatrix B = new SimpleMatrix(N,N);

        for( int i = 0; i < N-1; i++ ) {
            B.set(i,i,diag[i]);
            B.set(i,i+1,off[i]);
        }
        B.set(N-1,N-1,diag[N-1]);

        return B;
    }
View Full Code Here

      return;
    }

    System.err.println("Using wordVector " + source + " for " + target);

    wordVectors.put(target, new SimpleMatrix(wordVectors.get(source)));
  }
View Full Code Here

  public static void replaceWordVector(Map<String, SimpleMatrix> wordVectors, String source, String target) {
    if (!wordVectors.containsKey(source)) {
      return;
    }

    wordVectors.put(target, new SimpleMatrix(wordVectors.get(source)));
  }
View Full Code Here

      slices[i] = loadMatrix(basePath + "bin/Wt_" + (i + 1) + ".bin", basePath + "Wt_" + (i + 1) + ".txt");
    }
    SimpleTensor tensor = new SimpleTensor(slices);
    System.err.println("W tensor size: " + tensor.numRows() + "x" + tensor.numCols() + "x" + tensor.numSlices());

    SimpleMatrix W = loadMatrix(basePath + "bin/W.bin", basePath + "W.txt");
    System.err.println("W matrix size: " + W.numRows() + "x" + W.numCols());

    SimpleMatrix Wcat = loadMatrix(basePath + "bin/Wcat.bin", basePath + "Wcat.txt");
    System.err.println("W cat size: " + Wcat.numRows() + "x" + Wcat.numCols());

    SimpleMatrix combinedWV = loadMatrix(basePath + "bin/Wv.bin", basePath + "Wv.txt");
    System.err.println("Word matrix size: " + combinedWV.numRows() + "x" + combinedWV.numCols());

    File vocabFile = new File(basePath + "vocab_1.txt");
    if (!vocabFile.exists()) {
      vocabFile = new File(basePath + "words.txt");
    }
    List<String> lines = Generics.newArrayList();
    for (String line : IOUtils.readLines(vocabFile)) {
      lines.add(line.trim());
    }

    System.err.println("Lines in vocab file: " + lines.size());

    Map<String, SimpleMatrix> wordVectors = Generics.newTreeMap();

    for (int i = 0; i < lines.size() && i < combinedWV.numCols(); ++i) {
      String[] pieces = lines.get(i).split(" +");
      if (pieces.length == 0 || pieces.length > 1) {
        continue;
      }
      wordVectors.put(pieces[0], combinedWV.extractMatrix(0, numSlices, i, i+1));
      if (pieces[0].equals("UNK")) {
        wordVectors.put(SentimentModel.UNKNOWN_WORD, wordVectors.get("UNK"));
      }
    }
View Full Code Here

    for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : model.binaryTransform) {
      int numRows = entry.getValue().numRows();
      int numCols = entry.getValue().numCols();

      binaryTD.put(entry.getFirstKey(), entry.getSecondKey(), new SimpleMatrix(numRows, numCols));
    }

    if (!model.op.combineClassification) {
      for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : model.binaryClassification) {
        int numRows = entry.getValue().numRows();
        int numCols = entry.getValue().numCols();

        binaryCD.put(entry.getFirstKey(), entry.getSecondKey(), new SimpleMatrix(numRows, numCols));
      }
    }

    if (model.op.useTensors) {
      for (TwoDimensionalMap.Entry<String, String, SimpleTensor> entry : model.binaryTensors) {
        int numRows = entry.getValue().numRows();
        int numCols = entry.getValue().numCols();
        int numSlices = entry.getValue().numSlices();

        binaryTensorTD.put(entry.getFirstKey(), entry.getSecondKey(), new SimpleTensor(numRows, numCols, numSlices));
      }
    }

    for (Map.Entry<String, SimpleMatrix> entry : model.unaryClassification.entrySet()) {
      int numRows = entry.getValue().numRows();
      int numCols = entry.getValue().numCols();
      unaryCD.put(entry.getKey(), new SimpleMatrix(numRows, numCols));
    }
    for (Map.Entry<String, SimpleMatrix> entry : model.wordVectors.entrySet()) {
      int numRows = entry.getValue().numRows();
      int numCols = entry.getValue().numCols();
      wordVectorD.put(entry.getKey(), new SimpleMatrix(numRows, numCols));
    }

    // TODO: This part can easily be parallelized
    List<Tree> forwardPropTrees = Generics.newArrayList();
    for (Tree tree : trainingBatch) {
View Full Code Here

                            TwoDimensionalMap<String, String, SimpleMatrix> currentMatrices,
                            double scale,
                            double regCost) {
    double cost = 0.0; // the regularization cost
    for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : currentMatrices) {
      SimpleMatrix D = derivatives.get(entry.getFirstKey(), entry.getSecondKey());
      D = D.scale(scale).plus(entry.getValue().scale(regCost));
      derivatives.put(entry.getFirstKey(), entry.getSecondKey(), D);
      cost += entry.getValue().elementMult(entry.getValue()).elementSum() * regCost / 2.0;
    }
    return cost;
  }
View Full Code Here

                            Map<String, SimpleMatrix> currentMatrices,
                            double scale,
                            double regCost) {
    double cost = 0.0; // the regularization cost
    for (Map.Entry<String, SimpleMatrix> entry : currentMatrices.entrySet()) {
      SimpleMatrix D = derivatives.get(entry.getKey());
      D = D.scale(scale).plus(entry.getValue().scale(regCost));
      derivatives.put(entry.getKey(), D);
      cost += entry.getValue().elementMult(entry.getValue()).elementSum() * regCost / 2.0;
    }
    return cost;
  }
View Full Code Here

                                           TwoDimensionalMap<String, String, SimpleMatrix> binaryTD,
                                           TwoDimensionalMap<String, String, SimpleMatrix> binaryCD,
                                           TwoDimensionalMap<String, String, SimpleTensor> binaryTensorTD,
                                           Map<String, SimpleMatrix> unaryCD,
                                           Map<String, SimpleMatrix> wordVectorD) {
    SimpleMatrix delta = new SimpleMatrix(model.op.numHid, 1);
    backpropDerivativesAndError(tree, binaryTD, binaryCD, binaryTensorTD, unaryCD, wordVectorD, delta);
  }
View Full Code Here

TOP

Related Classes of org.ejml.simple.SimpleMatrix

Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., which is owned by Oracle. Contact coftware#gmail.com.