package cn.com.duiba.nezha.alg.model.tf;

import cn.com.duiba.nezha.alg.api.model.TFModel;
import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import cn.com.duiba.nezha.alg.common.util.MathUtil;
import cn.com.duiba.nezha.alg.model.vo.CodeDo;
import cn.com.duiba.nezha.alg.model.vo.CodeSizeDo;
import com.alibaba.fastjson.JSON;
import lombok.Data;
import org.slf4j.LoggerFactory;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

import java.net.Inet4Address;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

//import cn.com.duiba.wolf.perf.timeprofile.DBTimeProfile;


/**
 * Local TensorFlow SavedModel wrapper that scores samples described by {@link CodeDo} feature codes.
 *
 * <p>Input and output tensor names are resolved against the {@code serving_default} signature.
 * Up to three inputs are fed:
 * <ul>
 *   <li>{@code input_s}: single-valued features, int matrix;</li>
 *   <li>{@code input_m}: multi-valued features, one UTF-8 string per feature;</li>
 *   <li>{@code input_f}: float features, float matrix.</li>
 * </ul>
 * The requested output is copied into a {@code float[]} with one entry per sample.
 *
 * <p>NOTE(review): {@code outputNameMap} is a plain {@link HashMap} that is filled lazily during
 * prediction. Confirm whether instances are shared across threads.
 */
@Data
public class LocalTFModelV2 extends TensorflowUtils implements TFModel {

    private static final org.slf4j.Logger logger = LoggerFactory.getLogger(LocalTFModelV2.class);

    /** Initial capacity of the result map built by {@link #getPredictReturn}. */
    private int DF_MAP_SIZE = 128;

    /** Signature that input/output names are resolved against. */
    private static final String SIGNATURE_DEF = "serving_default";

    /** Default output key used by {@link #predictAll(Map, CodeSizeDo)}. */
    public static String DF_OUTPUT = "prob";

    /** Signature input key for single-valued features. */
    private static final String ALL_INPUT_1 = "input_s";
    /** Signature input key for multi-valued features. */
    private static final String ALL_INPUT_2 = "input_m";
    /** Signature input key for float features. */
    private static final String ALL_INPUT_3 = "input_f";

    private SavedModelBundle bundle = null;

    private String version = null;
    private String path = null;

    /** Resolved tensor names for the three inputs; null when the lookup returned null. */
    private String inputName1 = null;
    private String inputName2 = null;
    private String inputName3 = null;

    /** Cache of signature output key -> resolved output tensor name for the loaded bundle. */
    private Map<String, String> outputNameMap = null;

    /**
     * Loads the latest model version under {@code path}. The version is resolved via
     * {@code getLastVersion}.
     *
     * @param path model base directory
     * @throws Exception if the version cannot be resolved or the bundle cannot be loaded
     */
    public void loadModel(String path) throws Exception {
        loadModel(path, null);
    }

    /**
     * Loads the given model version and resolves the input tensor names.
     *
     * @param path    model base directory
     * @param version version to load; {@code null} means the version returned by {@code getLastVersion}
     * @throws Exception if the version cannot be resolved or the bundle cannot be loaded
     */
    public void loadModel(String path, String version) throws Exception {
        String resolvedVersion = (version != null) ? version : getLastVersion(path);

        // NOTE(review): a previously loaded bundle is replaced here without being closed, so its
        // native resources are not released. Call close() first when no prediction can still use it.
        this.bundle = getModelBundle(path, resolvedVersion);

        this.inputName1 = TensorflowUtils.getInputName(ALL_INPUT_1, bundle, SIGNATURE_DEF);
        this.inputName2 = TensorflowUtils.getInputName(ALL_INPUT_2, bundle, SIGNATURE_DEF);
        this.inputName3 = TensorflowUtils.getInputName(ALL_INPUT_3, bundle, SIGNATURE_DEF);

        // Output names cached for a previously loaded model are not valid for the new bundle.
        this.outputNameMap = new HashMap<>();

        this.path = path;
        this.version = resolvedVersion;
    }

