/**
* Copyright 2014 National University of Ireland, Galway.
*
* This file is part of the SIREn project. Project and contact information:
*
* https://github.com/rdelbru/SIREn
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.sindice.siren.search.node;
import java.io.IOException;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermContext;
import org.apache.lucene.index.TermState;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.MultiTermQuery.ConstantScoreAutoRewrite;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.ByteBlockPool;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefHash;
import org.apache.lucene.util.BytesRefHash.DirectBytesStartArray;
import org.apache.lucene.util.RamUsageEstimator;
import org.sindice.siren.search.node.NodeBooleanClause.Occur;
/**
* A rewrite method that tries to pick the best constant-score rewrite method
* based on term and document counts from the query.
*
* <p>
*
* Term and document cutoffs are deactivated to disable {@link
* MultiNodeTermQuery#CONSTANT_SCORE_FILTER_REWRITE}.
*
* <p>
*
* Code taken from {@link ConstantScoreAutoRewrite} and adapted for SIREn.
*/
class NodeConstantScoreAutoRewrite extends NodeTermCollectingRewrite<NodeBooleanQuery> {
// Term cutoff deactivated until an efficient filter-based approach is found
public static int DEFAULT_TERM_COUNT_CUTOFF = Integer.MAX_VALUE;
// Document cutoff deactivated until an efficient filter-based approach is found
public static double DEFAULT_DOC_COUNT_PERCENT = Integer.MAX_VALUE;
// Per-instance settings, initialised from the defaults above. Within this
// class they are read only by the getters, equals() and hashCode().
private int termCountCutoff = DEFAULT_TERM_COUNT_CUTOFF;
private double docCountPercent = DEFAULT_DOC_COUNT_PERCENT;
/**
 * Sets the term count cutoff: the number of terms at or above which
 * {@link MultiNodeTermQuery#CONSTANT_SCORE_FILTER_REWRITE} is meant to be
 * used instead of a boolean rewrite.
 *
 * <p>
 *
 * Note that {@link #rewrite} currently builds its collector with both cutoffs
 * hard-wired to {@link Integer#MAX_VALUE} and does not read this setting, so
 * the value is only visible through {@link #getTermCountCutoff()},
 * {@link #equals(Object)} and {@link #hashCode()}.
 *
 * @param count the new term count cutoff
 */
public void setTermCountCutoff(final int count) {
  this.termCountCutoff = count;
}
/**
 * Returns the term count cutoff currently stored on this instance.
 *
 * @return the term count cutoff
 * @see #setTermCountCutoff
 */
public int getTermCountCutoff() {
  return this.termCountCutoff;
}
/**
 * Sets the document count cutoff, expressed as a percentage of the index
 * maxDoc(): when the number of documents to be visited in the postings
 * exceeds it, {@link MultiNodeTermQuery#CONSTANT_SCORE_FILTER_REWRITE} is
 * meant to be used.
 *
 * <p>
 *
 * Note that {@link #rewrite} currently builds its collector with both cutoffs
 * hard-wired to {@link Integer#MAX_VALUE} and does not read this setting, so
 * the value is only visible through {@link #getDocCountPercent()},
 * {@link #equals(Object)} and {@link #hashCode()}.
 *
 * @param percent 0.0 to 100.0
 */
public void setDocCountPercent(final double percent) {
  this.docCountPercent = percent;
}
/**
 * Returns the document count percentage currently stored on this instance.
 *
 * @return the document count percentage
 * @see #setDocCountPercent
 */
public double getDocCountPercent() {
  return this.docCountPercent;
}
/**
 * Creates the empty {@link NodeBooleanQuery} that will receive one clause
 * per collected term.
 *
 * @return a fresh, empty {@link NodeBooleanQuery}
 */
@Override
protected NodeBooleanQuery getTopLevelQuery() {
  final NodeBooleanQuery topLevel = new NodeBooleanQuery();
  return topLevel;
}
/**
 * Adds a {@link NodeTermQuery} for the given term to the top-level boolean
 * query as an optional ({@link Occur#SHOULD}) clause.
 *
 * @param topLevel the boolean query receiving the clause
 * @param term the term to match
 * @param docFreq not used by this implementation
 * @param boost not used by this implementation
 * @param states the term context handed to the new {@link NodeTermQuery}
 */
@Override
protected void addClause(final NodeBooleanQuery topLevel, final Term term,
                         final int docFreq, final float boost /*ignored*/,
                         final TermContext states) {
  final NodeTermQuery clause = new NodeTermQuery(term, states);
  topLevel.add(clause, Occur.SHOULD);
}
/**
 * Rewrites a multi-term node query into a constant-score boolean query over
 * the terms it expands to.
 *
 * <p>
 *
 * The collector is created with both the document and the term limit set to
 * {@link Integer#MAX_VALUE}; the instance-level cutoff settings are not
 * consulted here.
 *
 * @param reader the index reader the terms are collected from
 * @param query the multi-term query to rewrite
 * @return the result of
 *         {@link MultiNodeTermQuery#CONSTANT_SCORE_FILTER_REWRITE} if the
 *         collector reported a cutoff; an empty top-level boolean query if no
 *         term was collected; otherwise a {@link NodeConstantScoreQuery}
 *         wrapping one optional clause per collected term and carrying the
 *         boost, level constraint, node constraint and ancestor pointer of
 *         {@code query}
 * @throws IOException if collecting the terms fails
 */
@Override
public Query rewrite(final IndexReader reader, final MultiNodeTermQuery query) throws IOException {
  // Disabled cutoffs: both limits are pinned to the largest int value
  final CutOffTermCollector collector =
      new CutOffTermCollector(Integer.MAX_VALUE, Integer.MAX_VALUE);
  this.collectTerms(reader, query, collector);

  if (collector.hasCutOff) {
    return MultiNodeTermQuery.CONSTANT_SCORE_FILTER_REWRITE.rewrite(reader, query);
  }

  final BytesRefHash collected = collector.pendingTerms;
  final int termCount = collected.size();
  if (termCount == 0) {
    return this.getTopLevelQuery();
  }

  final NodeBooleanQuery disjunction = this.getTopLevelQuery();
  final int[] order = collected.sort(collector.termsEnum.getComparator());
  // Only the first termCount entries of the sorted array are consumed
  for (int rank = 0; rank < termCount; rank++) {
    final int slot = order[rank];
    // Each Term gets its own BytesRef instance
    final Term term = new Term(query.field, collected.get(slot, new BytesRef()));
    // docFreq is not used for constant score here; a fake value of 1 is
    // passed explicitly so that it is not calculated
    this.addClause(disjunction, term, 1, 1.0f, collector.array.termState[slot]);
  }

  // Strip scores
  final NodeQuery constantScore = new NodeConstantScoreQuery(disjunction);
  constantScore.setBoost(query.getBoost());
  // Propagate the level and node constraints of the original query
  constantScore.setLevelConstraint(query.getLevelConstraint());
  constantScore.setNodeConstraint(query.getNodeConstraint()[0], query.getNodeConstraint()[1]);
  // Propagate the ancestor of the original query
  constantScore.setAncestorPointer(query.ancestor);
  return constantScore;
}
/**
 * Term collector that stores every collected term in a {@link BytesRefHash},
 * keeps a {@link TermContext} per stored term in a parallel array, and stops
 * collecting once a term count limit or a document visit limit is reached.
 */
static final class CutOffTermCollector extends TermCollector {

