Package net.myrrix.online.candidate

Source Code of net.myrrix.online.candidate.LocationSensitiveHashTest

/*
* Copyright Myrrix Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*      http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package net.myrrix.online.candidate;

import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.SortedMap;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.stat.descriptive.moment.Mean;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import net.myrrix.common.MyrrixTest;
import net.myrrix.common.collection.FastByIDMap;
import net.myrrix.common.collection.FastIDSet;
import net.myrrix.common.math.SimpleVectorMath;
import net.myrrix.common.random.RandomManager;
import net.myrrix.common.random.RandomUtils;

public final class LocationSensitiveHashTest extends MyrrixTest {

  private static final Logger log = LoggerFactory.getLogger(LocationSensitiveHashTest.class);

  private static final int NUM_FEATURES = 50;
  private static final int NUM_ITEMS = 2000000;
  private static final int NUM_RECS = 10;
  private static final int ITERATIONS = 20;
  private static final double LN2 = Math.log(2.0);

  @Test
  public void testLSH() {
    System.setProperty("model.lsh.sampleRatio", "0.1");
    System.setProperty("model.lsh.numHashes", "20");
    RandomGenerator random = RandomManager.getRandom();

    Mean avgPercentTopRecsConsidered = new Mean();
    Mean avgNDCG = new Mean();
    Mean avgPercentAllItemsConsidered= new Mean();

    for (int iteration = 0; iteration < ITERATIONS; iteration++) {

      FastByIDMap<float[]> Y = new FastByIDMap<float[]>();
      for (int i = 0; i < NUM_ITEMS; i++) {
        Y.put(i, RandomUtils.randomUnitVector(NUM_FEATURES, random));
      }
      float[] userVec = RandomUtils.randomUnitVector(NUM_FEATURES, random);

      double[] results = doTestRandomVecs(Y, userVec);
      double percentTopRecsConsidered = results[0];
      double ndcg = results[1];
      double percentAllItemsConsidered = results[2];

      log.info("Considered {}% of all candidates, {} nDCG, got {}% recommendations correct",
               100 * percentAllItemsConsidered,
               ndcg,
               100 * percentTopRecsConsidered);

      avgPercentTopRecsConsidered.increment(percentTopRecsConsidered);
      avgNDCG.increment(ndcg);
      avgPercentAllItemsConsidered.increment(percentAllItemsConsidered);
    }

    log.info("{}", avgPercentTopRecsConsidered.getResult());
    log.info("{}", avgNDCG.getResult());
    log.info("{}", avgPercentAllItemsConsidered.getResult());

    assertTrue(avgPercentTopRecsConsidered.getResult() > 0.55);
    assertTrue(avgNDCG.getResult() > 0.55);
    assertTrue(avgPercentAllItemsConsidered.getResult() < 0.075);
  }

  private static double[] doTestRandomVecs(FastByIDMap<float[]> Y, float[] userVec) {

    CandidateFilter lsh = new LocationSensitiveHash(Y);

    FastIDSet candidates = new FastIDSet();
    float[][] userVecs = { userVec };
    for (Iterator<FastByIDMap.MapEntry<float[]>> candidatesIterator : lsh.getCandidateIterator(userVecs)) {
      while (candidatesIterator.hasNext()) {
        candidates.add(candidatesIterator.next().getKey());
      }
    }

    List<Long> topIDs = findTopRecommendations(Y, userVec);

    double score = 0.0;
    double maxScore = 0.0;
    int intersectionSize = 0;
    for (int i = 0; i < topIDs.size(); i++) {
      double value = LN2 / Math.log(2.0 + i);
      long id = topIDs.get(i);
      if (candidates.contains(id)) {
        intersectionSize++;
        score += value;
      }
      maxScore += value;
    }

    double percentTopRecsConsidered = (double) intersectionSize / topIDs.size();
    double ndcg = maxScore == 0.0 ? 0.0 : score / maxScore;
    double percentAllItemsConsidered = (double) candidates.size() / Y.size();

    return new double[] {percentTopRecsConsidered, ndcg, percentAllItemsConsidered};
  }

  private static List<Long> findTopRecommendations(FastByIDMap<float[]> Y, float[] userVec) {
    SortedMap<Double,Long> allScores = Maps.newTreeMap(Collections.reverseOrder());
    for (FastByIDMap.MapEntry<float[]> entry : Y.entrySet()) {
      double dot = SimpleVectorMath.dot(entry.getValue(), userVec);
      allScores.put(dot, entry.getKey());
    }
    List<Long> topRecommendations = Lists.newArrayList();
    for (Map.Entry<Double,Long> entry : allScores.entrySet()) {
      topRecommendations.add(entry.getValue());
      if (topRecommendations.size() == NUM_RECS) {
        return topRecommendations;
      }
    }
    return topRecommendations;
  }

}
TOP

Related Classes of net.myrrix.online.candidate.LocationSensitiveHashTest

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.