/*
 * Decompiled with CFR 0.152.
 */
package cn.com.duiba.nezha.alg.alg.recall.model;

import cn.com.duiba.nezha.alg.common.util.NumberUtil;
import cn.com.duiba.nezha.alg.model.tf.TensorflowUtils;
import com.google.common.primitives.Floats;
import java.io.File;
import java.io.UnsupportedEncodingException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.DataType;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;
import org.tensorflow.framework.TensorInfo;

/**
 * Local wrapper around a TensorFlow {@link SavedModelBundle}, named for a DSSM recall model.
 *
 * <p>Inputs and outputs are discovered from the {@code serving_default} signature of the loaded
 * model. Every signature input is fed as a {@code [1, 1]} tensor: {@code DT_FLOAT} inputs receive a
 * parsed float, all other inputs receive the UTF-8 bytes of the raw string value.
 *
 * <p>NOTE(review): this class has no synchronization; callers that reload or close the model while
 * {@link #predict} may be running must coordinate externally.
 */
public class DSSMLocalTFModel {
    private static final Logger logger = LoggerFactory.getLogger(DSSMLocalTFModel.class);
    /** Name of the SavedModel signature whose inputs/outputs drive feeding and fetching. */
    private static final String SIGNATURE_DEF = "serving_default";
    private SavedModelBundle bundle;
    private String path;
    private String version;
    /** Signature input key -> tensor info; the info's name is the graph tensor that gets fed. */
    private Map<String, TensorInfo> inputSpecMap;
    /** Signature output key -> tensor info; the info's name is the graph tensor that gets fetched. */
    private Map<String, TensorInfo> outputSpecMap;

    /**
     * Installs an already-loaded bundle and reads its {@code serving_default} input/output specs.
     *
     * <p>The specs are resolved before any field is assigned, so if the signature cannot be read
     * this instance keeps its previous state. The previously installed bundle is NOT closed here.
     *
     * @param bundle the loaded SavedModel to use for prediction
     * @throws Exception if the meta graph cannot be parsed or the signature is missing
     */
    @SuppressWarnings("unchecked")
    public void setModel(SavedModelBundle bundle) throws Exception {
        Map<String, TensorInfo> inputs = this.getInputSpecMap(bundle, SIGNATURE_DEF);
        Map<String, TensorInfo> outputs = this.getOutputSpecMap(bundle, SIGNATURE_DEF);
        this.bundle = bundle;
        this.inputSpecMap = inputs;
        this.outputSpecMap = outputs;
    }

    /**
     * Loads the highest numeric version directory found under {@code path}.
     *
     * <p>If no numeric entry exists, {@code null} is passed on so that
     * {@link #loadModel(String, String)} applies its own fallback, instead of trying to load a
     * version literally named {@code "null"}.
     *
     * @param path directory containing one sub-entry per model version
     * @throws Exception if loading fails
     */
    public void loadModel(String path) throws Exception {
        Long lastVersion = this.getLastVersion(path);
        this.loadModel(path, lastVersion == null ? null : lastVersion.toString());
    }

    /**
     * Returns the largest entry name under {@code path} that is numeric and parses as a {@code long}.
     *
     * @param path directory to scan
     * @return the largest version, or {@code null} when {@code path} is not a listable directory or
     *     holds no numeric entries
     * @throws Exception declared for backward compatibility
     */
    public Long getLastVersion(String path) throws Exception {
        // listFiles() returns null both for non-directories and on I/O errors.
        File[] children = new File(path).listFiles();
        if (children == null) {
            return null;
        }
        Long latest = null;
        for (File child : children) {
            String name = child.getName();
            if (!NumberUtil.isNumber(name)) {
                continue;
            }
            long candidate;
            try {
                candidate = Long.parseLong(name);
            } catch (NumberFormatException e) {
                // NOTE(review): NumberUtil.isNumber's exact rules are not visible here; skip names
                // it accepts that are not valid longs (e.g. decimals or overflowing values).
                continue;
            }
            if (latest == null || candidate > latest) {
                latest = candidate;
            }
        }
        return latest;
    }

    /**
     * Loads the given model version and installs it via {@link #setModel}.
     *
     * @param path model root directory
     * @param version version to load; when {@code null}, {@code TensorflowUtils.getLastVersion(path)}
     *     decides
     * @throws Exception if the bundle cannot be loaded or its signature cannot be read
     */
    public void loadModel(String path, String version) throws Exception {
        if (version == null) {
            version = TensorflowUtils.getLastVersion(path);
        }
        SavedModelBundle loaded = TensorflowUtils.getModelBundle(path, version);
        try {
            this.setModel(loaded);
        } catch (Exception e) {
            // The new bundle never became current, so release its native resources here.
            if (loaded != null) {
                loaded.close();
            }
            throw e;
        }
        this.path = path;
        this.version = version;
    }

