/*
 * Decompiled with CFR 0.152.
 */
package cn.com.duiba.nezha.alg.model.tf;

import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import cn.com.duiba.nezha.alg.common.util.MathUtil;
import cn.com.duiba.nezha.alg.common.util.NumberUtil;
import cn.com.duiba.nezha.alg.model.enums.PredictResultType;
import cn.com.duiba.nezha.alg.model.tf.TensorflowUtils;
import com.alibaba.fastjson.JSON;
import java.io.File;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import jdk.nashorn.internal.ir.debug.ObjectSizeCalculator;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;

/**
 * Wraps a TensorFlow {@link SavedModelBundle} loaded from local disk and exposes batch
 * prediction helpers.
 *
 * <p>Inputs are fed to the tensor resolved from the logical input name {@code "feat_vals"} of
 * the {@code "serving_default"} signature; outputs are fetched by logical name ({@link
 * #DF_OUTPUT}, {@link #ESMM_CTR_OUTPUT}, {@link #ESMM_CVR_OUTPUT}) and resolved to graph tensor
 * names through {@code TensorflowUtils}.
 *
 * <p>Every input/output {@link Tensor} created by this class is closed after use, because
 * TensorFlow tensors hold native memory that the JVM garbage collector does not release.
 *
 * <p>The output-name cache is a plain {@link HashMap} populated lazily; no synchronization is
 * done here, so concurrent use from several threads is not guarded by this class.
 */
public class LocalTFModel implements AutoCloseable {
    /** Initial capacity used for result maps. */
    private static final int DF_MAP_SIZE = 128;
    private static final String SIGNATURE_DEF = "serving_default";
    private static final String INPUT = "feat_vals";
    // NOTE: the public output names below are left non-final to stay compatible with any caller
    // that reassigns them.
    public static String DF_OUTPUT = "prob";
    public static String ESMM_CTR_OUTPUT = "ctr_probabilities";
    public static String ESMM_CVR_OUTPUT = "cvr_probabilities";

    private SavedModelBundle bundle = null;
    private String inputName = null;
    /** Cache: logical output name -> graph tensor name, valid only for the current bundle. */
    private Map<String, String> outputNameMap = null;

    /**
     * Returns the currently loaded bundle.
     *
     * @return the bundle, or {@code null} if no model is loaded or {@link #close()} was called
     */
    public SavedModelBundle getBundle() {
        return this.bundle;
    }

    /**
     * Resolves (and caches) the graph tensor name for a logical output of the serving signature.
     *
     * @param output logical output name, e.g. {@link #DF_OUTPUT}
     * @return the tensor name as returned by {@code TensorflowUtils.getOutputName}
     * @throws IllegalStateException if no model is loaded and the name is not cached yet
     * @throws Exception propagated from {@code TensorflowUtils.getOutputName}
     */
    public String getOutputName(String output) throws Exception {
        if (this.outputNameMap == null) {
            this.outputNameMap = new HashMap<>();
        }
        String outputName = this.outputNameMap.get(output);
        if (outputName == null) {
            outputName = TensorflowUtils.getOutputName(output, requireBundle(), SIGNATURE_DEF);
            this.outputNameMap.put(output, outputName);
        }
        return outputName;
    }

    /**
     * Loads a specific model version via {@code TensorflowUtils.loadModel(path, version)} and
     * installs it with {@link #setModel(SavedModelBundle)}. The previous bundle, if any, is not
     * closed here.
     *
     * @param path model root directory (presumably containing numeric version sub-directories)
     * @param version version identifier passed through to {@code TensorflowUtils.loadModel}
     */
    public void loadModel(String path, String version) throws Exception {
        System.out.println("loadModel with version=" + version);
        SavedModelBundle loaded = TensorflowUtils.loadModel(path, version);
        this.setModel(loaded);
    }

    /**
     * Closes the loaded bundle (if any) and forgets all cached tensor names. Safe to call more
     * than once; state is cleared before the native close so a failing close cannot leave a
     * half-closed bundle reachable through this instance.
     */
    @Override
    public void close() throws Exception {
        SavedModelBundle toClose = this.bundle;
        this.bundle = null;
        this.inputName = null;
        this.outputNameMap = null;
        if (toClose != null) {
            System.out.println("closing bundle");
            toClose.close();
        }
    }

