Package com.twitter.pers.bipartite

Source Code of com.twitter.pers.bipartite.SALSASmallMem

package com.twitter.pers.bipartite;

import edu.cmu.graphchi.ChiLogger;
import edu.cmu.graphchi.ChiVertex;
import edu.cmu.graphchi.GraphChiContext;
import edu.cmu.graphchi.GraphChiProgram;
import edu.cmu.graphchi.vertexdata.ForeachCallback;
import edu.cmu.graphchi.vertexdata.VertexAggregator;
import edu.cmu.graphchi.datablocks.FloatConverter;
import edu.cmu.graphchi.datablocks.FloatPair;
import edu.cmu.graphchi.datablocks.FloatPairConverter;
import edu.cmu.graphchi.engine.GraphChiEngine;
import edu.cmu.graphchi.engine.VertexInterval;
import edu.cmu.graphchi.hadoop.PigGraphChiBase;
import edu.cmu.graphchi.preprocessing.EdgeProcessor;
import edu.cmu.graphchi.preprocessing.FastSharder;
import edu.cmu.graphchi.preprocessing.VertexProcessor;
import edu.cmu.graphchi.util.IdFloat;
import org.apache.pig.backend.executionengine.ExecException;
import org.apache.pig.data.Tuple;
import org.apache.pig.data.TupleFactory;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.logging.Logger;

/**
* Version of SALSA that uses just a little memory (values propagated
* via edges), and can be run under Pig.
*
* On each iteration either left or right side is computed. Each vertex
* can represent both sides. Left side has out-edges, right side in-edges.
* Left side = authorities (users)
* Right side = hubs
*
* The algorithm starts with the right side, and the edges have initial
* values for the left side vertices (authorities).
*
* @author Aapo Kyrola, akyrola@cs.cmu.edu
* @copyright Twitter  (done during internship, Fall 2012)
*/
public class SALSASmallMem extends PigGraphChiBase implements GraphChiProgram<FloatPair, Float>  {


    private final static int RIGHTSIDE = 0; // Start with right side
    private final static int LEFTSIDE = 1;

    private String graphName;
    private final static Logger logger = ChiLogger.getLogger("salsa-smallmem");

    int numShards = 20;

    GraphChiEngine<FloatPair, Float> engine;

    public SALSASmallMem() {
        super();
    }