    /**
     * Closes the current bundle, if any, and forgets it. Subsequent {@link #predict} calls fail
     * with {@link IllegalStateException} until a model is loaded again.
     *
     * @throws Exception declared for backward compatibility
     */
    public void close() throws Exception {
        if (this.bundle != null) {
            logger.info("close model, path={} with version={}", this.path, this.version);
            this.bundle.close();
            this.bundle = null;
        }
    }

    /**
     * Reads the input specs of the named signature from the bundle's meta graph.
     *
     * @param bundle loaded SavedModel
     * @param signatureName signature to look up, e.g. {@code serving_default}
     * @return map from signature input key to {@link TensorInfo}
     * @throws Exception if the meta graph cannot be parsed or the signature is absent
     */
    public Map getInputSpecMap(SavedModelBundle bundle, String signatureName) throws Exception {
        return readSignatureDef(bundle, signatureName).getInputsMap();
    }

    /**
     * Reads the output specs of the named signature from the bundle's meta graph.
     *
     * @param bundle loaded SavedModel
     * @param signatureName signature to look up, e.g. {@code serving_default}
     * @return map from signature output key to {@link TensorInfo}
     * @throws Exception if the meta graph cannot be parsed or the signature is absent
     */
    public Map getOutputSpecMap(SavedModelBundle bundle, String signatureName) throws Exception {
        return readSignatureDef(bundle, signatureName).getOutputsMap();
    }

    /** Parses the bundle's serialized MetaGraphDef and returns the requested signature. */
    private static SignatureDef readSignatureDef(SavedModelBundle bundle, String signatureName)
            throws Exception {
        return MetaGraphDef.parseFrom(bundle.metaGraphDef()).getSignatureDefOrThrow(signatureName);
    }

    /**
     * Builds one {@code [1, 1]} input tensor per signature input.
     *
     * <p>Feature values are looked up by signature input key; tensors are keyed by the graph tensor
     * name from the spec. Missing or {@code null} float features become {@code 0.0f}, as do
     * unparseable ones (with a warning). Missing or {@code null} string features become {@code ""}.
     * If building any tensor fails, all tensors created so far are closed before rethrowing.
     *
     * @param featureMap raw feature values keyed by signature input key
     * @return tensors keyed by graph tensor name; the caller must close them
     * @throws UnsupportedEncodingException declared for backward compatibility (UTF-8 is always
     *     available)
     */
    public Map<String, Tensor> preprocessInput(Map<String, String> featureMap)
            throws UnsupportedEncodingException {
        Map<String, Tensor> inputTensorMap = new HashMap<String, Tensor>();
        boolean complete = false;
        try {
            for (Map.Entry<String, TensorInfo> spec : this.inputSpecMap.entrySet()) {
                String inputKey = spec.getKey();
                TensorInfo inputInfo = spec.getValue();
                String rawValue = featureMap.get(inputKey);
                Tensor tensorInput = inputInfo.getDtype() == org.tensorflow.framework.DataType.DT_FLOAT
                        ? createFloatTensor(inputKey, rawValue)
                        : createStringTensor(rawValue);
                // Two signature inputs sharing one tensor name would otherwise leak the first tensor.
                Tensor replaced = inputTensorMap.put(inputInfo.getName(), tensorInput);
                if (replaced != null) {
                    replaced.close();
                }
            }
            complete = true;
            return inputTensorMap;
        } finally {
            if (!complete) {
                closeTensors(inputTensorMap.values());
            }
        }
    }

    /** Creates a {@code [1, 1]} float tensor; null or unparseable values become {@code 0.0f}. */
    private static Tensor createFloatTensor(String inputKey, String rawValue) {
        float value = 0.0f;
        if (rawValue != null) {
            try {
                value = Float.parseFloat(rawValue);
            } catch (NumberFormatException e) {
                logger.warn("non-numeric value for float input {}: '{}', using 0.0", inputKey, rawValue);
            }
        }
        return Tensor.create(new float[][] {{value}}, Float.class);
    }

    /** Creates a {@code [1, 1]} string tensor holding the UTF-8 bytes of the value ({@code ""} if null). */
    private static Tensor createStringTensor(String rawValue) {
        byte[] bytes = (rawValue == null ? "" : rawValue).getBytes(StandardCharsets.UTF_8);
        return Tensor.create(new byte[][][] {{bytes}}, String.class);
    }

