Package org.grouplens.lenskit.mf.funksvd

Source Code of org.grouplens.lenskit.mf.funksvd.FunkSVDItemScorer

/*
* LensKit, an open source recommender systems toolkit.
* Copyright 2010-2014 LensKit Contributors.  See CONTRIBUTORS.md.
* Work on LensKit has been funded by the National Science Foundation under
* grants IIS 05-34939, 08-08692, 08-12148, and 10-17697.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
* FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
* details.
*
* You should have received a copy of the GNU General Public License along with
* this program; if not, write to the Free Software Foundation, Inc., 51
* Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/
package org.grouplens.lenskit.mf.funksvd;

import it.unimi.dsi.fastutil.longs.LongSet;
import it.unimi.dsi.fastutil.longs.LongSortedSet;
import mikera.vectorz.AVector;
import mikera.vectorz.Vector;
import org.grouplens.lenskit.ItemScorer;
import org.grouplens.lenskit.baseline.BaselineScorer;
import org.grouplens.lenskit.basic.AbstractItemScorer;
import org.grouplens.lenskit.collections.LongUtils;
import org.grouplens.lenskit.data.dao.UserEventDAO;
import org.grouplens.lenskit.data.event.Rating;
import org.grouplens.lenskit.data.event.Ratings;
import org.grouplens.lenskit.data.history.History;
import org.grouplens.lenskit.data.history.UserHistory;
import org.grouplens.lenskit.data.pref.PreferenceDomain;
import org.grouplens.lenskit.iterative.TrainingLoopController;
import org.grouplens.lenskit.mf.svd.BiasedMFKernel;
import org.grouplens.lenskit.mf.svd.DomainClampingKernel;
import org.grouplens.lenskit.mf.svd.DotProductKernel;
import org.grouplens.lenskit.vectors.MutableSparseVector;
import org.grouplens.lenskit.vectors.SparseVector;
import org.grouplens.lenskit.vectors.VectorEntry;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.inject.Inject;

/**
* Do recommendations and predictions based on SVD matrix factorization.
*
* Recommendation is done based on folding-in.  The strategy is do a fold-in
* operation as described in
* <a href="http://www.grouplens.org/node/212">Sarwar et al., 2002</a> with the
* user's ratings.
*
* @author <a href="http://www.grouplens.org">GroupLens Research</a>
*/
public class FunkSVDItemScorer extends AbstractItemScorer {

    protected final FunkSVDModel model;
    protected final BiasedMFKernel kernel;
    private UserEventDAO dao;
    private final ItemScorer baselineScorer;
    private final int featureCount;

    @Nullable
    private final FunkSVDUpdateRule rule;

    /**
     * Construct the item scorer.
     *
     * @param dao      The DAO.
     * @param model    The model.
     * @param baseline The baseline scorer.  Be very careful when configuring a different baseline
     *                 at runtime than at model-build time; such a configuration is unlikely to
     *                 perform well.
     * @param rule     The update rule, or {@code null} (the default) to only use the user features
     *                 from the model. If provided, this update rule is used to update a user's
     *                 feature values based on their profile when scores are requested.
     */
    @Inject
    public FunkSVDItemScorer(UserEventDAO dao, FunkSVDModel model,
                             @BaselineScorer ItemScorer baseline,
                             @Nullable PreferenceDomain dom,
                             @Nullable @RuntimeUpdate FunkSVDUpdateRule rule) {
        // FIXME Unify requirement on update rule and DAO
        this.dao = dao;
        this.model = model;
        baselineScorer = baseline;
        this.rule = rule;

        if (dom == null) {
            kernel = new DotProductKernel();
        } else {
            kernel = new DomainClampingKernel(dom);
        }

        featureCount = model.getFeatureCount();
    }

    @Nullable
    public FunkSVDUpdateRule getUpdateRule() {
        return rule;
    }

