/*
 * Decompiled with CFR 0.152.
 */
package cn.com.duiba.nezha.alg.model.tf;

import cn.com.duiba.nezha.alg.model.util.OneTFUtils;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import jdk.nashorn.internal.ir.debug.ObjectSizeCalculator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;
import org.tensorflow.framework.TensorInfo;

/**
 * Local TensorFlow SavedModel wrapper for models whose {@code serving_default} signature has
 * exactly one output.
 *
 * <p>{@link #predict(Map)} builds one input tensor per signature input, run against each
 * sample's feature map. Features are looked up by the signature's input key. It then runs the
 * graph and returns column 0 of the single output for every sample.
 *
 * <p>This class is not synchronized. {@link #loadLatestModel(String)} closes the previously
 * loaded bundle, so it must not be called while {@link #predict(Map)} is running on the same
 * instance.
 */
public class OneLocalTfModel {
    private static final Logger logger = LoggerFactory.getLogger(OneLocalTfModel.class);
    /** Name of the SignatureDef whose inputs/outputs define the model interface. */
    private static final String SIGNATURE_DEF = "serving_default";
    private SavedModelBundle bundle;
    private String path;
    private Long version;
    /** Signature input key -> tensor info; the key is also the feature name in sample maps. */
    private Map<String, TensorInfo> inputSpecMap;
    /** Signature output key -> tensor info; validated to contain exactly one entry. */
    private Map<String, TensorInfo> outputSpecMap;

    /**
     * Loads the version under {@code path} reported by {@code OneTFUtils.getLastVersion} and
     * makes it the active model.
     *
     * <p>The new bundle is validated before it replaces the current one. If validation fails, the
     * new bundle is closed and the previously loaded model (if any) stays active. On success, the
     * previously loaded bundle is closed, because this instance no longer references it.
     *
     * @param path model root directory passed to {@code OneTFUtils}
     * @throws IllegalStateException if the serving signature does not have exactly one output
     * @throws Exception if loading fails or the serving signature cannot be read
     */
    public void loadLatestModel(String path) throws Exception {
        Long lastVersion = OneTFUtils.getLastVersion(path);
        SavedModelBundle newBundle = OneTFUtils.loadModel(path, String.valueOf(lastVersion));
        logger.info("loadModel, path={} with version={}", path, lastVersion);

        SavedModelBundle previousBundle = this.bundle;
        String previousPath = this.path;
        Long previousVersion = this.version;

        this.setModel(newBundle);
        this.path = path;
        this.version = lastVersion;

        // Without this the old bundle's native TensorFlow resources would be unreachable and leak.
        if (previousBundle != null && previousBundle != newBundle) {
            logger.info("close previous model, path={} with version={}", previousPath, previousVersion);
            previousBundle.close();
        }
    }

    /**
     * Closes the currently loaded bundle, if any. Calling it again after a close is a no-op.
     */
    public void close() throws Exception {
        if (this.bundle == null) {
            return;
        }
        logger.info("close model, path={} with version={}", this.path, this.version);
        // ObjectSizeCalculator walks the whole object graph reflectively (internal Nashorn API);
        // only pay that cost when debug logging is actually enabled.
        if (logger.isDebugEnabled()) {
            logger.debug("bundle.size before close :{}", ObjectSizeCalculator.getObjectSize(this.bundle));
        }
        try {
            this.bundle.close();
        } finally {
            this.bundle = null;
        }
    }

    /**
     * Validates {@code newBundle} and installs it together with its input/output specs.
     *
     * <p>Instance fields are only modified after validation succeeds. If the serving signature
     * cannot be read, or it does not have exactly one output, {@code newBundle} is closed and the
     * exception is propagated.
     *
     * @param newBundle freshly loaded bundle; ownership passes to this method
     * @throws IllegalStateException if the signature does not have exactly one output
     * @throws Exception if the MetaGraphDef cannot be parsed or lacks {@value #SIGNATURE_DEF}
     */
    private void setModel(SavedModelBundle newBundle) throws Exception {
        SignatureDef signature;
        try {
            signature = getServingSignature(newBundle);
        } catch (Exception e) {
            newBundle.close();
            throw e;
        }
        Map<String, TensorInfo> newInputSpecMap = signature.getInputsMap();
        Map<String, TensorInfo> newOutputSpecMap = signature.getOutputsMap();
        if (newOutputSpecMap.size() != 1) {
            String message = "model output num must be 1 but is " + newOutputSpecMap.size()
                    + ", should not use OneLocalTfModel, close...";
            logger.error(message);
            newBundle.close();
            throw new IllegalStateException(message);
        }
        this.bundle = newBundle;
        this.inputSpecMap = newInputSpecMap;
        this.outputSpecMap = newOutputSpecMap;
    }

    /**
     * Parses the bundle's MetaGraphDef once and returns its {@value #SIGNATURE_DEF} signature.
     *
     * @throws Exception if the MetaGraphDef bytes cannot be parsed or the signature is missing
     */
    private static SignatureDef getServingSignature(SavedModelBundle bundle) throws Exception {
        return MetaGraphDef.parseFrom(bundle.metaGraphDef()).getSignatureDefOrThrow(SIGNATURE_DEF);
    }