    /**
     * Converts the first row of a rank-2 STRING or FLOAT tensor into a list.
     *
     * @param tensorResult tensor to read; it is not closed here
     * @return the first row as {@code List<String>} (UTF-8 decoded) or {@code List<Float>}, an empty
     *     list when the tensor has zero rows, or {@code null} for any other data type
     * @throws IllegalArgumentException if a STRING/FLOAT tensor is not rank 2
     */
    public List tensorToList(Tensor tensorResult) {
        DataType dataType = tensorResult.dataType();
        if (dataType != DataType.STRING && dataType != DataType.FLOAT) {
            return null;
        }
        long[] shape = tensorResult.shape();
        if (shape.length != 2) {
            throw new IllegalArgumentException(
                    "expected a rank-2 " + dataType + " tensor but got shape " + Arrays.toString(shape));
        }
        int rows = (int) shape[0];
        int cols = (int) shape[1];
        if (rows == 0) {
            return new ArrayList<Object>();
        }
        if (dataType == DataType.STRING) {
            byte[][][] raw = new byte[rows][cols][];
            tensorResult.copyTo(raw);
            // Only row 0 is returned, so only row 0 is decoded.
            String[] firstRow = new String[cols];
            for (int q = 0; q < cols; ++q) {
                firstRow[q] = new String(raw[0][q], StandardCharsets.UTF_8);
            }
            return Arrays.asList(firstRow);
        }
        float[][] values = new float[rows][cols];
        tensorResult.copyTo(values);
        return Floats.asList(values[0]);
    }

    /**
     * Runs the model on one feature map and returns the first row of every signature output.
     *
     * <p>All input and output tensors are closed before returning, including on failure.
     *
     * @param featureMap raw feature values keyed by signature input key
     * @return map from output graph tensor name to the list produced by {@link #tensorToList}
     * @throws UnsupportedEncodingException declared for backward compatibility
     * @throws IllegalStateException if no model is loaded (never loaded, or already closed)
     */
    public Map<String, List> predict(Map<String, String> featureMap) throws UnsupportedEncodingException {
        if (this.bundle == null) {
            throw new IllegalStateException("model is not loaded");
        }
        Map<String, Tensor> inputTensorMap = this.preprocessInput(featureMap);
        List<Tensor<?>> outputTensors = null;
        try {
            List<String> outputNames = new ArrayList<String>(this.outputSpecMap.size());
            for (TensorInfo outputInfo : this.outputSpecMap.values()) {
                outputNames.add(outputInfo.getName());
            }
            Session.Runner runner = this.bundle.session().runner();
            for (Map.Entry<String, Tensor> input : inputTensorMap.entrySet()) {
                runner.feed(input.getKey(), input.getValue());
            }
            for (String outputName : outputNames) {
                runner.fetch(outputName);
            }
            outputTensors = runner.run();
            // run() returns fetched tensors in fetch order, so index i matches outputNames.get(i).
            Map<String, List> resultMap = new HashMap<String, List>();
            for (int i = 0; i < outputNames.size(); ++i) {
                resultMap.put(outputNames.get(i), this.tensorToList(outputTensors.get(i)));
            }
            return resultMap;
        } finally {
            closeTensors(inputTensorMap.values());
            if (outputTensors != null) {
                closeTensors(outputTensors);
            }
        }
    }

    /** Closes every non-null tensor in the given collection. */
    private static void closeTensors(Iterable<? extends Tensor> tensors) {
        for (Tensor tensor : tensors) {
            if (tensor != null) {
                tensor.close();
            }
        }
    }

    /**
     * Manual smoke test: loads a model, runs 10000 predictions and prints the last result.
     *
     * @param args optional {@code args[0]} overrides the default model directory
     * @throws Exception if loading or prediction fails
     */
    public static void main(String[] args) throws Exception {
        String path = args.length > 0 ? args[0] : "C:\\Users\\wangyaole\\Desktop\\nezha\\model\\dssm01";
        DSSMLocalTFModel model = new DSSMLocalTFModel();
        model.loadModel(path);
        try {
            Map<String, String> featureMap = new HashMap<String, String>();
            String colStr = "f501001,f504001,f508005,f108001,f201001,f507001,f503001,f406001,f502001,f502002,afee,ad_pk";
            String valStr = "Android,HLK-AL00,\u8363\u8000,325913,70125,3,3609,1,16,2,7.2,72093_177812";
            String[] colList = colStr.split(",");
            String[] valList = valStr.split(",");
            // NOTE(review): the last column (ad_pk) is not added as a feature; confirm this is intended.
            for (int i = 0; i < colList.length - 1; ++i) {
                featureMap.put(colList[i], valList[i]);
            }
            for (int i = 0; i < 10000; ++i) {
                model.predict(featureMap);
                if (i % 1000 == 0) {
                    System.out.println(i);
                }
            }
            System.out.println(model.predict(featureMap));
        } finally {
            model.close();
        }
    }

