package cn.com.duiba.nezha.alg.model.tf;

import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import cn.com.duiba.nezha.alg.common.util.MathUtil;
import com.alibaba.fastjson.JSON;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;

/* loaded from: input_file:cn/com/duiba/nezha/alg/model/tf/LocalTFModelV2.class */
public class LocalTFModelV2 extends TensorflowUtils {
    private int DF_MAP_SIZE = 128;
    private SavedModelBundle bundle = null;
    private String version = null;
    private String path = null;
    private String inputName = null;
    private Map<String, String> outputNameMap = null;
    private static final Logger logger = LoggerFactory.getLogger(LocalTFModelV2.class);
    private static String SIGNATURE_DEF = "serving_default";
    private static String DF_INPUT = "feat_vals";
    public static String DF_OUTPUT = "prob";

    public String getOutputName(String str) throws Exception {
        if (this.outputNameMap == null) {
            this.outputNameMap = new HashMap();
        }
        if (!this.outputNameMap.containsKey(str)) {
            try {
                String outputName = getOutputName(str, this.bundle, SIGNATURE_DEF);
                if (outputName != null) {
                    this.outputNameMap.put(str, outputName);
                }
            } catch (Exception e) {
                logger.info("getOutputName , empty", e);
                throw e;
            }
        }
        return this.outputNameMap.get(str);
    }

    public void loadModel(String str) throws Exception {
        loadModel(str, null);
    }

    public void loadModel(String str, String str2) throws Exception {
        if (str2 == null) {
            str2 = getLastVersion(str);
        }
        this.bundle = getModelBundle(str, str2);
        this.inputName = TensorflowUtils.getInputName(DF_INPUT, this.bundle, SIGNATURE_DEF);
        this.path = str;
        this.version = str2;
    }

    public void close() throws Exception {
        if (this.bundle != null) {
            logger.info("close model  ,path=" + this.path);
            this.bundle.close();
            this.bundle = null;
        }
    }

    public <T> float[] predict(Tensor<T> tensor, int i, String str) throws Exception {
        String outputName = getOutputName(str);
        if (outputName == null || this.inputName == null) {
            logger.info("predict(Tensor<T> x, int size, String output)  invalid,outputName= " + outputName + ",inputName=" + this.inputName);
        }
        if (tensor == null) {
            logger.info("predict(Tensor<T> x, int size, String output)  invalid,x=null");
        }
        List run = this.bundle.session().runner().feed(this.inputName, tensor).fetch(outputName).run();
        float[] fArr = new float[i];
        ((Tensor) run.get(0)).copyTo(fArr);
        ((Tensor) run.get(0)).close();
        return fArr;
    }

    public <T> Map<T, Double> predictStr(Map<T, String> map) throws Exception {
        return predictStr(map, DF_OUTPUT);
    }

    public <T> Map<T, Double> predictInt(Map<T, List<Integer>> map) throws Exception {
        return predictInt(map, DF_OUTPUT);
    }

    public <T> Map<T, Double> predictFloat(Map<T, List<Float>> map) throws Exception {
        return predictFloat(map, DF_OUTPUT);
    }

    public <T> Map<T, Double> predictStr(Map<T, String> map, String str) throws Exception {
        Map<T, Double> hashMap = new HashMap();
        if (AssertUtil.isNotEmpty(map)) {
            ArrayList arrayList = new ArrayList(map.keySet());
            try {
                Tensor paramsTensorStr = getParamsTensorStr(arrayList, map, arrayList.size(), 1);
                Throwable th = null;
                try {
                    hashMap = getPredictReturn(arrayList, paramsTensorStr, str);
                    if (paramsTensorStr != null) {
                        if (0 != 0) {
                            try {
                                paramsTensorStr.close();
                            } catch (Throwable th2) {
                                th.addSuppressed(th2);
                            }
                        } else {
                            paramsTensorStr.close();
                        }
                    }
                } finally {
                }
            } catch (Exception e) {
                logger.info("predictV(Map<T, String> dataMap) happend error", e);
                throw e;
            }
        }
        return hashMap;
    }

