package cn.com.duiba.nezha.alg.model.tf;


import org.tensorflow.SavedModelBundle;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;

public class TensorflowUtils {

    /** MetaGraph tag used when loading SavedModels in this class. */
    private static final String SERVE_TAG = "serve";

    /**
     * Loads the SavedModel found at {@code modelpath} using the {@code "serve"} tag.
     *
     * @param modelpath directory containing the exported SavedModel
     * @return the loaded bundle. NOTE(review): the bundle wraps native TensorFlow resources;
     *     callers presumably own it and should close it when done — confirm at call sites.
     */
    public static SavedModelBundle loadModel(String modelpath) {
        return SavedModelBundle.load(modelpath, SERVE_TAG);
    }

    /**
     * Loads the SavedModel stored in the {@code version} sub-directory of {@code modelpath},
     * using the {@code "serve"} tag. The resolved path is printed to stdout.
     *
     * @param modelpath base directory; may or may not end with {@code "/"}
     * @param version   name of the sub-directory appended to {@code modelpath}
     * @return the loaded bundle
     * @throws NullPointerException if {@code modelpath} is null
     */
    public static SavedModelBundle loadModel(String modelpath, String version) {
        // Join with exactly one '/' regardless of whether the base path already ends with one.
        String path = modelpath.endsWith("/") ? modelpath + version : modelpath + "/" + version;
        System.out.println("local model path =" + path);
        return loadModel(path);
    }

    /**
     * Resolves the graph tensor name bound to the logical output {@code output} of the given
     * signature.
     *
     * @param output        logical output key as declared in the signature's outputs map
     * @param bundle        loaded SavedModel whose MetaGraphDef is inspected
     * @param SIGNATURE_DEF key of the signature in the MetaGraphDef's signature map
     * @return the {@code TensorInfo} name for that output
     * @throws IllegalArgumentException if the signature or the output key does not exist
     * @throws Exception if the bundle's MetaGraphDef bytes cannot be parsed
     */
    public static String getOutputName(String output, SavedModelBundle bundle, String SIGNATURE_DEF) throws Exception {
        SignatureDef sig = getSignature(bundle, SIGNATURE_DEF);
        // Fail with context instead of an NPE on get(...).getName() when the key is absent.
        if (!sig.containsOutputs(output)) {
            throw new IllegalArgumentException("output '" + output + "' not found in signature '"
                    + SIGNATURE_DEF + "'; available outputs: " + sig.getOutputsMap().keySet());
        }
        return sig.getOutputsOrThrow(output).getName();
    }

    /**
     * Resolves the graph tensor name bound to the logical input {@code input} of the given
     * signature.
     *
     * @param input         logical input key as declared in the signature's inputs map
     * @param bundle        loaded SavedModel whose MetaGraphDef is inspected
     * @param SIGNATURE_DEF key of the signature in the MetaGraphDef's signature map
     * @return the {@code TensorInfo} name for that input
     * @throws IllegalArgumentException if the signature or the input key does not exist
     * @throws Exception if the bundle's MetaGraphDef bytes cannot be parsed
     */
    public static String getInputName(String input, SavedModelBundle bundle, String SIGNATURE_DEF) throws Exception {
        SignatureDef sig = getSignature(bundle, SIGNATURE_DEF);
        // Fail with context instead of an NPE on get(...).getName() when the key is absent.
        if (!sig.containsInputs(input)) {
            throw new IllegalArgumentException("input '" + input + "' not found in signature '"
                    + SIGNATURE_DEF + "'; available inputs: " + sig.getInputsMap().keySet());
        }
        return sig.getInputsOrThrow(input).getName();
    }

    /**
     * Parses the bundle's serialized MetaGraphDef and returns the signature stored under
     * {@code signatureKey}.
     *
     * @throws IllegalArgumentException if no signature with that key exists
     * @throws Exception if the MetaGraphDef bytes cannot be parsed
     */
    private static SignatureDef getSignature(SavedModelBundle bundle, String signatureKey) throws Exception {
        MetaGraphDef meta = MetaGraphDef.parseFrom(bundle.metaGraphDef());
        if (!meta.containsSignatureDef(signatureKey)) {
            throw new IllegalArgumentException("signature '" + signatureKey
                    + "' not found; available signatures: " + meta.getSignatureDefMap().keySet());
        }
        return meta.getSignatureDefOrThrow(signatureKey);
    }
}
