package test;


import com.alibaba.fastjson.JSON;
import my.util.TensorflowUtils2;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;

import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

//import cn.com.duiba.nezha.alg.tf.TensorflowUtils3;
//import com.google.protobuf.ByteString;
//import org.tensorflow.framework.MetaGraphDef;
//import org.tensorflow.framework.SignatureDef;

/**
 * Small harness around a TensorFlow {@link SavedModelBundle} that feeds a batch of
 * samples to the model's {@code serving_default} signature and reads back one output.
 *
 * <p>Typical use: {@link #loadModel(String, String)} (or {@link #setModel(SavedModelBundle)}),
 * then one of the {@code predict} overloads. Instances keep mutable state (the bundle and a
 * cache of resolved output tensor names) without any synchronization.
 */
public class LocalTFModelTest {

    /** Initial capacity of the maps returned by the prediction methods. */
    private static final int DF_MAP_SIZE = 128;

    /** Signature looked up in the saved model's meta graph. */
    private static final String SIGNATURE_DEF = "serving_default";
    /** Logical name of the signature input that receives the feature tensor. */
    private static final String INPUT = "feat_vals";

    /** Logical name of the default output fetched by the prediction methods. */
    public static final String DF_OUTPUT = "prob";
    /** Logical output name {@code "ctr_probabilities"}; pass it as {@code output} to fetch it. */
    public static final String ESMM_CTR_OUTPUT = "ctr_probabilities";
    /** Logical output name {@code "cvr_probabilities"}; pass it as {@code output} to fetch it. */
    public static final String ESMM_CVR_OUTPUT = "cvr_probabilities";

    /** Currently loaded model; {@code null} until a model has been set. */
    private SavedModelBundle bundle = null;

    /** Graph tensor name resolved for {@link #INPUT} against the current bundle. */
    private String inputName = null;

    /** Lazily filled cache: logical output name -> graph tensor name for the current bundle. */
    private Map<String, String> outputNameMap = null;

    /**
     * Returns the currently loaded bundle.
     *
     * @return the bundle, or {@code null} when no model has been set
     */
    public SavedModelBundle getBundle() {
        return this.bundle;
    }

    /**
     * Resolves a logical output name of the signature to the graph tensor name, caching the
     * result so {@code TensorflowUtils2.getOutputName} is called at most once per name and bundle.
     *
     * @param output logical output name, e.g. {@link #DF_OUTPUT}
     * @return the name to pass to {@code Session.Runner.fetch}
     * @throws Exception propagated from {@code TensorflowUtils2.getOutputName}
     */
    public String getOutputName(String output) throws Exception {
        if (outputNameMap == null) {
            outputNameMap = new HashMap<>();
        }
        if (!outputNameMap.containsKey(output)) {
            String outputName = TensorflowUtils2.getOutputName(output, bundle, SIGNATURE_DEF);
            outputNameMap.put(output, outputName);
        }
        return outputNameMap.get(output);
    }

    /**
     * Loads the given version of a model and makes it the current one.
     *
     * @param path    model location handed to {@code TensorflowUtils2.loadModel}
     * @param version model version handed to {@code TensorflowUtils2.loadModel}
     * @throws Exception propagated from loading the model or resolving its input name
     */
    public void loadModel(String path, String version) throws Exception {
        SavedModelBundle loaded = TensorflowUtils2.loadModel(path, version);
        setModel(loaded);
    }

    /**
     * Makes {@code bundle} the current model and resolves its input tensor name.
     *
     * <p>The previous bundle is not closed here; the caller keeps ownership of it.
     *
     * @param bundle the model to use for subsequent predictions
     * @throws Exception propagated from {@code TensorflowUtils2.getInputName}
     */
    public void setModel(SavedModelBundle bundle) throws Exception {
        this.bundle = bundle;
        // Output names cached so far were resolved against the previous bundle's signature
        // and may not exist in the new one, so start from an empty cache.
        this.outputNameMap = null;
        this.inputName = TensorflowUtils2.getInputName(INPUT, bundle, SIGNATURE_DEF);
    }

    /**
     * Runs the model on {@code x} and returns the {@link #DF_OUTPUT} output.
     *
     * @param x    input tensor; ownership stays with the caller
     * @param size number of values expected in the output
     * @return the output values
     * @throws Exception see {@link #predict(Tensor, int, String)}
     */
    public <T> float[] predict(Tensor<T> x, int size) throws Exception {
        return predict(x, size, DF_OUTPUT);
    }

    /**
     * Runs the model on {@code x} and copies the requested output into a {@code float[size]}.
     *
     * @param x      input tensor; it is not closed by this method
     * @param size   length of the returned array; the copy fails if the fetched tensor's
     *               shape does not match a one-dimensional array of this length
     * @param output logical output name, e.g. {@link #DF_OUTPUT}
     * @return the output values
     * @throws IllegalStateException if no model has been loaded
     * @throws Exception             propagated from output-name resolution or the session run
     */
    public <T> float[] predict(Tensor<T> x, int size, String output) throws Exception {
        if (bundle == null) {
            throw new IllegalStateException("No model loaded; call loadModel or setModel first");
        }

        String outputName = getOutputName(output);
        List<Tensor<?>> fetched = bundle.session().runner().feed(inputName, x).fetch(outputName).run();
        try {
            float[] result = new float[size];
            fetched.get(0).copyTo(result);
            return result;
        } finally {
            // Tensors returned by run() hold native memory until explicitly closed.
            for (Tensor<?> tensor : fetched) {
                tensor.close();
            }
        }
    }

