/*
 * Decompiled with CFR 0.152.
 */
package cn.com.duiba.nezha.alg.api.model;

import cn.com.duiba.nezha.alg.api.model.E2ELocalTFModel;
import cn.com.duiba.nezha.alg.api.model.util.E2ETFUtils;
import cn.com.duiba.nezha.alg.api.model.util.TensorflowUtils;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import jdk.nashorn.internal.ir.debug.ObjectSizeCalculator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.DataType;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;
import org.tensorflow.framework.TensorInfo;

public class AdxLocalTFModel {
    private static final Logger logger = LoggerFactory.getLogger(E2ELocalTFModel.class);
    private static final String SIGNATURE_DEF = "serving_default";
    private SavedModelBundle bundle;
    private String path;
    private Long version;
    private Map<String, TensorInfo> inputSpecMap;
    private Map<String, TensorInfo> outputSpecMap;

    public void loadLatestModel(String path) throws Exception {
        Long lastVersion = E2ETFUtils.getLastVersion(path);
        SavedModelBundle bundle = E2ETFUtils.loadModel(path, lastVersion + "");
        logger.info("loadModel ,path=" + path + "with version=" + lastVersion);
        this.setModel(bundle);
        this.path = path;
        this.version = lastVersion;
    }

    public void close() throws Exception {
        if (this.bundle != null) {
            logger.info("close model  ,path=" + this.path + "with version=" + this.version);
            logger.info("bundle.size before close :" + ObjectSizeCalculator.getObjectSize((Object)this.bundle));
            this.bundle.close();
            this.bundle = null;
        }
    }

    private void setModel(SavedModelBundle bundle) throws Exception {
        this.bundle = bundle;
        this.inputSpecMap = this.getInputSpecMap(bundle);
        this.outputSpecMap = this.getOutputSpecMap(bundle);
    }

    private Map<String, TensorInfo> getInputSpecMap(SavedModelBundle bundle) throws Exception {
        SignatureDef sig = MetaGraphDef.parseFrom((byte[])bundle.metaGraphDef()).getSignatureDefOrThrow(SIGNATURE_DEF);
        return sig.getInputsMap();
    }

    private Map<String, TensorInfo> getOutputSpecMap(SavedModelBundle bundle) throws Exception {
        SignatureDef sig = MetaGraphDef.parseFrom((byte[])bundle.metaGraphDef()).getSignatureDefOrThrow(SIGNATURE_DEF);
        return sig.getOutputsMap();
    }

