package org.apache.lucene.search.highlight;
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.lucene.analysis.CachingTokenFilter;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.index.FilterIndexReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermEnum;
import org.apache.lucene.index.memory.MemoryIndex;
import org.apache.lucene.search.*;
import org.apache.lucene.search.spans.FieldMaskingSpanQuery;
import org.apache.lucene.search.spans.SpanFirstQuery;
import org.apache.lucene.search.spans.SpanNearQuery;
import org.apache.lucene.search.spans.SpanNotQuery;
import org.apache.lucene.search.spans.SpanOrQuery;
import org.apache.lucene.search.spans.SpanQuery;
import org.apache.lucene.search.spans.SpanTermQuery;
import org.apache.lucene.search.spans.Spans;
import org.apache.lucene.util.StringHelper;
/**
* Class used to extract {@link WeightedSpanTerm}s from a {@link Query} based on whether
* {@link Term}s from the {@link Query} are contained in a supplied {@link TokenStream}.
*/
public class WeightedSpanTermExtractor {

  /** If non-null, restricts extraction to terms of this field (interned). */
  private String fieldName;
  /** Token stream of the text being highlighted; indexed on demand into per-field MemoryIndexes. */
  private TokenStream tokenStream;
  /** Lazily created, per-field MemoryIndex readers; closed by {@link #closeReaders()} after extraction. */
  private Map<String,IndexReader> readers = new HashMap<String,IndexReader>(10);
  /** Optional additional field name that always matches (supports a query-parser default field). */
  private String defaultField;
  /** Whether MultiTermQueries (wildcard, fuzzy, ...) are rewritten so their expanded terms get highlighted. */
  private boolean expandMultiTermQuery;
  /** True once {@link #tokenStream} has been wrapped in a CachingTokenFilter. */
  private boolean cachedTokenStream;
  /** Whether to auto-wrap non-caching token streams; see {@link #setWrapIfNotCachingTokenFilter(boolean)}. */
  private boolean wrapToCaching = true;
  /** Character limit applied via OffsetLimitTokenFilter when analyzing the document text. */
  private int maxDocCharsToAnalyze;

  public WeightedSpanTermExtractor() {
  }

  /**
   * @param defaultField field name that should always be considered a match in
   *        {@link #fieldNameComparator(String)}, or null for none
   */
  public WeightedSpanTermExtractor(String defaultField) {
    if (defaultField != null) {
      this.defaultField = StringHelper.intern(defaultField);
    }
  }

  /**
   * Closes all per-field MemoryIndex readers opened during extraction.
   * Close failures are deliberately ignored: the readers are throw-away,
   * in-memory structures and a failure here must not mask the real result.
   */
  private void closeReaders() {
    Collection<IndexReader> readerSet = readers.values();

    for (final IndexReader reader : readerSet) {
      try {
        reader.close();
      } catch (IOException e) {
        // alert?
      }
    }
  }

