package cn.com.duiba.nezha.alg.common.model.match;

import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import org.dmg.pmml.FieldName;
import org.jpmml.evaluator.*;
import org.xml.sax.SAXException;

import javax.xml.bind.JAXBException;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class PMMLModel {

    /** Target category whose probability {@link #predict(Map)} returns. */
    public static final String POSITIVE_CATEGORY = "1";

    /**
     * The currently loaded evaluator; {@code null} until {@link #reloadModel(InputStream)} succeeds.
     * Declared {@code volatile} so that a model published by {@code reloadModel} on one thread is
     * visible to other threads calling {@code predict}.
     */
    private volatile ModelEvaluator<?> evaluator;

    /**
     * Scores a single feature vector with the currently loaded model.
     *
     * <p>Every input field declared by the model is looked up in {@code features} by its field
     * name. A name absent from the map is passed to {@link InputField#prepare} as {@code null}
     * (presumably treated as a missing value by JPMML — confirm against the model's missing-value
     * treatment).
     *
     * @param features raw model features keyed by PMML input field name; must not be null
     * @return probability of {@link #POSITIVE_CATEGORY} for the model's first target field
     * @throws IllegalArgumentException if {@code features} is null
     * @throws IllegalStateException if no model is loaded, the model declares no target field,
     *     the target result is not a probability distribution, or it carries no probability for
     *     {@link #POSITIVE_CATEGORY}
     * @throws Exception anything propagated by the JPMML evaluator
     */
    public double predict(Map<String, Double> features) throws Exception {
        if (features == null) {
            throw new IllegalArgumentException("features must not be null");
        }
        // Read the field once so a concurrent reloadModel cannot mix two models within one call.
        ModelEvaluator<?> current = this.evaluator;
        if (current == null) {
            throw new IllegalStateException("No PMML model loaded; call reloadModel first");
        }

        List<InputField> inputFields = current.getInputFields();
        List<TargetField> targetFields = current.getTargetFields();
        if (targetFields.isEmpty()) {
            throw new IllegalStateException("PMML model declares no target field");
        }

        // Raw features taken from the profile, passed to the model as its input arguments.
        Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
        for (InputField inputField : inputFields) {
            FieldName inputFieldName = inputField.getName();
            Double featureValue = features.get(inputFieldName.getValue());
            arguments.put(inputFieldName, inputField.prepare(featureValue));
        }

        Map<FieldName, ?> results = current.evaluate(arguments);

        FieldName targetName = targetFields.get(0).getFieldName();
        Object targetValue = results.get(targetName);
        if (!(targetValue instanceof ProbabilityDistribution)) {
            throw new IllegalStateException("Target field " + targetName
                    + " did not produce a probability distribution: " + targetValue);
        }
        Double probability = ((ProbabilityDistribution) targetValue).getProbability(POSITIVE_CATEGORY);
        if (probability == null) {
            throw new IllegalStateException("Target field " + targetName
                    + " has no probability for category " + POSITIVE_CATEGORY);
        }
        return probability;
    }

    /**
     * Parses a PMML document, verifies it and makes it the model used by subsequent predictions.
     *
     * <p>The new model is published only after {@code verify()} returns, so a load that throws
     * leaves the previously loaded model in place. This method does not close {@code is}; the
     * caller owns the stream.
     *
     * @param is stream containing the PMML document; must not be null
     * @throws IllegalArgumentException if {@code is} is null
     * @throws SAXException if the document cannot be parsed
     * @throws JAXBException if the document cannot be unmarshalled
     */
    public void reloadModel(InputStream is) throws SAXException, JAXBException {
        if (is == null) {
            throw new IllegalArgumentException("is must not be null");
        }
        ModelEvaluator<?> loaded = new LoadingModelEvaluatorBuilder().load(is).build();
        // Verify before publishing so a broken model never replaces a working one.
        loaded.verify();

        this.evaluator = loaded;
    }

    /**
     * Batch interface: scores each feature map in {@code advertMap} with {@link #predict(Map)}.
     *
     * @param advertMap feature maps keyed by an arbitrary caller-defined key
     * @param <T> key type
     * @return a new map from each key to its predicted probability; empty when
     *     {@code AssertUtil.isAnyEmpty(advertMap)} is true (presumably null or empty — confirm)
     * @throws Exception anything thrown by {@link #predict(Map)} for any entry
     */
    public <T> Map<T, Double> predicts(Map<T, Map<String, Double>> advertMap) throws Exception {
        Map<T, Double> ret = new HashMap<>();
        if (AssertUtil.isAnyEmpty(advertMap)) {
            return ret;
        }
        for (Map.Entry<T, Map<String, Double>> entry : advertMap.entrySet()) {
            // Use the entry's value directly instead of a second lookup by key.
            ret.put(entry.getKey(), predict(entry.getValue()));
        }
        return ret;
    }

    /**
     * Manual smoke test: loads a PMML model and prints the prediction for a fixed feature vector.
     *
     * @param args optional; {@code args[0]} overrides the default PMML file path
     * @throws Exception if loading or scoring fails
     */
    public static void main(String[] args) throws Exception {
        String pathxml = args.length > 0 ? args[0] : "/Users/houyawei/Desktop/xgboost.pmml";
        Map<String, Double> map = new HashMap<String, Double>();
        map.put("bx", 0.0);
        map.put("hjk", 0.0);
        map.put("1", 0.0);
        map.put("2", 0.0);
        map.put("3", 0.0);
        map.put("4", 0.4);
        map.put("5", 0.5);
        map.put("6", 0.6);
        map.put("7", 0.5);
        map.put("8", 0.0);
        map.put("9", 0.0);
        map.put("10", 0.0);
        map.put("11", 0.0);
        map.put("12", 0.0);
        map.put("13", 0.0);
        map.put("14", 0.6);
        map.put("15", 0.0);
        map.put("16", 0.0);
        map.put("17", 0.0);
        map.put("18", 0.0);
        map.put("19", 0.3);
        map.put("20", 0.0);
        map.put("21", 0.0);
        map.put("22", 0.6);
        map.put("23", 0.8);
        map.put("24", 0.8);
        map.put("25", 0.9);
        map.put("26", 0.9);
        map.put("27", 0.1);
        map.put("28", 0.0);
        map.put("29", 0.4);
        map.put("30", 0.6);
        map.put("31", 0.0);
        map.put("32", 0.6);
        map.put("33", 0.1);
        map.put("34", 0.0);
        map.put("35", 0.0);
        map.put("36", 0.0);
        map.put("s1", 0.0);
        map.put("s2", 0.0);
        map.put("s3", 0.4);
        map.put("p1", 0.3);

        // Load the model; try-with-resources closes the file stream even if loading fails.
        PMMLModel test = new PMMLModel();
        try (InputStream is = new FileInputStream(new File(pathxml))) {
            test.reloadModel(is);
        }
        double pval = test.predict(map);
        System.out.println(pval);
    }

}