    @Override
    public void update(ChiVertex<FloatPair, Float> vertex, GraphChiContext context) {
        int side = context.getIteration() % 2;
        if (vertex.numEdges() > 0) {
            float nbrSum = 0.0f;

            if (side == LEFTSIDE) {
                for(int i=0; i < vertex.numOutEdges(); i++) {
                    nbrSum += vertex.outEdge(i).getValue();
                }
            } else {
                for(int i=0; i < vertex.numInEdges(); i++) {
                    nbrSum += vertex.inEdge(i).getValue();
                }
            }

            float newValue = nbrSum;

            FloatPair curValue = vertex.getValue();
            if (side == LEFTSIDE && vertex.numOutEdges() > 0) {
                curValue = new FloatPair(newValue, curValue.second);
                // Write value to outedges
                float broadcastValue = newValue / vertex.numOutEdges();
                for(int i=0; i < vertex.numOutEdges(); i++) {
                    vertex.outEdge(i).setValue(broadcastValue);
                }
            }
            else if (side == RIGHTSIDE && vertex.numInEdges() > 0) {
                // Renormalization
                int numRelevantEdges = vertex.numInEdges();
                int totalEdges = (intcurValue.second;
                if (totalEdges == 0) {
                    logger.warning("Normalization factor cannot be zero! Id:" + context.getVertexIdTranslate().backward(vertex.getId()));
                    totalEdges = numRelevantEdges;
                }
                newValue *= numRelevantEdges * 1.0f / (float)totalEdges;

                // Write value to in-edges
                float broadcastValue = newValue / vertex.numInEdges();
                for(int i=0; i < vertex.numInEdges(); i++) {
                    vertex.inEdge(i).setValue(broadcastValue);
                }
            }
            vertex.setValue(curValue);
        }
    }

    @Override
    public void beginIteration(GraphChiContext ctx) {
    }


    @Override
    public void beginInterval(GraphChiContext ctx, VertexInterval interval) {
    }

    @Override
    public void endInterval(GraphChiContext ctx, VertexInterval interval) {
    }



    public void endIteration(GraphChiContext ctx) {

    }

    @Override
    public void beginSubInterval(GraphChiContext ctx, VertexInterval interval) {
    }

    @Override
    public void endSubInterval(GraphChiContext ctx, VertexInterval interval) {
    }

    public void run(String graphName, int numShards) throws Exception {
        this.graphName = graphName;
        engine = new GraphChiEngine<FloatPair, Float>(graphName, numShards);
        engine.setEnableScheduler(false);
        engine.setSkipZeroDegreeVertices(true);
        engine.setEdataConverter(new FloatConverter());
        engine.setVertexDataConverter(new FloatPairConverter());
        engine.setMaxWindow(20000000);
        engine.run(this, 8);
    }

    private void outputResults(String graphName) throws IOException {

        VertexAggregator.foreach(engine.numVertices(), graphName, new FloatPairConverter(), new ForeachCallback<FloatPair>() {
            @Override
            public void callback(int vertexId, FloatPair vertexValue) {
                if (vertexValue.first > 0) {
                    System.out.println(engine.getVertexIdTranslate().backward(vertexId+ "\t" + vertexValue.first);
                }
            }
        });
    }

    /**
     ]     * @param args
     * @throws Exception
     */
    public static void main(String[] args) throws  Exception {
        int k = 0;
        String graphName = null;
        if (args.length == 2) graphName = args[k++];
        int nShards = Integer.parseInt(args[k++]);
        SALSASmallMem hits = new SALSASmallMem();

        if (graphName == null) {
            graphName = "pipein";
            FastSharder sharder = hits.createSharder(graphName, nShards);
            sharder.shard(System.in);
        }
        hits.run(graphName, nShards);

        hits.outputResults(graphName);
    }

    // PIG support


    @Override
    protected String getSchemaString() {
        return "(weight:float, vertex:int)";
    }

    @Override
    protected int getNumShards() {
        return numShards;
    }

    private ArrayList<IdFloat> results;
    private Iterator<IdFloat> resultIter;

    @Override
    protected void runGraphChi() throws Exception {
        run(getGraphName(), getNumShards());
        results = new ArrayList<IdFloat>(100000);

        // Collect results - into memory ... This may consume a lot of memory.
        // It would be better to have an iterator for the vertex data.
        VertexAggregator.foreach(engine.numVertices(), graphName, new FloatPairConverter(), new ForeachCallback<FloatPair>() {
            @Override
            public void callback(int vertexId, FloatPair vertexValue) {
                if (vertexValue.first > 0) {
                    results.add(new IdFloat(engine.getVertexIdTranslate().backward(vertexId), vertexValue.first));
                }
            }
        });
        engine = null;
        resultIter = results.iterator();
    }

    @Override
    protected FastSharder createSharder(String graphName, int numShards) throws IOException {
        this.numShards = numShards;
        return new FastSharder<FloatPair, Float>(graphName, numShards, new VertexProcessor<FloatPair>() {
            @Override
            /* For lists (hubs), the vertex value will encode the total number of edges */
            public FloatPair receiveVertexValue(int vertexId, String token) {
                return new FloatPair(0.0f, Float.parseFloat(token));
            }
        }, new EdgeProcessor<Float>() {
            @Override
            public Float receiveEdge(int from, int to, String token) {
                return Float.parseFloat(token);
            }
        }, new FloatPairConverter(), new FloatConverter());
    }



    @Override
    protected Tuple getNextResult(TupleFactory tupleFactory) throws ExecException {
        if (resultIter.hasNext()) {
            IdFloat res = resultIter.next();
            Tuple t = tupleFactory.newTuple(2);
            t.set(0, res.getValue());
            t.set(1, res.getVertexId());
            return t;
        } else {
            return null;
        }
    }

    @Override
    protected String getStatusString() {
        if (engine != null) {
            GraphChiContext ctx = engine.getContext();
            if (ctx != null) {
                return ctx.getCurInterval() + " iteration: " +  ctx.getIteration() + "/" + ctx.getNumIterations();
            }
        }
        return "Initializing";
    }
}
TOP

Related Classes of com.twitter.pers.bipartite.SALSASmallMem

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.