  /**
   * Fills a <code>Map</code> with {@link WeightedSpanTerm}s using the terms from the supplied <code>Query</code>.
   *
   * @param query
   *          Query to extract Terms from
   * @param terms
   *          Map to place created WeightedSpanTerms in
   * @throws IOException
   */
  protected void extract(Query query, Map<String,WeightedSpanTerm> terms) throws IOException {
    if (query instanceof BooleanQuery) {
      BooleanClause[] queryClauses = ((BooleanQuery) query).getClauses();

      for (int i = 0; i < queryClauses.length; i++) {
        // prohibited (NOT) clauses never contribute highlightable terms
        if (!queryClauses[i].isProhibited()) {
          extract(queryClauses[i].getQuery(), terms);
        }
      }
    } else if (query instanceof PhraseQuery) {
      // Convert the phrase to an equivalent SpanNearQuery so we get positions.
      PhraseQuery phraseQuery = ((PhraseQuery) query);
      Term[] phraseQueryTerms = phraseQuery.getTerms();
      SpanQuery[] clauses = new SpanQuery[phraseQueryTerms.length];
      for (int i = 0; i < phraseQueryTerms.length; i++) {
        clauses[i] = new SpanTermQuery(phraseQueryTerms[i]);
      }
      int slop = phraseQuery.getSlop();
      int[] positions = phraseQuery.getPositions();
      // add largest position increment to slop, so that gaps introduced by
      // stopword removal etc. do not prevent the span from matching
      if (positions.length > 0) {
        int lastPos = positions[0];
        int largestInc = 0;
        int sz = positions.length;
        for (int i = 1; i < sz; i++) {
          int pos = positions[i];
          int inc = pos - lastPos;
          if (inc > largestInc) {
            largestInc = inc;
          }
          lastPos = pos;
        }
        if(largestInc > 1) {
          slop += largestInc;
        }
      }

      // an exact phrase (slop 0) must also match in order
      boolean inorder = false;

      if (slop == 0) {
        inorder = true;
      }

      SpanNearQuery sp = new SpanNearQuery(clauses, slop, inorder);
      sp.setBoost(query.getBoost());
      extractWeightedSpanTerms(terms, sp);
    } else if (query instanceof TermQuery) {
      extractWeightedTerms(terms, query);
    } else if (query instanceof SpanQuery) {
      extractWeightedSpanTerms(terms, (SpanQuery) query);
    } else if (query instanceof FilteredQuery) {
      // the filter only restricts documents, not which terms match
      extract(((FilteredQuery) query).getQuery(), terms);
    } else if (query instanceof DisjunctionMaxQuery) {
      for (Iterator<Query> iterator = ((DisjunctionMaxQuery) query).iterator(); iterator.hasNext();) {
        extract(iterator.next(), terms);
      }
    } else if (query instanceof MultiTermQuery && expandMultiTermQuery) {
      MultiTermQuery mtq = ((MultiTermQuery)query);
      // force a rewrite method that enumerates the actual matching terms;
      // CONSTANT_SCORE rewrites would lose them
      if(mtq.getRewriteMethod() != MultiTermQuery.SCORING_BOOLEAN_QUERY_REWRITE) {
        mtq = (MultiTermQuery) mtq.clone();
        mtq.setRewriteMethod(MultiTermQuery.SCORING_BOOLEAN_QUERY_REWRITE);
        query = mtq;
      }
      // FakeReader captures the field the MultiTermQuery enumerates over,
      // so we know which field of the token stream to index in memory
      FakeReader fReader = new FakeReader();
      MultiTermQuery.SCORING_BOOLEAN_QUERY_REWRITE.rewrite(fReader, mtq);
      if (fReader.field != null) {
        IndexReader ir = getReaderForField(fReader.field);
        extract(query.rewrite(ir), terms);
      }
    } else if (query instanceof MultiPhraseQuery) {
      // Convert to a SpanNearQuery of per-position SpanOrQueries.
      final MultiPhraseQuery mpq = (MultiPhraseQuery) query;
      final List<Term[]> termArrays = mpq.getTermArrays();
      final int[] positions = mpq.getPositions();
      if (positions.length > 0) {

        int maxPosition = positions[positions.length - 1];
        for (int i = 0; i < positions.length - 1; ++i) {
          if (positions[i] > maxPosition) {
            maxPosition = positions[i];
          }
        }

        // group all terms that share a position into one disjunction
        @SuppressWarnings("unchecked")
        final List<SpanQuery>[] disjunctLists = new List[maxPosition + 1];
        int distinctPositions = 0;

        for (int i = 0; i < termArrays.size(); ++i) {
          final Term[] termArray = termArrays.get(i);
          List<SpanQuery> disjuncts = disjunctLists[positions[i]];
          if (disjuncts == null) {
            disjuncts = (disjunctLists[positions[i]] = new ArrayList<SpanQuery>(termArray.length));
            ++distinctPositions;
          }
          for (int j = 0; j < termArray.length; ++j) {
            disjuncts.add(new SpanTermQuery(termArray[j]));
          }
        }

        // unused positions become extra allowed slop (positionGaps)
        int positionGaps = 0;
        int position = 0;
        final SpanQuery[] clauses = new SpanQuery[distinctPositions];
        for (int i = 0; i < disjunctLists.length; ++i) {
          List<SpanQuery> disjuncts = disjunctLists[i];
          if (disjuncts != null) {
            clauses[position++] = new SpanOrQuery(disjuncts
                .toArray(new SpanQuery[disjuncts.size()]));
          } else {
            ++positionGaps;
          }
        }

        final int slop = mpq.getSlop();
        final boolean inorder = (slop == 0);

        SpanNearQuery sp = new SpanNearQuery(clauses, slop + positionGaps, inorder);
        sp.setBoost(query.getBoost());
        extractWeightedSpanTerms(terms, sp);
      }
    }
    // always give subclasses a chance to handle query types we don't know
    extractUnknownQuery(query, terms);
  }

