/*
 * Decompiled with CFR 0.152.
 */
package cn.com.duiba.nezha.alg.stat.utils.tf;

import cn.com.duiba.nezha.alg.stat.utils.AssertUtil;
import cn.com.duiba.nezha.alg.stat.utils.tf.PredictResultType;
import cn.com.duiba.nezha.alg.stat.utils.tf.TensorflowUtils;
import com.alibaba.fastjson.JSON;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;
import org.tensorflow.framework.TensorInfo;

/**
 * Thin wrapper around a TensorFlow {@link SavedModelBundle} that feeds a feature tensor into the
 * {@code serving_default} signature (input key {@code feat_vals}) and turns the float outputs into
 * per-key scores.
 *
 * <p>A model must be installed via {@link #loadModel(String, String)} or {@link #setModel} before
 * any prediction method is called.
 */
public class LocalTFModel {
    /** Initial capacity of the result maps. */
    private final int DF_MAP_SIZE = 128;
    private static final String SIGNATURE_DEF = "serving_default";
    private static final String INPUT = "feat_vals";
    public static String DF_OUTPUT = "prob";
    public static String ESMM_CTR_OUTPUT = "ctr_probabilities";
    public static String ESMM_CVR_OUTPUT = "cvr_probabilities";
    /** Currently installed model; {@code null} until {@link #setModel} succeeds. */
    private SavedModelBundle bundle = null;
    /** Graph tensor name bound to the {@code feat_vals} signature input of {@link #bundle}. */
    private String inputName = null;
    /**
     * Parsed serving signature of {@link #bundle}. It is replaced together with the bundle, so output
     * names can never go stale after a reload.
     */
    private SignatureDef signature = null;

    /**
     * Resolves a signature output key (e.g. {@link #DF_OUTPUT}) to the graph tensor name to fetch.
     *
     * @param output output key in the serving signature
     * @return the tensor name registered for {@code output} in the current model
     * @throws IllegalStateException if no model has been installed yet
     * @throws IllegalArgumentException if the signature has no output named {@code output}
     */
    public String getOutputName(String output) throws Exception {
        SignatureDef sig = this.signature;
        if (sig == null) {
            throw new IllegalStateException("No model loaded; call loadModel or setModel first");
        }
        TensorInfo info = sig.getOutputsMap().get(output);
        if (info == null) {
            throw new IllegalArgumentException("Signature '" + SIGNATURE_DEF + "' has no output '"
                    + output + "'; available outputs: " + sig.getOutputsMap().keySet());
        }
        return info.getName();
    }

    /**
     * Loads a SavedModel through {@code TensorflowUtils.loadModel} and installs it.
     *
     * @param path    model directory passed to {@code TensorflowUtils.loadModel}
     * @param version version string passed to {@code TensorflowUtils.loadModel}
     */
    public void loadModel(String path, String version) throws Exception {
        SavedModelBundle loaded = TensorflowUtils.loadModel(path, version);
        this.setModel(loaded);
    }

    /**
     * Installs {@code bundle} as the model used for prediction.
     *
     * <p>The signature is parsed and validated before any field is assigned. If this method throws,
     * the previously installed model stays fully intact.
     *
     * @param bundle the model to install; must expose a {@code serving_default} signature with a
     *               {@code feat_vals} input
     * @throws IllegalArgumentException if {@code bundle} is null or its signature lacks the input
     */
    public void setModel(SavedModelBundle bundle) throws Exception {
        if (bundle == null) {
            throw new IllegalArgumentException("bundle must not be null");
        }
        SignatureDef sig = MetaGraphDef.parseFrom(bundle.metaGraphDef()).getSignatureDefOrThrow(SIGNATURE_DEF);
        TensorInfo input = sig.getInputsMap().get(INPUT);
        if (input == null) {
            throw new IllegalArgumentException("Signature '" + SIGNATURE_DEF + "' has no input '"
                    + INPUT + "'; available inputs: " + sig.getInputsMap().keySet());
        }
        this.bundle = bundle;
        this.signature = sig;
        this.inputName = input.getName();
    }

    /**
     * Runs the model on {@code x} and fetches the {@link #DF_OUTPUT} output.
     *
     * @see #predict(Tensor, int, String)
     */
    public <T> float[] predict(Tensor<T> x, int size) throws Exception {
        return this.predict(x, size, DF_OUTPUT);
    }

