package cn.com.duiba.nezha.alg.model.tf;

import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import cn.com.duiba.nezha.alg.common.util.MathUtil;
import cn.com.duiba.nezha.alg.common.util.NumberUtil;
import cn.com.duiba.nezha.alg.model.CODER;
import cn.com.duiba.nezha.alg.model.enums.PredictResultType;
import com.alibaba.fastjson.JSON;
import java.io.File;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import jdk.nashorn.internal.ir.debug.ObjectSizeCalculator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;

/* loaded from: input_file:cn/com/duiba/nezha/alg/model/tf/LocalTFModel.class */
public class LocalTFModel {
    /** Initial capacity for the per-key result maps built by this class. */
    private final int DF_MAP_SIZE = 128;
    /** Currently loaded model; {@code null} until {@link #setModel} or after {@link #close}. */
    private SavedModelBundle bundle = null;
    /** Resolved name of the input tensor for {@link #INPUT} in the current bundle. */
    private String inputName = null;
    /** Lazily filled cache: logical output name -> resolved tensor name in the current bundle. */
    private Map<String, String> outputNameMap = null;
    private static final Logger logger = LoggerFactory.getLogger(LocalTFModel.class);
    private static final String SIGNATURE_DEF = "serving_default";
    private static final String INPUT = "feat_vals";
    public static String DF_OUTPUT = "prob";
    public static String ESMM_CTR_OUTPUT = "ctr_probabilities";
    public static String ESMM_CVR_OUTPUT = "cvr_probabilities";

    /**
     * Returns the currently loaded bundle.
     *
     * @return the bundle, or {@code null} if no model is loaded
     */
    public SavedModelBundle getBundle() {
        return this.bundle;
    }

    /**
     * Resolves a logical output name to the tensor name in the current bundle's
     * {@code serving_default} signature, caching the result (including a {@code null} result).
     *
     * <p>NOTE(review): the cache is a plain {@link HashMap} filled lazily, so concurrent callers
     * must synchronize externally. Confirm how instances are shared.
     *
     * @param str logical output name, e.g. {@link #DF_OUTPUT}
     * @return the resolved tensor name, or {@code null} if {@code TensorflowUtils} found none
     * @throws Exception whatever {@code TensorflowUtils.getOutputName} throws
     */
    public String getOutputName(String str) throws Exception {
        if (this.outputNameMap == null) {
            this.outputNameMap = new HashMap<>();
        }
        if (!this.outputNameMap.containsKey(str)) {
            this.outputNameMap.put(str, TensorflowUtils.getOutputName(str, this.bundle, SIGNATURE_DEF));
        }
        return this.outputNameMap.get(str);
    }

    /**
     * Loads a specific model version via {@code TensorflowUtils.loadModel} and installs it.
     *
     * <p>NOTE(review): a previously loaded bundle is replaced but not closed here. Call
     * {@link #close()} first if it is no longer in use.
     *
     * @param str  model base directory
     * @param str2 version identifier (presumably a sub-directory name under {@code str})
     * @throws Exception if loading or input-name resolution fails
     */
    public void loadModel(String str, String str2) throws Exception {
        logger.info("loadModel with version={}", str2);
        setModel(TensorflowUtils.loadModel(str, str2));
    }

    /**
     * Closes the current bundle (if any) and forgets everything resolved from it.
     *
     * @throws Exception kept for interface compatibility
     */
    public void close() throws Exception {
        if (this.bundle != null) {
            this.bundle.close();
            this.bundle = null;
            // Cached names belong to the bundle that was just closed.
            this.outputNameMap = null;
            logger.info("LocalTFModel bundle closed");
        }
    }

    /**
     * Loads the highest numeric version found under {@code str}.
     *
     * @param str model base directory containing numerically named version entries
     * @throws IllegalArgumentException if no numeric version entry exists under {@code str}
     * @throws Exception                if loading fails
     */
    public void loadModel(String str) throws Exception {
        Long version = getLastVersion(str);
        if (version == null) {
            // Previously this silently tried to load version "null".
            throw new IllegalArgumentException("no numeric model version found under " + str);
        }
        loadModel(str, version.toString());
    }

