package cn.com.duiba.nezha.alg.api.model;

import cn.com.duiba.nezha.alg.api.model.util.E2ETFUtils;
import cn.com.duiba.nezha.alg.api.model.util.TensorflowUtils;
import com.alibaba.fastjson.JSON;
import jdk.nashorn.internal.ir.debug.ObjectSizeCalculator;
import lombok.Data;
import org.apache.commons.lang.StringUtils;
import org.slf4j.LoggerFactory;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;
import org.tensorflow.framework.TensorInfo;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static java.nio.charset.StandardCharsets.UTF_8;

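/**
 * Local TensorFlow SavedModel (PB) wrapper for Adx models: loads a model from disk,
 * converts string features into tensors, runs the serving signature and returns the
 * predictions per sample.
 *
 * @author caiyida@duiba.com.cn
 * @date 2022/3/24
 */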
@Data
public class AdxLocalTFModel {


    // Logger bound to this class so that log lines are attributed to AdxLocalTFModel.
    private static final org.slf4j.Logger logger = LoggerFactory.getLogger(AdxLocalTFModel.class);
    // Default signature name used when the PB model is exported; it rarely changes.
    private static final String SIGNATURE_DEF = "serving_default";
    // Entry point between Java and TensorFlow: inputs, outputs and the session are all
    // obtained from it (it is acquired when the model is loaded).
    private SavedModelBundle bundle;
    // Path of the model.
    private String path;
    // Timestamp at which the model was created.
    private Long version;
    // Input names mapped to their TensorInfo; TensorInfo carries the shape and the type.
    private Map<String, TensorInfo> inputSpecMap;
    // Output names mapped to their TensorInfo; TensorInfo carries the shape and the type.
    private Map<String, TensorInfo> outputSpecMap;

    /**
     * Loads the latest model found under {@code path} and makes it the active model.
     *
     * <p>The version is resolved with {@code E2ETFUtils.getLastVersion(path)} and the bundle
     * is loaded with {@code E2ETFUtils.loadModel}. If installing the bundle fails (for example
     * the serving signature is missing), the freshly loaded bundle is closed and the previous
     * bundle and specs are put back, so no native resources leak and the object is not left
     * half-updated.
     *
     * <p>A bundle that was already held by this object is NOT closed by this method; call
     * {@link #close()} before reloading if the old model must be released.
     *
     * @param path location handed to {@code E2ETFUtils} to find and load the model
     * @throws Exception if the version lookup, the load, or the signature parsing fails
     */
    public void loadLatestModel(String path) throws Exception {
        Long lastVersion = E2ETFUtils.getLastVersion(path);
        SavedModelBundle loaded = E2ETFUtils.loadModel(path, String.valueOf(lastVersion));
        logger.info("loadModel ,path={} with version={}", path, lastVersion);

        // Remember the current state so a failed install can be rolled back.
        SavedModelBundle previousBundle = this.bundle;
        Map<String, TensorInfo> previousInputs = this.inputSpecMap;
        Map<String, TensorInfo> previousOutputs = this.outputSpecMap;
        try {
            setModel(loaded);
        } catch (Exception e) {
            this.bundle = previousBundle;
            this.inputSpecMap = previousInputs;
            this.outputSpecMap = previousOutputs;
            // The new bundle never became active; release it instead of leaking it.
            if (loaded != null) {
                loaded.close();
            }
            throw e;
        }
        this.path = path;
        this.version = lastVersion;
    }

    /**
     * Closes the loaded model; without this call the bundle keeps occupying memory.
     *
     * <p>Does nothing when no bundle is held. The bundle size is logged on a best-effort
     * basis only: {@code ObjectSizeCalculator} is an internal JDK class, so any failure
     * while measuring is logged and ignored, and the bundle is closed regardless.
     *
     * @throws Exception kept for interface compatibility with existing callers
     */
    public void close() throws Exception {
        SavedModelBundle closing = this.bundle;
        if (closing == null) {
            return;
        }
        // Detach first so the field never refers to a bundle that is closed or mid-close.
        this.bundle = null;
        try {
            logger.info("close model  ,path={} with version={}", path, version);
            // Measuring walks the object graph, so skip it entirely when INFO is disabled.
            if (logger.isInfoEnabled()) {
                try {
                    logger.info("bundle.size before close :{}", ObjectSizeCalculator.getObjectSize(closing));
                } catch (RuntimeException | LinkageError e) {
                    logger.warn("bundle size unavailable, closing model anyway", e);
                }
            }
        } finally {
            closing.close();
        }
    }

    /**
     * Installs the three core fields: the bundle and its input and output specs.
     *
     * <p>Both spec maps are resolved before any field is assigned, so if the signature
     * cannot be read the previously installed bundle and specs stay untouched and
     * consistent with each other.
     *
     * @param bundle the loaded model to install
     * @throws Exception if the meta graph cannot be parsed or the signature is missing
     */
    private void setModel(SavedModelBundle bundle) throws Exception {
        Map<String, TensorInfo> resolvedInputs = getInputSpecMap(bundle);
        Map<String, TensorInfo> resolvedOutputs = getOutputSpecMap(bundle);
        this.bundle = bundle;
        this.inputSpecMap = resolvedInputs;
        this.outputSpecMap = resolvedOutputs;
    }

    /**
     * Reads the names of the model inputs together with their shape and type.
     *
     * @param bundle the loaded model whose meta graph is parsed
     * @return the inputs map of the {@code SIGNATURE_DEF} signature
     * @throws Exception if the meta graph cannot be parsed or the signature does not exist
     */
    private Map<String, TensorInfo> getInputSpecMap(SavedModelBundle bundle) throws Exception {
        byte[] serializedMetaGraph = bundle.metaGraphDef();
        MetaGraphDef metaGraph = MetaGraphDef.parseFrom(serializedMetaGraph);
        SignatureDef servingSignature = metaGraph.getSignatureDefOrThrow(AdxLocalTFModel.SIGNATURE_DEF);
        return servingSignature.getInputsMap();
    }

    /**
     * Reads the names of the model outputs together with their shape and type.
     *
     * @param bundle the loaded model whose meta graph is parsed
     * @return the outputs map of the {@code SIGNATURE_DEF} signature
     * @throws Exception if the meta graph cannot be parsed or the signature does not exist
     */
    private Map<String, TensorInfo> getOutputSpecMap(SavedModelBundle bundle) throws Exception {
        byte[] serializedMetaGraph = bundle.metaGraphDef();
        MetaGraphDef metaGraph = MetaGraphDef.parseFrom(serializedMetaGraph);
        SignatureDef servingSignature = metaGraph.getSignatureDefOrThrow(AdxLocalTFModel.SIGNATURE_DEF);
        return servingSignature.getOutputsMap();
    }

    /**
     * Turns row-oriented samples into column-oriented tensors, one tensor per model input.
     *
     * <p>For every input of the signature, the feature value of each sample is collected in
     * the order of {@code sampleKeyList} and converted with {@link #listToTensor}. A feature
     * that is missing from a sample, or a sample whose feature map is {@code null}, yields a
     * {@code null} entry, which {@code listToTensor} replaces with its per-type default.
     *
     * <p>If anything fails part-way, the tensors created so far are closed before the
     * exception propagates, so no native memory is leaked.
     *
     * @param sampleKeyList        sample keys; defines the row order inside every tensor
     * @param sampleWithFeatureMap sample key mapped to its feature name/value pairs
     * @param <T1>                 type of the sample key
     * @return tensors keyed by the TensorFlow name of the input they feed
     * @throws IllegalStateException if an input has a dtype that {@code listToTensor} does not support
     */
    private <T1> Map<String, Tensor<?>> preprocessInput(List<T1> sampleKeyList, Map<T1, Map<String, String>> sampleWithFeatureMap) {
        Map<String, Tensor<?>> formattedInput = new HashMap<>();
        try {
            for (Map.Entry<String, TensorInfo> inputSpec : this.inputSpecMap.entrySet()) {
                String featureKey = inputSpec.getKey();
                TensorInfo inputInfo = inputSpec.getValue();

                List<String> oneFeatureList = new ArrayList<>(sampleKeyList.size());
                for (T1 sampleKey : sampleKeyList) {
                    Map<String, String> featureMap = sampleWithFeatureMap.get(sampleKey);
                    oneFeatureList.add(featureMap == null ? null : featureMap.get(featureKey));
                }

                Tensor<?> oneFeatureTensor = AdxLocalTFModel.listToTensor(oneFeatureList, inputInfo);
                if (oneFeatureTensor == null) {
                    // listToTensor returns null for dtypes it cannot build; fail here with context
                    // rather than later with an anonymous NullPointerException while feeding.
                    throw new IllegalStateException(
                            "unsupported dtype " + inputInfo.getDtype() + " for input feature " + featureKey);
                }
                formattedInput.put(inputInfo.getName(), oneFeatureTensor);
            }
        } catch (RuntimeException e) {
            // Nobody else holds these tensors yet, so release them before propagating.
            TensorflowUtils.closeQuitely(formattedInput.values());
            throw e;
        }
        return formattedInput;
    }

    /**
     * Feeds the prepared tensors into the TensorFlow session and collects every output of
     * the signature.
     *
     * <p>All input tensors and all fetched tensors are closed before this method returns,
     * whether it succeeds or throws.
     *
     * @param formattedInput column-oriented input, already converted to tensors and keyed by
     *                       the TensorFlow input name; rows follow the sample key order, e.g.
     *                       {"f1":Tensor<String>("wowowo","ninini"),
     *                       "f2":Tensor<String>("tututu","hahaha")}
     * @param <T>            element type of the outputs
     * @return map whose key is the output name chosen when the PB was exported (not the name
     *         of the TensorFlow operation). For instance, exporting
     *         <pre>
     *         &#64;tf.function
     *         def serving_function(data):
     *             result={'prob':model_1(data)}
     *             return result
     *
     *         output_func=serving_function.get_concrete_function(input_spec)
     *         </pre>
     *         with {@code output_func} as the saved-model signature gives a map with the single
     *         key {@code prob}. An example result:
     *         {"ctr":[[0.1],[0.3]],
     *         "cvr":[[0.03],[0.09]]}
     *         Values are two-dimensional arrays to stay compatible with multi-class and
     *         recall use cases.
     */
    private <T> Map<String, T[][]> executeTfGraph(Map<String, Tensor<?>> formattedInput) {
        List<Tensor<?>> fetchedTensors = null;
        try {
            // Record both names of every output in one pass so that the fetch order and the
            // order used to read the results are guaranteed to match.
            List<String> signatureNames = new ArrayList<>();
            List<String> graphNames = new ArrayList<>();
            for (Map.Entry<String, TensorInfo> outputSpec : this.outputSpecMap.entrySet()) {
                signatureNames.add(outputSpec.getKey());
                graphNames.add(outputSpec.getValue().getName());
            }

            Session.Runner runner = this.bundle.session().runner();
            for (Map.Entry<String, Tensor<?>> feed : formattedInput.entrySet()) {
                runner.feed(feed.getKey(), feed.getValue());
            }
            for (String graphName : graphNames) {
                runner.fetch(graphName);
            }
            fetchedTensors = runner.run();

            Map<String, T[][]> outputsByName = new HashMap<>();
            int position = 0;
            for (String signatureName : signatureNames) {
                T[][] values = tensorToArray(fetchedTensors.get(position));
                outputsByName.put(signatureName, values);
                position++;
            }
            return outputsByName;
        } finally {
            TensorflowUtils.closeQuitely(formattedInput.values());
            TensorflowUtils.closeQuitely(fetchedTensors);
        }
    }

    /**
     * Re-keys the column-oriented model outputs by sample key.
     *
     * @param sampleKeyList sample keys; row {@code i} of every output belongs to the
     *                      {@code i}-th key of this list
     * @param rawOutputs    column-oriented predictions in the same row order, e.g.
     *                      {"ctr":[[0.1],[0.3]],
     *                      "cvr":[[0.03],[0.09]]}
     * @param <T1>          type of the sample key
     * @param <T2>          element type of the outputs
     * @return output name mapped to (sample key mapped to that sample's row), e.g.
     *         {"ctr":{sample1:[0.1],sample2:[0.3]},
     *         "cvr":{sample1:[0.03],sample2:[0.09]}}
     */
    private <T1, T2> Map<String, Map<T1, T2[]>> formatOutputs(List<T1> sampleKeyList, Map<String, T2[][]> rawOutputs) {
        Map<String, Map<T1, T2[]>> byOutputName = new HashMap<>();
        for (String outputName : this.outputSpecMap.keySet()) {
            T2[][] rows = rawOutputs.get(outputName);
            Map<T1, T2[]> bySampleKey = new HashMap<>();
            int rowIndex = 0;
            for (T1 sampleKey : sampleKeyList) {
                bySampleKey.put(sampleKey, rows[rowIndex]);
                rowIndex++;
            }
            byOutputName.put(outputName, bySampleKey);
        }
        return byOutputName;
    }

    /**
     * Runs the model on a batch of samples.
     *
     * @param sampleWithFeatureMap sample key mapped to that sample's feature name/value pairs
     * @param <T1>                 type of the sample key
     * @param <T2>                 element type of the model outputs
     * @return output name mapped to (sample key mapped to the predicted values)
     */
    public <T1, T2> Map<String, Map<T1, T2[]>> predict(Map<T1, Map<String, String>> sampleWithFeatureMap) {
        // 1. Freeze the sample order; every later step relies on this list.
        List<T1> orderedKeys = new ArrayList<>(sampleWithFeatureMap.keySet());

        // 2. Pivot the features from one row per sample to one tensor per feature, each
        //    tensor holding the values in the order of orderedKeys.
        Map<String, Tensor<?>> inputTensors = preprocessInput(orderedKeys, sampleWithFeatureMap);

        // 3. Run TensorFlow; the results come back in orderedKeys order, keyed by the output
        //    names of the tf.function, and every tensor is closed by the callee.
        Map<String, T2[][]> columnOutputs = executeTfGraph(inputTensors);

        // 4. Attach the sample keys to the raw rows; the outer key stays the output name.
        return formatOutputs(orderedKeys, columnOutputs);
    }

    /**
     * Converts a list of string values into a one-dimensional tensor of the requested type.
     *
     * <p>Only {@code DT_FLOAT} and {@code DT_STRING} are supported. A {@code null} element
     * becomes the type's default: {@code 0.0} for floats and the empty string for strings.
     * A float value that cannot be parsed is logged as a warning and also becomes {@code 0.0}.
     *
     * @param input     the values, one per sample
     * @param inputInfo describes the input; its dtype selects the conversion
     * @return the created tensor, or {@code null} when the dtype is neither float nor string
     */
    public static Tensor<?> listToTensor(List<String> input, TensorInfo inputInfo) {
        final DataType requestedType = inputInfo.getDtype();
        final int rowCount = input.size();

        if (requestedType == DataType.DT_FLOAT) {
            float[] floats = new float[rowCount];
            int row = 0;
            for (String raw : input) {
                float parsed = 0.0f;
                if (raw != null) {
                    try {
                        parsed = Float.parseFloat(raw);
                    } catch (NumberFormatException e) {
                        // Keep the default and carry on with the remaining rows.
                        logger.warn("listToTensor exception", e);
                    }
                }
                floats[row++] = parsed;
            }
            return Tensors.create(floats);
        }

        if (requestedType == DataType.DT_STRING) {
            byte[][] encoded = new byte[rowCount][];
            int row = 0;
            for (String raw : input) {
                String text = (raw == null) ? "" : raw;
                encoded[row++] = text.getBytes(UTF_8);
            }
            return Tensors.create(encoded);
        }

        // Any other dtype is unsupported.
        return null;
    }

    /**
     * Copies a two-dimensional result tensor into a Java array of boxed values.
     *
     * <p>STRING tensors become {@code String[][]} (bytes decoded as UTF-8) and FLOAT tensors
     * become {@code Float[][]}. The first two entries of the tensor shape are used as the
     * array dimensions.
     *
     * @param tensorResult the tensor fetched from the session
     * @param <T>          element type expected by the caller
     * @return the copied values, or {@code null} when the tensor is neither STRING nor FLOAT
     */
    private <T> T[][] tensorToArray(Tensor<?> tensorResult) {
        final long[] dims = tensorResult.shape();
        final org.tensorflow.DataType elementType = tensorResult.dataType();

        final boolean isString = org.tensorflow.DataType.STRING == elementType;
        final boolean isFloat = org.tensorflow.DataType.FLOAT == elementType;
        if (!isString && !isFloat) {
            return null;
        }

        final int rows = (int) dims[0];
        final int cols = (int) dims[1];

        if (isString) {
            byte[][][] rawBytes = new byte[rows][cols][];
            tensorResult.copyTo(rawBytes);
            String[][] decoded = new String[rows][cols];
            for (int r = 0; r < rows; r++) {
                for (int c = 0; c < cols; c++) {
                    decoded[r][c] = new String(rawBytes[r][c], UTF_8);
                }
            }
            return (T[][]) decoded;
        }

        float[][] primitives = new float[rows][cols];
        tensorResult.copyTo(primitives);
        // A tensor cannot be copied straight into a Float[][], so box element by element.
        Float[][] boxed = new Float[rows][cols];
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                boxed[r][c] = primitives[r][c];
            }
        }
        return (T[][]) boxed;
    }

    /**
     * Manual smoke test: loads the latest model from a hard-coded local path, parses one
     * sample's features from an inline JSON string, runs {@code predict} for that single
     * sample (key {@code 123L}) and prints the result to standard output.
     *
     * @param args unused
     * @throws Exception if loading the model fails
     */
    public static void main(String[] args) throws Exception {
        AdxLocalTFModel adxLocalTFModel = new AdxLocalTFModel();
        // NOTE(review): developer-machine path; this only works where that directory exists.
        adxLocalTFModel.loadLatestModel("/Users/panhangyi/mid-adx-esmm-ctclk-phy-v2");

//        HashMap<String, String> featureMap = new HashMap<>();
        // Feature name -> feature value for one sample, encoded as a JSON object.
        String str = "{\"f309040\":\"0.0\",\"f414004\":\"0\",\"f3010010\":\"32545\",\"f5020101\":\"0\",\"f5020103\":\"0\",\"f309050\":\"0.0\",\"f5020102\":\"0\",\"f5020105\":\"1000\",\"f5020104\":\"1000\",\"f5020106\":\"1000\",\"f413004\":\"0\",\"f2010010\":\"2636\",\"f5020010\":\"0\",\"f1010020\":\"44\",\"f7040050\":\"0\",\"f3016024\":\"221290\",\"f3016025\":\"120527\",\"f3016022\":\"1261178\",\"f4010010\":\"Mozilla/5.0 (iPhone; CPU iPhone OS 15_6 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Mobile/15E148 MicroMessenger/8.0.28(0x18001c2c) NetType/4G Language/zh_CN\",\"f3016023\":\"457932\",\"f7040060\":\"0\",\"f409004\":\"0\",\"f3016020\":\"113899\",\"f3010161\":\"113899\",\"f3016021\":\"62498\",\"f3010160\":\"234457\",\"f5020030\":\"0\",\"f4010121\":\"15\",\"f5010110\":\"1\",\"f2010060\":\"411143\",\"f410004\":\"0\",\"f5020020\":\"0\",\"f4010122\":\"7\",\"f4010030\":\"0\",\"f7040080\":\"0\",\"f3010144\":\"9\",\"f3010143\":\"10\",\"f3010142\":\"7\",\"f3010141\":\"8\",\"f5020050\":\"0\",\"f3010149\":\"78286\",\"f3010148\":\"220104\",\"f0001\":\"15\",\"f3010147\":\"7173\",\"f3010146\":\"69150\",\"f3010145\":\"7\",\"f0003\":\"1\",\"f0002\":\"7\",\"f5010135\":\"\",\"f3016013\":\"7778\",\"f5010134\":\"\",\"f3016014\":\"220104\",\"f5010133\":\"\",\"f3016011\":\"25292\",\"f3016012\":\"12710\",\"f3016017\":\"23726\",\"f3016018\":\"666218\",\"f7040090\":\"0\",\"f3016015\":\"78286\",\"f5010136\":\"\",\"f3016016\":\"39700\",\"f3010155\":\"457932\",\"f3010154\":\"1261178\",\"f3010153\":\"62498\",\"f3010152\":\"666218\",\"f3010151\":\"23726\",\"f5020040\":\"0\",\"f3010150\":\"39700\",\"f3016010\":\"75969\",\"f3010159\":\"11655\",\"f3010158\":\"23170\",\"f3010157\":\"120527\",\"f6060011\":\"0\",\"f3010156\":\"221290\",\"f6060012\":\"0\",\"f6060013\":\"0\",\"f6060014\":\"1\",\"f6060015\":\"1\",\"f3016019\":\"234457\",\"f5010010\":\"010105\",\"f0309\":\"0\",\"f0308\":\"0\",\"f0307\":\"0\",\"f411004\":\"0\",\"f5020070\":\"0\",\"f3030011\":\"\",\"f0302\":\"0\",\"f0301\":\
"0\",\"f0306\":\"0\",\"f0305\":\"0\",\"f0304\":\"0\",\"f0303\":\"0\",\"f5010036\":\"0\",\"f5010035\":\"0\",\"f5010034\":\"0\",\"f5010033\":\"0\",\"f5020060\":\"0\",\"f6010102\":\"[1]\",\"f4010070\":\"3\",\"f5020090\":\"0\",\"f0320\":\"0\",\"f307040\":\"0.0\",\"f0203\":\"0\",\"f0324\":\"0\",\"f0202\":\"0\",\"f0323\":\"0\",\"f0201\":\"0\",\"f0322\":\"0\",\"f0321\":\"0\",\"f0327\":\"0\",\"f0326\":\"0\",\"f0204\":\"0\",\"f0325\":\"0\",\"f0319\":\"0\",\"f0318\":\"0\",\"f5020080\":\"0\",\"f3060017\":\"1\",\"f3060016\":\"1\",\"f307050\":\"0.0\",\"f3060013\":\"1\",\"f5010050\":\"8e3H3Ny3GBDTLn3T\",\"f0313\":\"0\",\"f3060012\":\"1\",\"f0312\":\"0\",\"f3060015\":\"1\",\"f0311\":\"0\",\"f3060014\":\"1\",\"f0310\":\"0\",\"f0317\":\"0\",\"f0316\":\"0\",\"f3060011\":\"1\",\"f0315\":\"0\",\"f0314\":\"0\",\"f4010091\":\"0\",\"f1010010\":\"190\",\"f5010060\":\"0\",\"f3020000\":\"0\",\"f5010079\":\"0\",\"f5010078\":\"0\",\"f5010077\":\"0\",\"f5030101\":\"190&1000\",\"f5030102\":\"2636&1000\",\"f5030103\":\"32545&1000\",\"f4010080\":\"0\",\"f5010070\":\"0\",\"f5010076\":\"0\",\"f5010075\":\"0\",\"f5010074\":\"0\",\"f6010090\":\"411143\",\"f5010082\":\"0\",\"f5010081\":\"0\",\"f5010080\":\"0\",\"f412004\":\"0\",\"f5010090\":\"0\"}";
        HashMap<String, String> featureMap = JSON.parseObject(str, HashMap.class);
//        featureMap.put("ft1","aaa");

        // A batch of exactly one sample, keyed by an arbitrary id.
        Map<Long, Map<String, String>> sample = new HashMap<>();
        sample.put(123L,featureMap);

        Map<String, Map<Long, Float[]>> predict = adxLocalTFModel.predict(sample);

        System.out.println(predict);

    }
}