    /**
     * Runs the model on {@code x} and copies the requested output into a new float array.
     *
     * <p>The caller keeps ownership of {@code x} and must close it. The output tensors produced by
     * the session are closed here.
     *
     * @param x      input tensor fed to the {@code feat_vals} signature input
     * @param size   length of the returned array; must match the element count of the output tensor
     *               (assumed rank-1, {@code [size]} — TODO confirm per model)
     * @param output signature output key to fetch
     * @return the fetched output values
     */
    public <T> float[] predict(Tensor<T> x, int size, String output) throws Exception {
        String outputName = this.getOutputName(output);
        List<Tensor<?>> outputs = this.bundle.session().runner().feed(this.inputName, x).fetch(outputName).run();
        try {
            float[] result = new float[size];
            outputs.get(0).copyTo(result);
            return result;
        } finally {
            // Session outputs own native memory and must be released explicitly.
            for (Tensor<?> t : outputs) {
                t.close();
            }
        }
    }

    /**
     * Scores dense float feature vectors with an ESMM model, returning both CTR and CVR outputs.
     *
     * @param dataMap feature vector per key; all vectors must have the same length
     * @return map from {@link PredictResultType#CTR}/{@link PredictResultType#CVR} to per-key scores;
     *         empty when {@code AssertUtil.isAllNotEmpty(dataMap)} is false
     */
    public <T> Map<PredictResultType, Map<T, Double>> predictESMM(Map<T, List<Float>> dataMap) throws Exception {
        Map<PredictResultType, Map<T, Double>> ret = new HashMap<PredictResultType, Map<T, Double>>();
        if (AssertUtil.isAllNotEmpty(dataMap)) {
            List<T> keyList = new ArrayList<T>(dataMap.keySet());
            try (Tensor<Float> x = LocalTFModel.getParamsTensor(keyList, dataMap)) {
                this.fillEsmm(ret, keyList, x);
            }
        }
        return ret;
    }

    /**
     * Scores string-encoded samples with an ESMM model, returning both CTR and CVR outputs.
     *
     * @param dataMap serialized feature string per key
     * @return map from {@link PredictResultType#CTR}/{@link PredictResultType#CVR} to per-key scores;
     *         empty when {@code AssertUtil.isAllNotEmpty(dataMap)} is false
     */
    public <T> Map<PredictResultType, Map<T, Double>> predictStrESMM(Map<T, String> dataMap) throws Exception {
        Map<PredictResultType, Map<T, Double>> ret = new HashMap<PredictResultType, Map<T, Double>>();
        if (AssertUtil.isAllNotEmpty(dataMap)) {
            List<T> keyList = new ArrayList<T>(dataMap.keySet());
            try (Tensor<String> x = LocalTFModel.getParamsTensorStr(keyList, dataMap)) {
                this.fillEsmm(ret, keyList, x);
            }
        }
        return ret;
    }

    /** Runs the CTR and CVR outputs on one shared input tensor and stores both into {@code ret}. */
    private <T> void fillEsmm(Map<PredictResultType, Map<T, Double>> ret, List<T> keyList, Tensor<?> x) throws Exception {
        ret.put(PredictResultType.CTR, this.getPredictReturn(keyList, x, ESMM_CTR_OUTPUT));
        ret.put(PredictResultType.CVR, this.getPredictReturn(keyList, x, ESMM_CVR_OUTPUT));
    }

    /**
     * Scores dense float feature vectors using the {@link #DF_OUTPUT} output.
     *
     * @param dataMap feature vector per key; all vectors must have the same length
     * @return per-key scores; empty when {@code AssertUtil.isNotEmpty(dataMap)} is false
     */
    public <T> Map<T, Double> predict(Map<T, List<Float>> dataMap) throws Exception {
        if (!AssertUtil.isNotEmpty(dataMap)) {
            return new HashMap<T, Double>(this.DF_MAP_SIZE);
        }
        List<T> keyList = new ArrayList<T>(dataMap.keySet());
        try (Tensor<Float> x = LocalTFModel.getParamsTensor(keyList, dataMap)) {
            return this.getPredictReturn(keyList, x, DF_OUTPUT);
        }
    }

    /**
     * Scores string-encoded samples using the {@link #DF_OUTPUT} output.
     *
     * @param dataMap serialized feature string per key
     * @return per-key scores; empty when {@code AssertUtil.isNotEmpty(dataMap)} is false
     */
    public <T> Map<T, Double> predictStr(Map<T, String> dataMap) throws Exception {
        if (!AssertUtil.isNotEmpty(dataMap)) {
            return new HashMap<T, Double>();
        }
        List<T> keyList = new ArrayList<T>(dataMap.keySet());
        try (Tensor<String> x = LocalTFModel.getParamsTensorStr(keyList, dataMap)) {
            return this.getPredictReturn(keyList, x, DF_OUTPUT);
        }
    }