    /**
     * Finds the largest numerically named entry directly under {@code str}.
     *
     * @param str directory to scan
     * @return the largest numeric entry name, or {@code null} if {@code str} is not a readable
     *     directory or contains no numeric entries
     * @throws Exception kept for interface compatibility
     */
    public Long getLastVersion(String str) throws Exception {
        // listFiles() returns null both for non-directories and on I/O errors.
        File[] entries = new File(str).listFiles();
        if (entries == null) {
            return null;
        }
        Long latest = null;
        for (File entry : entries) {
            String name = entry.getName();
            if (!NumberUtil.isNumber(name)) {
                continue;
            }
            long version;
            try {
                version = Long.parseLong(name);
            } catch (NumberFormatException e) {
                // NOTE(review): isNumber may accept forms Long cannot parse (e.g. decimals); skip them.
                continue;
            }
            if (latest == null || latest < version) {
                latest = version;
            }
        }
        return latest;
    }

    /**
     * Installs {@code savedModelBundle} as the current model and resolves its input tensor name.
     *
     * @param savedModelBundle bundle to use for subsequent predictions
     * @throws Exception whatever {@code TensorflowUtils.getInputName} throws
     */
    public void setModel(SavedModelBundle savedModelBundle) throws Exception {
        this.bundle = savedModelBundle;
        // Output names resolved against a previous bundle may not apply to this one.
        this.outputNameMap = null;
        this.inputName = TensorflowUtils.getInputName(INPUT, savedModelBundle, SIGNATURE_DEF);
    }

    /**
     * Runs the model on {@code tensor} and reads the {@link #DF_OUTPUT} output.
     *
     * @see #predict(Tensor, int, String)
     */
    public <T> float[] predict(Tensor<T> tensor, int i) throws Exception {
        return predict(tensor, i, DF_OUTPUT);
    }

    /**
     * Feeds {@code tensor} to the model input and copies the requested output into a float array.
     *
     * <p>The caller keeps ownership of {@code tensor}. All tensors returned by the session run are
     * closed before this method returns.
     *
     * @param tensor input tensor, fed to the resolved input name
     * @param i      number of output values to copy (must match the output tensor's element layout)
     * @param str    logical output name to fetch
     * @return the output values
     * @throws IllegalStateException    if no model is loaded
     * @throws IllegalArgumentException if {@code str} cannot be resolved to an output tensor
     * @throws Exception                if output-name resolution or the session run fails
     */
    public <T> float[] predict(Tensor<T> tensor, int i, String str) throws Exception {
        if (this.bundle == null) {
            throw new IllegalStateException("model is not loaded; call loadModel or setModel first");
        }
        String outputName = getOutputName(str);
        if (outputName == null) {
            logger.warn("predict(Tensor<T> x, int size, String output) invalid, no output tensor for output={}, inputName={}",
                    str, this.inputName);
            // fetch(null) cannot succeed, so fail fast with a descriptive error.
            throw new IllegalArgumentException("no output tensor registered for output '" + str + "'");
        }
        List<Tensor<?>> outputs = this.bundle.session().runner()
                .feed(this.inputName, tensor)
                .fetch(outputName)
                .run();
        try {
            float[] scores = new float[i];
            outputs.get(0).copyTo(scores);
            return scores;
        } finally {
            // Session outputs hold native buffers; release every one of them.
            for (Tensor<?> output : outputs) {
                output.close();
            }
        }
    }