    /**
     * Predicts {@link #DF_OUTPUT} for a batch of float feature vectors.
     *
     * @param dataMap sample key -> feature vector; every vector is read up to the length of
     *                the first one
     * @return sample key -> predicted value; empty when {@code dataMap} is {@code null} or empty
     * @throws Exception propagated from the underlying prediction
     */
    public <T> Map<T, Double> predict(Map<T, List<Float>> dataMap) throws Exception {
        if (dataMap == null || dataMap.isEmpty()) {
            return new HashMap<>(DF_MAP_SIZE);
        }

        // Snapshot the keys once so tensor rows and result entries share the same order.
        List<T> keyList = new ArrayList<>(dataMap.keySet());
        try (Tensor<Float> x = getParamsTensor(keyList, dataMap)) {
            return getPredictReturn(keyList, x, DF_OUTPUT);
        }
    }

    /**
     * Predicts {@link #DF_OUTPUT} for a batch of string-encoded samples.
     *
     * @param dataMap sample key -> string payload fed to the model as UTF-8 bytes
     * @return sample key -> predicted value; empty when {@code dataMap} is {@code null} or empty
     * @throws Exception propagated from tensor construction or the underlying prediction
     */
    public <T> Map<T, Double> predictStr(Map<T, String> dataMap) throws Exception {
        if (dataMap == null || dataMap.isEmpty()) {
            return new HashMap<>();
        }

        List<T> keyList = new ArrayList<>(dataMap.keySet());
        try (Tensor<String> x = getParamsTensorStr(keyList, dataMap)) {
            return getPredictReturn(keyList, x, DF_OUTPUT);
        }
    }

    /**
     * Runs the model on {@code x} and pairs the i-th output value with the i-th key.
     *
     * @param keyList sample keys, in the same order as the rows of {@code x}
     * @param x       input tensor; ownership stays with the caller
     * @param output  logical output name to fetch
     * @return sample key -> predicted value widened to {@code double}
     * @throws Exception propagated from {@link #predict(Tensor, int, String)}
     */
    public <T, K> Map<T, Double> getPredictReturn(List<T> keyList, Tensor<K> x, String output) throws Exception {
        Map<T, Double> ret = new HashMap<>(DF_MAP_SIZE);

        int sampleSize = keyList.size();
        float[] pred = predict(x, sampleSize, output);

        for (int i = 0; i < sampleSize; i++) {
            ret.put(keyList.get(i), (double) pred[i]);
        }
        return ret;
    }

    /**
     * Builds a {@code [samples, features]} float tensor from the feature vectors of the given keys.
     *
     * <p>The feature count is taken from the first key's vector; only that many values are
     * read from every other vector. The caller must close the returned tensor.
     *
     * @param keyList keys selecting and ordering the rows
     * @param dataMap sample key -> feature vector; must contain every key in {@code keyList}
     * @return the tensor, or {@code null} when either argument is {@code null} or
     *         {@code keyList} is empty
     */
    public static <T> Tensor<Float> getParamsTensor(List<T> keyList, Map<T, List<Float>> dataMap) {
        if (keyList == null || dataMap == null || keyList.isEmpty()) {
            return null;
        }

        int sampleSize = keyList.size();
        int featureSize = dataMap.get(keyList.get(0)).size();
        float[][] matrix = new float[sampleSize][featureSize];

        // Convert the boxed lists into the primitive matrix layout Tensor.create expects.
        for (int i = 0; i < sampleSize; i++) {
            List<Float> feature = dataMap.get(keyList.get(i));
            for (int j = 0; j < featureSize; j++) {
                matrix[i][j] = feature.get(j);
            }
        }
        return Tensor.create(matrix, Float.class);
    }

    /**
     * Builds a string tensor holding one UTF-8 encoded payload per key.
     *
     * <p>The caller must close the returned tensor.
     *
     * @param keyList keys selecting and ordering the elements
     * @param dataMap sample key -> string payload; must contain every key in {@code keyList}
     * @return the tensor, or {@code null} when either argument is {@code null}
     * @throws Exception if the UTF-8 charset lookup fails
     */
    public static <T> Tensor<String> getParamsTensorStr(List<T> keyList, Map<T, String> dataMap) throws Exception {
        if (keyList == null || dataMap == null) {
            return null;
        }

        int sampleSize = keyList.size();
        // Rows are assigned below, so only the outer array needs allocating.
        byte[][] matrix = new byte[sampleSize][];

        for (int i = 0; i < sampleSize; i++) {
            String feature = dataMap.get(keyList.get(i));
            matrix[i] = feature.getBytes("UTF-8");
        }
        return Tensor.create(matrix, String.class);
    }

    /**
     * Manual smoke test: loads a model from a developer-local path and builds two
     * 240-dimensional one-hot samples. The prediction call itself is left commented out.
     */
    public static void main(String[] args) throws Exception {
        LocalTFModelTest model = new LocalTFModelTest();

        model.loadModel("/Users/lwj/Desktop/model/", "20");

        /* Build the prediction input. */
        List<Float> tmplist1 = new ArrayList<>();
        List<Float> tmplist2 = new ArrayList<>();
        for (int i = 0; i < 240; i++) {
            tmplist1.add(0f);
            tmplist2.add(0f);
        }

        tmplist1.set(10, 1f);
        tmplist2.set(50, 1f);

        Map<String, List<Float>> x = new HashMap<>();
        x.put("a", tmplist1);
        x.put("b", tmplist2);

//        System.out.println(JSON.toJSONString(model.predict(x)));

    }

}