    /**
     * Runs the model on {@code x} and maps the i-th output value to {@code keyList.get(i)}.
     *
     * <p>The caller keeps ownership of {@code x}.
     *
     * @param keyList keys in the same order as the rows of {@code x}
     * @param x       input tensor with one row per key
     * @param output  signature output key to fetch
     * @return per-key scores
     */
    public <T, K> Map<T, Double> getPredictReturn(List<T> keyList, Tensor<K> x, String output) throws Exception {
        Map<T, Double> ret = new HashMap<T, Double>(this.DF_MAP_SIZE);
        int sampleSize = keyList.size();
        float[] pred = this.predict(x, sampleSize, output);
        for (int i = 0; i < sampleSize; ++i) {
            // "+ 0.0" normalizes a -0.0 score to 0.0 (kept from the original behavior).
            ret.put(keyList.get(i), (double) pred[i] + 0.0);
        }
        return ret;
    }

    /**
     * Packs per-key float feature vectors into a {@code [keys, features]} float tensor.
     *
     * <p>Row {@code i} holds the features of {@code keyList.get(i)}. The caller owns the returned
     * tensor and must close it.
     *
     * @param keyList row order
     * @param dataMap feature vector per key; every vector must have the length of the first one
     * @return the tensor, or {@code null} when {@code AssertUtil.isAllNotEmpty(keyList, dataMap)} is
     *         false
     * @throws IllegalArgumentException if a key has no vector or a vector length differs
     */
    public static <T> Tensor<Float> getParamsTensor(List<T> keyList, Map<T, List<Float>> dataMap) {
        if (!AssertUtil.isAllNotEmpty(keyList, dataMap)) {
            return null;
        }
        int sampleSize = keyList.size();
        List<Float> first = dataMap.get(keyList.get(0));
        if (first == null) {
            throw new IllegalArgumentException("No feature vector for key " + keyList.get(0));
        }
        int featureSize = first.size();
        float[][] matrix = new float[sampleSize][featureSize];
        for (int i = 0; i < sampleSize; ++i) {
            T key = keyList.get(i);
            List<Float> feature = dataMap.get(key);
            // Ragged rows would otherwise be silently truncated or fail with an opaque index error.
            if (feature == null || feature.size() != featureSize) {
                throw new IllegalArgumentException("Feature vector for key " + key + " has size "
                        + (feature == null ? "null" : String.valueOf(feature.size()))
                        + ", expected " + featureSize);
            }
            float[] row = matrix[i];
            for (int j = 0; j < featureSize; ++j) {
                row[j] = feature.get(j);
            }
        }
        return Tensor.create((Object) matrix, Float.class);
    }

    /**
     * Packs per-key feature strings into a 1-D string tensor ordered like {@code keyList}.
     *
     * <p>The caller owns the returned tensor and must close it.
     *
     * @param keyList element order
     * @param dataMap feature string per key
     * @return the tensor, or {@code null} when {@code AssertUtil.isAllNotEmpty(keyList, dataMap)} is
     *         false
     */
    public static <T> Tensor<String> getParamsTensorStr(List<T> keyList, Map<T, String> dataMap) {
        if (!AssertUtil.isAllNotEmpty(keyList, dataMap)) {
            return null;
        }
        String[] matrix = new String[keyList.size()];
        for (int i = 0; i < matrix.length; ++i) {
            matrix[i] = dataMap.get(keyList.get(i));
        }
        return Tensor.create((Object) matrix, String.class);
    }

    /**
     * Manual smoke test.
     *
     * <p>Usage: {@code LocalTFModel [modelPath] [version]}. The defaults are the original hard-coded
     * values.
     */
    public static void main(String[] args) throws Exception {
        String modelPath = args.length > 0 ? args[0] : "/Users/lwj/Desktop/model/";
        String version = args.length > 1 ? args[1] : "1";
        LocalTFModel model = new LocalTFModel();
        model.loadModel(modelPath, version);
        List<Float> tmplist1 = new ArrayList<Float>();
        List<Float> tmplist2 = new ArrayList<Float>();
        for (int i = 0; i < 240; ++i) {
            tmplist1.add(0.0f);
            tmplist2.add(0.0f);
        }
        tmplist1.set(10, 1.0f);
        tmplist2.set(50, 1.0f);
        Map<String, List<Float>> x = new HashMap<String, List<Float>>();
        x.put("a", tmplist1);
        x.put("b", tmplist2);
        System.out.println(JSON.toJSONString(model.predict(x)));
    }
}

