机器学习模型之 PMML

这里讲如何用Python生成机器学习模型的PMML文件,然后在Java读取文件用于线上预测。

文章里提到的Python和Java代码在git代码库 PMML

一、PMML简介

PMML全称预言模型标记模型(Predictive Model Markup Language),以XML 为载体呈现数据挖掘模型。PMML 允许您在不同的应用程序之间轻松共享预测分析模型。因此,您可以在一个系统中定型一个模型,在 PMML 中对其进行表达,然后将其移动到另一个系统中,而不需考虑分析和预测过程中的具体实现细节。使得模型的部署摆脱了模型开发和产品整合的束缚。

PMML 标准是数据挖掘过程的一个实例化标准,它按照数据挖掘任务执行过程,有序的定义了数据挖掘不同阶段的相关信息:

file

  • 头信息(Header)
  • 数据字典(DataDictionary)
  • 挖掘模式(Mining Schema)
  • 数据转换(Transformations)
  • 模型定义 (Model Definition)
  • 评分结果 (Score Result)

二、离线和线上

机器学习模型的应用一般会经历两个主要过程:离线开发和线上部署。

  • 离线开发
    • 人员:数据科学家或者建模人员
    • 语言:Python
    • 目标:分析数据,加工变量,建立模型用于预测或者分类
  • 线上部署
    • 人员:数据工程师
    • 语言:Java
    • 目标:让模型在线上运行,保证稳定可靠效率高

很显然,Python代码开发的模型肯定是无法直接用在Java里。这里有几个解决方法

  • Java调用Python
    • 方法:把Python当成一个脚本,用Java去调用。
    • 评价:这种方法很原始,练习可以,但是不可用于线上,因为效率没有保障。
  • Java重新实现
    • 方法:参考Python代码,用Java重新实现所有逻辑:特征加工、模型选择、参数配置等。
    • 评价:这种方法比较理想,但是维护代价大。
      • 需要同时维护两套系统,而且保证它们的逻辑和结果一致,改动时两处都需要修改。
  • 分工合作:PMML
    • 方法:Python生成PMML文件,Java读取文件用于线上预测。
    • 评价:复杂的变量的加工可以在Java里进行,余下的变量加工和模型都在PMML文件里。

如果模型的开发人员同时也会Python和Java,那就很完美。但是建模过程更多是离线式迭代式的工作,用Python更加适合。当模型用于生产系统时,用Java更加合适。大部分公司里有专门的数据科学家和数据工程师,各自负责一块,协作一起完成一项任务。

三、Python里如何生成PMML?

  • 开发工具: PyCharm
  • 语言:Python 3.7
  • 三方包
    • sklearn 机器学习
    • sklearn2pmml 把机器学习模型翻译成PMML文件
  • 任务
    • 读取数据,训练模型,输出PMML文件
    • 备注:为了简化例子,这里用了全量数据来训练。真实场景里会分成train和test数据集,还会用cross validation来调节超参数。

Python代码如下

from sklearn import tree
from sklearn.datasets import load_iris
from sklearn2pmml.pipeline import PMMLPipeline
from sklearn2pmml import sklearn2pmml

if __name__ == '__main__':
    iris = load_iris() # 经典的数据
    X = iris.data  # 样本特征
    y = iris.target  # 分类目标
    pipeline = PMMLPipeline([("classifier", tree.DecisionTreeClassifier())]) # 用决策树分类
    pipeline.fit(X, y)  # 训练
    sklearn2pmml(pipeline, "iris.pmml", with_repr=True)  # 输出PMML文件

如上代码会生成iris.pmml文件,部分截图如下

file
输入的4个变量是x1, x2, x3和x4,输出变量是y。在这个简单的决策树里,x1没有被用到。

测试案例

为了在Java里测试,我们输出前2个实例的特征和分类结果

# 打印如下的数据用于测试
print(X[0, :], y[0])
print(X[1, :], y[1])

输出结果

[5.1 3.5 1.4 0.2] 0
[4.9 3. 1.4 0.2] 0

