/*
 * Decompiled with CFR 0.152.
 */
package cn.com.duiba.nezha.alg.api.model;

import cn.com.duiba.nezha.alg.api.model.util.WuWuPredResult;
import cn.com.duiba.nezha.alg.api.model.util.WuWuTFUtils;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import jdk.nashorn.internal.ir.debug.ObjectSizeCalculator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;
import org.tensorflow.framework.TensorInfo;

/**
 * Wraps a locally loaded TensorFlow {@link SavedModelBundle} and exposes a
 * batch {@link #predict(Map)} call driven by the model's
 * {@code serving_default} signature.
 *
 * <p>Typical usage: {@link #loadLatestModel(String)}, any number of
 * {@link #predict(Map)} calls, then {@link #close()} to release the bundle.
 *
 * <p>This class performs no synchronization of its own; callers that share an
 * instance between threads must coordinate load/close with predict themselves.
 */
public class WuWuLocalTFModel {
    private static final Logger logger = LoggerFactory.getLogger(WuWuLocalTFModel.class);

    /** Name of the signature whose inputs/outputs drive feeding and fetching. */
    private static final String SIGNATURE_DEF = "serving_default";

    private SavedModelBundle bundle;
    private String path;
    private Long version;
    /** Signature input key -> tensor spec, as declared by the loaded model. */
    private Map<String, TensorInfo> inputSpecMap;
    /** Signature output key -> tensor spec, as declared by the loaded model. */
    private Map<String, TensorInfo> outputSpecMap;

    /**
     * Loads the newest model version found under {@code path} (as reported by
     * {@code WuWuTFUtils.getLastVersion}) and makes it the active model.
     *
     * <p>If the bundle loads but its signature cannot be read, the freshly
     * loaded bundle is closed and the previously active model (if any) is left
     * untouched.
     *
     * @param path base directory passed to {@code WuWuTFUtils}
     * @throws Exception if the version lookup, model load, or signature parsing fails
     */
    public void loadLatestModel(String path) throws Exception {
        Long lastVersion = WuWuTFUtils.getLastVersion(path);
        SavedModelBundle loaded = WuWuTFUtils.loadModel(path, String.valueOf(lastVersion));
        logger.info("loadModel ,path={}with version={}", path, lastVersion);
        try {
            // NOTE(review): a bundle installed by an earlier call is not closed here;
            // callers that reload on the same instance should call close() first.
            this.setModel(loaded);
        } catch (Exception e) {
            // The new bundle never became active, so nobody else can release it.
            if (loaded != null) {
                loaded.close();
            }
            throw e;
        }
        this.path = path;
        this.version = lastVersion;
    }

    /**
     * Closes the active bundle, if any. Calling this with no active bundle is a no-op.
     *
     * @throws Exception declared for interface compatibility with existing callers
     */
    public void close() throws Exception {
        if (this.bundle == null) {
            return;
        }
        try {
            logger.info("close model  ,path={}with version={}", this.path, this.version);
            logger.info("bundle.size before close :{}", ObjectSizeCalculator.getObjectSize(this.bundle));
        } finally {
            // Diagnostics above must never prevent the bundle from being released.
            this.bundle.close();
            this.bundle = null;
        }
    }

    /**
     * Installs {@code bundle} as the active model together with the input and
     * output specs of its {@code serving_default} signature. The signature is
     * parsed before any field is written, so a failure leaves this object's
     * state unchanged.
     */
    private void setModel(SavedModelBundle bundle) throws Exception {
        SignatureDef signature = MetaGraphDef.parseFrom(bundle.metaGraphDef())
                .getSignatureDefOrThrow(SIGNATURE_DEF);
        Map<String, TensorInfo> inputs = signature.getInputsMap();
        Map<String, TensorInfo> outputs = signature.getOutputsMap();
        this.bundle = bundle;
        this.inputSpecMap = inputs;
        this.outputSpecMap = outputs;
    }

    /**
     * Builds one input tensor per signature input. For each input key, the
     * feature value of every sample (in {@code sampleKeyList} order) is
     * collected and converted by {@code WuWuTFUtils.listToTensor}. A sample
     * that lacks the feature, or whose feature map is {@code null},
     * contributes {@code null}.
     *
     * @return graph tensor name -> tensor; the caller owns and must close the tensors
     */
    private <T> Map<String, Tensor<?>> preprocessInput(List<T> sampleKeyList, Map<T, Map<String, String>> sampleWithFeatureMap) {
        Map<String, Tensor<?>> formattedInput = new HashMap<>();
        try {
            for (Map.Entry<String, TensorInfo> spec : this.inputSpecMap.entrySet()) {
                String featureKey = spec.getKey();
                TensorInfo inputInfo = spec.getValue();
                List<String> oneFeatureList = new ArrayList<>(sampleKeyList.size());
                for (T sampleKey : sampleKeyList) {
                    Map<String, String> featureMap = sampleWithFeatureMap.get(sampleKey);
                    oneFeatureList.add(featureMap == null ? null : featureMap.get(featureKey));
                }
                Tensor<?> oneFeatureTensor = WuWuTFUtils.listToTensor(oneFeatureList, inputInfo);
                formattedInput.put(inputInfo.getName(), oneFeatureTensor);
            }
        } catch (RuntimeException | Error e) {
            // Release tensors built so far; the caller never receives them.
            closeTensors(formattedInput.values());
            throw e;
        }
        return formattedInput;
    }

