package cn.com.duiba.nezha.alg.model.tf;

import cn.com.duiba.nezha.alg.model.util.OneTFUtils;
import jdk.nashorn.internal.ir.debug.ObjectSizeCalculator;
import org.slf4j.LoggerFactory;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;
import org.tensorflow.framework.TensorInfo;

import java.util.*;


/**
 * A local TensorFlow SavedModel wrapper that deliberately favors clarity over generality.
 * <ul>
 *   <li>Every input feature is expected to have shape (None,).</li>
 *   <li>Every output is expected to have shape (None, 1).</li>
 *   <li>Inputs may only be Float or String; exactly one output is allowed and it must be Float.</li>
 * </ul>
 * NOTE(review): there is no synchronization; loading or closing while another thread is
 * predicting is unsafe. Confirm how callers coordinate this.
 */
public class OneLocalTfModel {

    private static final org.slf4j.Logger logger = LoggerFactory.getLogger(OneLocalTfModel.class);
    // Signature name used when the SavedModel was exported; effectively fixed.
    private static final String SIGNATURE_DEF = "serving_default";
    // Java <-> TF entry point: inputs, outputs and the session all come from it (obtained at load time).
    private SavedModelBundle bundle;
    // Model directory path.
    private String path;
    // Model version (described by the original author as the model creation timestamp).
    private Long version;
    // Signature input name -> TensorInfo (holds the graph tensor name, shape and dtype).
    private Map<String, TensorInfo> inputSpecMap;
    // Signature output name -> TensorInfo (holds the graph tensor name, shape and dtype).
    private Map<String, TensorInfo> outputSpecMap;

    /**
     * Loads the latest model version found under {@code path} and makes it the active model.
     * If the loaded model fails validation, the newly loaded bundle is closed and the exception
     * is propagated; the previously active model (if any) stays in place.
     *
     * @param path model directory
     * @throws Exception if the version lookup, the load, or the signature validation fails
     */
    public void loadLatestModel(String path) throws Exception {
        Long lastVersion = OneTFUtils.getLastVersion(path);
        SavedModelBundle newBundle = OneTFUtils.loadModel(path, String.valueOf(lastVersion));
        logger.info("loadModel, path={} with version={}", path, lastVersion);
        // NOTE(review): a previously loaded bundle is replaced without being closed here; callers
        // should close() the old model once no prediction uses it, otherwise its native memory leaks.
        setModel(newBundle);
        this.path = path;
        this.version = lastVersion;
    }

    /**
     * Closes the active model. Without this the native memory held by the bundle is not released.
     * Safe to call repeatedly; subsequent calls are no-ops.
     */
    public void close() throws Exception {
        if (bundle == null) {
            return;
        }
        logger.info("close model, path={} with version={}", path, version);
        logger.info("bundle.size before close :{}", ObjectSizeCalculator.getObjectSize(bundle));
        try {
            bundle.close();
        } finally {
            // Drop the reference even if close() fails, so the instance is never left half-closed.
            bundle = null;
        }
    }

    /**
     * Validates {@code candidate} and, on success, installs it together with its input/output specs.
     * On any failure the candidate is closed (so its native memory is not leaked) and the currently
     * installed model is left untouched.
     *
     * @throws IllegalStateException if the serving signature does not have exactly one output
     * @throws Exception             if the signature cannot be read
     */
    private void setModel(SavedModelBundle candidate) throws Exception {
        boolean accepted = false;
        try {
            SignatureDef sig = getServingSignature(candidate);
            Map<String, TensorInfo> inputs = sig.getInputsMap();
            Map<String, TensorInfo> outputs = sig.getOutputsMap();
            if (outputs.size() != 1) {
                logger.error("model output num is {} (expected exactly 1), should not use OneLocalTfModel, close...",
                        outputs.size());
                throw new IllegalStateException(
                        "OneLocalTfModel requires exactly 1 model output, got " + outputs.size());
            }
            this.bundle = candidate;
            this.inputSpecMap = inputs;
            this.outputSpecMap = outputs;
            accepted = true;
        } finally {
            if (!accepted) {
                candidate.close();
            }
        }
    }