Java逻辑

  • java: 1.8
  • gradle依赖包
    • implementation("org.jpmml:pmml-evaluator:1.4.11")
    • implementation("org.jpmml:pmml-evaluator-extension:1.4.11")

gradle的配置如下
file

代码逻辑

  • 读取PMML文件
  • 读取数据成Map<String, Object>格式
    • key: 特征名字
    • value: 特征值
  • 把原始数据转成PMML需要的Map<FieldName, FieldValue>格式
  • 运行模型得到结果,并且把结果转换成Map<String, Object>格式
    • key: 输出结果字段名字
    • value: 输出结果
  • 提取需要的结果字段

完整Java代码如下

import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.*;
import org.xml.sax.SAXException;

import javax.xml.bind.JAXBException;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.util.*;

public class TestPmml {
    public static void main(String args[]) throws Exception {
        String fp = "iris.pmml";
        TestPmml obj = new TestPmml();
        Evaluator model = obj.loadPmml(fp);
        List<Map<String, Object>> inputs = new ArrayList<>();
        inputs.add(obj.getRawMap(5.1, 3.5, 1.4, 0.2));
        inputs.add(obj.getRawMap(4.9, 3, 1.4, 0.2));
        for (int i = 0; i < inputs.size(); i++) {
            Map<String, Object> output = obj.predict(model, inputs.get(i));
            System.out.println("X=" + inputs.get(i) + " -> y=" + output.get("y"));
        }
    }

    private Evaluator loadPmml(String fp) throws FileNotFoundException, JAXBException, SAXException {
        InputStream is = new FileInputStream(fp);
        PMML pmml = org.jpmml.model.PMMLUtil.unmarshal(is);
        try {
            is.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
        ModelEvaluatorFactory factory = ModelEvaluatorFactory.newInstance();
        return factory.newModelEvaluator(pmml);
    }

    private Map<String, Object> getRawMap(Object a, Object b, Object c, Object d) {
        Map<String, Object> data = new HashMap<String, Object>();
        data.put("x1", a);
        data.put("x2", b);
        data.put("x3", c);
        data.put("x4", d);
        return data;
    }

    /**
     * 运行模型得到结果。
     */
    private Map<String, Object> predict(Evaluator evaluator, Map<String, Object> data) {
        Map<FieldName, FieldValue> input = getFieldMap(evaluator, data);
        Map<String, Object> output = evaluate(evaluator, input);
        return output;
    }

    /**
     * 把原始输入转换成PMML格式的输入。
     */
    private Map<FieldName, FieldValue> getFieldMap(Evaluator evaluator, Map<String, Object> input) {
        List<InputField> inputFields = evaluator.getInputFields();
        Map<FieldName, FieldValue> map = new LinkedHashMap<FieldName, FieldValue>();
        for (InputField field : inputFields) {
            FieldName fieldName = field.getName();
            Object rawValue = input.get(fieldName.getValue());
            FieldValue value = field.prepare(rawValue);
            map.put(fieldName, value);
        }
        return map;
    }

    /**
     * 运行模型得到结果。
     */
    private Map<String, Object> evaluate(Evaluator evaluator, Map<FieldName, FieldValue> input) {
        Map<FieldName, ?> results = evaluator.evaluate(input);
        List<TargetField> targetFields = evaluator.getTargetFields();
        Map<String, Object> output = new LinkedHashMap<String, Object>();
        for (int i = 0; i < targetFields.size(); i++) {
            TargetField field = targetFields.get(i);
            FieldName fieldName = field.getName();
            Object value = results.get(fieldName);
            if (value instanceof Computable) {
                Computable computable = (Computable) value;
                value = computable.getResult();
            }
            output.put(fieldName.getValue(), value);
        }
        return output;
    }

}

输出结果如下
file


相关文章:
pmml(模型标准化)
机器学习模型之PMML
SkLearn2PMML
用PMML实现机器学习模型的跨平台上线
使用IDEA及Gradle创建Java项目

为者常成,行者常至