package cn.com.duiba.nezha.alg.api.model;

import cn.com.duiba.nezha.alg.api.model.util.E2ETFUtils;
import cn.com.duiba.nezha.alg.api.model.util.TensorflowUtils;
import com.alibaba.fastjson.JSON;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import jdk.nashorn.internal.ir.debug.ObjectSizeCalculator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.TensorInfo;

/* loaded from: input_file:cn/com/duiba/nezha/alg/api/model/AdxLocalTFModel.class */
public class AdxLocalTFModel {
    /** Class logger, bound to this class so log lines are attributed to the ADX model wrapper. */
    private static final Logger logger = LoggerFactory.getLogger(AdxLocalTFModel.class);

    /** Key of the SavedModel signature whose inputs and outputs drive feeding and fetching. */
    private static final String SIGNATURE_DEF = "serving_default";

    /** Bundle currently held by this instance; {@code null} until one is set and after {@link #close()}. */
    private SavedModelBundle bundle;

    /** Model location recorded by {@link #loadLatestModel(String)}. */
    private String path;

    /** Model version recorded by {@link #loadLatestModel(String)}. */
    private Long version;

    /** Inputs of the {@link #SIGNATURE_DEF} signature, keyed by signature input key. */
    private Map<String, TensorInfo> inputSpecMap;

    /** Outputs of the {@link #SIGNATURE_DEF} signature, keyed by signature output key. */
    private Map<String, TensorInfo> outputSpecMap;

    /**
     * Looks up the version reported by {@link E2ETFUtils#getLastVersion} for the given location,
     * loads that model and installs it as the bundle of this instance.
     *
     * <p>A bundle already held by this instance is replaced without being closed here.
     *
     * @param str model location handed to {@link E2ETFUtils}
     * @throws Exception if version lookup, model loading or signature parsing fails
     */
    public void loadLatestModel(String str) throws Exception {
        final Long resolvedVersion = E2ETFUtils.getLastVersion(str);
        final String versionLabel = String.valueOf(resolvedVersion);
        final SavedModelBundle loaded = E2ETFUtils.loadModel(str, versionLabel);
        logger.info("loadModel ,path=" + str + "with version=" + resolvedVersion);
        setModel(loaded);
        // Path and version are only recorded once the model has been installed successfully.
        this.path = str;
        this.version = resolvedVersion;
    }

    /**
     * Closes the held bundle, if any, and clears the reference so repeated calls are no-ops.
     *
     * <p>The size log line is purely diagnostic and relies on a JDK-internal Nashorn class, so a
     * failure there is logged and ignored instead of preventing the bundle from being closed.
     *
     * @throws Exception declared for interface compatibility with existing callers
     */
    public void close() throws Exception {
        if (this.bundle == null) {
            return;
        }
        logger.info("close model  ,path=" + this.path + "with version=" + this.version);
        try {
            logger.info("bundle.size before close :" + ObjectSizeCalculator.getObjectSize(this.bundle));
        } catch (RuntimeException | LinkageError e) {
            // Internal API: may be unsupported on this VM or missing from newer JDKs.
            // Never let a diagnostic stop us from closing the bundle.
            logger.warn("bundle.size unavailable before close", e);
        }
        this.bundle.close();
        this.bundle = null;
    }

    /**
     * Installs the given bundle and refreshes the cached input/output specs from its
     * {@link #SIGNATURE_DEF} signature.
     *
     * <p>The bundle reference is stored before the signature is parsed, so it stays set even if
     * parsing throws.
     *
     * @param savedModelBundle bundle to install
     * @throws Exception if the meta graph cannot be parsed or the signature is missing
     */
    private void setModel(SavedModelBundle savedModelBundle) throws Exception {
        this.bundle = savedModelBundle;

        final Map<String, TensorInfo> inputs = getInputSpecMap(savedModelBundle);
        this.inputSpecMap = inputs;

        final Map<String, TensorInfo> outputs = getOutputSpecMap(savedModelBundle);
        this.outputSpecMap = outputs;
    }