    /**
     * Parses the bundle's MetaGraphDef once and returns the {@value #SIGNATURE_DEF} signature,
     * which carries both the input and the output TensorInfo maps.
     */
    private static SignatureDef getServingSignature(SavedModelBundle bundle) throws Exception {
        return MetaGraphDef.parseFrom(bundle.metaGraphDef()).getSignatureDefOrThrow(SIGNATURE_DEF);
    }

    /**
     * Converts row-oriented feature maps into column-oriented Tensors ready to feed a TF session.
     * For example, {"sample1":{"f1":"wowowo","f2":"tututu"}, "sample2":{"f1":"ninini","f2":"hahaha"}}
     * becomes {"f1":Tensor("wowowo","ninini"), "f2":Tensor("tututu","hahaha")}, keyed by graph name.
     * A feature missing from a sample (or a null feature map) contributes a {@code null} value, which
     * is handed to {@code OneTFUtils.listToTensor} as-is.
     * If building any Tensor fails, the Tensors already built are closed before the exception propagates.
     *
     * @param sampleKeyList        sample keys; this fixes the row order of every produced Tensor
     * @param sampleWithFeatureMap sample key -> that sample's feature map
     * @param <T>                  sample key type
     * @return graph tensor name (from TensorInfo, not the feature name) -> column Tensor
     */
    private <T> Map<String, Tensor<?>> preprocessInput(List<T> sampleKeyList, Map<T, Map<String, String>> sampleWithFeatureMap) {
        Map<String, Tensor<?>> formattedInput = new HashMap<>();
        boolean completed = false;
        try {
            for (Map.Entry<String, TensorInfo> inputEntry : this.inputSpecMap.entrySet()) {
                String featureName = inputEntry.getKey();
                TensorInfo inputInfo = inputEntry.getValue();
                List<String> oneFeatureList = new ArrayList<>(sampleKeyList.size());
                for (T sampleKey : sampleKeyList) {
                    Map<String, String> featureMap = sampleWithFeatureMap.get(sampleKey);
                    oneFeatureList.add(featureMap == null ? null : featureMap.get(featureName));
                }
                formattedInput.put(inputInfo.getName(), OneTFUtils.listToTensor(oneFeatureList, inputInfo));
            }
            completed = true;
            return formattedInput;
        } finally {
            if (!completed) {
                // Tensors hold native memory; release the ones built before the failure.
                OneTFUtils.closeQuitely(formattedInput.values());
            }
        }
    }

    /**
     * Feeds the column-oriented inputs into the session, fetches the single signature output and
     * converts it into a float matrix. All input Tensors and all fetched Tensors are closed
     * before returning, whether or not the run succeeds.
     * <p>Example export that yields a single output named {@code prob}:
     * <pre>
     * {@literal @}tf.function
     * def serving_function(data):
     *     result = {'prob': model_1(data)}
     *     return result
     *
     * output_func = serving_function.get_concrete_function(input_spec)
     * </pre>
     * With output_func exported as the SavedModel signature, the result is e.g. [[0.1],[0.3]]:
     * an n*1 matrix (Keras' default output shape), where n is the number of samples.
     *
     * @param formattedInput graph tensor name -> column Tensor, rows ordered like sampleKeyList
     * @return the model output as returned by {@code OneTFUtils.tensorToArray}
     */
    private float[][] executeTfGraph(Map<String, Tensor<?>> formattedInput) {
        List<Tensor<?>> result = null;
        try {
            Session.Runner runner = this.bundle.session().runner();
            for (Map.Entry<String, Tensor<?>> input : formattedInput.entrySet()) {
                runner.feed(input.getKey(), input.getValue());
            }
            // setModel guarantees exactly one output, so result.get(0) below is that output.
            for (TensorInfo outputInfo : this.outputSpecMap.values()) {
                runner.fetch(outputInfo.getName());
            }
            result = runner.run();
            return OneTFUtils.tensorToArray(result.get(0));
        } finally {
            // Tensors are backed by native memory; not closing them makes memory grow without bound.
            OneTFUtils.closeQuitely(formattedInput.values());
            OneTFUtils.closeQuitely(result);
        }
    }

