package cn.com.duiba.nezha.alg.model.tf;

import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import cn.com.duiba.nezha.alg.common.util.MathUtil;
import cn.com.duiba.nezha.alg.common.util.NumberUtil;
import cn.com.duiba.nezha.alg.model.CODER;
import cn.com.duiba.nezha.alg.model.enums.PredictResultType;
//import cn.com.duiba.wolf.perf.timeprofile.DBTimeProfile;
import cn.com.duiba.nezha.alg.model.vo.CodeDo;
import com.alibaba.fastjson.JSON;
import io.grpc.Status;
import jdk.nashorn.internal.ir.debug.ObjectSizeCalculator;
import org.slf4j.LoggerFactory;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.framework.Code;

import java.io.File;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;


/**
 * Wraps a locally loaded TensorFlow {@link SavedModelBundle} and offers helpers that turn
 * per-key feature maps into input tensors, run the model and map the scores back to the keys.
 *
 * <p>Input and output tensor names are resolved through the {@code serving_default} signature.
 * Resolved output names are cached per model and the cache is reset whenever a new bundle is
 * installed via {@link #setModel(SavedModelBundle)}.
 *
 * <p>This class performs no synchronization of its own. Installing a new bundle does not close
 * the previously held one; call {@link #close()} first when the old bundle is owned by this
 * instance.
 */
public class LocalTFModel {

    private static final org.slf4j.Logger logger = LoggerFactory.getLogger(LocalTFModel.class);

    /** Initial capacity used for per-key result maps. */
    private static final int DF_MAP_SIZE = 128;

    private static final String SIGNATURE_DEF = "serving_default";

    private static final String DF_INPUT = "feat_vals";

    // NOTE(review): the public output-name fields below are left non-final on purpose so that
    // any external code assigning them keeps compiling; they are only read inside this class.
    public static String DF_OUTPUT = "prob";
    public static String ESMM_CTR_OUTPUT = "ctr_probabilities";
    public static String ESMM_CVR_OUTPUT = "cvr_probabilities";
    public static String MMOE_CVR_OUTPUT = "prob0";
    public static String MMOE_BCVR_OUTPUT = "prob1";

    private SavedModelBundle bundle = null;

    private String version = null;

    private String path = null;

    private String inputName = null;

    /** Signature output key -> resolved tensor name, valid for the current bundle only. */
    private Map<String, String> outputNameMap = new HashMap<>();

    /**
     * @return the version recorded by the last successful {@code loadModel} call, or
     *         {@code null} when no model was loaded through {@code loadModel}
     */
    public String getVersion() {
        return version;
    }

    /**
     * @return the bundle currently held, or {@code null} when none is set or it was closed
     */
    public SavedModelBundle getBundle() {
        return this.bundle;
    }

    /**
     * Resolves the tensor name behind a signature output key, caching successful lookups.
     *
     * @param output signature output key, e.g. {@link #DF_OUTPUT}
     * @return the resolved tensor name, or {@code null} when the lookup yields none
     * @throws Exception whatever the underlying lookup throws; it is logged and rethrown
     */
    public String getOutputName(String output) throws Exception {
        if (outputNameMap == null) {
            outputNameMap = new HashMap<>();
        }

        // Only non-null names are ever stored, so a null here always means "not cached yet".
        String cached = outputNameMap.get(output);
        if (cached != null) {
            return cached;
        }

        try {
            String resolved = TensorflowUtils.getOutputName(output, bundle, SIGNATURE_DEF);
            if (resolved != null) {
                outputNameMap.put(output, resolved);
            }
            return resolved;
        } catch (Exception e) {
            logger.error("getOutputName , empty", e);
            throw e;
        }
    }

    /**
     * Loads the latest model version found under {@code path}.
     *
     * @param path model base directory
     * @throws Exception if the version cannot be determined or the model cannot be loaded
     */
    public void loadModel(String path) throws Exception {
        loadModel(path, null);
    }