    public <T> Map<T, Double> predictFloat(Map<T, List<Float>> map, String str) throws Exception {
        Map<T, Double> hashMap = new HashMap();
        if (AssertUtil.isNotEmpty(map)) {
            ArrayList arrayList = new ArrayList(map.keySet());
            try {
                Tensor paramsTensorFloat = getParamsTensorFloat(arrayList, map, arrayList.size(), map.get(arrayList.get(0)).size());
                Throwable th = null;
                try {
                    try {
                        hashMap = getPredictReturn(arrayList, paramsTensorFloat, str);
                        if (paramsTensorFloat != null) {
                            if (0 != 0) {
                                try {
                                    paramsTensorFloat.close();
                                } catch (Throwable th2) {
                                    th.addSuppressed(th2);
                                }
                            } else {
                                paramsTensorFloat.close();
                            }
                        }
                    } finally {
                    }
                } finally {
                }
            } catch (Exception e) {
                logger.info("predictV(Map<T, String> dataMap) happend error", e);
                throw e;
            }
        }
        return hashMap;
    }

    public <T> Map<T, Double> predictInt(Map<T, List<Integer>> map, String str) throws Exception {
        Map<T, Double> hashMap = new HashMap();
        if (AssertUtil.isNotEmpty(map)) {
            ArrayList arrayList = new ArrayList(map.keySet());
            try {
                Tensor paramsTensorInt = getParamsTensorInt(arrayList, map, arrayList.size(), map.get(arrayList.get(0)).size());
                Throwable th = null;
                try {
                    try {
                        hashMap = getPredictReturn(arrayList, paramsTensorInt, str);
                        if (paramsTensorInt != null) {
                            if (0 != 0) {
                                try {
                                    paramsTensorInt.close();
                                } catch (Throwable th2) {
                                    th.addSuppressed(th2);
                                }
                            } else {
                                paramsTensorInt.close();
                            }
                        }
                    } finally {
                    }
                } finally {
                }
            } catch (Exception e) {
                logger.info("predictV(Map<T, String> dataMap) happend error", e);
                throw e;
            }
        }
        return hashMap;
    }

    public <T, K> Map<T, Double> getPredictReturn(List<T> list, Tensor<K> tensor, String str) throws Exception {
        HashMap hashMap = new HashMap(this.DF_MAP_SIZE);
        int size = list.size();
        if (predict(tensor, size, str) != null) {
            for (int i = 0; i < size; i++) {
                hashMap.put(list.get(i), Double.valueOf(MathUtil.formatDouble(r0[i] + 1.0E-6d, 6)));
            }
        }
        return hashMap;
    }

    public static <T> Tensor getParamsTensorFloat(List<T> list, Map<T, List<Float>> map, int i, int i2) {
        float[][] fArr = new float[i][i2];
        for (int i3 = 0; i3 < i; i3++) {
            List<Float> list2 = map.get(list.get(i3));
            for (int i4 = 0; i4 < i2; i4++) {
                fArr[i3][i4] = list2.get(i4).floatValue();
            }
        }
        return Tensor.create(fArr, Integer.class);
    }

    public static <T> Tensor getParamsTensorInt(List<T> list, Map<T, List<Integer>> map, int i, int i2) {
        int[][] iArr = new int[i][i2];
        for (int i3 = 0; i3 < i; i3++) {
            List<Integer> list2 = map.get(list.get(i3));
            for (int i4 = 0; i4 < i2; i4++) {
                iArr[i3][i4] = list2.get(i4).intValue();
            }
        }
        return Tensor.create(iArr, Integer.class);
    }

    public static <T> Tensor getParamsTensorStr(List<T> list, Map<T, String> map, int i, int i2) throws Exception {
        byte[][] bArr = new byte[i][i2];
        for (int i3 = 0; i3 < i; i3++) {
            bArr[i3] = map.get(list.get(i3)).getBytes("UTF-8");
        }
        return Tensor.create(bArr, String.class);
    }