    /**
     * Resolves and caches the tensor name for a signature output key.
     *
     * @param output signature output key, e.g. {@link #DF_OUTPUT}
     * @return the resolved tensor name, or {@code null} when the lookup returned null (not cached)
     * @throws Exception propagated from the name lookup, after logging
     */
    public String getOutputName(String output) throws Exception {
        if (outputNameMap == null) {
            outputNameMap = new HashMap<>();
        }

        String cached = outputNameMap.get(output);
        if (cached != null) {
            return cached;
        }

        try {
            String outputName = TensorflowUtils.getOutputName(output, bundle, SIGNATURE_DEF);
            if (outputName != null) {
                outputNameMap.put(output, outputName);
            }
            return outputName;
        } catch (Exception e) {
            logger.info("getOutputName empty, output: {}, model path: {}, version: {}", output, path, version, e);
            throw e;
        }
    }

    /**
     * Closes the loaded bundle, if any. Calling it more than once is safe.
     *
     * @throws Exception declared for interface compatibility
     */
    public void close() throws Exception {
        if (bundle != null) {
            logger.info("close model  ,path={}", path);
            bundle.close();
            bundle = null;
        }
    }

    /**
     * Runs the model on the given input tensors and returns the requested output.
     *
     * <p>An input is fed only when both its resolved name and its tensor are non-null. The first
     * fetched tensor is copied into a {@code float[size]}. Per the TensorFlow {@code Tensor.copyTo}
     * contract, that tensor must be a rank-1 float tensor of length {@code size}. All fetched
     * tensors are closed before returning.
     *
     * @param input1 single-valued feature tensor, may be null
     * @param input2 multi-valued feature tensor, may be null
     * @param input3 float feature tensor, may be null
     * @param size   number of samples in the batch
     * @param output signature output key
     * @return one score per sample
     * @throws IllegalStateException if no model is loaded or the output key cannot be resolved
     * @throws Exception             propagated from output-name resolution or the TensorFlow run
     */
    public <T> float[] predict(Tensor<Integer> input1, Tensor<String> input2, Tensor<Float> input3, int size, String output) throws Exception {
        if (bundle == null) {
            throw new IllegalStateException("model is not loaded (path=" + path + ")");
        }

        String outputName = getOutputName(output);
        if (outputName == null) {
            // Fetching a null name would fail anyway; fail early with context instead.
            throw new IllegalStateException("output '" + output + "' not found in signature "
                    + SIGNATURE_DEF + " of model " + path + ", version " + version);
        }

        Session.Runner runner = bundle.session().runner();
        feedIfPresent(runner, inputName1, input1);
        feedIfPresent(runner, inputName2, input2);
        feedIfPresent(runner, inputName3, input3);

        List<Tensor<?>> fetched = runner.fetch(outputName).run();
        try {
            float[] result = new float[size];
            fetched.get(0).copyTo(result);
            return result;
        } finally {
            // Tensors hold native memory; release every fetched tensor even if copyTo throws.
            for (Tensor<?> tensor : fetched) {
                tensor.close();
            }
        }
    }

    /** Feeds {@code tensor} under {@code inputName} only when both are non-null. */
    private static void feedIfPresent(Session.Runner runner, String inputName, Tensor<?> tensor) {
        if (inputName != null && tensor != null) {
            runner.feed(inputName, tensor);
        }
    }

    /**
     * Scores every entry of {@code dataMap} using the default output {@link #DF_OUTPUT}.
     *
     * @see #predictAll(Map, String, CodeSizeDo)
     */
    public <T> Map<T, Double> predictAll(Map<T, CodeDo> dataMap, CodeSizeDo codeSizeDo) throws Exception {
        return predictAll(dataMap, DF_OUTPUT, codeSizeDo);
    }