    /**
     * Loads a specific model version.
     *
     * <p>When the freshly loaded bundle cannot be installed (for example the input tensor name
     * cannot be resolved) the new bundle is closed again and the previous state is kept.
     *
     * @param path    model base directory
     * @param version version to load; {@code null} selects the latest version
     * @throws Exception if the model cannot be loaded or installed
     */
    public void loadModel(String path, String version) throws Exception {
        String resolvedVersion = version;
        if (resolvedVersion == null) {
            resolvedVersion = TensorflowUtils.getLastVersion(path);
        }

        SavedModelBundle loaded = TensorflowUtils.loadModelV1(path, resolvedVersion);
        try {
            setModel(loaded);
        } catch (Exception e) {
            // The bundle never became visible through this instance, so nobody else can
            // release its native resources.
            if (loaded != null) {
                loaded.close();
            }
            throw e;
        }

        this.path = path;
        this.version = resolvedVersion;
    }

    /**
     * Closes the held bundle, if any. Calling this on an already closed instance is a no-op.
     *
     * @throws Exception declared for backward compatibility
     */
    public void close() throws Exception {
        SavedModelBundle current = bundle;
        if (current == null) {
            return;
        }
        bundle = null;

        logger.info("close model  ,paht=" + path + "with version=" + version);
        try {
            logger.info("bundle.size before close :" + ObjectSizeCalculator.getObjectSize(current));
        } catch (RuntimeException | LinkageError e) {
            // The size is a diagnostic only (ObjectSizeCalculator is an internal JDK class that
            // is not available on every runtime); it must never prevent releasing the bundle.
            logger.warn("bundle size measurement failed, closing anyway", e);
        }
        current.close();
    }

    /**
     * Installs a bundle as the active model.
     *
     * <p>The input tensor name is resolved before any field is updated, so a failure leaves the
     * previous model untouched. The output-name cache is reset because cached names belong to
     * the previous bundle. A previously held bundle is NOT closed here.
     *
     * @param bundle the bundle to use for subsequent predictions
     * @throws Exception if the input tensor name cannot be resolved
     */
    public void setModel(SavedModelBundle bundle) throws Exception {
        String resolvedInputName = TensorflowUtils.getInputName(DF_INPUT, bundle, SIGNATURE_DEF);

        this.bundle = bundle;
        this.inputName = resolvedInputName;
        this.outputNameMap = new HashMap<>();
    }

    /**
     * Runs the model on {@code x} and reads the default output ({@link #DF_OUTPUT}).
     *
     * @see #predict(Tensor, int, String)
     */
    public <T> float[] predict(Tensor<T> x, int size) throws Exception {
        return predict(x, size, DF_OUTPUT);
    }

    /**
     * Feeds {@code x} to the model and copies the first fetched tensor into a {@code float[]}.
     *
     * <p>The caller keeps ownership of {@code x}; every tensor returned by the session run is
     * closed before this method returns.
     *
     * @param x      input tensor
     * @param size   length of the result array, i.e. the number of scores expected
     * @param output signature output key to fetch
     * @return the scores, or {@code null} when the inputs are invalid or the run fails (the
     *         reason is logged)
     * @throws Exception if the output name lookup fails
     */
    public <T> float[] predict(Tensor<T> x, int size, String output) throws Exception {
        String outputName = getOutputName(output);
        SavedModelBundle current = bundle;

        boolean valid = true;
        if (outputName == null || inputName == null) {
            logger.warn("predict(Tensor<T> x, int size, String output)  invalid,outputName= {},inputName={}",
                    outputName, inputName);
            valid = false;
        }
        if (x == null) {
            logger.warn("predict(Tensor<T> x, int size, String output)  invalid,x=null");
            valid = false;
        }
        if (current == null) {
            logger.warn("predict(Tensor<T> x, int size, String output)  invalid,bundle=null");
            valid = false;
        }
        if (!valid) {
            return null;
        }

        List<Tensor<?>> outputs = null;
        try {
            outputs = current.session().runner().feed(inputName, x).fetch(outputName).run();
            float[] result = new float[size];
            outputs.get(0).copyTo(result);
            return result;
        } catch (Exception e) {
            logger.error("predict(Tensor<T> x, int size, String output)  run error", e);
            return null;
        } finally {
            // Fetched tensors hold native memory; release all of them on every path.
            if (outputs != null) {
                for (Tensor<?> fetched : outputs) {
                    if (fetched != null) {
                        fetched.close();
                    }
                }
            }
        }
    }