  /**
   * Hook for sub-classes to extract terms from custom {@link Query} types
   * not handled by {@link #extract(Query, Map)}. Default implementation is a no-op.
   */
  protected void extractUnknownQuery(Query query,
      Map<String, WeightedSpanTerm> terms) throws IOException {

    // for sub-classing to extract custom queries
  }

  /**
   * Fills a <code>Map</code> with {@link WeightedSpanTerm}s using the terms from the supplied <code>SpanQuery</code>.
   *
   * @param terms
   *          Map to place created WeightedSpanTerms in
   * @param spanQuery
   *          SpanQuery to extract Terms from
   * @throws IOException
   */
  protected void extractWeightedSpanTerms(Map<String,WeightedSpanTerm> terms, SpanQuery spanQuery) throws IOException {
    Set<String> fieldNames;

    // determine which fields the span query touches
    if (fieldName == null) {
      fieldNames = new HashSet<String>();
      collectSpanQueryFields(spanQuery, fieldNames);
    } else {
      fieldNames = new HashSet<String>(1);
      fieldNames.add(fieldName);
    }
    // To support the use of the default field name
    if (defaultField != null) {
      fieldNames.add(defaultField);
    }

    Map<String, SpanQuery> queries = new HashMap<String, SpanQuery>();

    Set<Term> nonWeightedTerms = new HashSet<Term>();
    final boolean mustRewriteQuery = mustRewriteQuery(spanQuery);
    if (mustRewriteQuery) {
      // rewrite per field so MultiTerm-style span queries expand to concrete terms
      for (final String field : fieldNames) {
        final SpanQuery rewrittenQuery = (SpanQuery) spanQuery.rewrite(getReaderForField(field));
        queries.put(field, rewrittenQuery);
        rewrittenQuery.extractTerms(nonWeightedTerms);
      }
    } else {
      spanQuery.extractTerms(nonWeightedTerms);
    }

    List<PositionSpan> spanPositions = new ArrayList<PositionSpan>();

    for (final String field : fieldNames) {

      IndexReader reader = getReaderForField(field);
      final Spans spans;
      if (mustRewriteQuery) {
        spans = queries.get(field).getSpans(reader);
      } else {
        spans = spanQuery.getSpans(reader);
      }


      // collect span positions (end() is exclusive, PositionSpan is inclusive)
      while (spans.next()) {
        spanPositions.add(new PositionSpan(spans.start(), spans.end() - 1));
      }

    }

    if (spanPositions.size() == 0) {
      // no spans found
      return;
    }

    for (final Term queryTerm :  nonWeightedTerms) {

      if (fieldNameComparator(queryTerm.field())) {
        WeightedSpanTerm weightedSpanTerm = terms.get(queryTerm.text());

        if (weightedSpanTerm == null) {
          weightedSpanTerm = new WeightedSpanTerm(spanQuery.getBoost(), queryTerm.text());
          weightedSpanTerm.addPositionSpans(spanPositions);
          weightedSpanTerm.positionSensitive = true;
          terms.put(queryTerm.text(), weightedSpanTerm);
        } else {
          if (spanPositions.size() > 0) {
            weightedSpanTerm.addPositionSpans(spanPositions);
          }
        }
      }
    }
  }

