package cn.com.duiba.nezha.alg.alg.advertexploitbak.model;

import cn.com.duiba.nezha.alg.common.util.NumberUtil;
import cn.com.duiba.nezha.alg.model.tf.TensorflowUtils;
import com.google.common.primitives.Floats;
//import jdk.nashorn.internal.ir.debug.ObjectSizeCalculator;
import org.slf4j.LoggerFactory;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;
import org.tensorflow.framework.TensorInfo;

import java.io.File;
import java.io.UnsupportedEncodingException;
import java.util.*;

import static java.nio.charset.StandardCharsets.UTF_8;

/**
 * Wrapper around a locally loaded TensorFlow {@link SavedModelBundle} that runs
 * single-sample inference through the {@code serving_default} signature.
 *
 * <p>Features are supplied as a {@code Map<String, String>} keyed by the signature's
 * input keys. Inputs declared as {@code DT_FLOAT} are parsed to floats; every other
 * input is fed as a UTF-8 string tensor. Each input is fed with shape {@code [1, 1]}.
 *
 * <p>This class performs no synchronization of its own; callers must not invoke
 * {@link #close()} or the load/set methods concurrently with {@link #predict(Map)}.
 */
public class DSSMLocalTFModel {
    private static final org.slf4j.Logger logger = LoggerFactory.getLogger(DSSMLocalTFModel.class);

    /** Signature key looked up in the SavedModel's MetaGraphDef. */
    private static final String SIGNATURE_DEF = "serving_default";

    private SavedModelBundle bundle;
    private String path;
    private String version;
    /** Signature input key -> tensor info (tensor name, dtype, ...). */
    private Map<String, TensorInfo> inputSpecMap;
    /** Signature output key -> tensor info (tensor name, dtype, ...). */
    private Map<String, TensorInfo> outputSpecMap;

    /**
     * Installs {@code bundle} as the active model and caches its input/output specs.
     *
     * <p>The signature is parsed before any field is assigned, so a failure leaves the
     * previously installed model (if any) untouched instead of half-replaced.
     *
     * @param bundle the loaded SavedModel; must not be null
     * @throws Exception if the MetaGraphDef cannot be parsed or the signature is missing
     */
    public void setModel(SavedModelBundle bundle) throws Exception {
        Objects.requireNonNull(bundle, "bundle");
        Map<String, TensorInfo> inputs = getInputSpecMap(bundle, SIGNATURE_DEF);
        Map<String, TensorInfo> outputs = getOutputSpecMap(bundle, SIGNATURE_DEF);
        this.bundle = bundle;
        this.inputSpecMap = inputs;
        this.outputSpecMap = outputs;
    }

    /**
     * Loads the highest-numbered version found directly under {@code path}.
     *
     * @param path directory containing one numerically named entry per model version
     * @throws IllegalArgumentException if no numerically named entry exists under {@code path}
     * @throws Exception if loading the model fails
     */
    public void loadModel(String path) throws Exception {
        Long lastVersion = getLastVersion(path);
        if (lastVersion == null) {
            // Previously the literal string "null" was passed on as the version.
            throw new IllegalArgumentException("no numeric model version found under path=" + path);
        }
        loadModel(path, String.valueOf(lastVersion));
    }

    /**
     * Finds the largest numeric entry name directly under {@code path}.
     *
     * @param path directory to scan
     * @return the largest version number, or {@code null} when {@code path} is not a
     *         readable directory or contains no numerically named entry
     */
    public Long getLastVersion(String path) throws Exception {
        // listFiles() returns null both for non-directories and on I/O errors.
        File[] children = new File(path).listFiles();
        if (children == null) {
            return null;
        }

        Long latest = null;
        for (File child : children) {
            String name = child.getName();
            if (!NumberUtil.isNumber(name)) {
                continue;
            }
            long candidate;
            try {
                candidate = Long.parseLong(name);
            } catch (NumberFormatException e) {
                // isNumber may accept names that are not valid longs (e.g. decimals); skip those.
                logger.warn("skip entry that is not a valid long version, name={}", name);
                continue;
            }
            if (latest == null || latest < candidate) {
                latest = candidate;
            }
        }
        return latest;
    }

    /**
     * Loads the given version of the model and makes it the active model.
     *
     * @param path    model base directory
     * @param version version identifier passed to {@code TensorflowUtils.getModelBundle}
     * @throws Exception if the bundle cannot be loaded or its signature cannot be read
     */
    public void loadModel(String path, String version) throws Exception {
        System.out.println("loadModel with version=" + version);
        logger.info("loadModel ,paht={}with version={}", path, version);
        SavedModelBundle loaded = TensorflowUtils.getModelBundle(path, version);

        setModel(loaded);
        this.path = path;
        this.version = version;
    }

