package cn.com.duiba.nezha.alg.model.tf;


import cn.com.duiba.nezha.alg.common.util.NumberUtil;
import com.alibaba.fastjson.JSON;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;
import org.tensorflow.framework.TensorInfo;

import java.io.File;
import java.io.IOException;
import java.nio.file.Paths;
import java.util.Arrays;

/**
 * Static helpers for loading TensorFlow SavedModel bundles from a local directory laid out as
 * {@code <base>/<version>/} and for resolving tensor names through a model's signature definitions.
 */
public class TensorflowUtils {

    /** Tag set used when loading a SavedModel. */
    private static final String SERVE_TAG = "serve";

    /**
     * Loads the SavedModel stored in sub-directory {@code version} of {@code path}, with the
     * "serve" tag. The two parts are joined with {@link java.nio.file.Paths#get}.
     *
     * @param path    base model directory
     * @param version name of the version sub-directory to load
     * @return the loaded bundle (it holds native resources; callers should close it when done)
     * @throws Exception declared for source compatibility with existing callers; this method
     *                   itself throws no checked exception
     */
    public static SavedModelBundle getModelBundle(String path, String version) throws Exception {
        System.out.println("local model path =" + path + ",version=" + version);
        String versionDir = Paths.get(path, version).toString();
        return SavedModelBundle.load(versionDir, SERVE_TAG);
    }

    /**
     * Loads the SavedModel stored in sub-directory {@code version} of {@code modelpath}, with the
     * "serve" tag. The path is joined by string concatenation with a '/' separator, which is
     * added only when {@code modelpath} does not already end with one.
     *
     * @param modelpath base model directory
     * @param version   name of the version sub-directory to load
     * @return the loaded bundle (it holds native resources; callers should close it when done)
     */
    public static SavedModelBundle loadModelV1(String modelpath, String version) {
        String path = modelpath.endsWith("/") ? modelpath + version : modelpath + "/" + version;
        System.out.println("local model path =" + path);
        return SavedModelBundle.load(path, SERVE_TAG);
    }

    /**
     * Finds the largest numeric child name directly under {@code path}.
     *
     * <p>Only entries accepted by {@code NumberUtil.isNumber} and parseable as a {@code long}
     * are considered; others are ignored.
     *
     * @param path directory to scan
     * @return the largest version as a decimal string; the literal string {@code "null"} when
     *         {@code path} is not a directory or contains no numeric entry (historical contract,
     *         kept for existing callers)
     * @throws IOException if {@code path} is a directory but its entries cannot be listed
     */
    public static String getLastVersion(String path) throws Exception {
        Long latest = null;
        File dir = new File(path);
        if (dir.isDirectory()) {
            File[] children = dir.listFiles();
            if (children == null) {
                // listFiles() returns null on an I/O error even for an existing directory;
                // fail with context instead of a NullPointerException.
                throw new IOException("failed to list model directory: " + path);
            }
            // Print the actual listing, not the array's identity hash.
            System.out.println(Arrays.toString(children));

            for (File child : children) {
                String name = child.getName();
                if (!NumberUtil.isNumber(name)) {
                    continue;
                }
                long version;
                try {
                    version = Long.parseLong(name);
                } catch (NumberFormatException ignored) {
                    // NOTE(review): presumably NumberUtil.isNumber can accept names that are not
                    // valid longs (decimals, out-of-range values) — skip them rather than abort.
                    continue;
                }
                if (latest == null || latest < version) {
                    latest = version;
                }
            }
        }
        // String.valueOf(Object) yields "null" for a null Long, matching the old `tmpVersion + ""`.
        return String.valueOf(latest);
    }

    /**
     * Resolves the graph tensor name bound to output key {@code output} in signature
     * {@code SIGNATURE_DEF} of the bundle's MetaGraphDef.
     *
     * @param output        output key inside the signature
     * @param bundle        loaded model
     * @param SIGNATURE_DEF signature key to look up
     * @return the tensor name for that output
     * @throws IllegalArgumentException if the signature has no output with that key
     * @throws Exception                if the MetaGraphDef cannot be parsed or the signature lookup fails
     */
    public static String getOutputName(String output, SavedModelBundle bundle, String SIGNATURE_DEF) throws Exception {
        SignatureDef sig = signatureOf(bundle, SIGNATURE_DEF);
        TensorInfo info = sig.getOutputsMap().get(output);
        if (info == null) {
            throw new IllegalArgumentException("output '" + output + "' not found in signature '"
                    + SIGNATURE_DEF + "', available outputs: " + sig.getOutputsMap().keySet());
        }
        return info.getName();
    }

    /**
     * Resolves the graph tensor name bound to input key {@code input} in signature
     * {@code SIGNATURE_DEF} of the bundle's MetaGraphDef.
     *
     * @param input         input key inside the signature
     * @param bundle        loaded model
     * @param SIGNATURE_DEF signature key to look up
     * @return the tensor name for that input, or {@code null} if the signature has no such input
     * @throws Exception if the MetaGraphDef cannot be parsed or the signature lookup fails
     */
    public static String getInputName(String input, SavedModelBundle bundle, String SIGNATURE_DEF) throws Exception {
        TensorInfo info = signatureOf(bundle, SIGNATURE_DEF).getInputsMap().get(input);
        return info == null ? null : info.getName();
    }

    /** Parses the bundle's serialized MetaGraphDef and returns the named signature via getSignatureDefOrThrow. */
    private static SignatureDef signatureOf(SavedModelBundle bundle, String signatureKey) throws Exception {
        return MetaGraphDef.parseFrom(bundle.metaGraphDef()).getSignatureDefOrThrow(signatureKey);
    }
}
