Export python scikit learn models into pmml

Selva · Oct 19, 2015 · Viewed 22.6k times

I want to export python scikit-learn models into PMML.

What python package is best suited?

I read about Augustus, but I was not able to find any example using scikit-learn models.


C8H10N4O2 · Sep 26, 2016

SkLearn2PMML is

a thin wrapper around the JPMML-SkLearn command-line application. For a list of supported Scikit-Learn Estimator and Transformer types, please refer to the documentation of the JPMML-SkLearn project.

As @user1808924 notes, it supports Python 2.7 or 3.4+. It also requires Java 1.7+.
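Before installing, it can help to confirm both prerequisites. The sketch below is not part of SkLearn2PMML: `parse_java_version` and `java_major_version` are hypothetical helper names, the code is written for Python 3.7+, and it assumes `java -version` prints a banner such as `java version "1.8.0_291"` or `openjdk version "17.0.1"`.

```python
import re
import shutil
import subprocess
import sys


def parse_java_version(banner):
    """Return the Java major version (8 for "1.8.x", 17 for "17.0.1"), or None."""
    match = re.search(r'version "(\d+)(?:\.(\d+))?', banner)
    if match is None:
        return None
    first = int(match.group(1))
    second = int(match.group(2)) if match.group(2) else 0
    # Releases before Java 9 report themselves as 1.x, so the major is the second part.
    return second if first == 1 else first


def java_major_version():
    """Run `java -version` if a java executable is on PATH; otherwise return None."""
    if shutil.which("java") is None:
        return None
    result = subprocess.run(["java", "-version"], capture_output=True, text=True)
    # The banner normally goes to stderr, but stdout is checked as well.
    return parse_java_version(result.stderr + result.stdout)


if __name__ == "__main__":
    print("Python OK:", sys.version_info >= (3, 4))
    major = java_major_version()
    print("Java OK:", major is not None and major >= 7)
```

If the Java check prints `False`, either no `java` executable is on the PATH or the banner format differs from the one assumed here.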

Install it with pip (requires git):

pip install git+https://github.com/jpmml/sklearn2pmml.git

Here is an example of how to export a classification tree to PMML. First, grow the tree:

# example tree & viz from http://scikit-learn.org/stable/modules/tree.html
from sklearn import datasets, tree
iris = datasets.load_iris()
clf = tree.DecisionTreeClassifier() 
clf = clf.fit(iris.data, iris.target)

There are two parts to a SkLearn2PMML conversion: an estimator (our clf) and a mapper (for preprocessing steps such as discretization or PCA). Our mapper is pretty basic, since we are not doing any transformations.

from sklearn_pandas import DataFrameMapper
default_mapper = DataFrameMapper([(i, None) for i in iris.feature_names + ['Species']])

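from sklearn2pmml import sklearn2pmml
# call signature of the sklearn2pmml release described in this answer;
# "IrisClassificationTree.pmml" is a placeholder output path
sklearn2pmml(estimator=clf, mapper=default_mapper, pmml="IrisClassificationTree.pmml")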

It is possible (though not documented) to pass mapper=None, but then the predictor names get lost: the file reports x1 instead of sepal length, and so on.
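One way to check whether the predictor names survived the conversion is to read the `DataField` names back out of the generated file. This is a minimal sketch using only the standard library: `data_field_names` is a hypothetical helper, the inline string is a trimmed, hand-written stand-in for a real export, and the PMML 4.3 namespace is an assumption that may need adjusting for other versions.

```python
import xml.etree.ElementTree as ET

# Assumed namespace: matches a PMML 4.3 document; adjust for other versions.
PMML_NS = {"pmml": "http://www.dmg.org/PMML-4_3"}

# Trimmed, hand-written stand-in for an exported file.
SAMPLE = """<?xml version="1.0" encoding="UTF-8"?>
<PMML xmlns="http://www.dmg.org/PMML-4_3" version="4.3">
    <DataDictionary>
        <DataField name="sepal length (cm)" optype="continuous" dataType="float"/>
        <DataField name="petal width (cm)" optype="continuous" dataType="float"/>
        <DataField name="Species" optype="categorical" dataType="string"/>
    </DataDictionary>
</PMML>"""


def data_field_names(pmml_text):
    """Return the name of every DataField, in document order."""
    root = ET.fromstring(pmml_text.encode("utf-8"))
    return [field.get("name") for field in root.iterfind(".//pmml:DataField", PMML_NS)]


names = data_field_names(SAMPLE)
print(names)
# Generic names such as x1, x2, ... would suggest that no mapper was applied.
print(any(name.startswith("x") and name[1:].isdigit() for name in names))
```

For a file on disk, the same function can be fed the result of reading the file as text.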

Let's look at the .pmml file:

<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<PMML xmlns="http://www.dmg.org/PMML-4_3" version="4.3">
    <Header>
        <Application name="JPMML-SkLearn" version="1.1.1"/>
    </Header>
    <DataDictionary>
        <DataField name="sepal length (cm)" optype="continuous" dataType="float"/>
        <DataField name="sepal width (cm)" optype="continuous" dataType="float"/>
        <DataField name="petal length (cm)" optype="continuous" dataType="float"/>
        <DataField name="petal width (cm)" optype="continuous" dataType="float"/>
        <DataField name="Species" optype="categorical" dataType="string">
            <Value value="setosa"/>
            <Value value="versicolor"/>
            <Value value="virginica"/>
        </DataField>
    </DataDictionary>
    <TreeModel functionName="classification" splitCharacteristic="binarySplit">
        <MiningSchema>
            <MiningField name="Species" usageType="target"/>
            <MiningField name="sepal length (cm)"/>
            <MiningField name="sepal width (cm)"/>
            <MiningField name="petal length (cm)"/>
            <MiningField name="petal width (cm)"/>
        </MiningSchema>
        <Output>
            <OutputField name="probability_setosa" dataType="double" feature="probability" value="setosa"/>
            <OutputField name="probability_versicolor" dataType="double" feature="probability" value="versicolor"/>
            <OutputField name="probability_virginica" dataType="double" feature="probability" value="virginica"/>
        </Output>
        <Node id="1">
            <True/>
            <Node id="2" score="setosa" recordCount="50.0">
                <SimplePredicate field="petal width (cm)" operator="lessOrEqual" value="0.8"/>
                <ScoreDistribution value="setosa" recordCount="50.0"/>
                <ScoreDistribution value="versicolor" recordCount="0.0"/>
                <ScoreDistribution value="virginica" recordCount="0.0"/>
            </Node>
            <Node id="3">
                <SimplePredicate field="petal width (cm)" operator="greaterThan" value="0.8"/>
                <Node id="4">
                    <SimplePredicate field="petal width (cm)" operator="lessOrEqual" value="1.75"/>
                    <Node id="5">
                        <SimplePredicate field="petal length (cm)" operator="lessOrEqual" value="4.95"/>
                        <Node id="6" score="versicolor" recordCount="47.0">
                            <SimplePredicate field="petal width (cm)" operator="lessOrEqual" value="1.6500001"/>
                            <ScoreDistribution value="setosa" recordCount="0.0"/>
                            <ScoreDistribution value="versicolor" recordCount="47.0"/>
                            <ScoreDistribution value="virginica" recordCount="0.0"/>
                        </Node>
                        <Node id="7" score="virginica" recordCount="1.0">
                            <SimplePredicate field="petal width (cm)" operator="greaterThan" value="1.6500001"/>
                            <ScoreDistribution value="setosa" recordCount="0.0"/>
                            <ScoreDistribution value="versicolor" recordCount="0.0"/>
                            <ScoreDistribution value="virginica" recordCount="1.0"/>
                        </Node>
                    </Node>
                    <Node id="8">
                        <SimplePredicate field="petal length (cm)" operator="greaterThan" value="4.95"/>
                        <Node id="9" score="virginica" recordCount="3.0">
                            <SimplePredicate field="petal width (cm)" operator="lessOrEqual" value="1.55"/>
                            <ScoreDistribution value="setosa" recordCount="0.0"/>
                            <ScoreDistribution value="versicolor" recordCount="0.0"/>
                            <ScoreDistribution value="virginica" recordCount="3.0"/>
                        </Node>
                        <Node id="10">
                            <SimplePredicate field="petal width (cm)" operator="greaterThan" value="1.55"/>
                            <Node id="11" score="versicolor" recordCount="2.0">
                                <SimplePredicate field="sepal length (cm)" operator="lessOrEqual" value="6.95"/>
                                <ScoreDistribution value="setosa" recordCount="0.0"/>
                                <ScoreDistribution value="versicolor" recordCount="2.0"/>
                                <ScoreDistribution value="virginica" recordCount="0.0"/>
                            </Node>
                            <Node id="12" score="virginica" recordCount="1.0">
                                <SimplePredicate field="sepal length (cm)" operator="greaterThan" value="6.95"/>
                                <ScoreDistribution value="setosa" recordCount="0.0"/>
                                <ScoreDistribution value="versicolor" recordCount="0.0"/>
                                <ScoreDistribution value="virginica" recordCount="1.0"/>
                            </Node>
                        </Node>
                    </Node>
                </Node>
                <Node id="13">
                    <SimplePredicate field="petal width (cm)" operator="greaterThan" value="1.75"/>
                    <Node id="14">
                        <SimplePredicate field="petal length (cm)" operator="lessOrEqual" value="4.8500004"/>
                        <Node id="15" score="virginica" recordCount="2.0">
                            <SimplePredicate field="sepal width (cm)" operator="lessOrEqual" value="3.1"/>
                            <ScoreDistribution value="setosa" recordCount="0.0"/>
                            <ScoreDistribution value="versicolor" recordCount="0.0"/>
                            <ScoreDistribution value="virginica" recordCount="2.0"/>
                        </Node>
                        <Node id="16" score="versicolor" recordCount="1.0">
                            <SimplePredicate field="sepal width (cm)" operator="greaterThan" value="3.1"/>
                            <ScoreDistribution value="setosa" recordCount="0.0"/>
                            <ScoreDistribution value="versicolor" recordCount="1.0"/>
                            <ScoreDistribution value="virginica" recordCount="0.0"/>
                        </Node>
                    </Node>
                    <Node id="17" score="virginica" recordCount="43.0">
                        <SimplePredicate field="petal length (cm)" operator="greaterThan" value="4.8500004"/>
                        <ScoreDistribution value="setosa" recordCount="0.0"/>
                        <ScoreDistribution value="versicolor" recordCount="0.0"/>
                        <ScoreDistribution value="virginica" recordCount="43.0"/>
                    </Node>
                </Node>
            </Node>
        </Node>
    </TreeModel>
</PMML>

The first split (Node 1) is on petal width at 0.8. Node 2 (petal width <= 0.8) captures all of the setosa, with nothing else.
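To make this reading of the tree concrete, the splits can be followed programmatically. The following is a rough sketch, not how a PMML engine such as JPMML evaluates models: `predicate_holds` and `score` are hypothetical helpers, only `True` and `SimplePredicate` with the `lessOrEqual` and `greaterThan` operators are handled, and the inline document is a hand-trimmed stub modelled on the top of the tree above (the "not setosa" label is a placeholder that does not appear in the real file).

```python
import xml.etree.ElementTree as ET

NS = "{http://www.dmg.org/PMML-4_3}"  # assumed PMML 4.3 namespace

# Hand-trimmed stub: only the first split of the tree, with a placeholder label.
SAMPLE = """<PMML xmlns="http://www.dmg.org/PMML-4_3" version="4.3">
  <TreeModel functionName="classification" splitCharacteristic="binarySplit">
    <Node id="1">
      <True/>
      <Node id="2" score="setosa" recordCount="50.0">
        <SimplePredicate field="petal width (cm)" operator="lessOrEqual" value="0.8"/>
      </Node>
      <Node id="3" score="not setosa">
        <SimplePredicate field="petal width (cm)" operator="greaterThan" value="0.8"/>
      </Node>
    </Node>
  </TreeModel>
</PMML>"""


def predicate_holds(node, record):
    """Evaluate the node's own predicate against a dict of feature values."""
    if node.find(NS + "True") is not None:
        return True
    pred = node.find(NS + "SimplePredicate")
    if pred is None:
        return False
    value = record[pred.get("field")]
    threshold = float(pred.get("value"))
    if pred.get("operator") == "lessOrEqual":
        return value <= threshold
    if pred.get("operator") == "greaterThan":
        return value > threshold
    raise ValueError("unsupported operator: " + pred.get("operator"))


def score(pmml_text, record):
    """Descend from the root node into the first child whose predicate holds."""
    root = ET.fromstring(pmml_text)
    node = root.find(NS + "TreeModel/" + NS + "Node")
    if node is None or not predicate_holds(node, record):
        return None
    while True:
        children = node.findall(NS + "Node")
        match = next((c for c in children if predicate_holds(c, record)), None)
        if match is None:
            # No child matched (or this is a leaf): report this node's score.
            return node.get("score")
        node = match


print(score(SAMPLE, {"petal width (cm)": 0.2}))  # setosa
print(score(SAMPLE, {"petal width (cm)": 1.5}))  # not setosa
```

A flower with petal width 0.2 lands in Node 2, which matches the observation that this node holds all of the setosa records.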

You can compare the PMML output to the graphviz output:

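# sklearn.externals.six is gone from recent scikit-learn releases;
# on an old release with Python 2.7, use sklearn.externals.six.StringIO instead
from io import StringIO
import pydotplus # this might be pydot for python 2.7
dot_data = StringIO()
# write the fitted tree to the in-memory buffer in DOT format
tree.export_graphviz(clf, out_file=dot_data,
                     feature_names=iris.feature_names,
                     class_names=iris.target_names,
                     filled=True, rounded=True,
                     special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
# for in-line display, you can also do:
# from IPython.display import Image
# Image(graph.create_png())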