    /**
     * Builds one input tensor per signature input, keyed by the tensor name to feed.
     *
     * <p>For each input, the value of every sample is looked up in that sample's feature map by
     * the signature input key. A missing feature or a missing feature map yields {@code null}.
     * NOTE(review): how {@code OneTFUtils.listToTensor} encodes {@code null} is not visible
     * here — confirm.
     *
     * <p>If building any tensor fails, the tensors created so far are closed before the
     * exception propagates.
     */
    private <T> Map<String, Tensor<?>> preprocessInput(List<T> sampleKeyList, Map<T, Map<String, String>> sampleWithFeatureMap) {
        Map<String, Tensor<?>> formattedInput = new HashMap<>();
        boolean completed = false;
        try {
            for (Map.Entry<String, TensorInfo> inputSpec : this.inputSpecMap.entrySet()) {
                String featureName = inputSpec.getKey();
                TensorInfo inputInfo = inputSpec.getValue();
                ArrayList<String> oneFeatureList = new ArrayList<>(sampleKeyList.size());
                for (T sampleKey : sampleKeyList) {
                    Map<String, String> featureMap = sampleWithFeatureMap.get(sampleKey);
                    oneFeatureList.add(featureMap == null ? null : featureMap.get(featureName));
                }
                formattedInput.put(inputInfo.getName(), OneTFUtils.listToTensor(oneFeatureList, inputInfo));
            }
            completed = true;
            return formattedInput;
        } finally {
            if (!completed) {
                OneTFUtils.closeQuitely(formattedInput.values());
            }
        }
    }

    /**
     * Feeds {@code formattedInput} into the session and converts the single fetched output.
     *
     * <p>The input tensors and the fetched tensors are closed whether or not the run succeeds.
     * The conversion happens before they are closed.
     */
    private float[][] executeTfGraph(Map<String, Tensor<?>> formattedInput) {
        List<Tensor<?>> result = null;
        try {
            Session.Runner runner = this.bundle.session().runner();
            for (Map.Entry<String, Tensor<?>> input : formattedInput.entrySet()) {
                runner.feed(input.getKey(), input.getValue());
            }
            for (TensorInfo outputInfo : this.outputSpecMap.values()) {
                runner.fetch(outputInfo.getName());
            }
            result = runner.run();
            // Raw cast kept for compatibility with OneTFUtils.tensorToArray's declared parameter.
            return OneTFUtils.tensorToArray((Tensor) result.get(0));
        } finally {
            OneTFUtils.closeQuitely(formattedInput.values());
            // result stays null when run() throws; don't let cleanup mask the original exception.
            if (result != null) {
                OneTFUtils.closeQuitely(result);
            }
        }
    }

    /**
     * Pairs each sample key with column 0 of its output row; row {@code i} belongs to
     * {@code sampleKeyList.get(i)}.
     *
     * @throws IllegalStateException if the model returned fewer rows than there are samples
     */
    private <T> Map<T, Double> formatOutputs(List<T> sampleKeyList, float[][] rawOutputs) {
        int rows = rawOutputs == null ? 0 : rawOutputs.length;
        if (rows < sampleKeyList.size()) {
            throw new IllegalStateException("model returned " + rows + " rows for " + sampleKeyList.size() + " samples");
        }
        Map<T, Double> resultMap = new HashMap<>();
        for (int i = 0; i < sampleKeyList.size(); ++i) {
            resultMap.put(sampleKeyList.get(i), (double) rawOutputs[i][0]);
        }
        return resultMap;
    }

    /**
     * Scores every sample in {@code sampleWithFeatureMap} in a single batch.
     *
     * @param sampleWithFeatureMap sample key -> (feature name -> raw feature string)
     * @return sample key -> column 0 of the model output; empty for a null or empty input
     * @throws IllegalStateException if no model is loaded (never loaded, or closed)
     */
    public <T> Map<T, Double> predict(Map<T, Map<String, String>> sampleWithFeatureMap) {
        if (this.bundle == null) {
            throw new IllegalStateException("model is not loaded or has been closed, path=" + this.path);
        }
        if (sampleWithFeatureMap == null || sampleWithFeatureMap.isEmpty()) {
            return new HashMap<>();
        }
        // Fixes the sample order so input rows and output rows line up.
        List<T> sampleKeyList = new ArrayList<>(sampleWithFeatureMap.keySet());
        Map<String, Tensor<?>> formattedInput = this.preprocessInput(sampleKeyList, sampleWithFeatureMap);
        float[][] rawOutputs = this.executeTfGraph(formattedInput);
        return this.formatOutputs(sampleKeyList, rawOutputs);
    }

    /**
     * Manual smoke test: loads a model and scores one hand-built sample.
     *
     * @param args optional; {@code args[0]} overrides the default local model path
     */
    public static void main(String[] args) throws Exception {
        String modelPath = args.length > 0 ? args[0] : "C:\\Users\\aaa\\Desktop\\nezha\\model\\full";
        Map<String, Map<String, String>> sampleWithFeatureMap = new HashMap<>();
        Map<String, String> featureMap = new HashMap<>();
        featureMap.put("f1", "a");
        featureMap.put("f2", "a,c,444");
        sampleWithFeatureMap.put("sample1", featureMap);
        OneLocalTfModel oneLocalTfModel = new OneLocalTfModel();
        oneLocalTfModel.loadLatestModel(modelPath);
        try {
            Map<String, Double> result = oneLocalTfModel.predict(sampleWithFeatureMap);
            System.out.println(result);
        } finally {
            oneLocalTfModel.close();
        }
    }
}