    /**
     * Feeds {@code formattedInput} into the session, fetches every signature
     * output and converts each fetched tensor with
     * {@code WuWuTFUtils.tensorToArray}.
     *
     * <p>All input tensors and all fetched tensors are closed before this
     * method returns, whether it completes normally or throws.
     *
     * @return signature output key -> converted output array
     */
    private Map<String, Object[][]> executeTfGraph(Map<String, Tensor<?>> formattedInput) {
        List<Tensor<?>> fetched = null;
        try {
            // Parallel lists: position i of the fetch order maps back to its signature key.
            List<String> outputKeyList = new ArrayList<>();
            List<String> outputTfKeyList = new ArrayList<>();
            for (Map.Entry<String, TensorInfo> spec : this.outputSpecMap.entrySet()) {
                outputKeyList.add(spec.getKey());
                outputTfKeyList.add(spec.getValue().getName());
            }

            Session.Runner runner = this.bundle.session().runner();
            for (Map.Entry<String, Tensor<?>> input : formattedInput.entrySet()) {
                runner.feed(input.getKey(), input.getValue());
            }
            for (String outputTfKey : outputTfKeyList) {
                runner.fetch(outputTfKey);
            }
            fetched = runner.run();

            Map<String, Object[][]> resultMap = new HashMap<>();
            for (int i = 0; i < outputKeyList.size(); ++i) {
                resultMap.put(outputKeyList.get(i), WuWuTFUtils.tensorToArray(fetched.get(i)));
            }
            return resultMap;
        } finally {
            closeTensors(formattedInput.values());
            if (fetched != null) {
                closeTensors(fetched);
            }
        }
    }

    /** Closes every tensor in {@code tensors}. */
    private static void closeTensors(Iterable<? extends Tensor<?>> tensors) {
        for (Tensor<?> tensor : tensors) {
            tensor.close();
        }
    }

    /**
     * Regroups raw per-output arrays into per-sample results. Row {@code i} of
     * each output array belongs to {@code sampleKeyList.get(i)}; a row of
     * length 1 is stored as a single result, any other length as an array result.
     */
    private <T> Map<T, Map<String, WuWuPredResult>> formatOutputs(List<T> sampleKeyList, Map<String, Object[][]> rawOutputs) {
        Map<T, Map<String, WuWuPredResult>> resultMap = new HashMap<>();
        for (int i = 0; i < sampleKeyList.size(); ++i) {
            Map<String, WuWuPredResult> outputMap = new HashMap<>();
            for (Map.Entry<String, TensorInfo> spec : this.outputSpecMap.entrySet()) {
                String outputKey = spec.getKey();
                DataType dtype = spec.getValue().getDtype();
                Object[] oneOutputs = rawOutputs.get(outputKey)[i];
                WuWuPredResult wuWuPredResult = new WuWuPredResult(outputKey);
                if (oneOutputs.length == 1) {
                    wuWuPredResult.setSingleResult(oneOutputs[0], dtype);
                } else {
                    wuWuPredResult.setArrayResult(oneOutputs, dtype);
                }
                outputMap.put(outputKey, wuWuPredResult);
            }
            resultMap.put(sampleKeyList.get(i), outputMap);
        }
        return resultMap;
    }

    /**
     * Runs the active model on a batch of samples.
     *
     * @param sampleWithFeatureMap sample key -> (signature input key -> raw feature value)
     * @param <T> sample key type
     * @return sample key -> (signature output key -> prediction); empty when the input is empty
     * @throws IllegalStateException if no model is currently loaded (never loaded, or closed)
     */
    public <T> Map<T, Map<String, WuWuPredResult>> predict(Map<T, Map<String, String>> sampleWithFeatureMap) {
        // Fail before any tensor is allocated, so nothing can leak on this path.
        if (this.bundle == null || this.inputSpecMap == null || this.outputSpecMap == null) {
            throw new IllegalStateException("No model loaded; call loadLatestModel before predict");
        }
        List<T> sampleKeyList = new ArrayList<>(sampleWithFeatureMap.keySet());
        if (sampleKeyList.isEmpty()) {
            // Nothing to score; skip building zero-length tensors and running the graph.
            return new HashMap<>();
        }
        Map<String, Tensor<?>> formattedInput = this.preprocessInput(sampleKeyList, sampleWithFeatureMap);
        Map<String, Object[][]> rawOutputs = this.executeTfGraph(formattedInput);
        return this.formatOutputs(sampleKeyList, rawOutputs);
    }

    /**
     * Manual smoke test: loads a model, scores one hard-coded sample and
     * prints the result. The model directory may be given as the first
     * argument; otherwise the original development path is used.
     */
    public static void main(String[] args) throws Exception {
        String modelPath = args.length > 0 ? args[0] : "C:\\Users\\aaa\\Desktop\\nezha\\model\\full";

        Map<String, String> featureMap = new HashMap<>();
        featureMap.put("f1", "a");
        featureMap.put("f2", "a,c,444");
        Map<String, Map<String, String>> sampleWithFeatureMap = new HashMap<>();
        sampleWithFeatureMap.put("sample1", featureMap);

        WuWuLocalTFModel model = new WuWuLocalTFModel();
        try {
            model.loadLatestModel(modelPath);
            Map<String, Map<String, WuWuPredResult>> result = model.predict(sampleWithFeatureMap);
            System.out.println(result);
        } finally {
            model.close();
        }
    }
}