    /** @return the currently installed bundle, or {@code null} if none */
    public SavedModelBundle getBundle() {
        return this.bundle;
    }

    /** @return the model root directory passed to the last successful {@code loadModel} */
    public String getPath() {
        return this.path;
    }

    /** @return the version loaded by the last successful {@code loadModel} */
    public String getVersion() {
        return this.version;
    }

    /** @return signature input key -> tensor info of the installed model */
    public Map<String, TensorInfo> getInputSpecMap() {
        return this.inputSpecMap;
    }

    /** @return signature output key -> tensor info of the installed model */
    public Map<String, TensorInfo> getOutputSpecMap() {
        return this.outputSpecMap;
    }

    /** Replaces the bundle without re-reading the signature specs. */
    public void setBundle(SavedModelBundle bundle) {
        this.bundle = bundle;
    }

    public void setPath(String path) {
        this.path = path;
    }

    public void setVersion(String version) {
        this.version = version;
    }

    public void setInputSpecMap(Map<String, TensorInfo> inputSpecMap) {
        this.inputSpecMap = inputSpecMap;
    }

    public void setOutputSpecMap(Map<String, TensorInfo> outputSpecMap) {
        this.outputSpecMap = outputSpecMap;
    }

    /**
     * Field-by-field equality over bundle, path, version and both spec maps.
     *
     * <p>NOTE(review): all compared fields are mutable, so instances are unsafe as hash keys.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof DSSMLocalTFModel)) {
            return false;
        }
        DSSMLocalTFModel other = (DSSMLocalTFModel) o;
        if (!other.canEqual(this)) {
            return false;
        }
        SavedModelBundle thisBundle = this.getBundle();
        SavedModelBundle otherBundle = other.getBundle();
        if (thisBundle == null ? otherBundle != null : !thisBundle.equals(otherBundle)) {
            return false;
        }
        String thisPath = this.getPath();
        String otherPath = other.getPath();
        if (thisPath == null ? otherPath != null : !thisPath.equals(otherPath)) {
            return false;
        }
        String thisVersion = this.getVersion();
        String otherVersion = other.getVersion();
        if (thisVersion == null ? otherVersion != null : !thisVersion.equals(otherVersion)) {
            return false;
        }
        Map<String, TensorInfo> thisInputs = this.getInputSpecMap();
        Map<String, TensorInfo> otherInputs = other.getInputSpecMap();
        if (thisInputs == null ? otherInputs != null : !thisInputs.equals(otherInputs)) {
            return false;
        }
        Map<String, TensorInfo> thisOutputs = this.getOutputSpecMap();
        Map<String, TensorInfo> otherOutputs = other.getOutputSpecMap();
        return thisOutputs == null ? otherOutputs == null : thisOutputs.equals(otherOutputs);
    }

    /** Symmetry hook for {@link #equals}: subclasses may narrow which types compare equal. */
    protected boolean canEqual(Object other) {
        return other instanceof DSSMLocalTFModel;
    }

    /** Hash over the same five fields as {@link #equals} (prime 59, 43 for null). */
    @Override
    public int hashCode() {
        final int PRIME = 59;
        int result = 1;
        SavedModelBundle bundleValue = this.getBundle();
        result = result * PRIME + (bundleValue == null ? 43 : bundleValue.hashCode());
        String pathValue = this.getPath();
        result = result * PRIME + (pathValue == null ? 43 : pathValue.hashCode());
        String versionValue = this.getVersion();
        result = result * PRIME + (versionValue == null ? 43 : versionValue.hashCode());
        Map<String, TensorInfo> inputs = this.getInputSpecMap();
        result = result * PRIME + (inputs == null ? 43 : inputs.hashCode());
        Map<String, TensorInfo> outputs = this.getOutputSpecMap();
        result = result * PRIME + (outputs == null ? 43 : outputs.hashCode());
        return result;
    }

    @Override
    public String toString() {
        return "DSSMLocalTFModel(bundle=" + this.getBundle() + ", path=" + this.getPath() + ", version=" + this.getVersion() + ", inputSpecMap=" + this.getInputSpecMap() + ", outputSpecMap=" + this.getOutputSpecMap() + ")";
    }
}