  /**
   * Fills a <code>Map</code> with {@link WeightedSpanTerm}s using the terms from the supplied <code>Query</code>.
   *
   * @param terms
   *          Map to place created WeightedSpanTerms in
   * @param query
   *          Query to extract Terms from
   * @throws IOException
   */
  protected void extractWeightedTerms(Map<String,WeightedSpanTerm> terms, Query query) throws IOException {
    Set<Term> nonWeightedTerms = new HashSet<Term>();
    query.extractTerms(nonWeightedTerms);

    for (final Term queryTerm : nonWeightedTerms) {

      if (fieldNameComparator(queryTerm.field())) {
        WeightedSpanTerm weightedSpanTerm = new WeightedSpanTerm(query.getBoost(), queryTerm.text());
        terms.put(queryTerm.text(), weightedSpanTerm);
      }
    }
  }

  /**
   * Necessary to implement matches for queries against <code>defaultField</code>.
   * <p>
   * Uses {@code equals} rather than reference identity: although {@code fieldName}
   * and {@code defaultField} are interned, the field names arriving here (e.g. from
   * {@code SpanQuery.getField()}) are not guaranteed to be, so an {@code ==}
   * comparison could silently miss matches.
   */
  protected boolean fieldNameComparator(String fieldNameToCheck) {
    boolean rv = fieldName == null || fieldName.equals(fieldNameToCheck)
        || (defaultField != null && defaultField.equals(fieldNameToCheck));
    return rv;
  }

  /**
   * Returns a reader over an in-memory index of the supplied token stream for
   * the given field, creating and caching it on first use. The token stream is
   * wrapped in a {@link CachingTokenFilter} (unless disabled) so it can be
   * consumed once per field via {@code reset()}.
   */
  protected IndexReader getReaderForField(String field) throws IOException {
    if(wrapToCaching && !cachedTokenStream && !(tokenStream instanceof CachingTokenFilter)) {
      tokenStream = new CachingTokenFilter(new OffsetLimitTokenFilter(tokenStream, maxDocCharsToAnalyze));
      cachedTokenStream = true;
    }
    IndexReader reader = readers.get(field);
    if (reader == null) {
      MemoryIndex indexer = new MemoryIndex();
      // addField consumes the stream; reset afterwards so another field can replay it
      indexer.addField(field, new OffsetLimitTokenFilter(tokenStream, maxDocCharsToAnalyze));
      tokenStream.reset();
      IndexSearcher searcher = indexer.createSearcher();
      reader = searcher.getIndexReader();
      readers.put(field, reader);
    }

    return reader;
  }

  /**
   * Creates a Map of <code>WeightedSpanTerms</code> from the given <code>Query</code> and <code>TokenStream</code>.
   *
   * <p>
   *
   * @param query
   *          that caused hit
   * @param tokenStream
   *          of text to be highlighted
   * @return Map containing WeightedSpanTerms
   * @throws IOException
   */
  public Map<String,WeightedSpanTerm> getWeightedSpanTerms(Query query, TokenStream tokenStream)
      throws IOException {
    return getWeightedSpanTerms(query, tokenStream, null);
  }

  /**
   * Creates a Map of <code>WeightedSpanTerms</code> from the given <code>Query</code> and <code>TokenStream</code>.
   *
   * <p>
   *
   * @param query
   *          that caused hit
   * @param tokenStream
   *          of text to be highlighted
   * @param fieldName
   *          restricts Term's used based on field name
   * @return Map containing WeightedSpanTerms
   * @throws IOException
   */
  public Map<String,WeightedSpanTerm> getWeightedSpanTerms(Query query, TokenStream tokenStream,
      String fieldName) throws IOException {
    if (fieldName != null) {
      this.fieldName = StringHelper.intern(fieldName);
    } else {
      this.fieldName = null;
    }

    Map<String,WeightedSpanTerm> terms = new PositionCheckingMap<String>();
    this.tokenStream = tokenStream;
    try {
      extract(query, terms);
    } finally {
      closeReaders();
    }

    return terms;
  }