  /** Limits compared against {@link #docVisitCount} and the number of stored terms. */
  final int docCountCutoff, termCountLimit;
  /** Parallel storage for the term contexts; must be created before {@link #pendingTerms}. */
  final TermStateByteStart array = new TermStateByteStart(16);
  /** The collected terms. */
  final BytesRefHash pendingTerms = new BytesRefHash(new ByteBlockPool(new ByteBlockPool.DirectAllocator()), 16, array);
  /** The enum most recently supplied through {@link #setNextEnum}. */
  TermsEnum termsEnum;
  /** Running sum of {@code docFreq()} over all calls to {@link #collect}. */
  int docVisitCount = 0;
  /** Set to true when {@link #collect} stops because a limit was reached. */
  boolean hasCutOff = false;

  CutOffTermCollector(final int docCountCutoff, final int termCountLimit) {
    this.docCountCutoff = docCountCutoff;
    this.termCountLimit = termCountLimit;
  }

  @Override
  public void setNextEnum(final TermsEnum termsEnum) throws IOException {
    this.termsEnum = termsEnum;
  }

  /**
   * Records the given term and its term state.
   *
   * @param bytes the term to record
   * @return false, after setting {@link #hasCutOff}, when a limit has been
   *         reached; true otherwise
   */
  @Override
  public boolean collect(final BytesRef bytes) throws IOException {
    final int added = pendingTerms.add(bytes);
    docVisitCount += termsEnum.docFreq();

    final boolean limitReached =
        pendingTerms.size() >= termCountLimit || docVisitCount >= docCountCutoff;
    if (limitReached) {
      hasCutOff = true;
      return false;
    }

    final TermState state = termsEnum.termState();
    assert state != null;

    if (added >= 0) {
      // First occurrence of the term: start a new context for it
      array.termState[added] = new TermContext(topReaderContext, state, readerContext.ord, termsEnum.docFreq(), termsEnum.totalTermFreq());
    } else {
      // A negative result encodes an already stored term as -(slot + 1):
      // decode the slot and register this reader's state on its context
      final int slot = -added - 1;
      array.termState[slot].register(state, readerContext.ord, termsEnum.docFreq(), termsEnum.totalTermFreq());
    }
    return true;
  }
}
/**
 * Computes a hash code from the two cutoff settings, consistent with
 * {@link #equals(Object)}, which compares the same two fields (the double by
 * its {@link Double#doubleToLongBits(double)} representation).
 *
 * @return the hash code of this rewrite method
 */
@Override
public int hashCode() {
  final int prime = 1279;
  final long percentBits = Double.doubleToLongBits(docCountPercent);
  // Fold the high word of the double's bit pattern into the low word before
  // narrowing to int. A plain (int) cast keeps only the low 32 bits, which
  // are all zero for values with a short mantissa such as 1.0 or 50.0, so
  // such percentages would not contribute to the hash at all.
  final int percentHash = (int) (percentBits ^ (percentBits >>> 32));
  return prime * termCountCutoff + percentHash;
}
/**
 * Two instances are equal when they have exactly the same class, the same
 * term count cutoff and the same document count percentage (the latter
 * compared by its {@link Double#doubleToLongBits(double)} representation).
 *
 * @param obj the object to compare with
 * @return true if {@code obj} is equal to this rewrite method
 */
@Override
public boolean equals(final Object obj) {
  if (this == obj) {
    return true;
  }
  if (obj == null || this.getClass() != obj.getClass()) {
    return false;
  }
  final NodeConstantScoreAutoRewrite that = (NodeConstantScoreAutoRewrite) obj;
  return that.termCountCutoff == termCountCutoff
      && Double.doubleToLongBits(that.docCountPercent) == Double.doubleToLongBits(docCountPercent);
}
/**
 * Special implementation of BytesStartArray that maintains a
 * {@link TermContext} array in parallel with the int array managed by the
 * superclass, keeping it at least as long as that array.
 */
static final class TermStateByteStart extends DirectBytesStartArray {

  /** Parallel array of term contexts; null after {@link #clear()}. */
  TermContext[] termState;

  public TermStateByteStart(final int initSize) {
    super(initSize);
  }

  /** Allocates a fresh parallel array sized for the array created by the superclass. */
  @Override
  public int[] init() {
    final int[] parent = super.init();
    termState = new TermContext[capacityFor(parent.length)];
    assert termState.length >= parent.length;
    return parent;
  }

  /** Enlarges the parallel array, preserving its content, if the superclass array outgrew it. */
  @Override
  public int[] grow() {
    final int[] parent = super.grow();
    if (parent.length > termState.length) {
      final TermContext[] enlarged = new TermContext[capacityFor(parent.length)];
      System.arraycopy(termState, 0, enlarged, 0, termState.length);
      termState = enlarged;
    }
    assert termState.length >= parent.length;
    return parent;
  }

  /** Drops the parallel array before delegating to the superclass. */
  @Override
  public int[] clear() {
    termState = null;
    return super.clear();
  }

  /** Over-allocated length to use for a parallel array that must hold at least minLength references. */
  private static int capacityFor(final int minLength) {
    return ArrayUtil.oversize(minLength, RamUsageEstimator.NUM_BYTES_OBJECT_REF);
  }
}
}