    /**
     * Parses the bundle's serialized meta graph and returns the inputs of the
     * {@link #SIGNATURE_DEF} signature.
     *
     * @param savedModelBundle bundle whose meta graph is inspected
     * @return signature inputs keyed by signature input key
     * @throws Exception if parsing fails or the signature is absent
     */
    private Map<String, TensorInfo> getInputSpecMap(SavedModelBundle savedModelBundle) throws Exception {
        final byte[] serializedMetaGraph = savedModelBundle.metaGraphDef();
        final MetaGraphDef metaGraph = MetaGraphDef.parseFrom(serializedMetaGraph);
        return metaGraph.getSignatureDefOrThrow(SIGNATURE_DEF).getInputsMap();
    }

    /**
     * Parses the bundle's serialized meta graph and returns the outputs of the
     * {@link #SIGNATURE_DEF} signature.
     *
     * @param savedModelBundle bundle whose meta graph is inspected
     * @return signature outputs keyed by signature output key
     * @throws Exception if parsing fails or the signature is absent
     */
    private Map<String, TensorInfo> getOutputSpecMap(SavedModelBundle savedModelBundle) throws Exception {
        final byte[] serializedMetaGraph = savedModelBundle.metaGraphDef();
        final MetaGraphDef metaGraph = MetaGraphDef.parseFrom(serializedMetaGraph);
        return metaGraph.getSignatureDefOrThrow(SIGNATURE_DEF).getOutputsMap();
    }

    /**
     * Builds one feed tensor per signature input by collecting, in {@code list} order, the raw
     * feature value of every sample for that input.
     *
     * <p>A sample without a feature map, or without a value for an input, contributes {@code null},
     * which {@link #listToTensor(List, TensorInfo)} turns into that dtype's default.
     *
     * @param list sample ids; fixes the row order of every tensor
     * @param map  per-sample feature maps keyed by signature input key
     * @param <T1> sample id type
     * @return feed tensors keyed by graph tensor name; the caller owns and must close them
     * @throws IllegalArgumentException if an input has a dtype {@code listToTensor} cannot encode
     */
    private <T1> Map<String, Tensor<?>> preprocessInput(List<T1> list, Map<T1, Map<String, String>> map) {
        Map<String, Tensor<?>> feeds = new HashMap<>();
        try {
            for (Map.Entry<String, TensorInfo> spec : this.inputSpecMap.entrySet()) {
                String inputKey = spec.getKey();
                TensorInfo tensorInfo = spec.getValue();
                String tensorName = tensorInfo.getName();

                List<String> column = new ArrayList<>(list.size());
                for (T1 id : list) {
                    Map<String, String> features = map.get(id);
                    // A missing feature map is treated like a missing feature, not an NPE.
                    column.add(features == null ? null : features.get(inputKey));
                }

                Tensor<?> tensor = listToTensor(column, tensorInfo);
                if (tensor == null) {
                    // Fail here with context; a null feed would presumably only surface later
                    // as an uninformative NullPointerException inside the session run.
                    throw new IllegalArgumentException(
                            "unsupported dtype " + tensorInfo.getDtype() + " for model input '" + inputKey + "'");
                }
                feeds.put(tensorName, tensor);
            }
            return feeds;
        } catch (RuntimeException | Error e) {
            // Tensors created so far are not handed to the caller, so release them here.
            TensorflowUtils.closeQuitely(feeds.values());
            throw e;
        }
    }