  /**
   * Creates a Map of <code>WeightedSpanTerms</code> from the given <code>Query</code> and <code>TokenStream</code>. Uses a supplied
   * <code>IndexReader</code> to properly weight terms (for gradient highlighting).
   *
   * <p>
   *
   * @param query
   *          that caused hit
   * @param tokenStream
   *          of text to be highlighted
   * @param fieldName
   *          restricts Term's used based on field name
   * @param reader
   *          to use for scoring
   * @return Map of WeightedSpanTerms with quasi tf/idf scores
   * @throws IOException
   */
  public Map<String,WeightedSpanTerm> getWeightedSpanTermsWithScores(Query query, TokenStream tokenStream, String fieldName,
      IndexReader reader) throws IOException {
    if (fieldName != null) {
      this.fieldName = StringHelper.intern(fieldName);
    } else {
      this.fieldName = null;
    }
    this.tokenStream = tokenStream;

    Map<String,WeightedSpanTerm> terms = new PositionCheckingMap<String>();
    int totalNumDocs = reader.numDocs();
    try {
      // extract() inside the try: it may open per-field readers and then throw,
      // and those readers must still be closed
      extract(query, terms);

      Set<String> weightedTerms = terms.keySet();
      Iterator<String> it = weightedTerms.iterator();

      while (it.hasNext()) {
        WeightedSpanTerm weightedSpanTerm = terms.get(it.next());
        int docFreq = reader.docFreq(new Term(fieldName, weightedSpanTerm.term));
        // docFreq counts deletes
        if(totalNumDocs < docFreq) {
          docFreq = totalNumDocs;
        }
        // IDF algorithm taken from DefaultSimilarity class
        float idf = (float) (Math.log((float) totalNumDocs / (double) (docFreq + 1)) + 1.0);
        weightedSpanTerm.weight *= idf;
      }
    } finally {

      closeReaders();
    }

    return terms;
  }

  /**
   * Recursively collects the field names a {@link SpanQuery} (and its nested
   * clauses) operates on. For {@link SpanNotQuery} only the include clause
   * matters, since the exclude clause cannot produce matches.
   */
  protected void collectSpanQueryFields(SpanQuery spanQuery, Set<String> fieldNames) {
    if (spanQuery instanceof FieldMaskingSpanQuery) {
      collectSpanQueryFields(((FieldMaskingSpanQuery)spanQuery).getMaskedQuery(), fieldNames);
    } else if (spanQuery instanceof SpanFirstQuery) {
      collectSpanQueryFields(((SpanFirstQuery)spanQuery).getMatch(), fieldNames);
    } else if (spanQuery instanceof SpanNearQuery) {
      for (final SpanQuery clause : ((SpanNearQuery)spanQuery).getClauses()) {
        collectSpanQueryFields(clause, fieldNames);
      }
    } else if (spanQuery instanceof SpanNotQuery) {
      collectSpanQueryFields(((SpanNotQuery)spanQuery).getInclude(), fieldNames);
    } else if (spanQuery instanceof SpanOrQuery) {
      for (final SpanQuery clause : ((SpanOrQuery)spanQuery).getClauses()) {
        collectSpanQueryFields(clause, fieldNames);
      }
    } else {
      fieldNames.add(spanQuery.getField());
    }
  }

