package cn.com.duiba.nezha.alg.api.model;

import cn.com.duiba.nezha.alg.api.model.util.E2EPredResult;
import cn.com.duiba.nezha.alg.api.model.util.E2ETFUtils;
import cn.com.duiba.nezha.alg.api.model.util.WuWuTFUtils;
import com.alibaba.fastjson.JSON;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import jdk.nashorn.internal.ir.debug.ObjectSizeCalculator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.TensorInfo;

/* loaded from: input_file:cn/com/duiba/nezha/alg/api/model/E2ELocalTFModel.class */
/**
 * Wraps a locally stored TensorFlow SavedModel and scores batches of string-featured samples
 * through its {@code serving_default} signature.
 *
 * <p>The newest model version under a directory is resolved and loaded through
 * {@link E2ETFUtils}. {@link #predict(Map)} returns the {@code "prob"} output of each sample,
 * rounded (away from zero) to six decimal places.
 *
 * <p>NOTE(review): no synchronization is performed here. Presumably callers build a new
 * instance per model version and {@link #close()} the old one after swapping; confirm against
 * the callers.
 */
public class E2ELocalTFModel {
    private static final Logger logger = LoggerFactory.getLogger(E2ELocalTFModel.class);
    private static final String SIGNATURE_DEF = "serving_default";
    /** Output key of the serving signature that {@link #predict(Map)} reads. */
    private static final String PROB_OUTPUT = "prob";
    /** Number of decimal places {@link #predict(Map)} keeps in each probability. */
    private static final int PROB_SCALE = 6;
    /** Model directory used by {@link #main(String[])} when no argument is supplied. */
    private static final String DEFAULT_MODEL_PATH = "/Users/pengzhixiong/Downloads/pengzong/full";

    private SavedModelBundle bundle;
    private String path;
    private Long version;
    private Map<String, TensorInfo> inputSpecMap;
    private Map<String, TensorInfo> outputSpecMap;

    /**
     * Loads the newest model version found under {@code str} and makes it the active model.
     *
     * <p>A previously loaded bundle is replaced but NOT closed here; callers that reuse an
     * instance must close the old bundle themselves (e.g. via {@link #getBundle()}).
     *
     * @param str model root directory, passed to {@link E2ETFUtils#getLastVersion} and
     *     {@link E2ETFUtils#loadModel}
     * @throws Exception if the version lookup or load fails, or if the model lacks a parsable
     *     {@code serving_default} signature (the newly loaded bundle is closed in that case)
     */
    public void loadLatestModel(String str) throws Exception {
        Long lastVersion = E2ETFUtils.getLastVersion(str);
        SavedModelBundle loaded = E2ETFUtils.loadModel(str, String.valueOf(lastVersion));
        logger.info("loadModel, path={} with version={}", str, lastVersion);
        try {
            setModel(loaded);
        } catch (Exception e) {
            // The signature could not be read: release the bundle we just opened instead of
            // leaking its native session, then propagate the original failure.
            try {
                loaded.close();
            } catch (RuntimeException closeError) {
                e.addSuppressed(closeError);
            }
            throw e;
        }
        this.path = str;
        this.version = lastVersion;
    }

    /**
     * Closes the active bundle, if any, and clears the reference to it.
     *
     * <p>Spec maps, path and version are left as they are.
     *
     * @throws Exception declared for backward compatibility
     */
    public void close() throws Exception {
        if (this.bundle == null) {
            return;
        }
        logger.info("close model, path={} with version={}", this.path, this.version);
        logBundleSizeBestEffort();
        this.bundle.close();
        this.bundle = null;
    }

    /**
     * Logs the retained heap size of the active bundle without ever preventing it from being
     * closed.
     *
     * <p>The size is computed with an internal (Nashorn) JDK utility that walks the object graph
     * and may be unsupported on the running VM.
     */
    private void logBundleSizeBestEffort() {
        if (!logger.isInfoEnabled()) {
            return; // the object-graph walk is costly; skip it when the line would be dropped
        }
        try {
            logger.info("bundle.size before close :{}", ObjectSizeCalculator.getObjectSize(this.bundle));
        } catch (RuntimeException e) {
            logger.warn("failed to compute bundle size before close, path={}", this.path, e);
        }
    }