    /**
     * Scores numeric feature rows with both ESMM outputs.
     *
     * @param map key -> feature row; every row must have the same length
     * @return {@code CTR} and {@code CVR} score maps keyed like {@code map}; empty if {@code map}
     *     is empty
     * @throws Exception if tensor construction or prediction fails
     */
    public <T> Map<PredictResultType, Map<T, Double>> predictESMM(Map<T, List<Float>> map) throws Exception {
        Map<PredictResultType, Map<T, Double>> result = new HashMap<>();
        if (AssertUtil.isAllNotEmpty(new Object[]{map})) {
            List<T> keys = new ArrayList<>(map.keySet());
            try (Tensor<Float> input = getParamsTensor(keys, map)) {
                result.put(PredictResultType.CTR, getPredictReturn(keys, input, ESMM_CTR_OUTPUT));
                result.put(PredictResultType.CVR, getPredictReturn(keys, input, ESMM_CVR_OUTPUT));
            }
        }
        return result;
    }

    /**
     * Scores string-encoded feature inputs with both ESMM outputs.
     *
     * @param map key -> feature string (encoded as UTF-8 bytes)
     * @return {@code CTR} and {@code CVR} score maps keyed like {@code map}; empty if {@code map}
     *     is empty
     * @throws Exception if tensor construction or prediction fails
     */
    public <T> Map<PredictResultType, Map<T, Double>> predictStrESMM(Map<T, String> map) throws Exception {
        Map<PredictResultType, Map<T, Double>> result = new HashMap<>();
        if (AssertUtil.isAllNotEmpty(new Object[]{map})) {
            List<T> keys = new ArrayList<>(map.keySet());
            // The input tensor was previously never closed here (native memory leak).
            try (Tensor<String> input = getParamsTensorStr(keys, map)) {
                result.put(PredictResultType.CTR, getPredictReturn(keys, input, ESMM_CTR_OUTPUT));
                result.put(PredictResultType.CVR, getPredictReturn(keys, input, ESMM_CVR_OUTPUT));
            }
        }
        return result;
    }

    /**
     * Scores numeric feature rows with the {@link #DF_OUTPUT} output.
     *
     * @param map key -> feature row; every row must have the same length
     * @return key -> score; empty if {@code map} is empty
     * @throws Exception if tensor construction or prediction fails
     */
    public <T> Map<T, Double> predict(Map<T, List<Float>> map) throws Exception {
        if (!AssertUtil.isNotEmpty(map)) {
            return new HashMap<>(this.DF_MAP_SIZE);
        }
        List<T> keys = new ArrayList<>(map.keySet());
        // The input tensor was previously never closed here (native memory leak).
        try (Tensor<Float> input = getParamsTensor(keys, map)) {
            return getPredictReturn(keys, input, DF_OUTPUT);
        }
    }

    /**
     * Scores string-encoded feature inputs with the {@link #DF_OUTPUT} output.
     *
     * @param map key -> feature string (encoded as UTF-8 bytes)
     * @return key -> score; empty if {@code map} is empty or no input tensor could be built
     * @throws Exception if tensor construction or prediction fails
     */
    public <T> Map<T, Double> predictStr(Map<T, String> map) throws Exception {
        if (!AssertUtil.isNotEmpty(map)) {
            return new HashMap<>();
        }
        List<T> keys = new ArrayList<>(map.keySet());
        try (Tensor<String> input = getParamsTensorStr(keys, map)) {
            if (input == null) {
                logger.warn(" predictStr input is invalid, x=null");
                return new HashMap<>();
            }
            return getPredictReturn(keys, input, DF_OUTPUT);
        }
    }

    /**
     * Runs the model and pairs the i-th output value with the i-th key of {@code list}.
     *
     * @param list   keys in the same order as the rows of {@code tensor}
     * @param tensor input tensor (not closed here)
     * @param str    logical output name to fetch
     * @return key -> score passed through {@code MathUtil.formatDouble(score, 6)}
     *     (presumably rounding to 6 decimal places)
     * @throws Exception if prediction fails
     */
    public <T, K> Map<T, Double> getPredictReturn(List<T> list, Tensor<K> tensor, String str) throws Exception {
        Map<T, Double> scores = new HashMap<>(this.DF_MAP_SIZE);
        int size = list.size();
        float[] predicted = predict(tensor, size, str);
        for (int i = 0; i < size; i++) {
            scores.put(list.get(i), Double.valueOf(MathUtil.formatDouble((double) predicted[i], 6)));
        }
        return scores;
    }