  /**
   * Returns true if the given span query (or any nested clause) is of a type
   * that must be rewritten before its terms can be extracted — i.e. anything
   * other than the known primitive/composite span query types. Always false
   * when multi-term expansion is disabled.
   */
  protected boolean mustRewriteQuery(SpanQuery spanQuery) {
    if (!expandMultiTermQuery) {
      return false; // Will throw UnsupportedOperationException in case of a SpanRegexQuery.
    } else if (spanQuery instanceof FieldMaskingSpanQuery) {
      return mustRewriteQuery(((FieldMaskingSpanQuery)spanQuery).getMaskedQuery());
    } else if (spanQuery instanceof SpanFirstQuery) {
      return mustRewriteQuery(((SpanFirstQuery)spanQuery).getMatch());
    } else if (spanQuery instanceof SpanNearQuery) {
      for (final SpanQuery clause : ((SpanNearQuery)spanQuery).getClauses()) {
        if (mustRewriteQuery(clause)) {
          return true;
        }
      }
      return false; 
    } else if (spanQuery instanceof SpanNotQuery) {
      SpanNotQuery spanNotQuery = (SpanNotQuery)spanQuery;
      return mustRewriteQuery(spanNotQuery.getInclude()) || mustRewriteQuery(spanNotQuery.getExclude());
    } else if (spanQuery instanceof SpanOrQuery) {
      for (final SpanQuery clause : ((SpanOrQuery)spanQuery).getClauses()) {
        if (mustRewriteQuery(clause)) {
          return true;
        }
      }
      return false; 
    } else if (spanQuery instanceof SpanTermQuery) {
      return false;
    } else {
      return true;
    }
  }

  /**
   * This class makes sure that if both position sensitive and insensitive
   * versions of the same term are added, the position insensitive one wins.
   */
  @SuppressWarnings("serial")
  protected static class PositionCheckingMap<K> extends HashMap<K,WeightedSpanTerm> {

    @Override
    public void putAll(Map<? extends K,? extends WeightedSpanTerm> m) {
      // route through put() so the position-sensitivity rule applies per entry
      for (Map.Entry<? extends K,? extends WeightedSpanTerm> entry : m.entrySet())
        this.put(entry.getKey(), entry.getValue());
    }

    @Override
    public WeightedSpanTerm put(K key, WeightedSpanTerm value) {
      WeightedSpanTerm prev = super.put(key, value);
      if (prev == null) return prev;
      // if the previous mapping was position insensitive, the new one must stay so
      if (!prev.positionSensitive) {
        value.positionSensitive = false;
      }
      return prev;
    }
    
  }

  public boolean getExpandMultiTermQuery() {
    return expandMultiTermQuery;
  }

  public void setExpandMultiTermQuery(boolean expandMultiTermQuery) {
    this.expandMultiTermQuery = expandMultiTermQuery;
  }

  public boolean isCachedTokenStream() {
    return cachedTokenStream;
  }

  public TokenStream getTokenStream() {
    return tokenStream;
  }
  
  /**
   * By default, {@link TokenStream}s that are not of the type
   * {@link CachingTokenFilter} are wrapped in a {@link CachingTokenFilter} to
   * ensure an efficient reset - if you are already using a different caching
   * {@link TokenStream} impl and you don't want it to be wrapped, set this to
   * false.
   * 
   * @param wrap
   */
  public void setWrapIfNotCachingTokenFilter(boolean wrap) {
    this.wrapToCaching = wrap;
  }

  /**
   * 
   * A fake IndexReader class to extract the field from a MultiTermQuery
   * 
   */
  static final class FakeReader extends FilterIndexReader {

    private static final IndexReader EMPTY_MEMORY_INDEX_READER =
      new MemoryIndex().createSearcher().getIndexReader();
    
    // first field name seen during term enumeration; null until terms(Term) is called
    String field;

    FakeReader() {
      super(EMPTY_MEMORY_INDEX_READER);
    }

    @Override
    public TermEnum terms(final Term t) throws IOException {
      // only set first fieldname, maybe use a Set?
      if (t != null && field == null)
        field = t.field();
      return super.terms(t);
    }

  }

  protected final void setMaxDocCharsToAnalyze(int maxDocCharsToAnalyze) {
    this.maxDocCharsToAnalyze = maxDocCharsToAnalyze;
  }
}