    /**
     * Activates {@code savedModelBundle} and caches the input/output specs of its
     * {@code serving_default} signature.
     *
     * <p>The meta graph is parsed once, and all lookups complete before any field is assigned.
     * A failure therefore leaves the current state unchanged.
     */
    private void setModel(SavedModelBundle savedModelBundle) throws Exception {
        MetaGraphDef metaGraphDef = MetaGraphDef.parseFrom(savedModelBundle.metaGraphDef());
        Map<String, TensorInfo> inputs = metaGraphDef.getSignatureDefOrThrow(SIGNATURE_DEF).getInputsMap();
        Map<String, TensorInfo> outputs = metaGraphDef.getSignatureDefOrThrow(SIGNATURE_DEF).getOutputsMap();
        this.bundle = savedModelBundle;
        this.inputSpecMap = inputs;
        this.outputSpecMap = outputs;
    }

    /**
     * Builds one feed tensor per signature input.
     *
     * <p>Each tensor is keyed by the tensor name of its input. Its values hold that feature for
     * every sample, in {@code list} order. A missing feature contributes {@code null}.
     *
     * <p>If any tensor cannot be built, the tensors already created are closed before the
     * exception propagates.
     */
    // Raw List: the element type E2ETFUtils.listToTensor expects is not visible here, so the
    // column is passed exactly as the original code did.
    @SuppressWarnings({"rawtypes", "unchecked"})
    private <T> Map<String, Tensor<?>> preprocessInput(List<T> list, Map<T, Map<String, String>> map) {
        logger.debug("this.inputSpecMap.keySet(): {}", this.inputSpecMap.keySet());
        Map<String, Tensor<?>> feeds = new HashMap<>();
        boolean complete = false;
        try {
            for (Map.Entry<String, TensorInfo> spec : this.inputSpecMap.entrySet()) {
                String feature = spec.getKey();
                TensorInfo tensorInfo = spec.getValue();
                List column = new ArrayList(list.size());
                for (T sampleKey : list) {
                    column.add(map.get(sampleKey).get(feature));
                }
                feeds.put(tensorInfo.getName(), E2ETFUtils.listToTensor(column, tensorInfo));
            }
            complete = true;
            return feeds;
        } finally {
            if (!complete) {
                closeAll(feeds.values());
            }
        }
    }

    /**
     * Runs the graph with the given feeds and fetches every signature output.
     *
     * <p>Each fetched tensor is converted with {@link WuWuTFUtils#tensorToArray}, keyed by its
     * signature output key. This method takes ownership of the feed tensors: they are closed
     * together with the fetched tensors, on success and on failure.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private Map<String, Object[][]> executeTfGraph(Map<String, Tensor<?>> map) {
        List<Tensor<?>> fetched = null;
        try {
            Session.Runner runner = this.bundle.session().runner();
            for (Map.Entry<String, Tensor<?>> feed : map.entrySet()) {
                runner.feed(feed.getKey(), feed.getValue());
            }
            // Keep output keys aligned with fetch order: run() returns tensors in fetch order.
            List<String> outputKeys = new ArrayList<>(this.outputSpecMap.size());
            for (Map.Entry<String, TensorInfo> spec : this.outputSpecMap.entrySet()) {
                outputKeys.add(spec.getKey());
                runner.fetch(spec.getValue().getName());
            }
            fetched = runner.run();
            Map<String, Object[][]> rawOutputs = new HashMap<>();
            for (int i = 0; i < outputKeys.size(); i++) {
                rawOutputs.put(outputKeys.get(i), WuWuTFUtils.tensorToArray((Tensor) fetched.get(i)));
            }
            return rawOutputs;
        } finally {
            closeAll(map.values());
            if (fetched != null) {
                closeAll(fetched);
            }
        }
    }

    /** Closes every non-null tensor in {@code tensors}. */
    private static void closeAll(Iterable<? extends Tensor<?>> tensors) {
        for (Tensor<?> tensor : tensors) {
            if (tensor != null) {
                tensor.close();
            }
        }
    }