    /**
     * Builds the input tensors from {@code dataMap}, runs the model and maps each key to its score.
     *
     * <p>Returns an empty map when {@code dataMap} is empty, or when no single-valued tensor is
     * built because {@code codeSizeDo.singleSize <= 0}. All input tensors are closed before
     * returning.
     *
     * @param dataMap    sample key -> feature codes
     * @param output     signature output key
     * @param codeSizeDo number of single-valued, multi-valued and float features per sample
     * @return sample key -> score
     * @throws Exception propagated (after logging) from tensor construction or prediction
     */
    public <T> Map<T, Double> predictAll(Map<T, CodeDo> dataMap, String output, CodeSizeDo codeSizeDo) throws Exception {
        Map<T, Double> ret = new HashMap<>();
        if (!AssertUtil.isNotEmpty(dataMap)) {
            return ret;
        }

        // Fix one key order so that tensor row i and keyList.get(i) refer to the same sample.
        List<T> keyList = new ArrayList<>(dataMap.keySet());
        int sampleSize = keyList.size();

        try (Tensor<Integer> input1 = getParamsTensorInt(keyList, dataMap, sampleSize, codeSizeDo.singleSize);
             Tensor<String> input2 = getParamsTensorStr(keyList, dataMap, sampleSize, codeSizeDo.multiSize);
             Tensor<Float> input3 = getParamsTensorFloat(keyList, dataMap, sampleSize, codeSizeDo.floatSize)) {
            if (input1 == null) {
                logger.info(" predictStr input1 is invalid, x=null");
            } else {
                ret = getPredictReturn(keyList, input1, input2, input3, output);
            }
        } catch (Exception e) {
            logger.info("predictStr(Map<T, String> dataMap) happend error", e);
            throw e;
        }
        return ret;
    }

    /**
     * Runs {@link #predict} and maps each key in {@code keyList} to its score.
     *
     * <p>Each score is shifted by {@code 1e-6} and passed to {@code MathUtil.formatDouble(..., 6)}.
     *
     * @param keyList sample keys, in the same order as the tensor rows
     * @return sample key -> formatted score
     * @throws Exception propagated from {@link #predict}
     */
    public <T, K> Map<T, Double> getPredictReturn(List<T> keyList, Tensor<Integer> input1, Tensor<String> input2, Tensor<Float> input3, String output) throws Exception {
        Map<T, Double> ret = new HashMap<>(DF_MAP_SIZE);
        int sampleSize = keyList.size();

        float[] pred = predict(input1, input2, input3, sampleSize, output);
        if (pred != null) {
            for (int i = 0; i < sampleSize; i++) {
                ret.put(keyList.get(i), MathUtil.formatDouble(pred[i] + 0.000001, 6));
            }
        }
        return ret;
    }

    /**
     * Builds a {@code [sampleSize, featureSize]} float tensor from each sample's float codes.
     *
     * <p>Only the first {@code featureSize} values of each list are used.
     *
     * @return the tensor, or {@code null} when {@code featureSize <= 0}. The caller must close it.
     * @throws IllegalArgumentException if a sample is missing or has fewer than {@code featureSize} values
     */
    public static <T> Tensor<Float> getParamsTensorFloat(List<T> keyList, Map<T, CodeDo> dataMap, int sampleSize, int featureSize) {
        if (featureSize <= 0) {
            return null;
        }

        float[][] matrix = new float[sampleSize][featureSize];
        for (int i = 0; i < sampleSize; i++) {
            T key = keyList.get(i);
            List<Float> feature = requireFeatures(key, requireCodeDo(dataMap, key).getFloatCode(), featureSize, "float");
            for (int j = 0; j < featureSize; j++) {
                matrix[i][j] = feature.get(j);
            }
        }
        return Tensor.create(matrix, Float.class);
    }

    /**
     * Builds a {@code [sampleSize, featureSize]} int tensor from each sample's single-valued codes.
     *
     * <p>Only the first {@code featureSize} values of each list are used.
     *
     * @return the tensor, or {@code null} when {@code featureSize <= 0}. The caller must close it.
     * @throws IllegalArgumentException if a sample is missing or has fewer than {@code featureSize} values
     */
    public static <T> Tensor<Integer> getParamsTensorInt(List<T> keyList, Map<T, CodeDo> dataMap, int sampleSize, int featureSize) {
        if (featureSize <= 0) {
            return null;
        }

        int[][] matrix = new int[sampleSize][featureSize];
        for (int i = 0; i < sampleSize; i++) {
            T key = keyList.get(i);
            List<Integer> feature = requireFeatures(key, requireCodeDo(dataMap, key).getSingleCode(), featureSize, "single");
            for (int j = 0; j < featureSize; j++) {
                matrix[i][j] = feature.get(j);
            }
        }
        return Tensor.create(matrix, Integer.class);
    }

