Web application that uses scikit-learn
I'm working on a Docker image that wraps the predict and predict_proba methods and exposes them as a web API: https://github.com/hexacta/docker-sklearn-predict-http-api
You need to save your model:
import joblib  # standalone package; sklearn.externals.joblib was removed in scikit-learn 0.23
joblib.dump(clf, 'iris-svc.pkl')
create a Dockerfile:
FROM hexacta/sklearn-predict-http-api:latest
COPY iris-svc.pkl /usr/src/app/model.pkl
and run the container:
$ docker build -t iris-svc .
$ docker run -d -p 4000:8080 iris-svc
then you can make requests:
$ curl -H "Content-Type: application/json" -X POST -d '{"sepal length (cm)":4.4}' http://localhost:4000/predictproba
$ curl -H "Content-Type: application/json" -X POST -d '[{"sepal length (cm)":4.4}, {"sepal length (cm)":15}]' http://localhost:4000/predict
[0, 2]
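The same requests can also be issued from Python. The following is a minimal sketch using only the standard library. It assumes the container started above is listening on localhost:4000 and that the service accepts and returns JSON as in the curl calls. The helper names build_request and call_api are illustrative and not part of the image.

```python
import json
import urllib.request

# Host port mapped by `docker run -d -p 4000:8080 iris-svc`
BASE_URL = "http://localhost:4000"


def build_request(endpoint, payload):
    """Build a JSON POST request for an endpoint such as 'predict'."""
    body = json.dumps(payload).encode("utf-8")
    return urllib.request.Request(
        url=f"{BASE_URL}/{endpoint}",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def call_api(endpoint, payload, timeout=10):
    """Send the request and decode the JSON reply (needs a running container)."""
    request = build_request(endpoint, payload)
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))
```

With the container running, call_api("predict", [{"sepal length (cm)": 4.4}, {"sepal length (cm)": 15}]) mirrors the second curl call.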
You can follow the tutorial below to deploy your scikit-learn model in Azure ML and get the web service automatically generated:
Build and Deploy a Predictive Web App Using Python and Azure ML
Alternatively, the combination of yHat + Heroku may also do the trick.
If this is just for a demo, train your classifier offline, pickle the model, and then use a simple Python web framework such as Flask or Bottle to unpickle the model at server startup and call its predict function in an HTTP request handler.
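A minimal sketch of that setup with Flask might look like the following. The file name model.pkl, the /predict route, and the helper names are assumptions for illustration. The request body is assumed to be a JSON list of feature rows, and the model is assumed to be any object with a scikit-learn style predict method.

```python
import pickle

from flask import Flask, jsonify, request

MODEL_PATH = "model.pkl"  # hypothetical path to the model pickled offline


def load_model(path=MODEL_PATH):
    """Unpickle the trained model; only load pickle files you trust."""
    with open(path, "rb") as f:
        return pickle.load(f)


def create_app(model):
    """Build the web app around a model that was loaded once at startup."""
    app = Flask(__name__)

    @app.route("/predict", methods=["POST"])
    def predict():
        rows = request.get_json(force=True)  # e.g. [[5.1, 3.5, 1.4, 0.2]]
        labels = model.predict(rows)
        # scikit-learn returns a NumPy array; tolist() gives plain Python values
        if hasattr(labels, "tolist"):
            labels = labels.tolist()
        return jsonify(list(labels))

    return app


def main():
    # The model is loaded here, once, not inside the request handler.
    create_app(load_model()).run(host="127.0.0.1", port=5000)
```

Calling main() starts Flask's development server. Passing the model into create_app keeps the handler easy to exercise with a stand-in model.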
Django is a full-featured framework and therefore takes longer to learn than Flask or Bottle, but it has great documentation and a larger community.
Heroku is a service for hosting your application in the cloud. It's possible to host Flask applications on Heroku; here is a simple template project + instructions to do so.
For "production" setups I would advise you not to use pickle but to write your own persistence layer for the machine learning model, so that you have full control over the parameters you store and are more robust to library upgrades that might break the unpickling of old models.
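One possible reading of "your own persistence layer", sketched here for a multiclass linear classifier only: store just the numbers needed for prediction in JSON and reimplement the decision rule. The field names mirror the coef_, intercept_ and classes_ attributes of fitted scikit-learn linear models. The file layout, the format_version field, and the helper names are assumptions for illustration.

```python
import json

FORMAT_VERSION = 1  # hypothetical version tag for this file layout


def save_linear_model(path, coef, intercept, classes):
    """Write only the parameters needed for prediction to a JSON file.

    coef holds one weight row per class and intercept one bias per class;
    with scikit-learn these would come from clf.coef_.tolist() and so on.
    Binary scikit-learn models store a single row and need a sign test instead.
    """
    payload = {
        "format_version": FORMAT_VERSION,
        "coef": coef,
        "intercept": intercept,
        "classes": classes,
    }
    with open(path, "w") as f:
        json.dump(payload, f)


def load_linear_model(path):
    """Read the parameters back, refusing layouts this code does not know."""
    with open(path) as f:
        payload = json.load(f)
    if payload.get("format_version") != FORMAT_VERSION:
        raise ValueError("unsupported model file version")
    return payload


def predict(model, row):
    """Return the class whose linear score w.x + b is highest."""
    scores = [
        sum(w * x for w, x in zip(weights, row)) + bias
        for weights, bias in zip(model["coef"], model["intercept"])
    ]
    return model["classes"][scores.index(max(scores))]
```

Because the file contains plain numbers and labels, it can still be read after a library upgrade, and the explicit version field lets the loader reject layouts it does not understand.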
While this is not a classifier, I have implemented a simple machine learning web service using the Bottle framework and scikit-learn. Given a dataset in .csv format, it returns 2D visualizations produced with principal component analysis (PCA) and linear discriminant analysis (LDA).
More information and example data files can be found at: http://mindwriting.org/blog/?p=153
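Judging from the parsing code in the implementation, the service expects a CSV file whose first row is a header, whose last column holds the class label, and whose remaining columns are numeric features. A minimal standard-library sketch of that parsing step follows; the sample data is made up and the helper name parse_dataset is hypothetical.

```python
import csv
import io

# Made-up sample in the expected layout: header row, features, label last
SAMPLE_CSV = """sepal_length,sepal_width,petal_length,species
5.1,3.5,1.4,setosa
7.0,3.2,4.7,versicolor
6.3,3.3,6.0,virginica
4.9,3.0,1.4,setosa
"""


def parse_dataset(text):
    """Split CSV text into feature rows, integer class indices, and label names."""
    rows = list(csv.reader(io.StringIO(text)))
    body = rows[1:]  # skip the header row
    features = [[float(value) for value in row[:-1]] for row in body]
    classes = [row[-1] for row in body]
    labels = sorted(set(classes))  # sorted so the indices are reproducible
    y = [labels.index(name) for name in classes]
    return features, y, labels


X, y, labels = parse_dataset(SAMPLE_CSV)
print(labels)
print(y)
```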
Here is the implementation.
upload.html:
<form action="/plot" method="post" enctype="multipart/form-data">
  Select a file: <input type="file" name="upload" />
  <input type="submit" value="PCA & LDA" />
</form>
pca_lda_viz.py (modify host name and port number):
import base64
import csv
from io import BytesIO

import matplotlib
matplotlib.use('Agg')  # render off-screen; a web server has no display
import matplotlib.pyplot as plt
import numpy as np
from bottle import route, run, request, static_file
from matplotlib.font_manager import FontProperties
from sklearn.decomposition import PCA
# sklearn.lda was removed from scikit-learn; the class now lives here
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA

html = '''
<img src="data:image/png;base64,{}" />
'''

@route('/')
def root():
    return static_file('upload.html', root='.')

@route('/plot', method='POST')
def plot():
    # Get the data: the first row is a header, the last column is the class
    upload = request.files.get('upload')
    text = upload.file.read().decode('utf-8')
    mydata = list(csv.reader(text.splitlines(), delimiter=','))
    x = [row[:-1] for row in mydata[1:]]
    classes = [row[-1] for row in mydata[1:]]
    labels = sorted(set(classes))
    classIndices = np.array([labels.index(myclass) for myclass in classes])
    X = np.array(x).astype('float')
    y = classIndices
    target_names = labels

    # Apply dimensionality reduction
    pca = PCA(n_components=2)
    X_r = pca.fit(X).transform(X)
    # LDA yields at most n_classes - 1 components, so this needs >= 3 classes
    lda = LDA(n_components=2)
    X_r2 = lda.fit(X, y).transform(X)

    # Create 2D visualizations: PCA on the left, LDA on the right
    fig = plt.figure()
    ax = fig.add_subplot(1, 2, 1)
    bx = fig.add_subplot(1, 2, 2)
    fontP = FontProperties()
    colors = np.random.rand(len(labels), 3)  # one random RGB color per class
    for c, i, target_name in zip(colors, range(len(labels)), target_names):
        # color= avoids matplotlib reading the RGB triple as per-point values
        ax.scatter(X_r[y == i, 0], X_r[y == i, 1], color=c,
                   label=target_name)
    ax.legend(loc='upper center', bbox_to_anchor=(1.05, -0.05),
              fancybox=True, shadow=True, ncol=len(labels), prop=fontP)
    ax.tick_params(axis='both', which='major', labelsize=6)
    for c, i, target_name in zip(colors, range(len(labels)), target_names):
        bx.scatter(X_r2[y == i, 0], X_r2[y == i, 1], color=c,
                   label=target_name)
    bx.tick_params(axis='both', which='major', labelsize=6)

    # Encode image to PNG in base64
    buf = BytesIO()
    fig.savefig(buf, format='png')
    plt.close(fig)  # release the figure so repeated requests do not leak memory
    data = base64.b64encode(buf.getvalue()).decode('ascii')
    return html.format(data)

run(host='mindwriting.org', port=8079, debug=True)