    private <T1> Map<String, Tensor<?>> preprocessInput(List<T1> sampleKeyList, Map<T1, Map<String, String>> sampleWithFeatureMap) {
        HashMap formattedInput = new HashMap();
        for (String featureKey : this.inputSpecMap.keySet()) {
            TensorInfo inputInfo = this.inputSpecMap.get(featureKey);
            String tfInputKey = inputInfo.getName();
            ArrayList<String> oneFeatureList = new ArrayList<String>();
            for (T1 sampleKey : sampleKeyList) {
                Map<String, String> featureMap = sampleWithFeatureMap.get(sampleKey);
                oneFeatureList.add(featureMap.getOrDefault(featureKey, null));
            }
            Tensor<?> oneFeatureTensor = AdxLocalTFModel.listToTensor(oneFeatureList, inputInfo);
            formattedInput.put(tfInputKey, oneFeatureTensor);
        }
        return formattedInput;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     * WARNING - void declaration
     */
    private <T> Map<String, T[][]> executeTfGraph(Map<String, Tensor<?>> formattedInput) {
        HashMap<String, T[][]> resultMap;
        List result = null;
        try {
            void var7_10;
            ArrayList<String> outputKeyList = new ArrayList<String>();
            ArrayList<String> outputTfKeyList = new ArrayList<String>();
            for (String string : this.outputSpecMap.keySet()) {
                outputKeyList.add(string);
                outputTfKeyList.add(this.outputSpecMap.get(string).getName());
            }
            Session.Runner runner = this.bundle.session().runner();
            for (String inputTfKey : formattedInput.keySet()) {
                runner.feed(inputTfKey, formattedInput.get(inputTfKey));
            }
            for (String outputTfKey : outputTfKeyList) {
                runner.fetch(outputTfKey);
            }
            result = runner.run();
            resultMap = new HashMap<String, T[][]>();
            boolean bl = false;
            while (var7_10 < outputKeyList.size()) {
                Tensor tensorResult = (Tensor)result.get((int)var7_10);
                resultMap.put((String)outputKeyList.get((int)var7_10), this.tensorToArray(tensorResult));
                ++var7_10;
            }
        }
        finally {
            TensorflowUtils.closeQuitely(formattedInput.values());
            TensorflowUtils.closeQuitely(result);
        }
        return resultMap;
    }

    private <T1, T2> Map<String, Map<T1, T2[]>> formatOutputs(List<T1> sampleKeyList, Map<String, T2[][]> rawOutputs) {
        HashMap<String, Map<T1, T2[]>> resultMap = new HashMap<String, Map<T1, T2[]>>();
        ArrayList<String> outputKeyList = new ArrayList<String>(this.outputSpecMap.keySet());
        for (String outputKey : outputKeyList) {
            HashMap<T1, T2[]> outputMap = new HashMap<T1, T2[]>();
            for (int i = 0; i < sampleKeyList.size(); ++i) {
                T2[] oneOutputs = rawOutputs.get(outputKey)[i];
                outputMap.put(sampleKeyList.get(i), oneOutputs);
            }
            resultMap.put(outputKey, outputMap);
        }
        return resultMap;
    }

    public <T1, T2> Map<String, Map<T1, T2[]>> predict(Map<T1, Map<String, String>> sampleWithFeatureMap) {
        ArrayList<T1> sampleKeyList = new ArrayList<T1>(sampleWithFeatureMap.keySet());
        Map<String, Tensor<?>> formattedInput = this.preprocessInput(sampleKeyList, sampleWithFeatureMap);
        Map<String, T2[][]> rawOutputs = this.executeTfGraph(formattedInput);
        Map<String, Map<T1, T2[]>> predResults = this.formatOutputs(sampleKeyList, rawOutputs);
        return predResults;
    }

    public static Tensor<?> listToTensor(List<String> input, TensorInfo inputInfo) {
        org.tensorflow.framework.DataType dtype = inputInfo.getDtype();
        int inputSize = input.size();
        if (dtype == org.tensorflow.framework.DataType.DT_FLOAT) {
            float[] arrayInput = new float[inputSize];
            for (int i = 0; i < input.size(); ++i) {
                String s = input.get(i);
                float inputWithType = 0.0f;
                if (s != null) {
                    try {
                        inputWithType = Float.parseFloat(s);
                    }
                    catch (NumberFormatException e) {
                        logger.warn("listToTensor exception", (Throwable)e);
                    }
                }
                arrayInput[i] = inputWithType;
            }
            return Tensors.create((float[])arrayInput);
        }
        if (dtype == org.tensorflow.framework.DataType.DT_STRING) {
            byte[][] arrayInput = new byte[inputSize][];
            for (int i = 0; i < input.size(); ++i) {
                String s = input.get(i);
                String inputWithType = "";
                if (s != null) {
                    inputWithType = s;
                }
                arrayInput[i] = inputWithType.getBytes(StandardCharsets.UTF_8);
            }
            return Tensors.create((byte[][])arrayInput);
        }
        return null;
    }

    private <T> T[][] tensorToArray(Tensor<?> tensorResult) {
        long[] tensorShape = tensorResult.shape();
        if (DataType.STRING == tensorResult.dataType()) {
            byte[][][] byteListResult = new byte[(int)tensorShape[0]][(int)tensorShape[1]][];
            tensorResult.copyTo((Object)byteListResult);
            String[][] listResult = new String[(int)tensorShape[0]][(int)tensorShape[1]];
            for (int p = 0; p < byteListResult.length; ++p) {
                for (int q = 0; q < byteListResult[p].length; ++q) {
                    listResult[p][q] = new String(byteListResult[p][q], StandardCharsets.UTF_8);
                }
            }
            return listResult;
        }
        if (DataType.FLOAT == tensorResult.dataType()) {
            float[][] listResult = new float[(int)tensorShape[0]][(int)tensorShape[1]];
            tensorResult.copyTo((Object)listResult);
            Float[][] listResult1 = new Float[(int)tensorShape[0]][(int)tensorShape[1]];
            for (int i = 0; i < listResult1.length; ++i) {
                for (int j = 0; j < listResult1[i].length; ++j) {
                    listResult1[i][j] = Float.valueOf(listResult[i][j]);
                }
            }
            return listResult1;
        }
        return null;
    }

    public static void main(String[] args) throws Exception {
        AdxLocalTFModel adxLocalTFModel = new AdxLocalTFModel();
        adxLocalTFModel.loadLatestModel("C:\\Users\\aaa\\Desktop\\adx\\model\\test");
        HashMap<String, String> featureMap = new HashMap<String, String>();
        featureMap.put("ft1", "aaa");
        HashMap<Long, HashMap<String, String>> sample = new HashMap<Long, HashMap<String, String>>();
        sample.put(123L, featureMap);
        Map predict = adxLocalTFModel.predict(sample);
        System.out.println(predict);
    }

    public SavedModelBundle getBundle() {
        return this.bundle;
    }

    public String getPath() {
        return this.path;
    }

    public Long getVersion() {
        return this.version;
    }

    public Map<String, TensorInfo> getInputSpecMap() {
        return this.inputSpecMap;
    }

    public Map<String, TensorInfo> getOutputSpecMap() {
        return this.outputSpecMap;
    }

    public void setBundle(SavedModelBundle bundle) {
        this.bundle = bundle;
    }

    public void setPath(String path) {
        this.path = path;
    }

    public void setVersion(Long version) {
        this.version = version;
    }

    public void setInputSpecMap(Map<String, TensorInfo> inputSpecMap) {
        this.inputSpecMap = inputSpecMap;
    }

    public void setOutputSpecMap(Map<String, TensorInfo> outputSpecMap) {
        this.outputSpecMap = outputSpecMap;
    }

    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof AdxLocalTFModel)) {
            return false;
        }
        AdxLocalTFModel other = (AdxLocalTFModel)o;
        if (!other.canEqual(this)) {
            return false;
        }
        SavedModelBundle this$bundle = this.getBundle();
        SavedModelBundle other$bundle = other.getBundle();
        if (this$bundle == null ? other$bundle != null : !this$bundle.equals(other$bundle)) {
            return false;
        }
        String this$path = this.getPath();
        String other$path = other.getPath();
        if (this$path == null ? other$path != null : !this$path.equals(other$path)) {
            return false;
        }
        Long this$version = this.getVersion();
        Long other$version = other.getVersion();
        if (this$version == null ? other$version != null : !((Object)this$version).equals(other$version)) {
            return false;
        }
        Map<String, TensorInfo> this$inputSpecMap = this.getInputSpecMap();
        Map<String, TensorInfo> other$inputSpecMap = other.getInputSpecMap();
        if (this$inputSpecMap == null ? other$inputSpecMap != null : !((Object)this$inputSpecMap).equals(other$inputSpecMap)) {
            return false;
        }
        Map<String, TensorInfo> this$outputSpecMap = this.getOutputSpecMap();
        Map<String, TensorInfo> other$outputSpecMap = other.getOutputSpecMap();
        return !(this$outputSpecMap == null ? other$outputSpecMap != null : !((Object)this$outputSpecMap).equals(other$outputSpecMap));
    }

    protected boolean canEqual(Object other) {
        return other instanceof AdxLocalTFModel;
    }

    public int hashCode() {
        int PRIME = 59;
        int result = 1;
        SavedModelBundle $bundle = this.getBundle();
        result = result * 59 + ($bundle == null ? 43 : $bundle.hashCode());
        String $path = this.getPath();
        result = result * 59 + ($path == null ? 43 : $path.hashCode());
        Long $version = this.getVersion();
        result = result * 59 + ($version == null ? 43 : ((Object)$version).hashCode());
        Map<String, TensorInfo> $inputSpecMap = this.getInputSpecMap();
        result = result * 59 + ($inputSpecMap == null ? 43 : ((Object)$inputSpecMap).hashCode());
        Map<String, TensorInfo> $outputSpecMap = this.getOutputSpecMap();
        result = result * 59 + ($outputSpecMap == null ? 43 : ((Object)$outputSpecMap).hashCode());
        return result;
    }

    public String toString() {
        return "AdxLocalTFModel(bundle=" + this.getBundle() + ", path=" + this.getPath() + ", version=" + this.getVersion() + ", inputSpecMap=" + this.getInputSpecMap() + ", outputSpecMap=" + this.getOutputSpecMap() + ")";
    }
}

