package cn.com.duiba.nezha.alg.api.model;

import cn.com.duiba.nezha.alg.api.model.util.NumberUtil;
import cn.com.duiba.nezha.alg.api.model.util.TensorflowUtils;
import com.google.common.primitives.Floats;
import lombok.Data;
import org.slf4j.LoggerFactory;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;
import org.tensorflow.framework.TensorInfo;

import java.io.File;
import java.util.*;

import static java.nio.charset.StandardCharsets.UTF_8;

//import cn.com.duiba.nezha.alg.model.tf.TensorflowUtils;

/**
 * Wrapper for running a locally exported TensorFlow SavedModel (pb) from Java.
 *
 * <p>Typical use: {@link #loadModel(String)} loads the newest numeric version directory under a
 * base path, {@link #predict(Map)} runs single-example inference, and {@link #close()} releases the
 * native model resources.
 *
 * <p>NOTE(review): nothing here is synchronized; reloading or closing while another thread is inside
 * {@link #predict(Map)} is not guarded — confirm how callers share instances.
 */
@Data
public class LocalTFModel {
    private static final org.slf4j.Logger logger = LoggerFactory.getLogger(LocalTFModel.class);

    /** Signature name used when the pb was exported; normally left unchanged. */
    private static final String SIGNATURE_DEF = "serving_default";

    /**
     * Handle to the loaded SavedModel; supplies the session and meta graph. {@code null} until a
     * model is loaded and again after {@link #close()}.
     */
    private SavedModelBundle bundle;
    /** Base directory the model was loaded from. */
    private String path;
    /** Version sub-directory that was loaded (per the original author, the model creation timestamp). */
    private String version;
    /** Signature input key -> TensorInfo (holds the tensor name, dtype and shape). */
    private Map<String, TensorInfo> inputSpecMap;
    /** Signature output key -> TensorInfo (holds the tensor name, dtype and shape). */
    private Map<String, TensorInfo> outputSpecMap;

    /**
     * Installs an already loaded bundle and reads its input/output specs from the
     * {@value #SIGNATURE_DEF} signature.
     *
     * <p>The signature is parsed before any field is assigned, so a failure leaves the previous
     * state untouched.
     *
     * @param bundle loaded SavedModel, must not be {@code null}
     * @throws Exception if the meta graph cannot be parsed or the signature is missing
     */
    public void setModel(SavedModelBundle bundle) throws Exception {
        Objects.requireNonNull(bundle, "bundle");
        SignatureDef signature = parseSignatureDef(bundle, SIGNATURE_DEF);
        Map<String, TensorInfo> inputs = signature.getInputsMap();
        Map<String, TensorInfo> outputs = signature.getOutputsMap();
        this.bundle = bundle;
        this.inputSpecMap = inputs;
        this.outputSpecMap = outputs;
    }

    /**
     * Loads the newest version found under {@code path} (see {@link #getLastVersion(String)}).
     *
     * @param path base model directory containing numeric version sub-directories
     * @throws IllegalArgumentException if no numeric version entry exists under {@code path}
     * @throws Exception if loading the model fails
     */
    public void loadModel(String path) throws Exception {
        Long lastVersion = getLastVersion(path);
        if (lastVersion == null) {
            // Without this guard the version would be stringified as "null" and passed on.
            throw new IllegalArgumentException("No numeric model version found under path=" + path);
        }
        loadModel(path, String.valueOf(lastVersion));
    }

    /**
     * Returns the largest numeric entry name directly under {@code path}.
     *
     * @param path base model directory
     * @return the largest version, or {@code null} if {@code path} is not a directory or contains no
     *     entry whose name is a number parseable as a {@code long}
     * @throws java.io.IOException if {@code path} is a directory but cannot be listed
     */
    public Long getLastVersion(String path) throws Exception {
        File dir = new File(path);
        if (!dir.isDirectory()) {
            return null;
        }
        File[] children = dir.listFiles();
        if (children == null) {
            throw new java.io.IOException("Unable to list model directory: " + path);
        }
        Long latest = null;
        for (File child : children) {
            String name = child.getName();
            if (!NumberUtil.isNumber(name)) {
                continue;
            }
            long candidate;
            try {
                candidate = Long.parseLong(name);
            } catch (NumberFormatException e) {
                // Numeric-looking but not a long (e.g. overflow or decimal): not a version dir.
                continue;
            }
            if (latest == null || latest < candidate) {
                latest = candidate;
            }
        }
        return latest;
    }

    /**
     * Loads the given version of the model and makes it current.
     *
     * <p>NOTE(review): a previously loaded bundle is replaced but not closed here; callers that
     * reload must call {@link #close()} first (or close the old bundle) to avoid leaking it.
     *
     * @param path base model directory
     * @param version version sub-directory name
     * @throws Exception if loading or signature parsing fails; a bundle loaded by this call is
     *     closed before the exception propagates
     */
    public void loadModel(String path, String version) throws Exception {
        logger.info("loadModel, path={} with version={}", path, version);
        SavedModelBundle loaded = TensorflowUtils.loadModel(path, version);
        try {
            setModel(loaded);
        } catch (Exception e) {
            if (loaded != null) {
                loaded.close();
            }
            throw e;
        }
        this.path = path;
        this.version = version;
    }