    public static void main(String[] strArr) throws Exception {
        LocalTFModelV2 localTFModelV2 = new LocalTFModelV2();
        localTFModelV2.loadModel("/Users/lwj/Documents/pythonProjectFiles/tf/data/serving-model/deepfm-v008/", null);
        ArrayList arrayList = new ArrayList();
        for (int i = 0; i < 100; i++) {
            arrayList.add(1);
        }
        HashMap hashMap = new HashMap();
        hashMap.put("s1", arrayList);
        hashMap.put("s2", arrayList);
        System.out.println("ret=" + JSON.toJSONString(localTFModelV2.predictInt(hashMap)));
    }

    public int getDF_MAP_SIZE() {
        return this.DF_MAP_SIZE;
    }

    public SavedModelBundle getBundle() {
        return this.bundle;
    }

    public String getVersion() {
        return this.version;
    }

    public String getPath() {
        return this.path;
    }

    public String getInputName() {
        return this.inputName;
    }

    public Map<String, String> getOutputNameMap() {
        return this.outputNameMap;
    }

    public void setDF_MAP_SIZE(int i) {
        this.DF_MAP_SIZE = i;
    }

    public void setBundle(SavedModelBundle savedModelBundle) {
        this.bundle = savedModelBundle;
    }

    public void setVersion(String str) {
        this.version = str;
    }

    public void setPath(String str) {
        this.path = str;
    }

    public void setInputName(String str) {
        this.inputName = str;
    }

    public void setOutputNameMap(Map<String, String> map) {
        this.outputNameMap = map;
    }

    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof LocalTFModelV2)) {
            return false;
        }
        LocalTFModelV2 localTFModelV2 = (LocalTFModelV2) obj;
        if (!localTFModelV2.canEqual(this) || getDF_MAP_SIZE() != localTFModelV2.getDF_MAP_SIZE()) {
            return false;
        }
        SavedModelBundle bundle = getBundle();
        SavedModelBundle bundle2 = localTFModelV2.getBundle();
        if (bundle == null) {
            if (bundle2 != null) {
                return false;
            }
        } else if (!bundle.equals(bundle2)) {
            return false;
        }
        String version = getVersion();
        String version2 = localTFModelV2.getVersion();
        if (version == null) {
            if (version2 != null) {
                return false;
            }
        } else if (!version.equals(version2)) {
            return false;
        }
        String path = getPath();
        String path2 = localTFModelV2.getPath();
        if (path == null) {
            if (path2 != null) {
                return false;
            }
        } else if (!path.equals(path2)) {
            return false;
        }
        String inputName = getInputName();
        String inputName2 = localTFModelV2.getInputName();
        if (inputName == null) {
            if (inputName2 != null) {
                return false;
            }
        } else if (!inputName.equals(inputName2)) {
            return false;
        }
        Map<String, String> outputNameMap = getOutputNameMap();
        Map<String, String> outputNameMap2 = localTFModelV2.getOutputNameMap();
        return outputNameMap == null ? outputNameMap2 == null : outputNameMap.equals(outputNameMap2);
    }

    protected boolean canEqual(Object obj) {
        return obj instanceof LocalTFModelV2;
    }

    public int hashCode() {
        int df_map_size = (1 * 59) + getDF_MAP_SIZE();
        SavedModelBundle bundle = getBundle();
        int hashCode = (df_map_size * 59) + (bundle == null ? 43 : bundle.hashCode());
        String version = getVersion();
        int hashCode2 = (hashCode * 59) + (version == null ? 43 : version.hashCode());
        String path = getPath();
        int hashCode3 = (hashCode2 * 59) + (path == null ? 43 : path.hashCode());
        String inputName = getInputName();
        int hashCode4 = (hashCode3 * 59) + (inputName == null ? 43 : inputName.hashCode());
        Map<String, String> outputNameMap = getOutputNameMap();
        return (hashCode4 * 59) + (outputNameMap == null ? 43 : outputNameMap.hashCode());
    }

    public String toString() {
        return "LocalTFModelV2(DF_MAP_SIZE=" + getDF_MAP_SIZE() + ", bundle=" + getBundle() + ", version=" + getVersion() + ", path=" + getPath() + ", inputName=" + getInputName() + ", outputNameMap=" + getOutputNameMap() + ")";
    }
}