    /**
     * Loads the highest numeric version found under {@code path}.
     *
     * @throws IllegalArgumentException if {@code path} has no numeric version entry
     */
    public void loadModel(String path) throws Exception {
        Long lastVersion = this.getLastVersion(path);
        if (lastVersion == null) {
            throw new IllegalArgumentException("no numeric model version found under: " + path);
        }
        this.loadModel(path, String.valueOf(lastVersion));
    }

    /**
     * Returns the largest entry name under {@code path} that {@code NumberUtil.isNumber} accepts
     * and that parses as a {@code long}.
     *
     * @param path directory to scan
     * @return the largest version, or {@code null} if {@code path} is not a readable directory or
     *     contains no such entry
     */
    public Long getLastVersion(String path) throws Exception {
        // listFiles() returns null when path is not a directory or cannot be read.
        File[] children = new File(path).listFiles();
        if (children == null) {
            return null;
        }
        Long latest = null;
        for (File child : children) {
            String name = child.getName();
            if (!NumberUtil.isNumber(name)) {
                continue;
            }
            long version;
            try {
                version = Long.parseLong(name);
            } catch (NumberFormatException notAnIntegerVersion) {
                // isNumber may accept forms such as decimals that are not valid version numbers.
                continue;
            }
            if (latest == null || version > latest) {
                latest = version;
            }
        }
        return latest;
    }

    /**
     * Installs {@code bundle} as the active model and resolves its input tensor name. Output
     * names cached for a previous bundle are discarded. The previous bundle is not closed.
     *
     * @throws IllegalArgumentException if {@code bundle} is null
     */
    public void setModel(SavedModelBundle bundle) throws Exception {
        if (bundle == null) {
            throw new IllegalArgumentException("bundle must not be null");
        }
        // Resolve first so a failure leaves this instance in its previous, consistent state.
        String resolvedInput = TensorflowUtils.getInputName(INPUT, bundle, SIGNATURE_DEF);
        this.bundle = bundle;
        this.inputName = resolvedInput;
        this.outputNameMap = null;
    }

    /** Runs {@code x} through the model and returns {@code size} values of {@link #DF_OUTPUT}. */
    public <T> float[] predict(Tensor<T> x, int size) throws Exception {
        return this.predict(x, size, DF_OUTPUT);
    }

    /**
     * Feeds {@code x} to the model input, fetches the given output and copies it into a
     * {@code float[size]}. The caller keeps ownership of {@code x}; the fetched output tensors
     * are closed here.
     *
     * @param x input tensor (not closed by this method)
     * @param size expected number of output values; the fetched tensor must match this shape
     * @param output logical output name
     * @return the copied output values
     * @throws IllegalStateException if no model is loaded
     */
    public <T> float[] predict(Tensor<T> x, int size, String output) throws Exception {
        SavedModelBundle current = requireBundle();
        String outputName = this.getOutputName(output);
        List<Tensor<?>> outputs =
                current.session().runner().feed(this.inputName, x).fetch(outputName).run();
        try {
            float[] result = new float[size];
            outputs.get(0).copyTo(result);
            return result;
        } finally {
            for (Tensor<?> t : outputs) {
                t.close();
            }
        }
    }

    /**
     * Predicts both ESMM heads ({@link #ESMM_CTR_OUTPUT}, {@link #ESMM_CVR_OUTPUT}) for dense
     * float features.
     *
     * @param dataMap sample key -> feature vector; all vectors must have the same length
     * @return {@code CTR} and {@code CVR} score maps, or an empty map when {@code dataMap} is
     *     empty according to {@code AssertUtil.isAllNotEmpty}
     */
    public <T> Map<PredictResultType, Map<T, Double>> predictESMM(Map<T, List<Float>> dataMap)
            throws Exception {
        if (!AssertUtil.isAllNotEmpty(new Object[] {dataMap})) {
            return new HashMap<>();
        }
        List<T> keyList = new ArrayList<>(dataMap.keySet());
        return this.predictEsmmAndClose(keyList, LocalTFModel.getParamsTensor(keyList, dataMap));
    }

