/*
 * Decompiled with CFR 0.152.
 */
package cn.com.duiba.nezha.alg.common.model.match;

import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import javax.xml.bind.JAXBException;
import org.dmg.pmml.FieldName;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.InputField;
import org.jpmml.evaluator.LoadingModelEvaluatorBuilder;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.ProbabilityDistribution;
import org.jpmml.evaluator.TargetField;
import org.xml.sax.SAXException;

/**
 * Thin wrapper around a JPMML {@link ModelEvaluator} that scores feature maps
 * and returns the probability of the {@link #POSITIVE_CATEGORY} class.
 *
 * <p>A model must be loaded with {@link #reloadModel(InputStream)} before
 * {@link #predict(Map)} or {@link #predicts(Map)} is called.
 */
public class PMMLModel {
    /** Label of the target category whose probability is reported by {@link #predict(Map)}. */
    public static final String POSITIVE_CATEGORY = "1";

    /** Currently loaded evaluator; {@code null} until {@link #reloadModel(InputStream)} succeeds. */
    private ModelEvaluator<?> evaluator;

    /**
     * Scores a single feature vector with the loaded model.
     *
     * @param features feature values keyed by the model's input field names; a field
     *                 that is absent from the map is handed to the model as {@code null}
     *                 (how that is treated is up to the model's own field preparation)
     * @return the probability the model assigns to {@link #POSITIVE_CATEGORY}
     * @throws IllegalStateException    if no model has been loaded, the model declares no
     *                                  target field, or the target result is not a
     *                                  probability distribution containing the positive category
     * @throws IllegalArgumentException if {@code features} is {@code null}
     * @throws Exception                any failure raised by the underlying evaluator
     */
    public double predict(Map<String, Double> features) throws Exception {
        // Read the field once so a reload between calls cannot mix two different models.
        ModelEvaluator<?> current = this.evaluator;
        if (current == null) {
            throw new IllegalStateException("No PMML model loaded; call reloadModel(InputStream) first");
        }
        if (features == null) {
            throw new IllegalArgumentException("features must not be null");
        }

        List<InputField> inputFields = current.getInputFields();
        List<TargetField> targetFields = current.getTargetFields();
        if (targetFields == null || targetFields.isEmpty()) {
            throw new IllegalStateException("Loaded PMML model declares no target field");
        }

        Map<FieldName, FieldValue> arguments = new LinkedHashMap<FieldName, FieldValue>();
        for (InputField inputField : inputFields) {
            FieldName inputFieldName = inputField.getName();
            Double featureValue = features.get(inputFieldName.getValue());
            arguments.put(inputFieldName, inputField.prepare(featureValue));
        }

        Map<FieldName, ?> results = current.evaluate(arguments);
        FieldName targetName = targetFields.get(0).getFieldName();
        Object targetValue = results.get(targetName);
        if (!(targetValue instanceof ProbabilityDistribution)) {
            throw new IllegalStateException(
                    "Target field '" + targetName + "' did not produce a probability distribution: " + targetValue);
        }

        Double probability = ((ProbabilityDistribution<?>) targetValue).getProbability(POSITIVE_CATEGORY);
        if (probability == null) {
            // Guard the auto-unboxing below so a missing category fails with a clear message.
            throw new IllegalStateException(
                    "Model output has no probability for category '" + POSITIVE_CATEGORY + "'");
        }
        return probability;
    }

    /**
     * Loads and verifies a PMML model from the given stream and, on success, makes it
     * the active model. If loading or verification fails the previously loaded model
     * (if any) stays in place.
     *
     * <p>The stream is not closed by this method; closing it is the caller's job.
     *
     * @param is stream positioned at the start of a PMML document
     * @throws IllegalArgumentException if {@code is} is {@code null}
     * @throws SAXException             if the document cannot be parsed
     * @throws JAXBException            if the document cannot be unmarshalled
     */
    public void reloadModel(InputStream is) throws SAXException, JAXBException {
        if (is == null) {
            throw new IllegalArgumentException("PMML input stream must not be null");
        }
        ModelEvaluator<?> loaded = new LoadingModelEvaluatorBuilder().load(is).build();
        loaded.verify();
        // Publish only after verification so a broken model never replaces a working one.
        this.evaluator = loaded;
    }

    /**
     * Scores every entry of {@code advertMap} with {@link #predict(Map)}.
     *
     * @param advertMap feature maps keyed by an arbitrary identifier
     * @param <T>       identifier type
     * @return a new map from each identifier to its predicted probability; empty when
     *         {@code AssertUtil.isAnyEmpty(advertMap)} reports the input as empty
     * @throws Exception if scoring any entry fails
     */
    public <T> Map<T, Double> predicts(Map<T, Map<String, Double>> advertMap) throws Exception {
        Map<T, Double> ret = new HashMap<T, Double>();
        if (AssertUtil.isAnyEmpty(advertMap)) {
            return ret;
        }
        for (Map.Entry<T, Map<String, Double>> entry : advertMap.entrySet()) {
            // Use the entry's own value rather than a second lookup by key.
            ret.put(entry.getKey(), this.predict(entry.getValue()));
        }
        return ret;
    }

    /**
     * Manual smoke test: loads a PMML file, scores one hard-coded feature vector and
     * prints the resulting probability.
     *
     * @param args optional; {@code args[0]} overrides the default PMML file path
     * @throws Exception if the file cannot be read or scoring fails
     */
    public static void main(String[] args) throws Exception {
        String pathxml = args != null && args.length > 0
                ? args[0]
                : "/Users/houyawei/Desktop/xgboost.pmml.0";
        Map<String, Double> map = new HashMap<String, Double>();
        map.put("w1", 0.1);
        map.put("w2", 0.0);
        map.put("w3", 0.0);
        map.put("w4", 0.0);
        map.put("w5", 0.0);
        map.put("w6", 0.0);
        map.put("w7", 0.2);
        map.put("w8", 0.1);
        map.put("w9", 0.0);
        map.put("w10", 0.0);
        map.put("w11", 0.0);
        map.put("w12", 0.0);
        map.put("w13", 0.0);
        map.put("w14", 0.0);
        map.put("w15", 0.0);
        map.put("w16", 0.0);
        map.put("w17", 0.0);
        map.put("w18", 0.8);
        map.put("w19", 0.0);
        map.put("w20", 0.0);
        map.put("w21", 0.0);
        map.put("w22", 0.0);
        map.put("w23", 0.0);
        map.put("w24", 0.1);
        map.put("w25", 0.0);
        map.put("w26", 0.4);
        map.put("w27", 0.0);
        map.put("w28", 0.5);
        map.put("w29", 0.0);
        map.put("w30", 0.2);
        map.put("w31", 0.0);
        map.put("w32", 0.3);
        map.put("w33", 0.0);
        map.put("w34", 0.0);
        map.put("w35", 0.0);
        map.put("w36", 0.0);
        map.put("s1", 0.1);
        map.put("s2", 0.2);
        map.put("s3", 0.0);
        map.put("p1", 0.2);

        PMMLModel test2 = new PMMLModel();
        File file = new File(pathxml);
        // try-with-resources so the file handle is released even if loading fails.
        try (FileInputStream is = new FileInputStream(file)) {
            test2.reloadModel(is);
        }
        double pval = test2.predict(map);
        System.out.println(pval);
    }
}