    /**
     * Predict for a user using their preference array and history vector.
     *
     * @param user   The user's ID
     * @param uprefs The user's preference vector.
     * @param output The output vector, whose key domain is the items to predict for. It must
     *               be initialized to the user's baseline predictions.
     */
    private void computeScores(long user, AVector uprefs, MutableSparseVector output) {
        for (VectorEntry e : output) {
            final long item = e.getKey();
            AVector ivec = model.getItemVector(item);
            if (ivec == null) {
                // no item-vector, cannot make an informed prediction.
                // unset the baseline to note that we are not predicting for this item.
                output.unset(e);
            } else {
                double score = kernel.apply(e.getValue(), uprefs, ivec);
                output.set(e, score);
            }
        }
    }

    /**
     * Get estimates for all a user's ratings and the target items.
     *
     * @param user    The user ID.
     * @param ratings The user's ratings.
     * @param items   The target items.
     * @return Baseline predictions for all items either in the target set or the set of
     *         rated items.
     */
    private MutableSparseVector initialEstimates(long user, SparseVector ratings, LongSortedSet items) {
        LongSet allItems = LongUtils.setUnion(items, ratings.keySet());
        MutableSparseVector estimates = MutableSparseVector.create(allItems);
        baselineScorer.score(user, estimates);
        return estimates;
    }

    @Override
    public void score(long user, @Nonnull MutableSparseVector scores) {
        UserHistory<Rating> history = dao.getEventsForUser(user, Rating.class);
        if (history == null) {
            history = History.forUser(user);
        }
        SparseVector ratings = Ratings.userRatingVector(history);
       
        AVector uprefs = model.getUserVector(user);
        if (uprefs == null) {
            if (ratings.isEmpty()) {
                // no real work to do.
                return;
            }
            uprefs = model.getAverageUserVector();
        }

        MutableSparseVector estimates = initialEstimates(user, ratings, scores.keyDomain());
        // propagate estimates to the output scores
        scores.set(estimates);

        if (!ratings.isEmpty() && rule != null) {
            AVector updated = Vector.create(uprefs);
            for (int f = 0; f < featureCount; f++) {
                trainUserFeature(user, updated, ratings, estimates, f);
            }
            uprefs = updated;
        }

        // scores are the estimates, uprefs are trained up.
        computeScores(user, uprefs, scores);
    }

    private void trainUserFeature(long user, AVector uprefs, SparseVector ratings,
                                  MutableSparseVector estimates, int feature) {
        assert rule != null;
        assert uprefs.length() == featureCount;
        assert feature >= 0 && feature < featureCount;

        int tailStart = feature + 1;
        int tailSize = featureCount - feature - 1;
        AVector utail = uprefs.subVector(tailStart, tailSize);
        MutableSparseVector tails = MutableSparseVector.create(ratings.keySet());
        for (VectorEntry e: tails.view(VectorEntry.State.EITHER)) {
            AVector ivec = model.getItemVector(e.getKey());
            if (ivec == null) {
                // FIXME Do this properly
                tails.set(e, 0);
            } else {
                ivec = ivec.subVector(tailStart, tailSize);
                tails.set(e, utail.dotProduct(ivec));
            }
        }

        double rmse = Double.MAX_VALUE;
        TrainingLoopController controller = rule.getTrainingLoopController();
        while (controller.keepTraining(rmse)) {
            rmse = doFeatureIteration(user, uprefs, ratings, estimates, feature, tails);
        }
    }

    private double doFeatureIteration(long user, AVector uprefs,
                                      SparseVector ratings, MutableSparseVector estimates,
                                      int feature, SparseVector itemTails) {
        assert rule != null;

        FunkSVDUpdater updater = rule.createUpdater();
        for (VectorEntry e: ratings) {
            final long iid = e.getKey();
            final AVector ivec = model.getItemVector(iid);
            if (ivec == null) {
                continue;
            }

            updater.prepare(feature, e.getValue(), estimates.get(iid),
                            uprefs.get(feature), ivec.get(feature), itemTails.get(iid));
            // Step 4: update user preferences
            uprefs.addAt(feature, updater.getUserFeatureUpdate());
        }
        return updater.getRMSE();
    }
}
TOP

Related Classes of org.grouplens.lenskit.mf.funksvd.FunkSVDItemScorer

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.