package cn.com.duiba.nezha.alg.common.model.match;

import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import org.dmg.pmml.FieldName;
import org.jpmml.evaluator.*;
import org.xml.sax.SAXException;

import javax.xml.bind.JAXBException;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * Thin wrapper around a JPMML {@link ModelEvaluator}: loads a PMML model from a stream and
 * scores feature maps, returning the probability of the {@link #POSITIVE_CATEGORY} class.
 *
 * <p>A model must be loaded with {@link #reloadModel(InputStream)} before {@link #predict(Map)}
 * or {@link #predicts(Map)} is called.
 *
 * <p>NOTE(review): {@code evaluator} is a plain (non-volatile) field. If {@code reloadModel} is
 * invoked from a different thread than the scoring methods, confirm whether visibility
 * guarantees are required.
 */
public class PMMLModel {

    /** Category label whose probability {@link #predict(Map)} returns. */
    public static final String POSITIVE_CATEGORY = "1";

    /** Currently loaded evaluator; {@code null} until {@link #reloadModel(InputStream)} succeeds. */
    private ModelEvaluator<?> evaluator;

    /**
     * Scores a single feature vector with the loaded model.
     *
     * @param features map from input-field name to value. Each model input field is looked up by
     *                 name; a missing key is passed to {@link InputField#prepare} as {@code null}.
     *                 Must not be {@code null}.
     * @return the probability the model assigns to {@link #POSITIVE_CATEGORY}
     * @throws IllegalStateException if no model has been loaded, the model declares no target
     *                               field, or evaluation produced no value for the first target
     * @throws Exception             any failure propagated from the JPMML evaluator
     */
    public double predict(Map<String, Double> features) throws Exception {
        // Read the field once so a single evaluation never mixes two models if
        // reloadModel swaps the evaluator in the middle of this call.
        ModelEvaluator<?> model = this.evaluator;
        if (model == null) {
            throw new IllegalStateException("PMML model not loaded; call reloadModel first");
        }

        List<InputField> inputFields = model.getInputFields();
        List<TargetField> targetFields = model.getTargetFields();
        if (targetFields.isEmpty()) {
            throw new IllegalStateException("PMML model declares no target field");
        }

        // Build the model input from the raw features: each input field's value is looked up
        // in the feature map (presumably derived from profile data) and prepared by the field.
        Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
        for (InputField inputField : inputFields) {
            FieldName inputFieldName = inputField.getName();
            Double featureValue = features.get(inputFieldName.getValue());
            arguments.put(inputFieldName, inputField.prepare(featureValue));
        }

        Map<FieldName, ?> results = model.evaluate(arguments);

        FieldName targetName = targetFields.get(0).getFieldName();
        Object targetValue = results.get(targetName);
        if (targetValue == null) {
            throw new IllegalStateException(
                    "PMML model produced no value for target field " + targetName);
        }
        // The model is expected to be a probabilistic classifier; any other result type
        // fails here with a ClassCastException, as before.
        ProbabilityDistribution distribution = (ProbabilityDistribution) targetValue;
        return distribution.getProbability(POSITIVE_CATEGORY);
    }

    /**
     * Parses a PMML document, verifies it, and makes it the active model.
     *
     * <p>{@code verify()} runs before the field assignment, so if it throws, the previously
     * loaded evaluator (if any) stays in place. This method does not close {@code is};
     * the caller retains ownership of the stream.
     *
     * @param is stream containing the PMML document
     * @throws SAXException  if the document cannot be parsed
     * @throws JAXBException if the document cannot be unmarshalled
     */
    public void reloadModel(InputStream is) throws SAXException, JAXBException {
        ModelEvaluator<?> loaded = new LoadingModelEvaluatorBuilder().load(is).build();
        loaded.verify();

        this.evaluator = loaded;
    }

    /**
     * Batch interface: scores every feature map in {@code advertMap} with {@link #predict(Map)}.
     *
     * @param advertMap map from caller-defined key to that key's feature map
     * @param <T>       key type
     * @return a new map from each key to its positive-class probability. It is empty when
     *         {@code AssertUtil.isAnyEmpty} reports the input as empty (presumably null or no
     *         entries).
     * @throws Exception any failure propagated from {@link #predict(Map)}
     */
    public <T> Map<T, Double> predicts(Map<T, Map<String, Double>> advertMap) throws Exception {

        Map<T, Double> ret = new HashMap<>();
        if (AssertUtil.isAnyEmpty(advertMap)) {
            return ret;
        }
        for (Map.Entry<T, Map<String, Double>> entry : advertMap.entrySet()) {
            ret.put(entry.getKey(), predict(entry.getValue()));
        }
        return ret;
    }


    /**
     * Manual smoke test: loads a PMML file and prints the score for a fixed feature vector.
     *
     * @param args optional; {@code args[0]} overrides the default model path
     */
    public static void main(String[] args) throws Exception {
        String pathxml = args.length > 0 ? args[0] : "/Users/houyawei/Desktop/xgboost.pmml.0";
        Map<String, Double> map = new HashMap<>();
        map.put("w1", 0.1);
        map.put("w2", 0.0);
        map.put("w3", 0.0);
        map.put("w4", 0.0);
        map.put("w5", 0.0);
        map.put("w6", 0.0);
        map.put("w7", 0.2);
        map.put("w8", 0.1);
        map.put("w9", 0.0);
        map.put("w10", 0.0);
        map.put("w11", 0.0);
        map.put("w12", 0.0);
        map.put("w13", 0.0);
        map.put("w14", 0.0);
        map.put("w15", 0.0);
        map.put("w16", 0.0);
        map.put("w17", 0.0);
        map.put("w18", 0.8);
        map.put("w19", 0.0);
        map.put("w20", 0.0);
        map.put("w21", 0.0);
        map.put("w22", 0.0);
        map.put("w23", 0.0);
        map.put("w24", 0.1);
        map.put("w25", 0.0);
        map.put("w26", 0.4);
        map.put("w27", 0.0);
        map.put("w28", 0.5);
        map.put("w29", 0.0);
        map.put("w30", 0.2);
        map.put("w31", 0.0);
        map.put("w32", 0.3);
        map.put("w33", 0.0);
        map.put("w34", 0.0);
        map.put("w35", 0.0);
        map.put("w36", 0.0);
        map.put("s1", 0.1);
        map.put("s2", 0.2);
        map.put("s3", 0.0);
        map.put("p1", 0.2);

        // Load the model; try-with-resources releases the file handle once parsing is done.
        PMMLModel test = new PMMLModel();
        try (InputStream is = new FileInputStream(new File(pathxml))) {
            test.reloadModel(is);
        }
        double pval = test.predict(map);
        System.out.println(pval);
    }

}