    /** Releases the loaded bundle, if any; calling it again is a no-op. */
    public void close() throws Exception {
        if (bundle != null) {
            logger.info("close model, path={} with version={}", path, version);
            bundle.close();
            bundle = null;
        }
    }

    /**
     * Returns the input specs of the named signature.
     *
     * @param bundle loaded SavedModel
     * @param signatureDef signature name, e.g. {@code "serving_default"}
     * @return signature input key -> TensorInfo
     * @throws Exception if the meta graph cannot be parsed or the signature is missing
     */
    public Map<String, TensorInfo> getInputSpecMap(SavedModelBundle bundle, String signatureDef) throws Exception {
        return parseSignatureDef(bundle, signatureDef).getInputsMap();
    }

    /**
     * Returns the output specs of the named signature.
     *
     * @param bundle loaded SavedModel
     * @param signatureDef signature name, e.g. {@code "serving_default"}
     * @return signature output key -> TensorInfo
     * @throws Exception if the meta graph cannot be parsed or the signature is missing
     */
    public Map<String, TensorInfo> getOutputSpecMap(SavedModelBundle bundle, String signatureDef) throws Exception {
        return parseSignatureDef(bundle, signatureDef).getOutputsMap();
    }

    /** Parses the bundle's serialized meta graph and returns the named signature. */
    private static SignatureDef parseSignatureDef(SavedModelBundle bundle, String signatureDef) throws Exception {
        return MetaGraphDef.parseFrom(bundle.metaGraphDef()).getSignatureDefOrThrow(signatureDef);
    }

    /**
     * Converts raw feature strings into 1x1 input tensors, one per signature input.
     *
     * <p>Inputs declared as {@code DT_FLOAT} are parsed as floats (missing, {@code null} or
     * unparseable values become {@code 0.0f}). All other inputs are fed as UTF-8 string tensors
     * (missing or {@code null} values become {@code ""}).
     *
     * @param featureMap signature input key -> raw feature value; {@code null} is treated as empty
     * @return tensor name -> tensor; the caller owns the tensors and must close them
     */
    public Map<String, Tensor<?>> preprocessInput(Map<String, String> featureMap) {
        Map<String, String> features = featureMap == null ? Collections.<String, String>emptyMap() : featureMap;
        Map<String, Tensor<?>> inputTensorMap = new HashMap<>();
        boolean completed = false;
        try {
            for (Map.Entry<String, TensorInfo> spec : this.inputSpecMap.entrySet()) {
                String inputKey = spec.getKey();
                TensorInfo inputInfo = spec.getValue();
                Tensor<?> tensor;
                if (inputInfo.getDtype() == DataType.DT_FLOAT) {
                    float[][] arrayInput = {{parseFloatFeature(features.get(inputKey))}};
                    tensor = Tensor.create(arrayInput, Float.class);
                } else {
                    String rawInput = features.get(inputKey);
                    if (rawInput == null) {
                        rawInput = "";
                    }
                    byte[][][] arrayInput = {{rawInput.getBytes(UTF_8)}};
                    tensor = Tensor.create(arrayInput, String.class);
                }
                // Two signature inputs may share one tensor name; close the one being replaced.
                Tensor<?> replaced = inputTensorMap.put(inputInfo.getName(), tensor);
                if (replaced != null) {
                    replaced.close();
                }
            }
            completed = true;
            return inputTensorMap;
        } finally {
            if (!completed) {
                // Tensors hold native memory; do not leak the ones created before a failure.
                closeAll(inputTensorMap.values());
            }
        }
    }

    /** Parses a float feature, mapping {@code null} or unparseable values to {@code 0.0f}. */
    private static float parseFloatFeature(String rawValue) {
        if (rawValue == null) {
            return 0.0f;
        }
        try {
            return Float.parseFloat(rawValue);
        } catch (NumberFormatException e) {
            logger.warn("LocalTFModel.preprocessInput exception", e);
            return 0.0f;
        }
    }

