Thursday, April 12, 2012

Image Classification (Photo or Drawing) using Weka

Some time back, I was asked if there was a simple way to automatically classify images as either photographs or drawings. I had initially thought this would involve some complex image processing, but the idea presented in this paper - A Statistical Combined Classifier and its Application to Region and Image Classification (PDF) by Steven J Simske - shows that the problem can be reduced to something similar to a bag of words model commonly used in text classification.

Consider the two images shown below. Clearly (to a human), the one on the left is a photograph and the one on the right is a chart (or drawing). The main thing that jumps out is how "soft" the color gradations are in the first one compared to the second. For ease of computation, we reduce the RGB values for each pixel to 256 grayscale values using the formula on this page. The corresponding black and white versions of the images are shown on the second row.

The third row shows the corresponding histogram plots of the percentage distribution of pixels across all 256 grayscale values. This is the reduction suggested by the paper referenced above. Effectively, the image has been turned into a "document" that uses a vocabulary of 256 "terms". Since the images can have different sizes (and hence different numbers of pixels), we normalize the counts to percentages so that histograms for different images are comparable to each other.
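To make the "document with 256 terms" analogy concrete, here is a condensed standalone sketch of the reduction (the class name and command-line argument are just for illustration; the real version is the buildHistogram() helper in the classifier code further down):

import java.awt.Color;
import java.awt.image.BufferedImage;
import java.io.File;
import javax.imageio.ImageIO;

public class GrayHistogramSketch {
  public static void main(String[] args) throws Exception {
    BufferedImage img = ImageIO.read(new File(args[0]));
    double[] hist = new double[256];
    double numPixels = img.getWidth() * img.getHeight();
    for (int x = 0; x < img.getWidth(); x++) {
      for (int y = 0; y < img.getHeight(); y++) {
        Color c = new Color(img.getRGB(x, y));
        // standard luminance formula: 0.299*R + 0.587*G + 0.114*B
        int gray = (int) (0.299 * c.getRed()
          + 0.587 * c.getGreen() + 0.114 * c.getBlue());
        hist[gray]++;
      }
    }
    // normalize counts to percentages so differently sized images are comparable
    for (int gl = 0; gl < hist.length; gl++) {
      hist[gl] = hist[gl] * 100 / numPixels;
      System.out.println(gl + "\t" + hist[gl]);
    }
  }
}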

Notice that the histogram for the photo is shorter and wider and in general smoother than the one for the drawing (the y-axis scale for the first is 0-3 and that for the second is 0-30). The paper describes three features that the authors used for their experiments:

  1. Pct0.5 - percent of the histogram bins with >0.5% of the pixels.
  2. Pct2Pk - percent of the histogram range within the largest 2 peaks.
  3. Bimodal - average of various sums of nearest neighbor pixel differences.

The paper has more detail about each feature (including how to compute them). Since all of these values are derived from the histogram itself, I initially decided to build a classifier that just used the grayscale percentage counts as features instead. My thought was that the features would be mutually independent and thus perform better with something like Naive Bayes, or at least no worse than using the three derived features suggested in the paper.

Turns out I was wrong, but since I used Weka for most of the experimentation, the mistake wasn't overwhelmingly expensive :-).

So, anyway, for my classifier, I decided to use the first two features suggested by the paper, plus another one that reflects the "choppiness" of the histogram - the average absolute difference between successive grayscale percentage counts, which I call AbsDiff. I used AbsDiff instead of Bimodal because the calculation of the Bimodal feature is quite compute intensive (see the paper for details).

I tested out both models using the Weka Explorer with the built-in Naive Bayes classifier. There are (many) other classifiers available in Weka, but it's been a while since I read the Weka book, and I chose Naive Bayes simply because I was familiar with it.

The training data for both models was a manually annotated training set of approximately 200 images. While tedious, such annotation does not require any specialized domain knowledge (unlike medical text for example). The images were then analyzed and their data (the grayscale percentage counts in the case of the first model and the feature values in the case of the second model) were written out to a file in Weka's ARFF format. Here is a snippet (the rest of it is just more training data in the same format as the first four rows) from the ARFF file for the model I ended up using.

@relation pdc2

@attribute pct05pc real
@attribute pct2pk real
@attribute absdiff real
@attribute klass {photo,drawing}

@data
30.46875,32.77869002006933,0.35458615684431966,photo
20.703125,73.67603966544412,0.45621152596647285,photo
14.453125,67.73598566997008,0.18333175409590724,drawing
3.90625,52.661171146208545,0.15240938475387075,drawing
...
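The code that generated this file isn't shown in the post; a hedged sketch of what it might look like follows, reusing the (protected, hence same-package) feature helpers from the PhotoOrDrawingClassifier class shown later. The labels.txt format, one "path,label" pair per line, is my own assumption.

package com.mycompany.classifiers;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.PrintWriter;

// Sketch only: writes one ARFF row per labeled image listed in labels.txt
// (assumed format: "path/to/image.jpg,photo" or "path/to/image.png,drawing").
public class ArffWriterSketch {
  public static void main(String[] args) throws Exception {
    PhotoOrDrawingClassifier fe = new PhotoOrDrawingClassifier();
    PrintWriter arff = new PrintWriter("pdc2.arff");
    arff.println("@relation pdc2\n");
    arff.println("@attribute pct05pc real");
    arff.println("@attribute pct2pk real");
    arff.println("@attribute absdiff real");
    arff.println("@attribute klass {photo,drawing}\n");
    arff.println("@data");
    BufferedReader reader = new BufferedReader(new FileReader("labels.txt"));
    String line;
    while ((line = reader.readLine()) != null) {
      String[] cols = line.split(",");
      double[] hist = fe.buildHistogram(new File(cols[0]));
      arff.println(fe.getPct05Pc(hist) + "," + fe.getPct2Pk(hist) + "," +
        fe.getAbsDiff(hist) + "," + cols[1]);
    }
    reader.close();
    arff.close();
  }
}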

Results for both models using Naive Bayes with 10-fold cross validation are shown below:

# Using raw grayscale percentage counts (pdc1)
Correctly Classified Instances         172               88.2051 %
Incorrectly Classified Instances        23               11.7949 %

# Using computed features (pdc2)
Correctly Classified Instances         176               90.2564 %
Incorrectly Classified Instances        19                9.7436 %

For testing, I once again manually annotated another set of 200 images, embedded the Weka NaiveBayes classifier inside my Java code (shown below), trained it with the ARFF file for the training data, then passed each of the manually annotated images through the classifier. Here is the code for the classifier, based on the example classifier code provided on this Weka Wiki page.

package com.mycompany.classifiers;

import java.awt.Color;
import java.awt.image.BufferedImage;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.URL;
import java.net.URLConnection;

import javax.imageio.ImageIO;

import org.apache.commons.collections15.Bag;
import org.apache.commons.collections15.bag.HashBag;

import weka.classifiers.Classifier;
import weka.classifiers.bayes.NaiveBayes;
import weka.core.Instance;
import weka.core.Instances;

public class PhotoOrDrawingClassifier {

  private Classifier classifier;
  private Instances dataset;

  //////////////// train /////////////////

  public void train(String arff) throws Exception {
    classifier = new NaiveBayes();
    dataset = new Instances(
      new BufferedReader(new FileReader(arff)));
    // our class is the last attribute
    dataset.setClassIndex(dataset.numAttributes() - 1);
    classifier.buildClassifier(dataset);
  }
  
  /////////////// test ////////////////////
  
  public String test(String imgfile) throws Exception {
    double[] histogram = buildHistogram(new File(imgfile));
    Instance instance = new Instance(dataset.numAttributes());
    instance.setDataset(dataset);
    instance.setValue(dataset.attribute(0), getPct05Pc(histogram));
    instance.setValue(dataset.attribute(1), getPct2Pk(histogram));
    instance.setValue(dataset.attribute(2), getAbsDiff(histogram));
    double[] distribution = classifier.distributionForInstance(instance);
    return distribution[0] > distribution[1] ? "photo" : "drawing";
  }
  
  //////////////// helper code /////////////////////////
  
  private static final double LUMINANCE_RED = 0.299D;
  private static final double LUMINANCE_GREEN = 0.587D;
  private static final double LUMINANCE_BLUE = 0.114;
  private static final int HIST_WIDTH = 256;
  private static final int HIST_HEIGHT = 100;
  private static final int HIST_5_PCT = 10;
  
  /**
   * Parses pixels out of an image file, converts the RGB values to
   * its equivalent grayscale value (0-255), then constructs a 
   * histogram of the percentage of counts of grayscale values.
   * @param infile - the image file.
   * @return - a histogram of grayscale percentage counts.
   */
  protected double[] buildHistogram(File infile) throws Exception {
    BufferedImage input = ImageIO.read(infile);
    int width = input.getWidth();
    int height = input.getHeight();
    Bag<Integer> graylevels = new HashBag<Integer>();
    double maxWidth = 0.0D;    // highest gray level seen (tracked but not used later)
    double maxHeight = 0.0D;   // total pixel count, used to normalize to percentages
    for (int x = 0; x < width; x++) {       // x ranges over the image width
      for (int y = 0; y < height; y++) {    // y ranges over the image height
        Color c = new Color(input.getRGB(x, y));
        int graylevel = (int) (LUMINANCE_RED * c.getRed() +
          LUMINANCE_GREEN * c.getGreen() + 
          LUMINANCE_BLUE * c.getBlue());
        graylevels.add(graylevel);
        maxHeight++;
        if (graylevel > maxWidth) {
          maxWidth = graylevel;
        }
      }
    }
    double[] histogram = new double[HIST_WIDTH];
    for (Integer graylevel : graylevels.uniqueSet()) {
      int idx = graylevel;
      histogram[idx] += 
        graylevels.getCount(graylevel) * HIST_HEIGHT / maxHeight;
    }
    return histogram;
  }
  
  protected double getPct05Pc(double[] histogram) {
    double numBins = 0.0D;
    for (int gl = 0; gl < histogram.length; gl++) {
      if (histogram[gl] > 0.5D) {
        numBins++;
      }
    }
    return numBins * 100 / HIST_WIDTH;
  }

  protected double getPct2Pk(double[] histogram) {
    double pct2pk = 0.0D;
    // find the maximum entry (first peak) 
    int maxima1 = getMaxima(histogram, new int[][] {{0, histogram.length}});
    // navigate left and right from the peak until inflection points are reached
    int lminima1 = getMinima(histogram, new int[] {maxima1, 0});
    int rminima1 = getMinima(histogram, new int[] {maxima1, histogram.length});
    for (int gl = lminima1; gl <= rminima1; gl++) {
      pct2pk += histogram[gl];
    }
    // find the second peak
    int maxima2 = getMaxima(histogram, new int[][] {
        {0, lminima1 - 1}, {rminima1 + 1, histogram.length}}); 
    int lminima2 = 0;
    int rminima2 = 0;
    if (maxima2 > maxima1) {
      // new maxima is to the right of the previous one
      lminima2 = getMinima(histogram, new int[] {maxima2, rminima1 + 1});
      rminima2 = getMinima(histogram, new int[] {maxima2, histogram.length}); 
    } else {
      // new maxima is to the left of previous one
      lminima2 = getMinima(histogram, new int[] {maxima2, 0});
      rminima2 = getMinima(histogram, new int[] {maxima2, lminima1 - 1});
    }
    for (int gl = lminima2; gl < rminima2; gl++) {
      pct2pk += histogram[gl];
    }
    return pct2pk;
  }
  
  protected double getAbsDiff(double[] histogram) {
    double absdiff = 0.0D;
    int diffSteps = 0;
    for (int i = 1; i < histogram.length; i++) {
      if (histogram[i-1] != histogram[i]) {
        absdiff += Math.abs(histogram[i] - histogram[i-1]);
        diffSteps++;
      }
    }
    // guard against a flat histogram where no successive bins differ
    return diffSteps == 0 ? 0.0D : absdiff / diffSteps;
  }
  
  private int getMaxima(double[] histogram, int[][] ranges) {
    int maxima = 0;
    double maxY = 0.0D;
    for (int i = 0; i < ranges.length; i++) {
      for (int gl = ranges[i][0]; gl < ranges[i][1]; gl++) {
        if (histogram[gl] > maxY) {
          maxY = histogram[gl];
          maxima = gl;
        }
      }
    }
    return maxima;
  }

  private int getMinima(double[] histogram, int[] range) {
    int start = range[0];
    int end = range[1];
    if (start == end) {
      return start;
    }
    boolean forward = start < end;
    double prevY = histogram[start];
    double dy = 0.0D;
    double prevDy = 0.0D;
    if (forward) {
      // avoid getting trapped in local minima
      int minlookahead = start + HIST_5_PCT;
      for (int pos = start + 1; pos < end; pos++) {
        dy = histogram[pos] - prevY;
        if (signdiff(dy, prevDy) && pos >= minlookahead) {
          return pos;
        }
        prevY = histogram[pos];
        prevDy = dy;
      }
    } else {
      // avoid getting trapped in local minima
      int minlookbehind = start - HIST_5_PCT;
      for (int pos = start - 1; pos >= end; pos--) {
        dy = histogram[pos] - prevY;
        if (signdiff(dy, prevDy) && pos <= minlookbehind) {
          return pos;
        }
        prevY = histogram[pos];
        prevDy = dy;
      }
    }
    return start;
  }

  private boolean signdiff(double dy, double prevDy) {
    return ((dy < 0.0D && prevDy > 0.0D) ||
        (dy > 0.0 && prevDy < 0.0D));
  }
  
  /**
   * Downloads image file referenced by URL into a local file specified
   * by filename for processing.
   * @param url - URL for the image file.
   * @param filename - the local filename for the file.
   */
  protected void download(String url, String filename) 
      throws Exception {
    URL u = new URL(url);
    URLConnection conn = u.openConnection();
    InputStream is = conn.getInputStream();
    byte[] buf = new byte[4096];
    int len = -1;
    OutputStream os = new FileOutputStream(new File(filename));
    while ((len = is.read(buf)) != -1) {
      os.write(buf, 0, len);
    }
    is.close();
    os.close();
  }
}

I tried to make it so the model would be dumped out into a serialized file after training, and the classification code (test) would only look at the serialized file, similar to how most other toolkits do it. Weka does provide a way to serialize the model, but the serialized model does not contain the data format (the header information in the training ARFF file) needed to classify unseen images, as explained in this thread. Since the training doesn't take that long, rather than supplying an empty ARFF file at test time, I built it so that the classifier and the header from the training data (dataset in the code) are kept in memory and used for testing.
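For reference, one way around this (a sketch of the approach, not what I ended up doing) is to serialize the trained classifier together with an empty copy of the training header using Weka's SerializationHelper, so the test-time code needs neither the ARFF file nor retraining:

import weka.classifiers.Classifier;
import weka.core.Instances;
import weka.core.SerializationHelper;

public class ModelPersistenceSketch {

  // save the trained classifier plus a data-less copy of the training header
  public static void save(Classifier classifier, Instances dataset,
      String modelFile) throws Exception {
    SerializationHelper.writeAll(modelFile,
      new Object[] {classifier, new Instances(dataset, 0)});
  }

  // load both back; element 0 is the Classifier, element 1 is the Instances
  // header, which is enough to build Instance objects at test time
  public static Object[] load(String modelFile) throws Exception {
    return SerializationHelper.readAll(modelFile);
  }
}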

Client code to train the classifier and then use it in a streaming manner on unseen images is shown in the JUnit snippet below.

  @Test
  public void testModel() throws Exception {
    PhotoOrDrawingClassifier classifier = new PhotoOrDrawingClassifier();
    classifier.train("/path/to/pdc2.arff");
    String filename = null;
    // a file containing paths to the test set images
    BufferedReader reader = new BufferedReader(
      new FileReader(new File("/path/to/test.txt")));
    while ((filename = reader.readLine()) != null) {
      String klass = classifier.test(filename);
      System.out.println(filename + " => " + klass);
    }
    reader.close();
  }

Running this code classified the images in my test set; comparing the output with my annotations gave the following results (formatted similarly to the Weka Explorer output, for ease of comparison).

Correctly Classified Instances         177               88.5000 %
Incorrectly Classified Instances        23               11.5000 %

88.5% seems a bit low for such a simple task, but it can probably be improved by using some other classifier. I guess I need to go through the Weka book again to find candidate classifier algorithms that would work well with the dataset.
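If I do revisit this, swapping in a different classifier and cross-validating it programmatically takes only a few lines using Weka's Evaluation class; something along these lines (J48 here is just an arbitrary candidate, not a recommendation):

import java.io.BufferedReader;
import java.io.FileReader;
import java.util.Random;

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.trees.J48;
import weka.core.Instances;

public class CrossValidateSketch {
  public static void main(String[] args) throws Exception {
    Instances data = new Instances(
      new BufferedReader(new FileReader("/path/to/pdc2.arff")));
    data.setClassIndex(data.numAttributes() - 1);
    Classifier candidate = new J48();   // try any Weka classifier here
    Evaluation eval = new Evaluation(data);
    eval.crossValidateModel(candidate, data, 10, new Random(1));
    System.out.println(eval.toSummaryString());
  }
}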

Another way to make the classification perform better may be to rethink it a bit. While collecting training and test instances, I found several images for which the classification is not clearcut. These are either image groups that contain line drawings and photographs together, or labelled, shaded color drawings whose histograms do not have the choppiness associated with charts. When annotating my training and test sets, I used my (programmer's) judgement to classify each of these as one or the other based on how close it was to a photo or a drawing. Perhaps we would get more accurate results if we classified the images into more than just two groups.

Wednesday, April 04, 2012

Generating Unigram and Bigrams into MySQL from Hadoop SequenceFiles

In my previous post, I described how I used GNU Parallel to read a fairly large Lucene index into a set of Hadoop SequenceFiles. The objective is to use the data in the index to build a Unigram and Bigram Language Model for a spelling corrector. Since the spelling correction code is going to be called from a web application, I figured a MySQL database would be a good place to store the unigrams and bigrams.
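The table definitions never appear in this post, so here is a minimal schema consistent with the INSERT statements in the code below. The column types and lengths are my guesses, and I show it as JDBC calls to stay with the Java used everywhere else:

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.Statement;

// Sketch only: creates the two tables the reducers below insert into.
public class CreateTablesSketch {
  public static void main(String[] args) throws Exception {
    Class.forName("com.mysql.jdbc.Driver");
    Connection conn = DriverManager.getConnection(
      "jdbc:mysql://localhost:3306/spelldb", "spelluser", "spelluser");
    Statement st = conn.createStatement();
    st.executeUpdate(
      "create table unigram_counts (" +
      "  word varchar(64) not null," +
      "  soundex varchar(8)," +
      "  metaphone varchar(8)," +
      "  cnt int not null)");
    st.executeUpdate(
      "create table bigram_counts (" +
      "  bigram varchar(160) not null," +
      "  cnt int not null)");
    st.close();
    conn.close();
  }
}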

This is a fairly trivial task from the point of view of writing Map-Reduce code (the unigram writer is just a minor variation of the WordCount example), but this is the first time I was using Map-Reduce to crunch through a reasonably large dataset. I was also running Hadoop on a single large machine in pseudo-distributed mode, unlike previously where I mostly used it in local mode to build little proofs of concept. So there were certain things I learned about running Hadoop, which I will mention as they come up. But first, the code.

Java Code

As stated above, the code for both the UnigramCounter and BigramCounter are fairly trivial examples of Map-Reduce code. But I include them anyway, for completeness.

UnigramCounter.java

The UnigramCounter Mapper splits the input text into words and writes them out to the context, where the Reducer picks them up, aggregates the counts, computes the soundex and metaphone values for each word, and writes the record out to a database table. The soundex and metaphone values are for finding sound-alikes - I am not sure which one will give me the best results, so I compute both.
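As a quick illustration (mine, not from the job) of why these encodings are useful for sound-alikes: a misspelling and its correction should collapse to the same codes.

import org.apache.commons.codec.language.Metaphone;
import org.apache.commons.codec.language.Soundex;

public class SoundAlikeDemo {
  public static void main(String[] args) {
    Soundex soundex = new Soundex();
    Metaphone metaphone = new Metaphone();
    // "resipe" is a misspelling of "recipe"; both should map to the same
    // Soundex and Metaphone codes, which is what lets us look up
    // sound-alike candidates for a misspelled word
    System.out.println(soundex.soundex("recipe") + " / " + soundex.soundex("resipe"));
    System.out.println(metaphone.metaphone("recipe") + " / " + metaphone.metaphone("resipe"));
  }
}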

package com.mycompany.spell3.train;

import java.io.IOException;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.concurrent.atomic.AtomicInteger;

import org.apache.commons.codec.language.Metaphone;
import org.apache.commons.codec.language.Soundex;
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;

public class UnigramCounter extends Configured implements Tool {

  private static final String PROP_DBNAME = "dbname";
  private static final String PROP_DBUSER = "dbuser";
  private static final String PROP_DBPASS = "dbpass";

  private static final String NULL_PATH = "/prod/hadoop/dummy";

  public static class MapClass extends 
      Mapper<LongWritable,Text,Text,IntWritable> {

    private static final IntWritable ONE = new IntWritable(1);
    
    @Override
    protected void map(LongWritable key, Text value, 
        Context context) throws IOException, 
        InterruptedException {
      String s = StringUtils.lowerCase(value.toString());
      String[] words = s.split("[^a-z]+");
      for (String word : words) {
        if (word.length() == 0) {
          continue;   // split() can produce an empty leading token
        }
        context.write(new Text(word), ONE);
      }
    }
  }
  
  public static class ReduceClass extends 
      Reducer<Text,IntWritable,Text,IntWritable> {

    private static final String MYSQL_DB_DRIVER = "com.mysql.jdbc.Driver";

    private Connection conn;
    private PreparedStatement ps;
    private AtomicInteger counter = new AtomicInteger(0);
    private Soundex soundex;
    private Metaphone metaphone;
    
    @Override
    protected void setup(Context context) 
        throws IOException, InterruptedException {
      try {
        Class.forName(MYSQL_DB_DRIVER);
        Configuration conf = context.getConfiguration();
        conn = DriverManager.getConnection(
          "jdbc:mysql://localhost:3306/" + conf.get(PROP_DBNAME),
          conf.get(PROP_DBUSER), conf.get(PROP_DBPASS));
        conn.setAutoCommit(false);
        ps = conn.prepareStatement(
          "insert into unigram_counts(word,soundex,metaphone,cnt) " +
          "values (?,?,?,?)");
        soundex = new Soundex();
        metaphone = new Metaphone();
      } catch (Exception e) {
        throw new IOException(e);
      }
    }

    @Override
    protected void reduce(Text key, Iterable<IntWritable> values, 
        Context context) throws IOException, 
        InterruptedException {
      int sum = 0;
      for (IntWritable value : values) {
        sum += value.get();
      }
      insertToDb(key.toString(), sum);
    }
    
    private void insertToDb(String word, int count) 
        throws IOException {
      try {
        ps.setString(1, word);
        ps.setString(2, soundex.soundex(word));
        ps.setString(3, metaphone.metaphone(word));
        ps.setInt(4, count);
        ps.execute();
        int current = counter.incrementAndGet();
        if (current % 1000 == 0) {
          conn.commit();
        }
      } catch (SQLException e) {
        System.out.println("Failed to insert unigram: " + word);
        e.printStackTrace();
      }
    }

    @Override
    protected void cleanup(Context context) 
        throws IOException, InterruptedException {
      if (ps != null) {
        try { ps.close(); } catch (SQLException e1) {}
      }
      if (conn != null) {
        try {
          conn.commit();
          conn.close();
        } catch (SQLException e) {
          throw new IOException(e);
        }
      }
    }
  }
  
  @Override
  public int run(String[] args) throws Exception {
    Path input = new Path(args[0]);
    Path output = new Path(NULL_PATH);
    
    Configuration conf = getConf();
    conf.set(PROP_DBNAME, args[1]);
    conf.set(PROP_DBUSER, args[2]);
    conf.set(PROP_DBPASS, args[3]);
    
    Job job = new Job(conf, "Unigram-Counter");
    
    FileInputFormat.setInputPaths(job, input);
    FileOutputFormat.setOutputPath(job, output);
    
    job.setJarByClass(UnigramCounter.class);
    job.setMapperClass(MapClass.class);
    job.setReducerClass(ReduceClass.class);
    job.setInputFormatClass(SequenceFileInputFormat.class);
    job.setMapOutputKeyClass(Text.class);
    job.setMapOutputValueClass(IntWritable.class);
    job.setNumReduceTasks(5);
    
    boolean succ = job.waitForCompletion(true);
    if (! succ) {
      System.out.println("Job failed, exiting");
      return -1;
    }
    return 0;
  }

  public static void main(String[] args) throws Exception {
    if (args.length != 4) {
      System.out.println(
        "Usage: UnigramCounter path_to_seqfiles output_db db_user db_pass");
      System.exit(-1);
    }
    int res = ToolRunner.run(new Configuration(), 
      new UnigramCounter(), args);
    System.exit(res);
  }
}

BigramCounter.java

The BigramCounter Mapper uses a Sentence BreakIterator to break the input up into sentences, computes bigrams of word pairs within each sentence and writes them out to the context, where the Reducer picks them up, aggregates the counts and writes the bigram and count to another database table.
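Here is a small standalone illustration (mine, not part of the job) of the sentence markers and the "__" separator convention the mapper uses:

import java.text.BreakIterator;

public class BigramDemo {
  public static void main(String[] args) {
    String text = "The cat sat. The dog ran.";
    BreakIterator sit = BreakIterator.getSentenceInstance();
    sit.setText(text);
    int start = sit.first();
    int end;
    while ((end = sit.next()) != BreakIterator.DONE) {
      String[] words = text.substring(start, end).toLowerCase().split("[^a-z]+");
      start = end;
      String prev = "<s>";
      for (String word : words) {
        if (word.length() == 0) {
          continue;     // split() can yield an empty leading token
        }
        System.out.println(prev + "__" + word);
        prev = word;
      }
      System.out.println(prev + "__</s>");
    }
    // the first sentence prints: <s>__the, the__cat, cat__sat, sat__</s>
  }
}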

package com.mycompany.spell3.train;

import java.io.IOException;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.text.BreakIterator;
import java.util.concurrent.atomic.AtomicInteger;

import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;

public class BigramCounter extends Configured implements Tool {

  private static final String PROP_DBNAME = "dbname";
  private static final String PROP_DBUSER = "dbuser";
  private static final String PROP_DBPASS = "dbpass";

  private static final String NULL_PATH = "/prod/hadoop/dummy";

  public static class MapClass extends 
      Mapper<LongWritable,Text,Text,IntWritable> {
    
    private static final IntWritable ONE = new IntWritable(1);
    private static final String SENTENCE_START = "<s>";
    private static final String SENTENCE_END = "</s>";
    private static final String WORD_SEPARATOR = "__";
    
    @Override
    protected void map(LongWritable key, Text value, 
        Context context) 
        throws IOException, InterruptedException {
      String s = value.toString();
      BreakIterator sit = BreakIterator.getSentenceInstance();
      sit.setText(s);
      int start = sit.first();
      int end = -1;
      while ((end = sit.next()) != BreakIterator.DONE) {
        String sentence = StringUtils.lowerCase(s.substring(start, end));
        start = end;
        String[] words = sentence.split("[^a-z]+");
        String prevWord = SENTENCE_START;
        for (int i = 0; i < words.length; i++) {
          if (words[i].length() == 0) {
            continue;   // split() can produce an empty leading token
          }
          // bigram of the previous token (or <s>) and the current word
          String bigram = StringUtils.join(
            new String[] {prevWord, words[i]}, WORD_SEPARATOR);
          context.write(new Text(bigram), ONE);
          prevWord = words[i];
        }
        if (!SENTENCE_START.equals(prevWord)) {
          // close the sentence with a bigram into the </s> marker
          context.write(new Text(StringUtils.join(
            new String[] {prevWord, SENTENCE_END}, WORD_SEPARATOR)), ONE);
        }
      }
    }
  }
  
  public static class ReduceClass extends 
    Reducer<Text,IntWritable,Text,IntWritable> {

    private static final String MYSQL_DB_DRIVER = "com.mysql.jdbc.Driver";

    private Connection conn;
    private PreparedStatement ps;
    private AtomicInteger counter = new AtomicInteger(0);
    
    @Override
    protected void setup(Context context) 
        throws IOException, InterruptedException {
      try {
        Class.forName(MYSQL_DB_DRIVER);
        Configuration conf = context.getConfiguration();
        conn = DriverManager.getConnection(
          "jdbc:mysql://localhost:3306/" + conf.get(PROP_DBNAME), 
          conf.get(PROP_DBUSER), conf.get(PROP_DBPASS));
        conn.setAutoCommit(false);
        ps = conn.prepareStatement(
          "insert into bigram_counts(bigram,cnt) values (?,?)");
      } catch (Exception e) {
        throw new IOException(e);
      }
    }
    
    @Override
    protected void reduce(Text key, Iterable<IntWritable> values, 
        Context context) 
        throws IOException, InterruptedException {
      int sum = 0;
      for (IntWritable value : values) {
        sum += value.get();
      }
      insertToDb(key.toString(), sum);
    }
    
    private void insertToDb(String bigram, int sum) 
        throws IOException {
      try {
        ps.setString(1, bigram);
        ps.setInt(2, sum);
        ps.execute();
        int current = counter.incrementAndGet();
        if (current % 1000 == 0) {
          conn.commit();
        }
      } catch (SQLException e) {
        System.out.println("Failed to insert bigram: " + bigram);
        e.printStackTrace();
      }
    }

    @Override
    protected void cleanup(Context context)
        throws IOException, InterruptedException {
      if (ps != null) {
        try { ps.close(); } catch (SQLException e) {}
      }
      if (conn != null) {
        try {
          conn.commit();
          conn.close();
        } catch (SQLException e) {
          throw new IOException(e);
        }
      }
    }
  }
  
  @Override
  public int run(String[] args) throws Exception {
    Path input = new Path(args[0]);
    Path output = new Path(NULL_PATH);
    
    Configuration conf = getConf();
    conf.set(PROP_DBNAME, args[1]);
    conf.set(PROP_DBUSER, args[2]);
    conf.set(PROP_DBPASS, args[3]);
    
    Job job = new Job(conf, "Bigram-Counter");
    
    FileInputFormat.setInputPaths(job, input);
    FileOutputFormat.setOutputPath(job, output);
    
    job.setJarByClass(BigramCounter.class);
    job.setMapperClass(MapClass.class);
    job.setReducerClass(ReduceClass.class);
    job.setInputFormatClass(SequenceFileInputFormat.class);
    job.setMapOutputKeyClass(Text.class);
    job.setMapOutputValueClass(IntWritable.class);
    job.setNumReduceTasks(5);
    
    boolean succ = job.waitForCompletion(true);
    if (! succ) {
      System.out.println("Job failed, exiting");
      return -1;
    }
    return 0;
  }

  public static void main(String[] args) throws Exception {
    if (args.length != 4) {
      System.out.println(
        "Usage: BigramCounter path_to_seqfiles output_db db_user db_pass");
      System.exit(-1);
    }
    int res = ToolRunner.run(new Configuration(), 
      new BigramCounter(), args);
    System.exit(res);
  }
}

Hadoop Configuration Changes

Hadoop is built to run on clusters of many medium-sized machines. What I had instead was one large 16-CPU machine, so I wanted to make sure its processing power was utilized as fully as possible. I made the following changes to mapred-site.xml, based on the advice in this StackOverflow page.

<!-- Source: $HADOOP_HOME/conf/mapred-site.xml -->
<configuration>
...
<property>
  <name>mapred.tasktracker.map.tasks.maximum</name>
  <value>10</value>
  <description/>
</property>

<property>
  <name>mapred.tasktracker.reduce.tasks.maximum</name>
  <value>10</value>
  <description/>
</property>
</configuration>

In core-site.xml, I changed the location of hadoop.tmp.dir to a large, relatively unused partition on the box instead of its default location. This was actually in response to a job failure where it ran out of HDFS space. Since at that point I had to rerun the job anyway, I shut down Hadoop, deleted the old hadoop.tmp.dir, restarted Hadoop and reformatted the namenode.

<!-- Source: $HADOOP_HOME/conf/core-site.xml -->
<configuration>

<property>
  <name>hadoop.tmp.dir</name>
  <value>/prod/hadoop/tmp</value>
  <description>A base for other temporary directories.</description>
</property>
...
</configuration>

Since I have only a single data node, I set the dfs.replication in hdfs-site.xml to 1.

<!-- Source: $HADOOP_HOME/conf/hdfs-site.xml -->
<configuration>

<property>
  <name>dfs.replication</name>
  <value>1</value>
  <description/>
</property>

</configuration>

MySQL Configuration Changes

The default location of the MySQL data directory was in /var/lib/mysql, which was in the "/" partition, too small for my purposes. I actually ran out of disk space in this partition while writing bigrams to MySQL (the job just hangs at a fixed map-reduce completion status). I had to kill the job, shut down MySQL, reconfigure the data directory and the socket location, move the contents over to the new location, and restart MySQL. Here are the configuration changes:

# Source: /etc/my.cnf
[mysqld]
#datadir=/var/lib/mysql
#socket=/var/lib/mysql/mysql.sock
datadir=/prod/mysql_db
socket=/prod/mysql_db/mysql.sock
...

Deployment

Before this, I used to write shell scripts that put the JARs required by Hadoop and my application on the classpath, and then called Java. While doing this, I discovered that you can use the $HADOOP_HOME/bin/hadoop script to launch your custom Map-Reduce jobs as well, so I decided to use that.

However, I needed to set a few custom JAR files that Hadoop did not have (or need) in its classpath. I was using commons-codec which provided me implementations of Soundex and Metaphone, and I was writing to a MySQL database for which I needed the JDBC driver JAR, plus a few others for functionality I was too lazy to implement on my own.

There are two ways to supply these extra JAR files to the bin/hadoop script. One is by specifying their paths in the -libjars parameter. I thought this was nice, but it didn't work for me - for some reason it could not see the parameters I was passing to my Map-Reduce job via the command line. The second way is to package your custom JARs in the lib subdirectory of your application's JAR file, a so-called fat jar. The fat JAR approach was the one I took, creating it using the simple Ant target shown below:

  <target name="fatjar" depends="compile" description="Build JAR to run in Hadoop">
    <mkdir dir="${maven.build.output}/lib"/>
    <copy todir="${maven.build.output}/lib">
      <fileset dir="${custom.jars.dir}">
        <include name="commons-lang3-3.0.1.jar"/>
        <include name="commons-codec-1.3.jar"/>
        <include name="mysql-connector-java-5.0.5-bin.jar"/>
        ...
      </fileset>
    </copy>
    <jar jarfile="${maven.build.directory}/${maven.build.final.name}-fatjar.jar"
        basedir="${maven.build.output}" excludes="**/package.html"/>
  </target>

Once this is done, the script to run either job is quite simple. I show the script to run the BigramCounter below; simply replace BigramCounter with UnigramCounter for the other one.

#!/bin/bash
# Source: bin/bigram_counter.sh
HADOOP_HOME=/opt/hadoop-1.0.1
$HADOOP_HOME/bin/hadoop fs -rmr /prod/hadoop/dummy
$HADOOP_HOME/bin/hadoop jar /path/to/my-fatjar.jar \
  com.mycompany.spell3.train.BigramCounter $*

To run this script from the command line:

hduser@bigmac:spell3$ nohup ./bigram_counter.sh /prod/hadoop/spell \
  spelldb spelluser spelluser &

Job Killing

I needed to kill the job midway multiple times, either because I discovered I had goofed on some programming issue (incorrect database column names, etc.) and the job would start throwing all kinds of exceptions down the line, or because (as mentioned previously) MySQL ran out of disk space. To do this, you need to use the bin/hadoop job -kill command.

hduser@bigmac:hadoop-1.0.1$ # list out running jobs
hduser@bigmac:hadoop-1.0.1$ bin/hadoop job -list
hduser@bigmac:hadoop-1.0.1$ # kill specific job
hduser@bigmac:hadoop-1.0.1$ bin/hadoop job -kill ${job-id}

Even I had enough sense not to do a kill -9 on the Hadoop daemons themselves, but there was one time when I did a stop-all.sh and ended up having to throw away all my data because Hadoop got all choked up.

Another little tip is to avoid throwing exceptions from your Mapper or Reducer. A better option is to log them. This is true for any batch job, of course, but I once had one of the jobs fail after about 2 days of processing because of too many exceptions thrown by the Reducer. In the code above, I just used System.out.println() to log SQLExceptions if they occur, but it's better to use a real logger.
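A minimal sketch of what that could look like with commons-logging and a Hadoop counter instead of System.out (the class and counter names here are made up):

import java.sql.PreparedStatement;
import java.sql.SQLException;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Reducer;

public class QuietInsertSketch {

  private static final Log LOG = LogFactory.getLog(QuietInsertSketch.class);

  // log the failure and bump a job counter instead of throwing, so one bad
  // record (or a flaky database) cannot kill a multi-day job
  static void insertOrLog(PreparedStatement ps, String bigram, int count,
      Reducer<Text,IntWritable,Text,IntWritable>.Context context) {
    try {
      ps.setString(1, bigram);
      ps.setInt(2, count);
      ps.execute();
    } catch (SQLException e) {
      LOG.warn("Failed to insert bigram: " + bigram, e);
      context.getCounter("BigramCounter", "INSERT_FAILED").increment(1L);
    }
  }
}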

So anyway, after about a week and a half of processing (including all sorts of silly but expensive mistakes), I ended up with approximately 400 million unigrams and 600 million bigrams in the database. Now to figure out how to actually use this information :-).

Update - 2012-04-09: I had a bug in my bigram generation code, which caused bad results, so I reran it. This time the job failed twice in a row, caused (I suspect) by extremely high loads on the MySQL database server. The first time I discovered that the mysql.sock file had disappeared, so I terminated the job manually. The second time I found that the mysql.sock file would disappear and then reappear after a while once the load came back down (this is the only place I have found another mention of this) - however, ultimately this job failed as well. I ended up writing the bigrams and counts to text files in HDFS, and the job completed in a fraction of the time it took before. So another lesson learned - avoid writing out to external datastores from within Hadoop.
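For reference, the HDFS variant of the reducer is just the standard word-count shape - a sketch (not necessarily the exact code I ran) is below; the default TextOutputFormat then writes tab-separated bigram/count lines into the part files.

import java.io.IOException;

import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Reducer;

// Reducer for a rerun that writes bigram counts to HDFS text files instead
// of MySQL; pair it with job.setOutputKeyClass(Text.class),
// job.setOutputValueClass(IntWritable.class) and a real output path set via
// FileOutputFormat.setOutputPath() instead of the dummy path used above.
public class BigramCountToHdfsReducer
    extends Reducer<Text,IntWritable,Text,IntWritable> {

  @Override
  protected void reduce(Text key, Iterable<IntWritable> values,
      Context context) throws IOException, InterruptedException {
    int sum = 0;
    for (IntWritable value : values) {
      sum += value.get();
    }
    // each output line will be "<bigram>\t<count>" in the part-r-* files
    context.write(key, new IntWritable(sum));
  }
}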