    /**
     * Scores string-encoded samples and returns the {@link #MMOE_CVR_OUTPUT} and
     * {@link #MMOE_BCVR_OUTPUT} outputs.
     *
     * @param dataMap key -> encoded sample
     * @return scores per result type; empty when {@code dataMap} fails the emptiness check
     * @throws Exception if tensor creation or the output name lookup fails
     */
    public <T> Map<PredictResultType, Map<T, Double>> predictStrMMOE(Map<T, String> dataMap) throws Exception {
        Map<PredictResultType, Map<T, Double>> ret = new HashMap<>();

        if (AssertUtil.isAllNotEmpty(dataMap)) {
            List<T> keyList = new ArrayList<>(dataMap.keySet());
            try (Tensor<String> x = buildStringTensor(keyList, dataMap)) {
                ret.put(PredictResultType.CVR, getPredictReturn(keyList, x, MMOE_CVR_OUTPUT));
                ret.put(PredictResultType.BACKEND_CVR, getPredictReturn(keyList, x, MMOE_BCVR_OUTPUT));
            }
        }
        return ret;
    }

    /**
     * Scores float-feature samples and returns the {@link #ESMM_CTR_OUTPUT} and
     * {@link #ESMM_CVR_OUTPUT} outputs.
     *
     * @param dataMap key -> feature vector
     * @return scores per result type; empty when {@code dataMap} fails the emptiness check
     * @throws Exception if tensor creation or the output name lookup fails
     */
    public <T> Map<PredictResultType, Map<T, Double>> predictESMM(Map<T, List<Float>> dataMap) throws Exception {
        Map<PredictResultType, Map<T, Double>> ret = new HashMap<>();

        if (AssertUtil.isAllNotEmpty(dataMap)) {
            List<T> keyList = new ArrayList<>(dataMap.keySet());
            try (Tensor<Float> x = buildFloatTensor(keyList, dataMap)) {
                ret.put(PredictResultType.CTR, getPredictReturn(keyList, x, ESMM_CTR_OUTPUT));
                ret.put(PredictResultType.CVR, getPredictReturn(keyList, x, ESMM_CVR_OUTPUT));
            }
        }
        return ret;
    }

    /**
     * Scores string-encoded samples and returns the {@link #ESMM_CTR_OUTPUT} and
     * {@link #ESMM_CVR_OUTPUT} outputs.
     *
     * @param dataMap key -> encoded sample
     * @return scores per result type; empty when {@code dataMap} fails the emptiness check
     * @throws Exception if tensor creation or the output name lookup fails
     */
    public <T> Map<PredictResultType, Map<T, Double>> predictStrESMM(Map<T, String> dataMap) throws Exception {
        Map<PredictResultType, Map<T, Double>> ret = new HashMap<>();

        if (AssertUtil.isAllNotEmpty(dataMap)) {
            List<T> keyList = new ArrayList<>(dataMap.keySet());
            try (Tensor<String> x = buildStringTensor(keyList, dataMap)) {
                ret.put(PredictResultType.CTR, getPredictReturn(keyList, x, ESMM_CTR_OUTPUT));
                ret.put(PredictResultType.CVR, getPredictReturn(keyList, x, ESMM_CVR_OUTPUT));
            }
        }
        return ret;
    }