    /**
     * Regroups raw batch outputs per sample.
     *
     * <p>Row {@code i} of each output belongs to {@code list.get(i)}. A row with a single value
     * becomes a single result; a longer row becomes an array result.
     */
    private <T> Map<T, Map<String, E2EPredResult>> formatOutputs(List<T> list, Map<String, Object[][]> map) {
        Map<T, Map<String, E2EPredResult>> perSample = new HashMap<>();
        for (int row = 0; row < list.size(); row++) {
            Map<String, E2EPredResult> perOutput = new HashMap<>();
            for (Map.Entry<String, TensorInfo> spec : this.outputSpecMap.entrySet()) {
                String outputKey = spec.getKey();
                E2EPredResult result = new E2EPredResult(outputKey);
                Object[] values = map.get(outputKey)[row];
                DataType dtype = spec.getValue().getDtype();
                if (values.length == 1) {
                    result.setSingleResult(values[0], dtype);
                } else {
                    result.setArrayResult(values, dtype);
                }
                perOutput.put(outputKey, result);
            }
            perSample.put(list.get(row), perOutput);
        }
        return perSample;
    }

    /**
     * Scores a batch of samples and returns each sample's {@code "prob"} output.
     *
     * <p>Each probability is rounded away from zero to six decimal places (see
     * {@link #formatDouble(float, int)}).
     *
     * @param map sample key to its features (feature name to string value); must not be null
     * @return sample key to probability; empty when {@code map} is empty
     * @throws NullPointerException if {@code map} is null
     * @throws IllegalStateException if no model is loaded (or it was closed), or if the model has
     *     no {@code "prob"} output
     */
    public <T> Map<T, Double> predict(Map<T, Map<String, String>> map) {
        Objects.requireNonNull(map, "map");
        Map<T, Double> scores = new HashMap<>();
        if (map.isEmpty()) {
            return scores;
        }
        if (this.bundle == null) {
            throw new IllegalStateException("model not loaded or already closed, path=" + this.path);
        }
        List<T> sampleKeys = new ArrayList<>(map.keySet());
        logger.debug("sampleKeyList: {}", sampleKeys);
        Map<String, Tensor<?>> feeds = preprocessInput(sampleKeys, map);
        if (logger.isDebugEnabled()) {
            logger.debug("formattedInput: {}", JSON.toJSONString(feeds.keySet()));
        }
        Map<String, Object[][]> rawOutputs = executeTfGraph(feeds);
        if (logger.isDebugEnabled()) {
            // Serialising every raw output is expensive; only do it when the line is emitted.
            logger.debug("rawOutputs: {}", JSON.toJSONString(rawOutputs));
        }
        Map<T, Map<String, E2EPredResult>> formatted = formatOutputs(sampleKeys, rawOutputs);
        for (Map.Entry<T, Map<String, E2EPredResult>> entry : formatted.entrySet()) {
            E2EPredResult prob = entry.getValue().get(PROB_OUTPUT);
            if (prob == null) {
                throw new IllegalStateException(
                        "output '" + PROB_OUTPUT + "' not found in signature " + SIGNATURE_DEF + ", path=" + this.path);
            }
            scores.put(entry.getKey(), formatDouble(prob.getFloatResult().floatValue(), PROB_SCALE));
        }
        return scores;
    }

