Loosely speaking, machine learning is using a computer to recognize patterns in data and then make predictions about new data based on what it has learned. It is like a marriage between computer science and statistics. Besides its most obvious application (an army of sentient robots that wages war against mankind, a.k.a. Skynet), machine learning has plenty of uses, including:
- Predicting housing prices
- Recommendation engine for a retail website based on past customer purchases
- Spam filter or any document classifier
- Autonomous vehicles such as cars or helicopters
- Forecasting electricity demand based on historical data
- Computers playing chess
The best way to understand how machine learning works is to see an example. Google has released a machine learning software-as-a-service product called the Prediction API. I wrote a simple spam filter using the Prediction API, and it works surprisingly well. The basic usage is two steps: train on historical data, then predict on new data.
Training sets in machine learning are generally tables or matrices where each row represents a single training example. The first column in the table is the classification for that example. The classification values for my spam filter are either “spam” for a spam email or “good” for a non-spam email. The remaining columns are the features of that training example. In the case of predicting house prices, the features would be attributes of the house such as the square footage, number of bedrooms, whether it has a pool, etc. In the case of the spam filter, the combined text of the email’s subject and body is the only feature. Other features of an email could be header information, but I kept it simple. The Prediction API requires the training set to be a comma-separated (CSV) file. Here’s what the training set looked like for the spam filter:
# classification, email text
good, "Some email text..."
spam, "Do you want to buy a degree?"
good, "Hi mom, ...."
spam, "Click on this shady link..."
To get a decent quantity of real data to test with, I wrote a little IMAP email exporter that downloaded a month’s worth of my good email and my spam email from Gmail. My training set was about 3000 emails. I also put 1500 additional emails in another CSV file to use for testing later; it is important that the testing data set be independent of the training set. Then I uploaded the training set to Google Storage, which is another cool software-as-a-service product that Google is developing. Google Storage is similar to Amazon S3. It is easy to use and I like the interface, but that’s another topic.
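If you're curious what such an exporter involves, here's a rough Python sketch of the same idea (not the exporter I actually wrote; the host, credentials, mailbox names, and message limit below are all placeholders you'd adjust for your own account):

```python
import csv
import email
import imaplib

# Placeholder credentials -- use your own account and an app password.
# Gmail exposes spam via the special folder "[Gmail]/Spam".
HOST, USER, PASSWORD = "imap.gmail.com", "me@example.com", "app-password"

def export_mailbox(conn, mailbox, label, writer, limit=1500):
    """Download up to `limit` messages from `mailbox` and write
    (label, subject + body text) rows to the CSV writer."""
    conn.select(mailbox, readonly=True)
    _, data = conn.search(None, "ALL")
    for num in data[0].split()[:limit]:
        _, msg_data = conn.fetch(num, "(RFC822)")
        msg = email.message_from_bytes(msg_data[0][1])
        parts = [msg.get("Subject", "")]
        for part in msg.walk():
            if part.get_content_type() == "text/plain":
                payload = part.get_payload(decode=True)
                if payload:
                    parts.append(payload.decode("utf-8", errors="replace"))
        writer.writerow([label, " ".join(parts)])

conn = imaplib.IMAP4_SSL(HOST)
conn.login(USER, PASSWORD)
with open("training.csv", "w", newline="") as f:
    writer = csv.writer(f)
    export_mailbox(conn, "INBOX", "good", writer)
    export_mailbox(conn, "[Gmail]/Spam", "spam", writer)
conn.logout()
```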
The Prediction API is REST-based. Once the training set is uploaded, you make a simple request to start training:
www.googleapis.com/prediction/v1/train/{my bucket}
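Kicking off training is just an authenticated POST to that URL. Here's a hedged sketch; the bucket/object path and the token-style Authorization header are my own placeholder assumptions, not verbatim from the docs:

```python
import urllib.request

# Placeholders: substitute a real auth token and your bucket/object path.
AUTH_TOKEN = "your-auth-token"
TRAIN_URL = "https://www.googleapis.com/prediction/v1/train/mybucket%2Ftraining.csv"

req = urllib.request.Request(TRAIN_URL, data=b"{}", method="POST")
# Assumed token header format; adjust to whatever auth scheme your account uses.
req.add_header("Authorization", "GoogleLogin auth=" + AUTH_TOKEN)
req.add_header("Content-Type", "application/json")
with urllib.request.urlopen(req) as resp:
    print(resp.status, resp.read().decode())
```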
There is another call to check the status of the training. For my data set, training only took a couple of minutes. After training has completed, you make another REST call to get a prediction. The payload of the request is a JSON object that has the same features as your training set. My spam filter only had one feature: text. If you used 10 features, the JSON object would have 10 fields. In practice, you can have hundreds or thousands of features.
www.googleapis.com/prediction/v1/train/{my bucket}/predict
{"data":{"input":{"text":["Want to buy a degree?"]}}}
The Prediction API returns the most probable classification for that request. The response looks like:
{"data":{ "output" :{ ["output_label":"spam"]}}}
Conclusion
I wrote a little Groovy script that ran through the 1500 test emails and checked whether the Prediction API would pick the correct classification. So how did it do? It correctly identified spam 91% of the time. I thought that was really good considering I only used the text of the email as an input feature. I was able to create a simple spam filter that is 91% accurate, and it only took me a couple of hours. Keep in mind the Prediction API documentation states that machine learning is the ultimate case of “garbage in, garbage out”: the quality of your training examples has a huge effect on the accuracy of your predictions.
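The evaluation loop itself is trivial. Mine was in Groovy, but a Python version looks something like this (assuming the hypothetical classify() helper from the earlier sketch, and a test CSV in the same "classification, email text" format as the training set):

```python
import csv

# Assumes classify() from the predict sketch above and a test file
# in the same format as the training set shown earlier.
correct = total = 0
with open("testing.csv", newline="") as f:
    for row in csv.reader(f, skipinitialspace=True):
        if not row or row[0].startswith("#"):
            continue  # skip the comment header line
        label, text = row[0], row[1]
        total += 1
        if classify(text) == label:
            correct += 1
print(f"accuracy: {correct / total:.1%} ({correct}/{total})")
```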
Machine learning algorithms are not one-size-fits-all by any means. The art of ML is using the right algorithm for your data and tuning that algorithm appropriately. Google has taken an interesting approach by making the Prediction API appear to be a “universal learner” that works on any data set. During the learning process, it must try several different algorithms and pick the best fit. However, Google is keeping mum on the internals of the Prediction API.
I ran into a few snags when I was running my tests, but the Prediction API message boards are responsive and I got help quickly. The Prediction API is currently experimental, so it should be interesting to see where Google takes it, and whether it sees the light of day as a commercial offering.
Machine learning is an interesting and valuable field that has a lot of uses in software development. In my next machine learning post, we’ll look at the open source ML toolkit Weka.