    /**
     * Feeds the given tensors into the bundle's session, fetches every signature output and
     * converts each fetched tensor with {@code tensorToArray}.
     *
     * <p>Both the supplied feed tensors and the fetched tensors are passed to
     * {@link TensorflowUtils#closeQuitely} before this method exits, on success and on failure.
     *
     * @param map feed tensors keyed by graph tensor name
     * @param <T> element type of the converted outputs
     * @return converted outputs keyed by signature output key
     */
    private <T> Map<String, T[][]> executeTfGraph(Map<String, Tensor<?>> map) {
        List<Tensor<?>> fetched = null;
        try {
            // Keep signature keys and graph tensor names in lockstep so results can be matched
            // back to their keys by position.
            final List<String> outputKeys = new ArrayList<>();
            final List<String> fetchNames = new ArrayList<>();
            for (Map.Entry<String, TensorInfo> spec : this.outputSpecMap.entrySet()) {
                outputKeys.add(spec.getKey());
                fetchNames.add(spec.getValue().getName());
            }

            final Session.Runner runner = this.bundle.session().runner();
            for (Map.Entry<String, Tensor<?>> feed : map.entrySet()) {
                runner.feed(feed.getKey(), feed.getValue());
            }
            for (String fetchName : fetchNames) {
                runner.fetch(fetchName);
            }
            fetched = runner.run();

            final Map<String, T[][]> decoded = new HashMap<>();
            int position = 0;
            for (String outputKey : outputKeys) {
                decoded.put(outputKey, this.<T>tensorToArray(fetched.get(position)));
                position++;
            }
            return decoded;
        } finally {
            TensorflowUtils.closeQuitely(map.values());
            TensorflowUtils.closeQuitely(fetched);
        }
    }

    /**
     * Re-keys batch outputs by sample id: for every signature output, row {@code i} of its
     * array is assigned to the id at position {@code i} of {@code list}.
     *
     * @param list sample ids in the order used to build the batch
     * @param map  converted outputs keyed by signature output key
     * @param <T1> sample id type
     * @param <T2> output element type
     * @return for each signature output key, a map from sample id to that sample's output row
     */
    private <T1, T2> Map<String, Map<T1, T2[]>> formatOutputs(List<T1> list, Map<String, T2[][]> map) {
        final Map<String, Map<T1, T2[]>> byOutput = new HashMap<>();
        final List<String> outputKeys = new ArrayList<>(this.outputSpecMap.keySet());
        for (String outputKey : outputKeys) {
            final T2[][] rows = map.get(outputKey);
            final Map<T1, T2[]> byId = new HashMap<>();
            int row = 0;
            for (T1 id : list) {
                byId.put(id, rows[row]);
                row++;
            }
            byOutput.put(outputKey, byId);
        }
        return byOutput;
    }

    /**
     * Runs the model on a batch of samples.
     *
     * @param map  feature maps keyed by sample id; each feature map is keyed by signature input key
     * @param <T1> sample id type
     * @param <T2> output element type
     * @return for each signature output key, a map from sample id to that sample's output row
     */
    public <T1, T2> Map<String, Map<T1, T2[]>> predict(Map<T1, Map<String, String>> map) {
        // Snapshot the ids once so input rows and output rows share the same order.
        final List<T1> ids = new ArrayList<>(map.keySet());
        final Map<String, Tensor<?>> feeds = preprocessInput(ids, map);
        final Map<String, T2[][]> rawOutputs = this.<T2>executeTfGraph(feeds);
        return formatOutputs(ids, rawOutputs);
    }

    /* JADX WARN: Type inference failed for: r0v8, types: [byte[], byte[][]] */
    public static Tensor<?> listToTensor(List<String> list, TensorInfo tensorInfo) {
        DataType dtype = tensorInfo.getDtype();
        int size = list.size();
        if (dtype == DataType.DT_FLOAT) {
            float[] fArr = new float[size];
            for (int i = 0; i < list.size(); i++) {
                String str = list.get(i);
                float f = 0.0f;
                if (str != null) {
                    try {
                        f = Float.parseFloat(str);
                    } catch (NumberFormatException e) {
                        logger.warn("listToTensor exception", e);
                    }
                }
                fArr[i] = f;
            }
            return Tensors.create(fArr);
        }
        if (dtype != DataType.DT_STRING) {
            return null;
        }
        ?? r0 = new byte[size];
        for (int i2 = 0; i2 < list.size(); i2++) {
            String str2 = list.get(i2);
            String str3 = "";
            if (str2 != null) {
                str3 = str2;
            }
            r0[i2] = str3.getBytes(StandardCharsets.UTF_8);
        }
        return Tensors.create((byte[][]) r0);
    }