    /**
     * Rounds {@code f} to {@code i} decimal places, away from zero.
     *
     * <p>The float's shortest decimal representation ({@link Float#toString(float)}) is used.
     * The exact binary expansion would, for example, turn {@code 0.1f} into {@code 0.100001}.
     *
     * @param f value to round
     * @param i number of decimal places to keep
     * @return the rounded value
     * @throws NumberFormatException if {@code f} is NaN or infinite
     */
    public static double formatDouble(float f, int i) {
        return new BigDecimal(Float.toString(f)).setScale(i, RoundingMode.UP).doubleValue();
    }

    /**
     * Manual smoke test: scores two identical samples whose features f00001..f00023 are all
     * {@code "a"}.
     *
     * @param strArr optional model directory as the first argument; defaults to a local path
     */
    public static void main(String[] strArr) throws Exception {
        Map<String, String> features = new HashMap<>();
        for (int idx = 1; idx <= 23; idx++) {
            features.put(String.format(Locale.ROOT, "f%05d", idx), "a");
        }
        Map<Long, Map<String, String>> samples = new HashMap<>();
        samples.put(1L, features);
        samples.put(2L, features);
        String modelPath = strArr.length > 0 ? strArr[0] : DEFAULT_MODEL_PATH;
        E2ELocalTFModel model = new E2ELocalTFModel();
        try {
            model.loadLatestModel(modelPath);
            System.out.println(model.predict(samples));
        } finally {
            model.close();
        }
    }

    public SavedModelBundle getBundle() {
        return this.bundle;
    }

    public String getPath() {
        return this.path;
    }

    public Long getVersion() {
        return this.version;
    }

    public Map<String, TensorInfo> getInputSpecMap() {
        return this.inputSpecMap;
    }

    public Map<String, TensorInfo> getOutputSpecMap() {
        return this.outputSpecMap;
    }

    /** Replaces the bundle only; the cached input/output specs are not refreshed. */
    public void setBundle(SavedModelBundle savedModelBundle) {
        this.bundle = savedModelBundle;
    }

    public void setPath(String str) {
        this.path = str;
    }

    public void setVersion(Long l) {
        this.version = l;
    }

    public void setInputSpecMap(Map<String, TensorInfo> map) {
        this.inputSpecMap = map;
    }

    public void setOutputSpecMap(Map<String, TensorInfo> map) {
        this.outputSpecMap = map;
    }

    /**
     * Field-wise equality over bundle, path, version and both spec maps.
     *
     * <p>This uses the Lombok-style {@code canEqual} pattern.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof E2ELocalTFModel)) {
            return false;
        }
        E2ELocalTFModel other = (E2ELocalTFModel) obj;
        return other.canEqual(this)
                && Objects.equals(getBundle(), other.getBundle())
                && Objects.equals(getPath(), other.getPath())
                && Objects.equals(getVersion(), other.getVersion())
                && Objects.equals(getInputSpecMap(), other.getInputSpecMap())
                && Objects.equals(getOutputSpecMap(), other.getOutputSpecMap());
    }

    protected boolean canEqual(Object obj) {
        return obj instanceof E2ELocalTFModel;
    }

    /**
     * Hash consistent with {@link #equals(Object)}.
     *
     * <p>This keeps the Lombok formula: prime 59, with 43 for null fields.
     */
    @Override
    public int hashCode() {
        int result = 1;
        result = result * 59 + fieldHash(getBundle());
        result = result * 59 + fieldHash(getPath());
        result = result * 59 + fieldHash(getVersion());
        result = result * 59 + fieldHash(getInputSpecMap());
        result = result * 59 + fieldHash(getOutputSpecMap());
        return result;
    }

    /** Per-field hash used by {@link #hashCode()}: 43 for null, else the field's own hash. */
    private static int fieldHash(Object field) {
        return field == null ? 43 : field.hashCode();
    }

    @Override
    public String toString() {
        return "E2ELocalTFModel(bundle=" + getBundle() + ", path=" + getPath() + ", version=" + getVersion() + ", inputSpecMap=" + getInputSpecMap() + ", outputSpecMap=" + getOutputSpecMap() + ")";
    }
}