    /**
     * Converts the first row of a rank-2 output tensor into a list. Only row 0 is returned, so this
     * suits single-example prediction.
     *
     * @param tensorResult output tensor; not closed by this method
     * @return a {@code List<String>} (UTF-8 decoded) for STRING tensors, a {@code List<Float>} for
     *     FLOAT tensors, an empty list when the first dimension is 0, or {@code null} for any other
     *     dtype
     * @throws IllegalArgumentException if a STRING/FLOAT tensor is not rank 2
     */
    public List<?> tensorToList(Tensor<?> tensorResult) {
        org.tensorflow.DataType dtype = tensorResult.dataType();
        boolean isString = dtype == org.tensorflow.DataType.STRING;
        if (!isString && dtype != org.tensorflow.DataType.FLOAT) {
            return null;
        }
        long[] shape = tensorResult.shape();
        if (shape.length != 2) {
            throw new IllegalArgumentException(
                    "tensorToList expects a rank-2 tensor, got shape " + Arrays.toString(shape));
        }
        int rows = (int) shape[0];
        int cols = (int) shape[1];
        if (rows == 0) {
            return Collections.emptyList();
        }
        if (isString) {
            byte[][][] raw = new byte[rows][cols][];
            tensorResult.copyTo(raw);
            // Only row 0 is returned, so only row 0 is decoded.
            String[] firstRow = new String[cols];
            for (int q = 0; q < cols; q++) {
                firstRow[q] = new String(raw[0][q], UTF_8);
            }
            return Arrays.asList(firstRow);
        }
        float[][] values = new float[rows][cols];
        tensorResult.copyTo(values);
        return Floats.asList(values[0]);
    }

    /**
     * Runs one prediction: feeds every signature input and fetches every signature output.
     *
     * <p>All input and output tensors are closed before returning, including on failure, because
     * they hold native memory.
     *
     * @param featureMap signature input key -> raw feature value
     * @return output tensor name -> first row of that output (see {@link #tensorToList(Tensor)})
     * @throws IllegalStateException if no model is loaded or it has been closed
     */
    public Map<String, List<?>> predict(Map<String, String> featureMap) {
        if (this.bundle == null) {
            throw new IllegalStateException("Model is not loaded or has been closed, path=" + path);
        }
        Map<String, Tensor<?>> inputTensorMap = preprocessInput(featureMap);
        List<Tensor<?>> result = Collections.emptyList();
        try {
            List<String> outputNameList = new ArrayList<>(this.outputSpecMap.size());
            for (TensorInfo outputInfo : this.outputSpecMap.values()) {
                outputNameList.add(outputInfo.getName());
            }
            Session.Runner runner = this.bundle.session().runner();
            for (Map.Entry<String, Tensor<?>> input : inputTensorMap.entrySet()) {
                runner.feed(input.getKey(), input.getValue());
            }
            for (String outputName : outputNameList) {
                runner.fetch(outputName);
            }
            result = runner.run();
            // run() returns tensors in fetch order, so index i matches outputNameList[i].
            Map<String, List<?>> resultMap = new HashMap<>();
            for (int i = 0; i < outputNameList.size(); i++) {
                resultMap.put(outputNameList.get(i), tensorToList(result.get(i)));
            }
            return resultMap;
        } finally {
            closeAll(inputTensorMap.values());
            closeAll(result);
        }
    }

    /** Closes every non-null tensor in {@code tensors}. */
    private static void closeAll(Collection<? extends Tensor<?>> tensors) {
        for (Tensor<?> tensor : tensors) {
            if (tensor != null) {
                tensor.close();
            }
        }
    }

    /**
     * Manual smoke test: loads a model, runs one prediction on a hard-coded example and prints it.
     *
     * @param args optional first argument overrides the developer-local default model path
     */
    public static void main(String[] args) throws Exception {
        String path = args.length > 0 ? args[0] : "/Users/butcher/Desktop/dssmV01";
        LocalTFModel model = new LocalTFModel();
        model.loadModel(path);
        try {
            Map<String, String> featureMap = new HashMap<>();
            String colStr = "f451001\tf451002\tf506001\tf507001\tf507003\tf509001\tf509002\tf505001\tf509004\tf501001\t" +
                    "f8330021\tf8300011\tf8300021\tf8300041\tf8310011\tf8310021\tf8310041\tf205002\tf201001\tf108001\t" +
                    "f503001\tf502001\tf502003\tf502002\tafee\tad_pk";
            String valStr = "0119010101\t01020301\t3\t1\tAndroid6.0.1\tvivoy55a\tvivo\t1100-1699\t2016\tAndroid\t" +
                    "78465,78752,73920,77698,77250,78407,77419,78446,73714,77975,77463,78326,78393,78553,78072,77498,78173,78364,76831,78174\t" +
                    "78465,78752,77698,77250,78407,78446,73714,77975,77463,78326,78393,78553,78072,77498,78173,76831,78174\t" +
                    "41283,41540,40232,40363,39724,37839,37297,41236,40855,40951,39767,41400,41465,41435\t16,21,15\t78393\t" +
                    "37297\t16\t268\t82239\t401503\t4403\t0\t16\t2\t7.13\t78553_191877\t";
            String[] colList = colStr.split("\t");
            // Limit -1 keeps trailing empty values so columns and values stay aligned.
            String[] valList = valStr.split("\t", -1);
            // NOTE(review): the last column (ad_pk) is skipped, presumably on purpose because it is a
            // key rather than a model input — confirm.
            int featureCount = Math.min(colList.length - 1, valList.length);
            for (int i = 0; i < featureCount; i++) {
                featureMap.put(colList[i], valList[i]);
            }

            Map<String, List<?>> result = model.predict(featureMap);
            System.out.println(result);
        } finally {
            model.close();
        }
    }

}