    /**
     * Closes the active bundle, if any. Safe to call more than once; after this call
     * {@link #predict(Map)} throws {@link IllegalStateException} until a model is loaded again.
     */
    public void close() throws Exception {
        if (bundle != null) {
            logger.info("close model  ,paht={}with version={}", path, version);
            bundle.close();
            bundle = null;
        }
    }

    /**
     * Reads the input specs of the named signature from the bundle's MetaGraphDef.
     *
     * @param bundle       the loaded SavedModel
     * @param signatureDef signature key to look up
     * @return signature input key -> tensor info
     * @throws Exception if the MetaGraphDef cannot be parsed or the signature is missing
     */
    public Map<String, TensorInfo> getInputSpecMap(SavedModelBundle bundle, String signatureDef) throws Exception {
        SignatureDef sig = MetaGraphDef.parseFrom(bundle.metaGraphDef()).getSignatureDefOrThrow(signatureDef);
        return sig.getInputsMap();
    }

    /**
     * Reads the output specs of the named signature from the bundle's MetaGraphDef.
     *
     * @param bundle       the loaded SavedModel
     * @param signatureDef signature key to look up
     * @return signature output key -> tensor info
     * @throws Exception if the MetaGraphDef cannot be parsed or the signature is missing
     */
    public Map<String, TensorInfo> getOutputSpecMap(SavedModelBundle bundle, String signatureDef) throws Exception {
        SignatureDef sig = MetaGraphDef.parseFrom(bundle.metaGraphDef()).getSignatureDefOrThrow(signatureDef);
        return sig.getOutputsMap();
    }

    /**
     * Packs the features into tensors keyed by tensor name; inputs default to string type.
     *
     * <p>Features are looked up by signature input key. A missing or {@code null} feature
     * becomes {@code 0.0} for {@code DT_FLOAT} inputs and the empty string otherwise; an
     * unparseable float also becomes {@code 0.0} (logged as a warning). Every non-float
     * input is encoded as a UTF-8 string tensor regardless of its declared dtype.
     *
     * <p>The caller owns the returned tensors and must close them. If building any tensor
     * fails, the tensors created so far are closed before the exception propagates.
     *
     * @param featureMap signature input key -> raw feature value; must not be null
     * @return tensor name -> tensor of shape {@code [1, 1]}
     * @throws IllegalStateException if no model has been loaded
     */
    @SuppressWarnings("rawtypes")
    public Map<String, Tensor> preprocessInput(Map<String,String> featureMap) throws UnsupportedEncodingException {
        if (inputSpecMap == null) {
            throw new IllegalStateException("model is not loaded");
        }
        Objects.requireNonNull(featureMap, "featureMap");

        Map<String, Tensor> inputTensorMap = new HashMap<>();
        boolean completed = false;
        try {
            for (Map.Entry<String, TensorInfo> spec : inputSpecMap.entrySet()) {
                String inputKey = spec.getKey();
                TensorInfo inputInfo = spec.getValue();
                // get() instead of getOrDefault(): a key explicitly mapped to null must also
                // fall back to the default rather than reach parseFloat/getBytes as null.
                String rawValue = featureMap.get(inputKey);

                Tensor tensorInput;
                if (inputInfo.getDtype() == DataType.DT_FLOAT) {
                    float[][] arrayInput = {{parseFloatOrZero(inputKey, rawValue)}};
                    tensorInput = Tensor.create(arrayInput, Float.class);
                } else {
                    String text = rawValue == null ? "" : rawValue;
                    byte[][][] arrayInput = {{text.getBytes(UTF_8)}};
                    tensorInput = Tensor.create(arrayInput, String.class);
                }
                inputTensorMap.put(inputInfo.getName(), tensorInput);
            }
            completed = true;
            return inputTensorMap;
        } finally {
            if (!completed) {
                // Tensors hold native memory; do not leak the ones already built.
                closeAll(inputTensorMap.values());
            }
        }
    }