    /**
     * Builds a {@code [sampleSize, featureSize]} string tensor from each sample's multi-valued codes.
     *
     * <p>Each code is one string (for example {@code "0,1,2"} in {@link #main}), encoded as UTF-8.
     * Only the first {@code featureSize} values of each list are used.
     *
     * @return the tensor, or {@code null} when {@code featureSize <= 0}. The caller must close it.
     * @throws IllegalArgumentException if a sample is missing or has fewer than {@code featureSize} values
     * @throws Exception                kept in the signature for source compatibility with existing callers
     */
    public static <T> Tensor<String> getParamsTensorStr(List<T> keyList, Map<T, CodeDo> dataMap, int sampleSize, int featureSize) throws Exception {
        if (featureSize <= 0) {
            return null;
        }

        byte[][][] matrix = new byte[sampleSize][featureSize][];
        for (int i = 0; i < sampleSize; i++) {
            T key = keyList.get(i);
            List<String> featureList = requireFeatures(key, requireCodeDo(dataMap, key).getMultiCode(), featureSize, "multi");
            for (int j = 0; j < featureSize; j++) {
                matrix[i][j] = featureList.get(j).getBytes(StandardCharsets.UTF_8);
            }
        }
        return Tensor.create(matrix, String.class);
    }

    /** Returns the {@link CodeDo} for {@code key}; fails with a descriptive message when it is absent. */
    private static <T> CodeDo requireCodeDo(Map<T, CodeDo> dataMap, T key) {
        CodeDo codeDo = dataMap.get(key);
        if (codeDo == null) {
            throw new IllegalArgumentException("no CodeDo for sample key " + key);
        }
        return codeDo;
    }

    /** Checks that {@code features} has at least {@code featureSize} entries, then returns it. */
    private static <E> List<E> requireFeatures(Object key, List<E> features, int featureSize, String kind) {
        if (features == null || features.size() < featureSize) {
            throw new IllegalArgumentException(kind + " features of sample " + key + ": expected at least "
                    + featureSize + " values but got " + (features == null ? "null" : String.valueOf(features.size())));
        }
        return features;
    }

    /** Manual smoke test against a local model directory; the paths are developer-specific. */
    public static void main(String[] args) throws Exception {
        LocalTFModelV2 model = new LocalTFModelV2();
        try {
            model.loadModel("/Users/panhangyi/data/serving-model/mid-ad-cvr-v0919/", "202309221515");
            System.out.println(JSON.toJSONString(model));

            // Build one sample: 62 single-valued, 3 multi-valued and 9 float features.
            List<Integer> tmplist1 = new ArrayList<>();
            List<String> tmplist2 = new ArrayList<>();
            List<Float> tmplist3 = new ArrayList<>();
            for (int i = 0; i < 62; i++) {
                tmplist1.add(10);
            }
            for (int i = 0; i < 3; i++) {
                tmplist2.add("0,1,2");
            }
            for (int i = 0; i < 9; ++i) {
                tmplist3.add(0.3f);
            }

            CodeSizeDo codeSizeDo = new CodeSizeDo();
            codeSizeDo.setSingleSize(62);
            codeSizeDo.setMultiSize(3);
            codeSizeDo.setFloatSize(9);

            Map<String, CodeDo> dataMap = new HashMap<>();
            CodeDo codeDo = new CodeDo();
            codeDo.setSingleCode(tmplist1);
            codeDo.setMultiCode(tmplist2);
            codeDo.setFloatCode(tmplist3);
            dataMap.put("s1", codeDo);

            Map<String, Double> ret = model.predictAll(dataMap, codeSizeDo);
            System.out.println("ret=" + JSON.toJSONString(ret));
        } finally {
            model.close();
        }
    }
}