    /**
     * Predicts both ESMM heads for string features (one UTF-8 string per sample).
     *
     * @return {@code CTR} and {@code CVR} score maps, or an empty map for empty input
     */
    public <T> Map<PredictResultType, Map<T, Double>> predictStrESMM(Map<T, String> dataMap)
            throws Exception {
        if (!AssertUtil.isAllNotEmpty(new Object[] {dataMap})) {
            return new HashMap<>();
        }
        List<T> keyList = new ArrayList<>(dataMap.keySet());
        return this.predictEsmmAndClose(keyList, LocalTFModel.getParamsTensorStr(keyList, dataMap));
    }

    /**
     * Predicts {@link #DF_OUTPUT} for dense float features.
     *
     * @param dataMap sample key -> feature vector; all vectors must have the same length
     * @return sample key -> score, or an empty map for empty input
     */
    public <T> Map<T, Double> predict(Map<T, List<Float>> dataMap) throws Exception {
        if (!AssertUtil.isNotEmpty(dataMap)) {
            return new HashMap<>(DF_MAP_SIZE);
        }
        List<T> keyList = new ArrayList<>(dataMap.keySet());
        return this.predictAndClose(keyList, LocalTFModel.getParamsTensor(keyList, dataMap), DF_OUTPUT);
    }

    /**
     * Predicts {@link #DF_OUTPUT} for string features (one UTF-8 string per sample).
     *
     * @return sample key -> score, or an empty map for empty input
     */
    public <T> Map<T, Double> predictStr(Map<T, String> dataMap) throws Exception {
        if (!AssertUtil.isNotEmpty(dataMap)) {
            return new HashMap<>();
        }
        List<T> keyList = new ArrayList<>(dataMap.keySet());
        return this.predictAndClose(
                keyList, LocalTFModel.getParamsTensorStr(keyList, dataMap), DF_OUTPUT);
    }

    /**
     * Runs one output for a batch and maps each prediction back to its key by position.
     * {@code x} is not closed here.
     *
     * @param keyList sample keys, in the same order as the rows of {@code x}
     * @param x batch input tensor
     * @param output logical output name
     * @return key -> score, each passed through {@code MathUtil.formatDouble(value, 6)}
     */
    public <T, K> Map<T, Double> getPredictReturn(List<T> keyList, Tensor<K> x, String output)
            throws Exception {
        int sampleSize = keyList.size();
        float[] pred = this.predict(x, sampleSize, output);
        Map<T, Double> ret = new HashMap<>(DF_MAP_SIZE);
        for (int i = 0; i < sampleSize; ++i) {
            // "+ 0.0" widens to double (and maps -0.0 to 0.0), as the original code did.
            ret.put(keyList.get(i), MathUtil.formatDouble(pred[i] + 0.0, 6));
        }
        return ret;
    }

    /**
     * Builds a {@code float[sampleSize][featureSize]} tensor, row {@code i} taken from
     * {@code dataMap.get(keyList.get(i))}.
     *
     * @return the tensor, or {@code null} when {@code AssertUtil.isAllNotEmpty} rejects the input
     * @throws IllegalArgumentException if a row is missing or its length differs from the first
     *     row's length
     */
    public static <T> Tensor<Float> getParamsTensor(List<T> keyList, Map<T, List<Float>> dataMap) {
        if (!AssertUtil.isAllNotEmpty(new Object[] {keyList, dataMap})) {
            return null;
        }
        int sampleSize = keyList.size();
        int featureSize = requireRow(dataMap, keyList.get(0)).size();
        float[][] matrix = new float[sampleSize][featureSize];
        for (int i = 0; i < sampleSize; ++i) {
            T key = keyList.get(i);
            List<Float> feature = requireRow(dataMap, key);
            // A ragged batch would otherwise crash or be silently truncated depending on key order.
            if (feature.size() != featureSize) {
                throw new IllegalArgumentException("feature size mismatch for key " + key
                        + ": expected " + featureSize + " but was " + feature.size());
            }
            for (int j = 0; j < featureSize; ++j) {
                matrix[i][j] = feature.get(j).floatValue();
            }
        }
        return Tensor.create(matrix, Float.class);
    }