    /**
     * Builds a {@code [list.size(), rowLength]} float tensor from the rows of {@code map}, in the
     * order given by {@code list}. The caller owns (and must close) the returned tensor.
     *
     * @param list row order
     * @param map  key -> feature row
     * @return the tensor, or {@code null} if {@code list} or {@code map} is empty
     * @throws IllegalArgumentException if a key has no row or rows differ in length
     */
    public static <T> Tensor<Float> getParamsTensor(List<T> list, Map<T, List<Float>> map) {
        if (!AssertUtil.isAllNotEmpty(new Object[]{list, map})) {
            return null;
        }
        List<Float> firstRow = map.get(list.get(0));
        if (firstRow == null) {
            throw new IllegalArgumentException("no feature row for key " + list.get(0));
        }
        int rows = list.size();
        int cols = firstRow.size();
        float[][] matrix = new float[rows][cols];
        for (int r = 0; r < rows; r++) {
            T key = list.get(r);
            List<Float> row = map.get(key);
            // A ragged input was previously truncated silently, misaligning features.
            if (row == null) {
                throw new IllegalArgumentException("no feature row for key " + key);
            }
            if (row.size() != cols) {
                throw new IllegalArgumentException("feature row for key " + key + " has " + row.size()
                        + " values, expected " + cols);
            }
            for (int c = 0; c < cols; c++) {
                matrix[r][c] = row.get(c);
            }
        }
        return Tensor.create(matrix, Float.class);
    }

    /**
     * Builds a string tensor with one UTF-8 encoded element per key, in the order given by
     * {@code list}. The caller owns (and must close) the returned tensor.
     *
     * @param list element order
     * @param map  key -> feature string
     * @return the tensor, or {@code null} if {@code list} or {@code map} is empty
     * @throws IllegalArgumentException if a key has no string value
     * @throws Exception                kept for interface compatibility
     */
    public static <T> Tensor<String> getParamsTensorStr(List<T> list, Map<T, String> map) throws Exception {
        if (!AssertUtil.isAllNotEmpty(new Object[]{list, map})) {
            return null;
        }
        // Inner arrays are assigned below, so none are pre-allocated.
        byte[][] encoded = new byte[list.size()][];
        for (int i = 0; i < encoded.length; i++) {
            T key = list.get(i);
            String value = map.get(key);
            if (value == null) {
                throw new IllegalArgumentException("no feature string for key " + key);
            }
            encoded[i] = value.getBytes(StandardCharsets.UTF_8);
        }
        return Tensor.create(encoded, String.class);
    }

    /**
     * Manual smoke test: scores two one-hot 240-wide rows and prints the JSON result.
     *
     * @param strArr optional {@code [modelDir [version]]}; defaults to the original local path and "1"
     */
    public static void main(String[] strArr) throws Exception {
        String modelDir = strArr.length > 0 ? strArr[0] : "/Users/lwj/Desktop/model/";
        String version = strArr.length > 1 ? strArr[1] : "1";
        LocalTFModel localTFModel = new LocalTFModel();
        localTFModel.loadModel(modelDir, version);
        try {
            List<Float> rowA = new ArrayList<>();
            List<Float> rowB = new ArrayList<>();
            for (int i = 0; i < 240; i++) {
                rowA.add(0.0f);
                rowB.add(0.0f);
            }
            rowA.set(10, 1.0f);
            rowB.set(50, 1.0f);
            Map<String, List<Float>> input = new HashMap<>();
            input.put("a", rowA);
            input.put("b", rowB);
            System.out.println(JSON.toJSONString(localTFModel.predict(input)));
        } finally {
            localTFModel.close();
        }
    }
}