    /**
     * Scores float-feature samples against the default output ({@link #DF_OUTPUT}).
     *
     * @param dataMap key -> feature vector
     * @return score per key; empty when {@code dataMap} is empty or the run fails
     * @throws Exception if tensor creation or the output name lookup fails
     */
    public <T> Map<T, Double> predict(Map<T, List<Float>> dataMap) throws Exception {
        if (!AssertUtil.isNotEmpty(dataMap)) {
            return new HashMap<>(DF_MAP_SIZE);
        }

        List<T> keyList = new ArrayList<>(dataMap.keySet());
        try (Tensor<Float> x = buildFloatTensor(keyList, dataMap)) {
            return getPredictReturn(keyList, x, DF_OUTPUT);
        }
    }

    /**
     * Scores integer-feature samples against the default output ({@link #DF_OUTPUT}).
     *
     * @see #predictInt(Map, String)
     */
    public <T> Map<T, Double> predictInt(Map<T, List<Integer>> dataMap) throws Exception {
        return predictInt(dataMap, DF_OUTPUT);
    }

    /**
     * Scores integer-feature samples against the given output.
     *
     * @param dataMap key -> feature vector
     * @param output  signature output key to fetch
     * @return score per key; empty when {@code dataMap} is empty, no tensor could be built or
     *         the run fails
     * @throws Exception if tensor creation or the output name lookup fails (logged, rethrown)
     */
    public <T> Map<T, Double> predictInt(Map<T, List<Integer>> dataMap, String output) throws Exception {
        Map<T, Double> ret = new HashMap<>();
        if (!AssertUtil.isNotEmpty(dataMap)) {
            return ret;
        }

        List<T> keyList = new ArrayList<>(dataMap.keySet());
        try (Tensor<Integer> x = buildIntTensor(keyList, dataMap)) {
            if (x == null) {
                logger.info(" predictInt input is invalid, x=null");
            } else {
                ret = getPredictReturn(keyList, x, output);
            }
        } catch (Exception e) {
            logger.error("predictInt(Map<T, List<Integer>> dataMap, String output) happend error", e);
            throw e;
        }
        return ret;
    }

    /**
     * Equivalent to {@link #predictInt(Map, String)}.
     *
     * <p>The type parameter {@code K} is unused; it is kept only so that existing call sites
     * which name it explicitly keep compiling.
     */
    public <T, K> Map<T, Double> predictV2(Map<T, List<Integer>> dataMap, String output) throws Exception {
        return predictInt(dataMap, output);
    }

    /**
     * Scores string-encoded samples against the default output ({@link #DF_OUTPUT}).
     *
     * @param dataMap key -> encoded sample
     * @return score per key; empty when {@code dataMap} is empty, no tensor could be built or
     *         the run fails
     * @throws Exception if tensor creation or the output name lookup fails (logged, rethrown)
     */
    public <T> Map<T, Double> predictStr(Map<T, String> dataMap) throws Exception {
        Map<T, Double> ret = new HashMap<>();
        if (!AssertUtil.isNotEmpty(dataMap)) {
            return ret;
        }

        List<T> keyList = new ArrayList<>(dataMap.keySet());
        try (Tensor<String> x = buildStringTensor(keyList, dataMap)) {
            if (x == null) {
                logger.info(" predictStr input is invalid, x=null");
            } else {
                ret = getPredictReturn(keyList, x, DF_OUTPUT);
            }
        } catch (Exception e) {
            logger.error("predictStr(Map<T, String> dataMap) happend error", e);
            throw e;
        }
        return ret;
    }

