Package net.bpiwowar.mg4j.extensions.adhoc

Source Code of net.bpiwowar.mg4j.extensions.adhoc.RelevanceModelScorer$Visitor

package net.bpiwowar.mg4j.extensions.adhoc;


import it.unimi.di.big.mg4j.index.Index;
import it.unimi.di.big.mg4j.index.IndexIterator;
import it.unimi.di.big.mg4j.search.AbstractIntersectionDocumentIterator;
import it.unimi.di.big.mg4j.search.DocumentIterator;
import it.unimi.di.big.mg4j.search.score.AbstractWeightedScorer;
import it.unimi.di.big.mg4j.search.score.DelegatingScorer;
import it.unimi.di.big.mg4j.search.visitor.AbstractDocumentIteratorVisitor;
import it.unimi.di.big.mg4j.search.visitor.CounterSetupVisitor;
import it.unimi.di.big.mg4j.search.visitor.TermCollectionVisitor;
import it.unimi.dsi.fastutil.ints.IntBigList;
import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
import org.apache.log4j.Logger;

import java.io.IOException;
import java.util.Arrays;

/** A scorer that implements the relevance model scorer
*
* @author B. Piwowarski
*/
public class RelevanceModelScorer extends AbstractWeightedScorer implements DelegatingScorer {
  public static final Logger LOGGER = Logger.getLogger( RelevanceModelScorer.class );
  public static final boolean DEBUG = true;

  private static final class Visitor extends AbstractDocumentIteratorVisitor {
    /** Offset-indexed precomputed values. */
    private final double[] k1Plus1TimesWeightedIdfPart;
    /** Offset-indexed precomputed values. */
    private final double k1Times1MinusB;
    /** An array (parallel to {@link TermCollectionVisitor#indices()} that caches average document sizes. */
    private final double k1TimesBDividedByAverageDocumentSize[];
    /** An array (parallel to {@link TermCollectionVisitor#indices()} that caches size lists. */
    private final IntBigList sizes[];
    /** Cached from {@link RelevanceModelScorer}. */
    private final double[] sizeComponent;
    /** Cached from {@link CounterSetupVisitor}. */
    private final int[] indexNumber;
    /** The length of {@link TermCollectionVisitor#indices()} cached. */
    private final int numberOfIndices;
    /** An array parallel to {@link #indexNumber} keeping track of whether we already accumulated the score for a specific term/index pair. */
    private final boolean[] seen;
    /** An array accumulating the indices in {@link #seen} that have been set to true, so to accelerate {@link #reset(int)}. */
    private final int[] seenList;
    /** The accumulated score. */
    public double score;
    /** The number of valid entries in {@link #seenList}. */
    private int numberOfSeen;
   
    public Visitor( final double k1Times1Minusb, final double[] k1Plus1TimesWeightedIdfPart, final double[] k1TimesBDividedByAverageDocumentSize, final int numberOfIndices, final int[] indexNumber, final IntBigList[] sizes ) {
      this.k1Times1MinusB = k1Times1Minusb;
      this.k1Plus1TimesWeightedIdfPart = k1Plus1TimesWeightedIdfPart;
      this.k1TimesBDividedByAverageDocumentSize = k1TimesBDividedByAverageDocumentSize;
      this.sizeComponent = new double[ numberOfIndices ];
      this.numberOfIndices = numberOfIndices;
      this.indexNumber = indexNumber;
      this.seen = new boolean[ indexNumber.length ];
      this.seenList = new int[ indexNumber.length ];
      this.sizes = sizes;
    }

    public Boolean visit( final IndexIterator indexIterator ) throws IOException {
      final int offset = indexIterator.id();
      if ( ! seen[ offset ] ) {
        seen[ seenList[ numberOfSeen++ ] = offset ] = true;
        final int count = indexIterator.count();
        score += ( count * k1Plus1TimesWeightedIdfPart[ offset ] ) / ( count + sizeComponent[ indexNumber[ offset ] ] );
      }
      return Boolean.TRUE;
    }
   
    public void reset( final long document ) {
      score = 0;
      // Clear seen information (on the first invocation does nothing as numberOfSeen == 0 ).
      while( numberOfSeen-- != 0 ) seen[ seenList[ numberOfSeen ] ] = false;
      numberOfSeen = 0;

      for( int i = numberOfIndices; i-- != 0; ) sizeComponent[ i ] = k1Times1MinusB + k1TimesBDividedByAverageDocumentSize[ i ] * sizes[ i ].getInt( document );
    }
  }

 
  /** The default value used for the parameter <var>k</var><sub>1</sub>. */
  public final static double DEFAULT_K1 = 1.2;
  /** The default value used for the parameter <var>b</var>. */
  public final static double DEFAULT_B = 0.5;
  /** The value of the document-frequency part for terms appearing in more than half of the documents. */
  public final static double EPSILON_SCORE = 1.0E-6;
  /** Disjunctive queries on {@linkplain IndexIterator index iterators} are handled using the flat evaluator only if they contain less than
   * this number of disjuncts. The generic evaluator is more efficient if there are several disjuncts, as it
   * invokes {@link IndexIterator#count()} only on the terms that are part of the front. This value is largely architecture, query,
   * term-distribution, and whatever else dependent. */
  public static final int MAX_FLAT_DISJUNCTS = 16;
 