    /**
     * Turns the column-oriented model output into a row-oriented map, e.g. [[0.1],[0.3]] with keys
     * [sample1, sample2] becomes {sample1:0.1, sample2:0.3}. Only column 0 of each row is used.
     *
     * @param sampleKeyList sample keys, in the same order as the rows of {@code rawOutputs}
     * @param rawOutputs    model output, one row per sample
     * @param <T>           sample key type
     * @return sample key -> prediction
     * @throws IllegalStateException if the model returned fewer rows than there are samples
     */
    private <T> Map<T, Double> formatOutputs(List<T> sampleKeyList, float[][] rawOutputs) {
        if (rawOutputs.length < sampleKeyList.size()) {
            throw new IllegalStateException("model returned " + rawOutputs.length
                    + " rows for " + sampleKeyList.size() + " samples");
        }
        Map<T, Double> resultMap = new HashMap<>();
        for (int i = 0; i < sampleKeyList.size(); i++) {
            resultMap.put(sampleKeyList.get(i), (double) rawOutputs[i][0]);
        }
        return resultMap;
    }

    /**
     * Main prediction entry point.
     *
     * @param sampleWithFeatureMap sample key -> that sample's feature map, e.g.
     *                             {"sample1":{"f1":"wowowo","f2":"tututu"},
     *                             "sample2":{"f1":"ninini","f2":"hahaha"}}
     * @param <T>                  sample key type
     * @return sample key -> prediction, e.g. {"sample1":0.1, "sample2":0.3}; an empty map for
     * null or empty input
     * @throws IllegalStateException if no model is loaded (never loaded, or already closed)
     */
    public <T> Map<T, Double> predict(Map<T, Map<String, String>> sampleWithFeatureMap) {
        if (this.bundle == null) {
            throw new IllegalStateException("model not loaded or already closed, call loadLatestModel first");
        }
        if (sampleWithFeatureMap == null || sampleWithFeatureMap.isEmpty()) {
            return new HashMap<>();
        }
        // Step 1: fix an order for the samples.
        List<T> sampleKeyList = new ArrayList<>(sampleWithFeatureMap.keySet());
        // Step 2: convert rows to columns: graph tensor name -> feature Tensor ordered like sampleKeyList.
        Map<String, Tensor<?>> formattedInput = preprocessInput(sampleKeyList, sampleWithFeatureMap);
        // Step 3: run the graph (this also closes every Tensor involved); rows follow sampleKeyList.
        float[][] rawOutputs = executeTfGraph(formattedInput);
        // Step 4: map each row back to its sample key.
        return formatOutputs(sampleKeyList, rawOutputs);
    }

    /**
     * Manual smoke test. The model directory may be passed as the first argument; otherwise the
     * original hard-coded local path is used. The model is always closed afterwards.
     */
    public static void main(String[] args) throws Exception {
        Map<String, Map<String, String>> sampleWithFeatureMap = new HashMap<>();
        Map<String, String> featureMap = new HashMap<>();
        featureMap.put("f1", "a");
        featureMap.put("f2", "a,c,444");
        sampleWithFeatureMap.put("sample1", featureMap);
        String modelPath = args.length > 0 ? args[0] : "C:\\Users\\aaa\\Desktop\\nezha\\model\\full";
        OneLocalTfModel oneLocalTfModel = new OneLocalTfModel();
        try {
            oneLocalTfModel.loadLatestModel(modelPath);
            Map<String, Double> result = oneLocalTfModel.predict(sampleWithFeatureMap);
            System.out.println(result);
        } finally {
            oneLocalTfModel.close();
        }
    }

}