    /**
     * Runs the model on {@code x} and pairs the i-th score with the i-th key.
     *
     * <p>Each score has {@code 0.000001} added and is then passed through
     * {@code MathUtil.formatDouble(value, 6)}.
     *
     * @param keyList keys in the same order as the rows of {@code x}
     * @param x       input tensor; ownership stays with the caller
     * @param output  signature output key to fetch
     * @return score per key; empty when the prediction returned no result
     * @throws Exception if the output name lookup fails
     */
    public <T, K> Map<T, Double> getPredictReturn(List<T> keyList, Tensor<K> x, String output) throws Exception {
        Map<T, Double> ret = new HashMap<>(DF_MAP_SIZE);

        int sampleSize = keyList.size();
        float[] pred = predict(x, sampleSize, output);

        if (pred != null) {
            for (int i = 0; i < sampleSize; i++) {
                ret.put(keyList.get(i), MathUtil.formatDouble(pred[i] + 0.000001, 6));
            }
        }
        return ret;
    }

    /**
     * Builds a {@code [samples, features]} float tensor, one row per key in {@code keyList}
     * order. The feature count is taken from the first key's list; longer lists are truncated
     * to that length.
     *
     * <p>The raw return type is kept for source compatibility with existing callers.
     *
     * @return the tensor (caller must close it), or {@code null} when the inputs are empty
     * @throws IllegalArgumentException if a key has no feature list or one shorter than the
     *                                  first key's
     */
    @SuppressWarnings("rawtypes")
    public static <T> Tensor getParamsTensor(List<T> keyList, Map<T, List<Float>> dataMap) {
        return buildFloatTensor(keyList, dataMap);
    }

    /**
     * Builds a {@code [samples, features]} int tensor, one row per key in {@code keyList}
     * order. The feature count is taken from the first key's list; longer lists are truncated
     * to that length.
     *
     * <p>The raw return type is kept for source compatibility with existing callers.
     *
     * @return the tensor (caller must close it), or {@code null} when the inputs are empty
     * @throws IllegalArgumentException if a key has no feature list or one shorter than the
     *                                  first key's
     */
    @SuppressWarnings("rawtypes")
    public static <T> Tensor getParamsTensorInt(List<T> keyList, Map<T, List<Integer>> dataMap) {
        return buildIntTensor(keyList, dataMap);
    }

    /**
     * Builds a string tensor holding one UTF-8 encoded element per key in {@code keyList}
     * order.
     *
     * <p>The raw return type and the {@code throws Exception} clause are kept for source
     * compatibility with existing callers.
     *
     * @return the tensor (caller must close it), or {@code null} when the inputs are empty
     * @throws IllegalArgumentException if a key has no value in {@code dataMap}
     */
    @SuppressWarnings("rawtypes")
    public static <T> Tensor getParamsTensorStr(List<T> keyList, Map<T, String> dataMap) throws Exception {
        return buildStringTensor(keyList, dataMap);
    }

    /** Typed implementation behind {@link #getParamsTensor(List, Map)}. */
    private static <T> Tensor<Float> buildFloatTensor(List<T> keyList, Map<T, List<Float>> dataMap) {
        if (!AssertUtil.isAllNotEmpty(keyList, dataMap)) {
            return null;
        }

        int sampleSize = keyList.size();
        T firstKey = keyList.get(0);
        int featureSize = requireFeatures(firstKey, dataMap.get(firstKey), 0).size();

        float[][] matrix = new float[sampleSize][featureSize];
        for (int i = 0; i < sampleSize; i++) {
            T key = keyList.get(i);
            List<Float> feature = requireFeatures(key, dataMap.get(key), featureSize);
            for (int j = 0; j < featureSize; j++) {
                matrix[i][j] = feature.get(j);
            }
        }
        return Tensor.create(matrix, Float.class);
    }

