Friday, October 01, 2010

Custom Scoring with Lucene Payloads

Problem Definition

A commenter on one of my earlier posts posed an interesting problem:

...we need to build a searching system, scoring the similarity only by the positions of the terms and with NO regard of TF/IDF. e.g. suppose we have q:{t1,t2,t3}, and documents:
d1:{t1,t3,t2,t4}
d2:{t1,t3}
d3:{t1,t2,t5,t3}
d4:{t4,t1,t2,t3}
then the result is to be d4>d3>d1>d2, just like a comparison among different bus paths and every single path consists of many stops, quite straightforward; however, we didn't find yet any lucene API fit for this job...

I have also been reading Lucene in Action, Second Edition (LIA2) recently, in an attempt to get up to speed with the many new features of Lucene. I also have a similar problem to solve here at work, so I decided that this would be a good opportunity to try out some of the things I had learned.

As I understand it, the comment translates to the following scoring rules. Given an index containing tour names and a list of points of interest (POI) where the tour bus stops, if a user enters a space-separated list of points of interest, say {p1,p2,p3}, we should display, in order:

  1. Tours that contain all the points in the {p1,p2,p3} sequence. Ordering is determined by how early the points occur in the document's POI list and by the length of that list (matches in longer tours are less relevant than matches in shorter ones).
  2. Tours that contain any point in the {p1,p2,p3} sequence. The more matches found, the more relevant the tour.
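
Before bringing Lucene into it, the two tiers can be sketched in plain Java. This is only my reading of the rules above, not part of the solution: the class and method names (TourTiers, matchCount, tier) are made up for illustration, and the ordering inside tier 1 (position and list length) is left out here.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

// Hypothetical helper, not part of Lucene or of the code in this post.
final class TourTiers {

  /** Number of distinct query POIs that occur anywhere in the tour. */
  static int matchCount(String tour, String query) {
    Set<String> stops = new HashSet<>(Arrays.asList(tour.trim().split("\\s+")));
    Set<String> wanted = new HashSet<>(Arrays.asList(query.trim().split("\\s+")));
    wanted.retainAll(stops);
    return wanted.size();
  }

  /** Tier 1: every query POI is present; tier 2: only some; tier 0: no match. */
  static int tier(String tour, String query) {
    int total = new HashSet<>(Arrays.asList(query.trim().split("\\s+"))).size();
    int matched = matchCount(tour, query);
    if (matched == 0) {
      return 0;
    }
    return matched == total ? 1 : 2;
  }

  public static void main(String[] args) {
    String query = "p1 p2 p3";
    for (String tour : new String[] {"p1 p2 p3 p4", "p1 p2", "p4 p2", "p4 p5"}) {
      System.out.println(tour + " -> tier " + tier(tour, query)
          + ", matches " + matchCount(tour, query));
    }
  }
}
```

Within tier 2, tours would then be ordered by descending match count, so "p1 p2" (two matches) comes before "p4 p2" (one match).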

Implementation

I first tried SpanQueries: a SpanOrQuery that combines (a) a SpanNearQuery over all the points of interest, with a "slop" of (n-1) and no ordering requirement, where n is the number of points of interest in the query, and (b) one SpanTermQuery for each point of interest. It's probably easier to see this in code; the JUnit test with this method and its private helper methods is shown below:

// Source: src/test/java/com/mycompany/poi/POIQueryTest.java
package com.mycompany.poi;

import java.io.IOException;
import java.util.HashSet;
import java.util.Set;

import org.apache.commons.lang.StringUtils;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.WhitespaceAnalyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.Field.Index;
import org.apache.lucene.document.Field.Store;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Similarity;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.payloads.PayloadNearQuery;
import org.apache.lucene.search.payloads.PayloadTermQuery;
import org.apache.lucene.search.spans.SpanNearQuery;
import org.apache.lucene.search.spans.SpanOrQuery;
import org.apache.lucene.search.spans.SpanQuery;
import org.apache.lucene.search.spans.SpanTermQuery;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.RAMDirectory;
import org.junit.Test;

import com.mycompany.poi.POIAnalyzer;
import com.mycompany.poi.POISimilarity;
import com.mycompany.poi.POISumPayloadFunction;

public class POIQueryTest {

  private static final boolean SHOW_EXPLANATION = false;
  
  private Directory directory;
  
  private String[] data = new String[] {
    "p1 p2 p3 p4",
    "p1 p2 p3",
    "p4 p1 p2 p3",
    "p2 p1 p3",
    "p1 p4 p2 p3",
    "p1 p2",
    "p4 p2",
    "p4 p5"
  };

  @Test
  public void testSpanQuery() throws Exception {
    index(new WhitespaceAnalyzer());
    String[] pois = StringUtils.split("p1 p2 p3", " ");
    SpanQuery[] baseQueries = new SpanQuery[pois.length];
    for (int i = 0; i < pois.length; i++) {
      baseQueries[i] = new SpanTermQuery(new Term("tour", pois[i]));
    }
    // first grab the cases where all points of interest are there
    // with a slop of (pois.length - 1) unordered
    SpanNearQuery allPois = new SpanNearQuery(baseQueries, 
      (pois.length - 1), false);
    allPois.setBoost(2.0F);
    // then append all the cases where at least one of the POIs appear
    SpanQuery[] poiQueries = new SpanQuery[pois.length + 1];
    poiQueries[0] = allPois;
    for (int i = 1; i < poiQueries.length; i++) {
      poiQueries[i] = baseQueries[i - 1];
    }
    SpanOrQuery query = new SpanOrQuery(poiQueries);
    printHits(query, "testSpanQuery", null, null, SHOW_EXPLANATION);    
  }
  
  private void index(Analyzer analyzer) throws Exception {
    directory = new RAMDirectory();
    IndexWriter writer = new IndexWriter(directory, 
      analyzer, IndexWriter.MaxFieldLength.UNLIMITED);
    for (int i = 0; i < data.length; i++) {
      Document doc = new Document();
      doc.add(new Field("title", "Tour #" + i, Store.YES, Index.NO));
      doc.add(new Field("tour", data[i], Store.YES, Index.ANALYZED));
      writer.addDocument(doc);
    }
    writer.close();
  }

  private void printHits(Query query, String testName, 
      Similarity similarity, Set<String> deduper, 
      boolean showExplanation) throws IOException {
    IndexSearcher searcher = new IndexSearcher(directory);
    if (similarity != null) {
      searcher.setSimilarity(similarity);
    }
    TopDocs topdocs = searcher.search(query, 10);
    ScoreDoc[] hits = topdocs.scoreDocs;
    System.out.println("==== Query: " + query.toString());
    System.out.println("==== Results for " + testName + " ====");
    for (int i = 0; i < hits.length; i++) {
      Document doc = searcher.doc(hits[i].doc);
      String title = doc.get("title");
      String tour = doc.get("tour");
      float score = hits[i].score;
      if (deduper != null) {
        if (deduper.contains(title)) {
          continue;
        } else {
          deduper.add(title);
        }
      }
      System.out.println(StringUtils.join(new String[] {
        String.valueOf(i),
        title,
        tour,
        String.valueOf(score)
      }, "  "));
      if (showExplanation) {
        System.out.println("EXPLANATION:" + 
          searcher.explain(query, hits[i].doc));
      }
    }
    searcher.close();
  }
}

This produces all the records of interest, but not in the correct order (per the application requirement). That is not surprising, since we are asking for spans, and unordered ones at that. As you can see, the scores on the top 4 results are identical. From Lucene's point of view this is correct, since all of them contain the unordered span {p1,p2,p3} separated by a maximum of 2 (3 - 1) POIs. But notice that the OR part of the query performs adequately, i.e., the score varies correctly with the number of POIs found in the document.

==== Query: spanOr([spanNear([tour:p1, tour:p2, tour:p3], 2, false)^2.0, 
     tour:p1, tour:p2, tour:p3])
==== Results for testSpanQuery ====
0  Tour #0  p1 p2 p3 p4  2.2629201
1  Tour #1  p1 p2 p3  2.2629201
2  Tour #2  p4 p1 p2 p3  2.2629201
3  Tour #3  p2 p1 p3  2.2629201
4  Tour #4  p1 p4 p2 p3  2.2303584
5  Tour #5  p1 p2  2.1382585
6  Tour #6  p4 p2  1.511977

So I need to somehow tell Lucene to use our custom rules for the first (allPois) part of the query. I do this using another new Lucene feature, Payloads (new to me, that is; it's been out since 2009). Effectively, for each POI in a tour document's POI list, I compute a score that takes into consideration its distance from the beginning of the list (position) and the number of POIs in the list (density), and store it as that token's payload in the index. To do this, I need a custom TokenFilter that does the scoring:

// Source: src/main/java/com/mycompany/poi/POITokenFilter.java
package com.mycompany.poi;

import java.io.IOException;

import org.apache.lucene.analysis.TokenFilter;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.payloads.PayloadHelper;
import org.apache.lucene.analysis.tokenattributes.PayloadAttribute;
import org.apache.lucene.analysis.tokenattributes.TermAttribute;
import org.apache.lucene.index.Payload;

public class POITokenFilter extends TokenFilter {

  private TermAttribute termAttr;
  private PayloadAttribute payloadAttr;
  private int termPosition = 0;
  private float density = 0;
  
  protected POITokenFilter(TokenStream input, int numTokens) {
    super(input);
    this.termAttr = addAttribute(TermAttribute.class);
    this.payloadAttr = addAttribute(PayloadAttribute.class);
    if (numTokens > 0) {
      this.density = 1.0F / numTokens;
    }
  }

  @Override
  public boolean incrementToken() throws IOException {
    if (input.incrementToken()) {
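      // the position weight falls by 0.1 per POI and stays at 0.1 from the
      // tenth POI onwards, so later POIs never outscore earlier ones
      float score = (1.0F - 
        ((termPosition <= 9) ? (0.1F * (float) termPosition) : 0.9F)) * density;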
//      System.out.println(termAttr.term() + "=>" + score);
      payloadAttr.setPayload(new Payload(PayloadHelper.encodeFloat(score)));
      termPosition++;
      return true;
    }
//    System.out.println("==");
    return false;
  }
}

I also need a custom Analyzer that counts the POIs in the list, uses a WhitespaceTokenizer to chop the list into individual POIs, and passes them through the TokenFilter:

// Source: src/main/java/com/mycompany/poi/POIAnalyzer.java
package com.mycompany.poi;

import java.io.IOException;
import java.io.Reader;
import java.io.StringReader;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.WhitespaceTokenizer;

public class POIAnalyzer extends Analyzer {

  @Override
  public TokenStream tokenStream(String fieldName, Reader reader) {
    // since our input is likely to be small pieces of text in
    // the "tour" field, we just convert Reader to String (we
    // need to compute the number of terms in this string for
    // scoring purposes), then pass the String through via a StringReader.
    // Counting spaces assumes the POIs are separated by exactly one space,
    // with no leading or trailing spaces.
    int numTokens = 0;
    int c = 0;
    StringBuilder buf = new StringBuilder();
    try {
      while ((c = reader.read()) != -1) {
        buf.append((char) c);
        if ((char) c == ' ') {
          numTokens++;
        }
      }
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
    return new POITokenFilter(
      new WhitespaceTokenizer(new StringReader(buf.toString())), numTokens + 1);
  }
}
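
To make the numbers concrete, here is a small plain-Java sketch (no Lucene involved) of the score that the filter and analyzer together attach to each POI. The class and method names are hypothetical. It follows the formula in POITokenFilter for the first ten positions; holding the weight at 0.1 for anything later is my assumption about the intent.

```java
// Hypothetical helper that mirrors the payload formula outside of Lucene.
final class PoiPayloadScores {

  /** 1.0 for the first POI, 0.1 less for each later one, held at 0.1 (assumed). */
  static float positionWeight(int position) {
    return 1.0F - 0.1F * Math.min(position, 9);
  }

  /** Payload score for each POI in a whitespace-separated tour. */
  static float[] scores(String tour) {
    String[] pois = tour.trim().split("\\s+");
    float density = 1.0F / pois.length;  // longer tours dilute every POI
    float[] out = new float[pois.length];
    for (int i = 0; i < pois.length; i++) {
      out[i] = positionWeight(i) * density;
    }
    return out;
  }

  public static void main(String[] args) {
    for (String tour : new String[] {"p1 p2 p3", "p1 p2 p3 p4", "p4 p1 p2 p3"}) {
      String[] pois = tour.split(" ");
      float[] scores = scores(tour);
      StringBuilder line = new StringBuilder(tour).append(" =>");
      for (int i = 0; i < pois.length; i++) {
        line.append(' ').append(pois[i]).append('=').append(scores[i]);
      }
      System.out.println(line);
    }
  }
}
```

For "p1 p2 p3" the three payloads add up to roughly 0.9, while the same three POIs in "p4 p1 p2 p3" add up to roughly 0.6, since they start one position later and sit in a longer list.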

After adding the payload scores to the index, we also need a custom Similarity class that reads the payload scores during search and makes them available to Lucene's scoring system. Here I have opted to combine the payload scores with Lucene's TF/IDF scores, but we could also do away with the TF/IDF score entirely; that choice is made when creating the PayloadTermQuery (see the JUnit test below).

// Source: src/main/java/com/mycompany/poi/POISimilarity.java
package com.mycompany.poi;

import org.apache.lucene.analysis.payloads.PayloadHelper;
import org.apache.lucene.search.DefaultSimilarity;

public class POISimilarity extends DefaultSimilarity {

  private static final long serialVersionUID = -909003452363957475L;

  @Override
  public float scorePayload(int docId, String fieldName,
      int start, int end, byte[] payload, int offset, int length) {
    if (payload != null) {
      float score = PayloadHelper.decodeFloat(payload, offset);
      return score;
    } else {
      return 1.0F;
    }
  }
}
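
The payload itself is just a byte array, so the float score has to be packed at index time and unpacked at search time. The sketch below shows that round trip in plain Java. It is not PayloadHelper's code: the class name is hypothetical, and the 4-byte big-endian layout (ByteBuffer's default) is an assumption on my part.

```java
import java.nio.ByteBuffer;

// Hypothetical stand-in for the encode/decode round trip; the byte layout
// (4 bytes, big-endian) is assumed, not taken from Lucene's source.
final class FloatPayloadCodec {

  /** Packs a float into a new 4-byte array. */
  static byte[] encode(float value) {
    return ByteBuffer.allocate(4).putFloat(value).array();
  }

  /** Reads a float back from 4 bytes starting at the given offset. */
  static float decode(byte[] bytes, int offset) {
    return ByteBuffer.wrap(bytes, offset, 4).getFloat();
  }

  public static void main(String[] args) {
    byte[] payload = encode(0.225F);
    System.out.println(payload.length + " bytes, decoded: " + decode(payload, 0));
  }
}
```

The offset parameter matters because, as in scorePayload() above, the payload may sit somewhere inside a larger shared buffer rather than at index 0.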

We also need to tell Lucene how to aggregate the payload scores across multiple matches in a document. If a document matches the POIs {p1,p2}, each contributing its own payload score, the payload score returned to the scoring subsystem is by default the average of the individual scores. I guess I could have used that, but I wanted the sum of the scores, so I subclassed the default AveragePayloadFunction and overrode its docScore() method (the subclass is passed to the PayloadTermQuery and PayloadNearQuery in the JUnit test below).

// Source: src/main/java/com/mycompany/poi/POISumPayloadFunction.java
package com.mycompany.poi;

import org.apache.lucene.search.payloads.AveragePayloadFunction;

public class POISumPayloadFunction extends AveragePayloadFunction {

  private static final long serialVersionUID = -3478867768985954830L;

  @Override
  public float docScore(int docId, String field, int numPayloadsSeen, 
      float payloadScore) {
    return numPayloadsSeen > 0 ? payloadScore : 1;
  }
}
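
The difference between the two aggregation strategies is easy to see with a couple of made-up payload values. The sketch below is plain Java with hypothetical names; it only mimics the final step (total payload score in, document score out) and returns 1 when no payloads were seen, as the docScore() override above does.

```java
// Hypothetical illustration of sum versus average aggregation.
final class PayloadAggregation {

  static float total(float[] payloads) {
    float sum = 0.0F;
    for (float p : payloads) {
      sum += p;
    }
    return sum;
  }

  /** Sum of the payloads, or 1 if there were none. */
  static float sumScore(float[] payloads) {
    return payloads.length > 0 ? total(payloads) : 1.0F;
  }

  /** Mean of the payloads, or 1 if there were none. */
  static float averageScore(float[] payloads) {
    return payloads.length > 0 ? total(payloads) / payloads.length : 1.0F;
  }

  public static void main(String[] args) {
    float[] oneMatch = {0.5F};          // made-up payload values
    float[] twoMatches = {0.5F, 0.25F};
    System.out.println("sum:     " + sumScore(oneMatch) + " vs " + sumScore(twoMatches));
    System.out.println("average: " + averageScore(oneMatch) + " vs " + averageScore(twoMatches));
  }
}
```

With these values the sum rewards the second match (0.5 versus 0.75), while the average penalizes it (0.5 versus 0.375), which is why I wanted the sum here.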

Having built all these components, we can finally write a JUnit test that uses a PayloadNearQuery instead of a SpanNearQuery for the ALL part of the query. There is no PayloadOrQuery in Lucene yet, so I could either build one, or combine the results of the PayloadNearQuery with those of a SpanOrQuery containing the individual PayloadTermQuery objects for each POI. I chose the latter approach even though it is suboptimal: it takes 2 calls to Lucene and is not particularly elegant, but I did not feel like figuring out how to build PayloadQuery subclasses right now. For the problem at work, I believe I will have to build PayloadAndQuery and PayloadOrQuery subclasses anyway; I will post details once I do that.

// Source: src/test/java/com/mycompany/poi/POIQueryTest.java (see above)
  ...
  @Test
  public void testPayloadSpanQuery() throws Exception {
    index(new POIAnalyzer());
    String[] pois = StringUtils.split("p1 p2 p3", " ");
    SpanQuery[] baseQueries = new SpanQuery[pois.length];
    for (int i = 0; i < pois.length; i++) {
      baseQueries[i] = new PayloadTermQuery(new Term("tour", pois[i]), 
        new POISumPayloadFunction());
    }
    // first grab the cases where all points of interest are there
    // with a slop of (pois.length - 1) unordered, and use our payload
    // scores to influence the ordering
    PayloadNearQuery allPois = new PayloadNearQuery(baseQueries, 
      (pois.length - 1), false, new POISumPayloadFunction());
    allPois.setBoost(2.0F);
    Set<String> deduper = new HashSet<String>();
    printHits(allPois, "testPayloadSpanQuery", new POISimilarity(), 
      deduper, SHOW_EXPLANATION);    
    // then backfill with the results from a SpanOrQuery filtering out
    // results that have already appeared (using deduper)
    SpanQuery[] poiQueries = new SpanQuery[pois.length + 1];
    poiQueries[0] = allPois;
    for (int i = 1; i < poiQueries.length; i++) {
      poiQueries[i] = baseQueries[i - 1];
    }
    SpanOrQuery query = new SpanOrQuery(poiQueries);
    printHits(query, "testBackfillSpanQuery", null, deduper, SHOW_EXPLANATION); 
  }
  ...

I get the following results from this code. As you can see, it's the same set of results, but they are now ordered as per our application's requirements. You can also see the boundary where the PayloadNearQuery results end and the SpanOrQuery backfill starts.

==== Query: payloadNear([tour:p1, tour:p2, tour:p3], 2, false)^2.0
==== Results for testPayloadSpanQuery ====
0  Tour #1  p1 p2 p3  0.7697731
1  Tour #3  p2 p1 p3  0.76977307
2  Tour #0  p1 p2 p3 p4  0.5773298
3  Tour #2  p4 p1 p2 p3  0.51318204
4  Tour #4  p1 p4 p2 p3  0.47812912
==== Query: spanOr([payloadNear([tour:p1, tour:p2, tour:p3], 2, false)^2.0, 
     tour:p1, tour:p2, tour:p3])
==== Results for testBackfillSpanQuery ====
5  Tour #5  p1 p2  2.1382585
6  Tour #6  p4 p2  1.511977
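
The backfill step boils down to merging two ranked lists while skipping anything already shown. Here is that idea in isolation, as a plain-Java sketch with hypothetical names. The backfill order used in main() is only illustrative (it reuses the order from the first SpanQuery run), since the test above prints only the backfill hits that survive de-duplication.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

// Hypothetical helper showing the de-duplicating merge on its own.
final class BackfillMerge {

  /** Primary results first, then backfill results not already present. */
  static List<String> merge(List<String> primary, List<String> backfill) {
    // LinkedHashSet keeps first-insertion order and ignores repeats
    Set<String> seen = new LinkedHashSet<>(primary);
    seen.addAll(backfill);
    return new ArrayList<>(seen);
  }

  public static void main(String[] args) {
    List<String> primary = Arrays.asList(
        "Tour #1", "Tour #3", "Tour #0", "Tour #2", "Tour #4");
    List<String> backfill = Arrays.asList(
        "Tour #0", "Tour #1", "Tour #2", "Tour #3", "Tour #4", "Tour #5", "Tour #6");
    System.out.println(merge(primary, backfill));
  }
}
```

This produces the same overall order as the two result blocks above: the five payload-scored tours first, followed by Tour #5 and Tour #6.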

Resources

The following resources came in handy while developing this solution.

  • This Lucene wiki page describes how to use Payloads with SpanQueries. I am not sure whether I had seen it before starting out; if I did, it may be where I got the idea for this solution.
  • This tutorial on Payloads by Grant Ingersoll on the Lucid Imagination blog.
  • The LIA2 book, pages 225-230.

2 comments:

Anonymous said...

Great post. Very useful.

Sujit Pal said...

Thanks, glad you found it useful.