    /** Returns {@code dataMap.get(key)}, failing clearly instead of with an NPE later. */
    private static <T> List<Float> requireRow(Map<T, List<Float>> dataMap, T key) {
        List<Float> row = dataMap.get(key);
        if (row == null) {
            throw new IllegalArgumentException("no feature vector for key " + key);
        }
        return row;
    }

    /**
     * Builds a string tensor holding the UTF-8 bytes of {@code dataMap.get(keyList.get(i))} at
     * position {@code i} (a {@code byte[][]} passed to {@code Tensor.create} with
     * {@code String.class}; presumably a 1-D string batch — confirm against the model signature).
     *
     * @return the tensor, or {@code null} when {@code AssertUtil.isAllNotEmpty} rejects the input
     * @throws IllegalArgumentException if a key has no feature string
     */
    public static <T> Tensor<String> getParamsTensorStr(List<T> keyList, Map<T, String> dataMap)
            throws Exception {
        if (!AssertUtil.isAllNotEmpty(new Object[] {keyList, dataMap})) {
            return null;
        }
        int sampleSize = keyList.size();
        // Rows are assigned below, so only the outer array needs allocating.
        byte[][] matrix = new byte[sampleSize][];
        for (int i = 0; i < sampleSize; ++i) {
            T key = keyList.get(i);
            String feature = dataMap.get(key);
            if (feature == null) {
                throw new IllegalArgumentException("no feature string for key " + key);
            }
            matrix[i] = feature.getBytes(StandardCharsets.UTF_8);
        }
        return Tensor.create(matrix, String.class);
    }

    /** Returns the bundle or fails with a clear message if none is loaded. */
    private SavedModelBundle requireBundle() {
        if (this.bundle == null) {
            throw new IllegalStateException("model is not loaded; call loadModel/setModel first");
        }
        return this.bundle;
    }

    /** Runs one output for {@code x} and then closes {@code x} (which may be null). */
    private <T> Map<T, Double> predictAndClose(List<T> keyList, Tensor<?> x, String output)
            throws Exception {
        try (Tensor<?> input = x) {
            return this.getPredictReturn(keyList, input, output);
        }
    }

    /** Runs both ESMM outputs for {@code x} and then closes {@code x} (which may be null). */
    private <T> Map<PredictResultType, Map<T, Double>> predictEsmmAndClose(
            List<T> keyList, Tensor<?> x) throws Exception {
        Map<PredictResultType, Map<T, Double>> ret = new HashMap<>();
        try (Tensor<?> input = x) {
            ret.put(PredictResultType.CTR, this.getPredictReturn(keyList, input, ESMM_CTR_OUTPUT));
            ret.put(PredictResultType.CVR, this.getPredictReturn(keyList, input, ESMM_CVR_OUTPUT));
        }
        return ret;
    }

    /**
     * Manual smoke test: loads a model and scores two one-hot samples of width 240.
     *
     * @param args optional {@code [modelDir [version]]}; defaults to the original hard-coded
     *     developer path and version {@code "1"}
     */
    public static void main(String[] args) throws Exception {
        String modelDir = args.length > 0 ? args[0] : "/Users/lwj/Desktop/model/";
        String version = args.length > 1 ? args[1] : "1";
        try (LocalTFModel model = new LocalTFModel()) {
            model.loadModel(modelDir, version);
            Map<String, List<Float>> x = new HashMap<>();
            x.put("a", oneHot(240, 10));
            x.put("b", oneHot(240, 50));
            System.out.println(JSON.toJSONString(model.predict(x)));
        }
    }

    /** Returns a list of {@code size} zeros with {@code 1.0f} at {@code hotIndex}. */
    private static List<Float> oneHot(int size, int hotIndex) {
        List<Float> vector = new ArrayList<>(size);
        for (int i = 0; i < size; ++i) {
            vector.add(i == hotIndex ? 1.0f : 0.0f);
        }
        return vector;
    }
}