  /** The counter setup visitor used to estimate counts. */
  private final CounterSetupVisitor setupVisitor;
  /** The term collection visitor used to estimate counts. */
  private final TermCollectionVisitor termVisitor;

  /** The parameter <var>k</var><sub>1</sub>. */
  private final double k1;
  /** The parameter <var>b</var>. */
  private final double b;

  /** The parameter {@link #k1} multiplied by one minus {@link #b}, precomputed. */
  private final double k1Times1MinusB;
  /** A value precomputed for flat evaluation. */
  private double k1TimesBDividedByAverageDocumentSize;
  /** The list of sizes, cached for flat evaluation. */
  private IntBigList sizes;
  /** An array indexed by offsets that caches the inverse document-frequency part of the formula, multiplied by the index weight, cached for flat evaluation. */
  private double[] k1Plus1TimesWeightedIdfPart;
  /** The value of {@link TermCollectionVisitor#numberOfPairs()} cached, if {@link #indexIterator} is <code>null</code>. */
  private int numberOfPairs;
  /** An array of nonzero-frequency index iterators, all on the same index, used by the flat evaluator, or <code>null</code> for generic evaluation. */
  private IndexIterator[] flatIndexIterator;
  /** A visitor used by the generic evaluator. */
  private Visitor visitor;

  /** Creates a BM25 scorer using {@link #DEFAULT_K1} and {@link #DEFAULT_B} as parameters.
   */
  public RelevanceModelScorer() {
    this( DEFAULT_K1, DEFAULT_B );
  }

  /** Creates a BM25 scorer using specified <var>k</var><sub>1</sub> and <var>b</var> parameters.
   * @param k1 the <var>k</var><sub>1</sub> parameter.
   * @param b the <var>b</var> parameter.
   */
  public RelevanceModelScorer(final double k1, final double b) {
    termVisitor = new TermCollectionVisitor();
    setupVisitor = new CounterSetupVisitor( termVisitor );
    this.k1 = k1;
    this.b = b;
    k1Times1MinusB = k1 * ( 1 - b );
  }

  /** Creates a BM25 scorer using specified <var>k</var><sub>1</sub> and <var>b</var> parameters specified by strings.
   *
   * @param k1 the <var>k</var><sub>1</sub> parameter.
   * @param b the <var>b</var> parameter.
   */
  public RelevanceModelScorer(final String k1, final String b) {
    this( Double.parseDouble( k1 ), Double.parseDouble( b ) );
  }

  public synchronized RelevanceModelScorer copy() {
    final RelevanceModelScorer scorer = new RelevanceModelScorer( k1, b );
    scorer.setWeights( index2Weight );
    return scorer;
  }

  public double score() throws IOException {
   
    final long document = documentIterator.document();

    if ( flatIndexIterator == null ) {
      visitor.reset( document );
      documentIterator.acceptOnTruePaths( visitor );
      return visitor.score;
    }
    else {
      final double sizeComponent = k1Times1MinusB + k1TimesBDividedByAverageDocumentSize * sizes.getInt( document );
      double score = 0;
      final double[] k1Plus1TimesWeightedIdfPart = this.k1Plus1TimesWeightedIdfPart;
      final IndexIterator[] actualIndexIterator = this.flatIndexIterator;

      for ( int i = numberOfPairs; i-- != 0; )
        if ( actualIndexIterator[ i ].document() == document ) {
          final int c = actualIndexIterator[ i ].count();
          score += ( c * k1Plus1TimesWeightedIdfPart[ i ] ) / ( c + sizeComponent );
        }
      return score;
    }
  }

  public double score( final Index index ) {
    throw new UnsupportedOperationException();
  }