    /**
     * Copies a rank-2 {@code STRING} or {@code FLOAT} tensor into a boxed two-dimensional array
     * ({@code String[][]} or {@code Float[][]} respectively).
     *
     * @param tensor tensor to convert; not closed by this method
     * @param <T>    element type expected by the caller; it must match the tensor's dtype, the
     *               cast is unchecked
     * @return the converted array, or {@code null} when the dtype is neither STRING nor FLOAT
     * @throws IllegalArgumentException if a STRING/FLOAT tensor is not rank 2
     */
    @SuppressWarnings("unchecked")
    private <T> T[][] tensorToArray(Tensor<?> tensor) {
        org.tensorflow.DataType dataType = tensor.dataType();
        boolean isString = org.tensorflow.DataType.STRING == dataType;
        if (!isString && org.tensorflow.DataType.FLOAT != dataType) {
            return null;
        }

        long[] shape = tensor.shape();
        if (shape.length != 2) {
            // Name the problem instead of failing on shape[1] with an index error.
            throw new IllegalArgumentException(
                    "expected a rank-2 tensor but got rank " + shape.length + " (dtype " + dataType + ")");
        }
        int rows = (int) shape[0];
        int cols = (int) shape[1];

        if (isString) {
            // String elements are copied out as byte[] leaves, hence the third dimension.
            byte[][][] raw = new byte[rows][cols][];
            tensor.copyTo(raw);
            String[][] decoded = new String[rows][cols];
            for (int r = 0; r < rows; r++) {
                for (int c = 0; c < cols; c++) {
                    decoded[r][c] = new String(raw[r][c], StandardCharsets.UTF_8);
                }
            }
            return (T[][]) decoded;
        }

        float[][] raw = new float[rows][cols];
        tensor.copyTo(raw);
        Float[][] boxed = new Float[rows][cols];
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                boxed[r][c] = Float.valueOf(raw[r][c]);
            }
        }
        return (T[][]) boxed;
    }

    public static void main(String[] strArr) throws Exception {
        AdxLocalTFModel adxLocalTFModel = new AdxLocalTFModel();
        adxLocalTFModel.loadLatestModel("/Users/panhangyi/mid-adx-esmm-ctclk-phy-v2");
        HashMap hashMap = (HashMap) JSON.parseObject("{\"f309040\":\"0.0\",\"f414004\":\"0\",\"f3010010\":\"32545\",\"f5020101\":\"0\",\"f5020103\":\"0\",\"f309050\":\"0.0\",\"f5020102\":\"0\",\"f5020105\":\"1000\",\"f5020104\":\"1000\",\"f5020106\":\"1000\",\"f413004\":\"0\",\"f2010010\":\"2636\",\"f5020010\":\"0\",\"f1010020\":\"44\",\"f7040050\":\"0\",\"f3016024\":\"221290\",\"f3016025\":\"120527\",\"f3016022\":\"1261178\",\"f4010010\":\"Mozilla/5.0 (iPhone; CPU iPhone OS 15_6 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Mobile/15E148 MicroMessenger/8.0.28(0x18001c2c) NetType/4G Language/zh_CN\",\"f3016023\":\"457932\",\"f7040060\":\"0\",\"f409004\":\"0\",\"f3016020\":\"113899\",\"f3010161\":\"113899\",\"f3016021\":\"62498\",\"f3010160\":\"234457\",\"f5020030\":\"0\",\"f4010121\":\"15\",\"f5010110\":\"1\",\"f2010060\":\"411143\",\"f410004\":\"0\",\"f5020020\":\"0\",\"f4010122\":\"7\",\"f4010030\":\"0\",\"f7040080\":\"0\",\"f3010144\":\"9\",\"f3010143\":\"10\",\"f3010142\":\"7\",\"f3010141\":\"8\",\"f5020050\":\"0\",\"f3010149\":\"78286\",\"f3010148\":\"220104\",\"f0001\":\"15\",\"f3010147\":\"7173\",\"f3010146\":\"69150\",\"f3010145\":\"7\",\"f0003\":\"1\",\"f0002\":\"7\",\"f5010135\":\"\",\"f3016013\":\"7778\",\"f5010134\":\"\",\"f3016014\":\"220104\",\"f5010133\":\"\",\"f3016011\":\"25292\",\"f3016012\":\"12710\",\"f3016017\":\"23726\",\"f3016018\":\"666218\",\"f7040090\":\"0\",\"f3016015\":\"78286\",\"f5010136\":\"\",\"f3016016\":\"39700\",\"f3010155\":\"457932\",\"f3010154\":\"1261178\",\"f3010153\":\"62498\",\"f3010152\":\"666218\",\"f3010151\":\"23726\",\"f5020040\":\"0\",\"f3010150\":\"39700\",\"f3016010\":\"75969\",\"f3010159\":\"11655\",\"f3010158\":\"23170\",\"f3010157\":\"120527\",\"f6060011\":\"0\",\"f3010156\":\"221290\",\"f6060012\":\"0\",\"f6060013\":\"0\",\"f6060014\":\"1\",\"f6060015\":\"1\",\"f3016019\":\"234457\",\"f5010010\":\"010105\",\"f0309\":\"0\",\"f0308\":\"0\",\"f0307\":\"0\",\"f411004\":\"0\",\"f5020070\":\"0\",\"f3030011\":
\"\",\"f0302\":\"0\",\"f0301\":\"0\",\"f0306\":\"0\",\"f0305\":\"0\",\"f0304\":\"0\",\"f0303\":\"0\",\"f5010036\":\"0\",\"f5010035\":\"0\",\"f5010034\":\"0\",\"f5010033\":\"0\",\"f5020060\":\"0\",\"f6010102\":\"[1]\",\"f4010070\":\"3\",\"f5020090\":\"0\",\"f0320\":\"0\",\"f307040\":\"0.0\",\"f0203\":\"0\",\"f0324\":\"0\",\"f0202\":\"0\",\"f0323\":\"0\",\"f0201\":\"0\",\"f0322\":\"0\",\"f0321\":\"0\",\"f0327\":\"0\",\"f0326\":\"0\",\"f0204\":\"0\",\"f0325\":\"0\",\"f0319\":\"0\",\"f0318\":\"0\",\"f5020080\":\"0\",\"f3060017\":\"1\",\"f3060016\":\"1\",\"f307050\":\"0.0\",\"f3060013\":\"1\",\"f5010050\":\"8e3H3Ny3GBDTLn3T\",\"f0313\":\"0\",\"f3060012\":\"1\",\"f0312\":\"0\",\"f3060015\":\"1\",\"f0311\":\"0\",\"f3060014\":\"1\",\"f0310\":\"0\",\"f0317\":\"0\",\"f0316\":\"0\",\"f3060011\":\"1\",\"f0315\":\"0\",\"f0314\":\"0\",\"f4010091\":\"0\",\"f1010010\":\"190\",\"f5010060\":\"0\",\"f3020000\":\"0\",\"f5010079\":\"0\",\"f5010078\":\"0\",\"f5010077\":\"0\",\"f5030101\":\"190&1000\",\"f5030102\":\"2636&1000\",\"f5030103\":\"32545&1000\",\"f4010080\":\"0\",\"f5010070\":\"0\",\"f5010076\":\"0\",\"f5010075\":\"0\",\"f5010074\":\"0\",\"f6010090\":\"411143\",\"f5010082\":\"0\",\"f5010081\":\"0\",\"f5010080\":\"0\",\"f412004\":\"0\",\"f5010090\":\"0\"}", HashMap.class);
        HashMap hashMap2 = new HashMap();
        hashMap2.put(123L, hashMap);
        System.out.println(adxLocalTFModel.predict(hashMap2));
    }

    /** @return the bundle currently held by this instance, possibly {@code null} */
    public SavedModelBundle getBundle() {
        return bundle;
    }

    /** @return the recorded model location, possibly {@code null} */
    public String getPath() {
        return path;
    }

    /** @return the recorded model version, possibly {@code null} */
    public Long getVersion() {
        return version;
    }

    /** @return the cached signature inputs, keyed by signature input key */
    public Map<String, TensorInfo> getInputSpecMap() {
        return inputSpecMap;
    }

    /** @return the cached signature outputs, keyed by signature output key */
    public Map<String, TensorInfo> getOutputSpecMap() {
        return outputSpecMap;
    }

    /**
     * Replaces the held bundle reference only; the cached input/output specs are not refreshed
     * and a previously held bundle is not closed.
     *
     * @param savedModelBundle new bundle reference
     */
    public void setBundle(SavedModelBundle savedModelBundle) {
        bundle = savedModelBundle;
    }

    /** @param str new recorded model location */
    public void setPath(String str) {
        path = str;
    }

    /** @param l new recorded model version */
    public void setVersion(Long l) {
        version = l;
    }

    /** @param map new cached signature inputs */
    public void setInputSpecMap(Map<String, TensorInfo> map) {
        inputSpecMap = map;
    }

    /** @param map new cached signature outputs */
    public void setOutputSpecMap(Map<String, TensorInfo> map) {
        outputSpecMap = map;
    }

    /**
     * Two models are equal when the other object is an {@code AdxLocalTFModel} that accepts this
     * one via {@link #canEqual(Object)} and bundle, path, version and both spec maps are
     * pairwise equal (null-safe).
     *
     * @param obj object to compare with
     * @return {@code true} if both objects are considered equal
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof AdxLocalTFModel)) {
            return false;
        }
        final AdxLocalTFModel that = (AdxLocalTFModel) obj;
        // Fields are read through the getters and compared in declaration order.
        return that.canEqual(this)
                && java.util.Objects.equals(getBundle(), that.getBundle())
                && java.util.Objects.equals(getPath(), that.getPath())
                && java.util.Objects.equals(getVersion(), that.getVersion())
                && java.util.Objects.equals(getInputSpecMap(), that.getInputSpecMap())
                && java.util.Objects.equals(getOutputSpecMap(), that.getOutputSpecMap());
    }

    /**
     * Hook letting subclasses veto equality with instances of other types.
     *
     * @param obj candidate for equality
     * @return {@code true} if {@code obj} is an {@code AdxLocalTFModel}
     */
    protected boolean canEqual(Object obj) {
        final boolean sameFamily = obj instanceof AdxLocalTFModel;
        return sameFamily;
    }

    /**
     * Combines the hash codes of bundle, path, version and both spec maps, in that order, with
     * multiplier 59 and the value 43 standing in for {@code null}.
     *
     * @return hash code consistent with {@link #equals(Object)}
     */
    @Override
    public int hashCode() {
        final Object[] components = {
            getBundle(), getPath(), getVersion(), getInputSpecMap(), getOutputSpecMap()
        };
        int result = 1;
        for (Object component : components) {
            final int componentHash = (component == null) ? 43 : component.hashCode();
            result = (result * 59) + componentHash;
        }
        return result;
    }

    /**
     * @return a debug representation listing bundle, path, version and both spec maps
     */
    @Override
    public String toString() {
        final StringBuilder text = new StringBuilder("AdxLocalTFModel(");
        text.append("bundle=").append(getBundle());
        text.append(", path=").append(getPath());
        text.append(", version=").append(getVersion());
        text.append(", inputSpecMap=").append(getInputSpecMap());
        text.append(", outputSpecMap=").append(getOutputSpecMap());
        text.append(")");
        return text.toString();
    }
}