    /** Typed implementation behind {@link #getParamsTensorInt(List, Map)}. */
    private static <T> Tensor<Integer> buildIntTensor(List<T> keyList, Map<T, List<Integer>> dataMap) {
        if (!AssertUtil.isAllNotEmpty(keyList, dataMap)) {
            return null;
        }

        int sampleSize = keyList.size();
        T firstKey = keyList.get(0);
        int featureSize = requireFeatures(firstKey, dataMap.get(firstKey), 0).size();

        int[][] matrix = new int[sampleSize][featureSize];
        for (int i = 0; i < sampleSize; i++) {
            T key = keyList.get(i);
            List<Integer> feature = requireFeatures(key, dataMap.get(key), featureSize);
            for (int j = 0; j < featureSize; j++) {
                matrix[i][j] = feature.get(j);
            }
        }
        return Tensor.create(matrix, Integer.class);
    }

    /** Typed implementation behind {@link #getParamsTensorStr(List, Map)}. */
    private static <T> Tensor<String> buildStringTensor(List<T> keyList, Map<T, String> dataMap) {
        if (!AssertUtil.isAllNotEmpty(keyList, dataMap)) {
            return null;
        }

        int sampleSize = keyList.size();
        // Rows are assigned below, so only the outer array needs allocating.
        byte[][] matrix = new byte[sampleSize][];
        for (int i = 0; i < sampleSize; i++) {
            T key = keyList.get(i);
            String feature = dataMap.get(key);
            if (feature == null) {
                throw new IllegalArgumentException("no feature string for key=" + key);
            }
            matrix[i] = feature.getBytes(StandardCharsets.UTF_8);
        }
        return Tensor.create(matrix, String.class);
    }

    /** Fails with a message naming the key when a feature list is missing or too short. */
    private static <T, V> List<V> requireFeatures(T key, List<V> features, int minSize) {
        if (features == null || features.size() < minSize) {
            throw new IllegalArgumentException("feature list for key=" + key
                    + " is " + (features == null ? "null" : "of size " + features.size())
                    + ", expected at least " + minSize + " values");
        }
        return features;
    }

    /**
     * Finds the highest numeric entry name directly under {@code path}.
     *
     * @param path directory whose entries are named after model versions
     * @return the largest version found, or {@code null} when {@code path} is not a directory,
     *         cannot be listed, or contains no numerically named entry
     * @throws Exception declared for backward compatibility
     */
    public Long getLastVersion(String path) throws Exception {
        File dir = new File(path);
        if (!dir.isDirectory()) {
            return null;
        }

        // listFiles() returns null when the directory cannot be read.
        File[] entries = dir.listFiles();
        if (entries == null) {
            return null;
        }

        Long latest = null;
        for (File entry : entries) {
            String name = entry.getName();
            if (!NumberUtil.isNumber(name)) {
                continue;
            }
            long candidate;
            try {
                candidate = Long.parseLong(name);
            } catch (NumberFormatException ignored) {
                // NOTE(review): isNumber may accept names that are not plain longs (decimals,
                // signs, ...) - confirm; such entries are not treated as versions.
                continue;
            }
            if (latest == null || latest < candidate) {
                latest = candidate;
            }
        }
        return latest;
    }

    /**
     * Manual smoke test: loads a model, scores two identical 35-feature samples and prints the
     * result as JSON.
     *
     * @param args optional first argument: model base directory (defaults to the developer's
     *             local path)
     */
    public static void main(String[] args) throws Exception {
        String modelPath = args != null && args.length > 0
                ? args[0]
                : "/Users/lwj/Documents/pythonProjectFiles/tf/data/serving-model/deepfm-v021/";

        LocalTFModel model = new LocalTFModel();
        model.loadModel(modelPath);
        try {
            // Build the prediction input.
            List<Integer> features = new ArrayList<>();
            for (int i = 0; i < 35; i++) {
                features.add(1);
            }

            Map<String, List<Integer>> dataMap = new HashMap<>();
            dataMap.put("s1", features);
            dataMap.put("s2", features);

            Map<String, Double> ret = model.predictInt(dataMap);
            System.out.println("ret=" + JSON.toJSONString(ret));
        } finally {
            model.close();
        }
    }

}