  public void wrap( DocumentIterator d ) throws IOException {
    super.wrap( d );

    /* Note that we use the index array provided by the weight function, *not* by the visitor or by the iterator.
     * If the function has an empty domain, this call is equivalent to prepare(). */
    termVisitor.prepare( index2Weight.keySet() );
   
    d.accept( termVisitor );

    if ( DEBUG ) LOGGER.debug( "Term Visitor found " + termVisitor.numberOfPairs() + " leaves" );

    // Note that we use the index array provided by the visitor, *not* by the iterator.
    final Index[] index = termVisitor.indices();

    if ( DEBUG ) LOGGER.debug( "Indices: " + Arrays.toString( index ) );

    flatIndexIterator = null;
   
    /* We use the flat evaluator only for single-index, term-only queries that are either quite small, and
     * then either conjunctive, or disjunctive with a reasonable number of terms. */

    if ( indexIterator != null && index.length == 1 && ( documentIterator instanceof AbstractIntersectionDocumentIterator || indexIterator.length < MAX_FLAT_DISJUNCTS ) ) {
      /* This code is a flat, simplified duplication of what a CounterSetupVisitor would do. It is here just for efficiency. */
      numberOfPairs = 0;
      /* Find duplicate terms. We score unique pairs term/index with nonzero frequency, as the standard method would do. */
      final LongOpenHashSet alreadySeen = new LongOpenHashSet();

      for( int i = indexIterator.length; i-- != 0; )
        if ( indexIterator[ i ].frequency() != 0 && alreadySeen.add( indexIterator[ i ].termNumber() ) ) numberOfPairs++;

      if ( numberOfPairs == indexIterator.length ) flatIndexIterator = indexIterator;
      else {
        /* We must compact the array, eliminating zero-frequency iterators. */
        flatIndexIterator = new IndexIterator[ numberOfPairs ];
        alreadySeen.clear();
        for( int i = 0, p = 0; i != indexIterator.length; i++ )
          if ( indexIterator[ i ].frequency() != 0 &&  alreadySeen.add( indexIterator[ i ].termNumber() ) ) flatIndexIterator[ p++ ] = indexIterator[ i ];
      }

      if ( flatIndexIterator.length != 0 ) {
        // Some caching of frequently-used values
        k1TimesBDividedByAverageDocumentSize = k1 * b * flatIndexIterator[ 0 ].index().numberOfDocuments / flatIndexIterator[ 0 ].index().numberOfOccurrences;
        if ( ( this.sizes = flatIndexIterator[ 0 ].index().sizes ) == null ) throw new IllegalStateException( "A BM25 scorer requires document sizes" );

        // We do all logs here, and multiply by the weight
        k1Plus1TimesWeightedIdfPart = new double[ numberOfPairs ];
        for( int i = 0; i < numberOfPairs; i++ ) {
          final long frequency = flatIndexIterator[ i ].frequency();
          k1Plus1TimesWeightedIdfPart[ i ] = ( k1 + 1 ) * Math.max( EPSILON_SCORE, 
              Math.log( ( flatIndexIterator[ i ].index().numberOfDocuments - frequency + 0.5 ) / ( frequency + 0.5 ) ) ) * index2Weight.getDouble( flatIndexIterator[ i ].index() );
        }
      }
    }
    else {
      // Some caching of frequently-used values
      final double[] k1TimesBDividedByAverageDocumentSize = new double[ index.length ];
      for ( int i = index.length; i-- != 0; )
        k1TimesBDividedByAverageDocumentSize[ i ] = k1 * b * index[ i ].numberOfDocuments / index[ i ].numberOfOccurrences;

      if ( DEBUG ) LOGGER.debug( "Average document sizes: " + Arrays.toString( k1TimesBDividedByAverageDocumentSize ) );
      final IntBigList[] sizes = new IntBigList[ index.length ];
      for( int i = index.length; i-- != 0; )
        if ( ( sizes[ i ] = index[ i ].sizes ) == null ) throw new IllegalStateException( "A BM25 scorer requires document sizes" );
     
      setupVisitor.prepare();
      d.accept( setupVisitor );
      numberOfPairs = termVisitor.numberOfPairs();
      final long[] frequency = setupVisitor.frequency;
      final int[] indexNumber = setupVisitor.indexNumber;

      // We do all logs here, and multiply by the weight
      k1Plus1TimesWeightedIdfPart = new double[ frequency.length ];
      for( int i = k1Plus1TimesWeightedIdfPart.length; i-- != 0; )
        k1Plus1TimesWeightedIdfPart[ i ] = ( k1 + 1 ) * Math.max( EPSILON_SCORE, 
            Math.log( ( index[ indexNumber[ i ] ].numberOfDocuments - frequency[ i ] + 0.5 ) / ( frequency[ i ] + 0.5 ) ) ) * index2Weight.getDouble( index[ indexNumber[ i ] ] );

      visitor = new Visitor( k1Times1MinusB, k1Plus1TimesWeightedIdfPart, k1TimesBDividedByAverageDocumentSize, termVisitor.indices().length, indexNumber, sizes );
    }

  }
 
  public boolean usesIntervals() {
    return false;
  }

}
TOP

Related Classes of net.bpiwowar.mg4j.extensions.adhoc.RelevanceModelScorer$Visitor

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.