    /**
     * Converts a rank-2 tensor to a list. Only row 0 of the first dimension is returned,
     * so this is suitable for one-sample-at-a-time prediction only.
     *
     * @param tensorResult a rank-2 tensor with at least one row
     * @return row 0 as a fixed-size list of {@code String} (STRING tensors) or
     *         {@code Float} (FLOAT tensors); {@code null} for any other data type
     * @throws IllegalArgumentException if a STRING/FLOAT tensor is not rank 2 or has no rows
     */
    @SuppressWarnings("rawtypes")
    public List tensorToList(Tensor tensorResult) {
        org.tensorflow.DataType dataType = tensorResult.dataType();
        boolean isString = org.tensorflow.DataType.STRING == dataType;
        boolean isFloat = org.tensorflow.DataType.FLOAT == dataType;
        if (!isString && !isFloat) {
            return null;
        }

        long[] tensorShape = tensorResult.shape();
        if (tensorShape.length != 2 || tensorShape[0] < 1) {
            throw new IllegalArgumentException(
                    "expected a rank-2 tensor with at least one row, got shape=" + Arrays.toString(tensorShape));
        }
        int rows = (int) tensorShape[0];
        int cols = (int) tensorShape[1];

        if (isString) {
            // copyTo needs an array matching the full tensor shape, even though only row 0 is used.
            byte[][][] rawBytes = new byte[rows][cols][];
            tensorResult.copyTo(rawBytes);
            String[] firstRow = new String[cols];
            for (int q = 0; q < cols; ++q) {
                firstRow[q] = new String(rawBytes[0][q], UTF_8);
            }
            return Arrays.asList(firstRow);
        }

        float[][] values = new float[rows][cols];
        tensorResult.copyTo(values);
        return Floats.asList(values[0]);
    }

    /**
     * Runs the model on one sample and returns every signature output.
     *
     * <p>All input and output tensors are closed before returning, including when
     * inference or result conversion throws; unclosed tensors leak native memory.
     *
     * @param featureMap signature input key -> raw feature value; must not be null
     * @return output tensor name -> row 0 of that output (see {@link #tensorToList(Tensor)})
     * @throws IllegalStateException if no model is loaded or the model has been closed
     */
    @SuppressWarnings("rawtypes")
    public Map<String,List> predict(Map<String,String> featureMap) throws UnsupportedEncodingException {
        if (bundle == null || outputSpecMap == null) {
            throw new IllegalStateException("model is not loaded or has been closed");
        }

        List<String> outputNameList = new ArrayList<>(outputSpecMap.size());
        for (TensorInfo outputInfo : outputSpecMap.values()) {
            outputNameList.add(outputInfo.getName());
        }

        Map<String, Tensor> inputTensorMap = preprocessInput(featureMap);
        List<Tensor<?>> result = null;
        try {
            Session.Runner runner = bundle.session().runner();
            for (Map.Entry<String, Tensor> input : inputTensorMap.entrySet()) {
                runner.feed(input.getKey(), input.getValue());
            }
            for (String outputName : outputNameList) {
                runner.fetch(outputName);
            }
            // Results come back in the same order the outputs were fetched.
            result = runner.run();

            Map<String, List> resultMap = new HashMap<>();
            for (int i = 0; i < outputNameList.size(); i++) {
                resultMap.put(outputNameList.get(i), tensorToList(result.get(i)));
            }
            return resultMap;
        } finally {
            closeAll(inputTensorMap.values());
            if (result != null) {
                closeAll(result);
            }
        }
    }

    /** Parses a float feature, falling back to 0.0 when it is missing or malformed. */
    private static float parseFloatOrZero(String inputKey, String rawValue) {
        if (rawValue == null) {
            return 0.0f;
        }
        try {
            return Float.parseFloat(rawValue);
        } catch (NumberFormatException e) {
            logger.warn("invalid float feature, key={} value={}, use 0.0", inputKey, rawValue, e);
            return 0.0f;
        }
    }

    /** Closes every tensor in the collection. */
    @SuppressWarnings("rawtypes")
    private static void closeAll(Collection<? extends Tensor> tensors) {
        for (Tensor tensor : tensors) {
            tensor.close();
        }
    }

    /**
     * Manual smoke test: loads a model, runs 10000 predictions (printing progress every
     * 1000) and prints one final result.
     *
     * @param args optional; {@code args[0]} overrides the default model directory
     */
    public static void main(String[] args) throws Exception{
        String path = args.length > 0 ? args[0] : "C:\\Users\\wangyaole\\Desktop\\nezha\\model\\dssm01";
        DSSMLocalTFModel model = new DSSMLocalTFModel();
        model.loadModel(path);
        try {
            Map<String,String> featureMap = new HashMap<>();
            String colStr="f501001,f504001,f508005,f108001,f201001,f507001,f503001,f406001,f502001,f502002,afee,ad_pk";
            String valStr="Android,HLK-AL00,荣耀,325913,70125,3,3609,1,16,2,7.2,72093_177812";
            String[] colList = colStr.split(",");
            String[] valList = valStr.split(",");
            // NOTE(review): the bound stops one short, so the last column (ad_pk) is never
            // put into the feature map. Kept as-is; confirm whether this is intentional.
            for (int i = 0; i < colList.length-1; i++) {
                featureMap.put(colList[i], valList[i]);
            }
            for (int i = 0; i < 10000; i++) {
                model.predict(featureMap);
                if (i % 1000 == 0) {
                    System.out.println(i);
                }
            }
            Map result = model.predict(featureMap);
            System.out.println(result);
        } finally {
            model.close();
        }
    }

}
