Monday, August 2, 2010

Machine Learning: Google Prediction API

Loosely speaking, machine learning is using a computer to recognize patterns in data and then make predictions about new data based on what it has learned. It is like a marriage between computer science and statistics. Besides its most obvious application (an army of sentient robots that wages war against mankind, a.k.a. Skynet), machine learning has a lot of uses, including:
  • Predicting housing prices
  • Recommendation engine for a retail website based on past customer purchases
  • Spam filter or any document classifier
  • Autonomous vehicles such as cars or helicopters
  • Forecasting electricity demand based on historical data
  • Computers playing chess
The best way to understand how machine learning works is to see an example. Google has released a machine learning software-as-a-service product called the Prediction API. I wrote a simple spam filter using the Prediction API, and it works surprisingly well. The basic usage is two steps: train on historical data, then predict on new data.

Training sets in machine learning are generally tables or matrices where each row represents a single training example. The first column in the table is the classification for that example. The classification values for my spam filter are either "spam" for a spam email or "good" for a non-spam email. The remaining columns are the features of that training example. In the case of predicting house prices, the features would be attributes of the house such as the square footage, the number of bedrooms, whether it has a pool, etc. In the case of the spam filter, the text of the email's subject and body is the only feature. Other features of an email could include header information, but I kept it simple. The Prediction API requires the training set to be a comma-separated values (CSV) file. Here's what the training set looked like for the spam filter:

# classification, email text
good, "Some email text..."
spam, "Do you want to buy a degree?"
good, "Hi mom, ...."
spam, "Click on this shady link..."

To get a decent quantity of real data to test with, I wrote a little IMAP email exporter that downloaded a month's worth of my good email and my spam email from Gmail. My training set was about 3000 emails. I also put 1500 additional emails in another CSV file to use for testing later. It is important that the testing data set be independent of the training set. Then I uploaded the training set to Google Storage, which is another cool software-as-a-service product that Google is developing. Google Storage is similar to Amazon S3. It is easy to use, and I like the interface, but that's another topic.
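
The exporter is straightforward with JavaMail, which speaks IMAP. Below is a rough sketch of the approach rather than the exact script I ran: the credentials are placeholders, it assumes plain-text messages (real code would walk the MIME parts), and it skips the date filtering:

@Grab('javax.mail:mail:1.4.4')
import javax.mail.*

def props = new Properties()
props.put('mail.store.protocol', 'imaps')
def store = Session.getInstance(props).getStore('imaps')
store.connect('imap.gmail.com', 'me@gmail.com', 'my-password')

// Pull a folder's messages and append them as labeled CSV rows.
def export = { folderName, label, writer ->
    def folder = store.getFolder(folderName)
    folder.open(Folder.READ_ONLY)
    folder.messages.each { msg ->
        // Subject + body flattened into the single "text" feature,
        // using the same quoting rule as above.
        def text = "${msg.subject} ${msg.content}"
        writer.writeLine(label + ',"' + text.replace('"', '""') + '"')
    }
    folder.close(false)
}

new File('training.csv').withWriter { w ->
    export('INBOX', 'good', w)
    export('[Gmail]/Spam', 'spam', w)
}
store.close()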

The Prediction API is REST-based. Once the training set is uploaded, you make a simple request to start training:

www.googleapis.com/prediction/v1/train/{my bucket}
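
In Groovy, kicking that off looks something like the sketch below. The Authorization header assumes a ClientLogin-style token, authToken is a placeholder, and the exact request payload and status-check call are spelled out in the API docs:

def authToken = 'YOUR_CLIENTLOGIN_TOKEN'   // placeholder; obtain via Google's auth APIs
def trainUrl = 'https://www.googleapis.com/prediction/v1/train/mybucket%2Ftraining.csv'

def conn = new URL(trainUrl).openConnection()
conn.requestMethod = 'POST'
conn.doOutput = true
conn.setRequestProperty('Content-Type', 'application/json')
conn.setRequestProperty('Authorization', "GoogleLogin auth=${authToken}")
conn.outputStream.withWriter { it << '{"data":{}}' }   // minimal body; see the docs for the exact payload
println conn.inputStream.text
// Training runs asynchronously; poll the status call until it reports completion.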

There is another call to check the status of the training. For my data set, training took only a couple of minutes. After training has completed, you make another REST call to get a prediction. The payload of the request is a JSON object with the same features as your training set. My spam filter had only one feature: text. If you used 10 features, the JSON object would have 10 fields. In practice, you can have hundreds or thousands of features.

www.googleapis.com/prediction/v1/train/{my bucket}/predict

{"data":{
"input": {
"text" : ["Want to buy a degree?"]}}}}

The Prediction API returns the most probable classification for that request. The response looks like:

{"data":{
"output" :{
["output_label":"spam"]}}}

Conclusion

I wrote a little Groovy script that ran through the 1500 test emails and checked whether the Prediction API picked the correct classification. So how did it do? It classified correctly 91% of the time. I thought that was really good considering I used only the text of the email as an input feature. I was able to create a simple spam filter that is 91% accurate, and it took me only a couple of hours. Keep in mind that the Prediction API documentation states that machine learning is the ultimate case of "garbage in, garbage out": the quality of your training examples has a huge effect on the accuracy of your predictions.
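
For the curious, the heart of that test harness is just a loop over the CSV. A simplified sketch (classify() is the helper sketched above; the split assumes the label column never contains a comma):

int correct = 0, total = 0
new File('testing.csv').eachLine { line ->
    def parts = line.split(',', 2)   // label column, then the email text
    if (classify(parts[1]) == parts[0].trim()) correct++
    total++
}
printf 'Accuracy: %.1f%% (%d/%d)%n', 100.0 * correct / total, correct, total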

Machine learning algorithms are not one-size-fits-all by any means. The art of ML is using the right algorithm for your data and tuning that algorithm appropriately. Google has taken an interesting approach by making the Prediction API appear to be a "universal learner" that works on any data set. During the learning process, it presumably tries several different algorithms and picks the best fit. However, Google is keeping mum on the internals of the Prediction API.

I ran into a few snags when I was running my tests, but the Prediction API message boards are responsive, and I got help quickly. The Prediction API is currently experimental, so it should be interesting to see where Google takes it and whether it sees the light of day as a commercial offering.

Machine learning is an interesting and valuable field that has a lot of uses in software development. In my next machine learning post, we'll look at the open source ML toolkit Weka.

4 comments:

  1. Nice write-up, I agree, very very interesting API with an enormous set of potential uses.

  2. How many emails in the set were spam (i.e. what was the prior probability)? And how many false positives and false negatives were there? To really understand how good 91% is, we need to know a few more things.

    For example, let's assume that the prior probability is 10%, i.e. 150 out of the 1500 emails are spam (this is probably very low). I'm going to assume that "91% accurate" means 9% false positives (spam slipping through the filter), but no number for false negatives is given (legitimate emails identified as spam), so let's make it easy and say that it too is 9%.

    Using these numbers we find that 136 out of 150 (0.91 * 150, rounded down) spam messages are identified, but also that 121 ((1500 - 150) * 0.09, rounded down) were incorrectly identified as spam. In total 257 messages are identified as spam, but only 136 of these are actually spam (and not all spam has been identified). This means that the probability of a message being spam, given that it has been identified as spam, is just above 50% (the probability of an email being legitimate, given that it has been identified as such, is 99% though).

    Now, that is probably completely wrong, because I just assumed the prior probability and the false negative rate (and I might have misunderstood the false positive rate too). If we re-run the calculation with the assumption that half of the emails are spam and the false negative rate is only 1% we get that the probability of something being correctly labeled as spam is 99%, but that something being correctly identified as legitimate will only be 92%.

    I hope this illustrates why I think it would be great if you posted the prior probability and both false positive and false negative rates.

  3. "good" email is also called "ham".

  4. @Lawrence: Thanks, "ham" it is!

    @Theo:

    2000 were "ham" and 1000 were "spam" in the training set. The testing set of 1500 rows maintained the same ham:spam ratio of 2:1.

    The 9% error rate includes both false positives and false negatives. Out of 3000 emails, about 270 were incorrectly classified. My little script didn't separate false positives from false negatives, but I'll try to update it and re-run the test to get those numbers.
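
    Something along these lines is what I have in mind for the tally (classify() as sketched in the post; the labels match the training file):

    // Count the two error types separately over the held-out CSV.
    int falsePositives = 0   // ham incorrectly labeled spam
    int falseNegatives = 0   // spam incorrectly labeled ham
    new File('testing.csv').eachLine { line ->
        def parts = line.split(',', 2)
        def predicted = classify(parts[1])
        if (parts[0] == 'good' && predicted == 'spam') falsePositives++
        if (parts[0] == 'spam' && predicted == 'good') falseNegatives++
    }
    println "false positives: ${falsePositives}, false negatives: ${falseNegatives}"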

    Of course, in a real-world spam filter you want to include a cost calculation that reduces false positives (classifying ham as spam) even though it will increase false negatives (classifying spam as ham), because most email users never check their spam folder, so false positives are unacceptable. The Prediction API doesn't include a way to assign a cost to a misclassification, and in fact I'm going to request that as an enhancement. The Google folks who support the Prediction API are very responsive and open to suggestions.

    Also, my simplistic spam filter used only the text of the emails and didn't include important attributes such as header info or how many other users marked an email as spam, and I used a fairly small training set. I thought 91% was pretty good for just basic text classification, all things considered (one of those considerations being that I'm new to ML in general). If you want more info on how accurate the Prediction API is on some common ML data sets, other users have posted their accuracy findings on the Prediction API message board (https://groups.google.com/group/prediction-api-discuss?pli=1).
