/*
 * Decompiled with CFR 0.152.
 */
package cn.com.duiba.nezha.alg.model.tf;

import cn.com.duiba.nezha.alg.api.model.TFModel;
import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import cn.com.duiba.nezha.alg.common.util.MathUtil;
import cn.com.duiba.nezha.alg.model.tf.TensorflowUtils;
import cn.com.duiba.nezha.alg.model.vo.CodeDo;
import cn.com.duiba.nezha.alg.model.vo.CodeSizeDo;
import com.alibaba.fastjson.JSON;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

public class LocalTFModelV2
extends TensorflowUtils
implements TFModel {
    private static final Logger logger = LoggerFactory.getLogger(LocalTFModelV2.class);
    private int DF_MAP_SIZE = 128;
    private static String SIGNATURE_DEF = "serving_default";
    public static String DF_OUTPUT = "prob";
    private static String ALL_INPUT_1 = "input_s";
    private static String ALL_INPUT_2 = "input_m";
    private static String ALL_INPUT_3 = "input_f";
    private SavedModelBundle bundle = null;
    private String version = null;
    private String path = null;
    private String inputName1 = null;
    private String inputName2 = null;
    private String inputName3 = null;
    private Map<String, String> outputNameMap = null;

    public void loadModel(String path) throws Exception {
        this.loadModel(path, null);
    }

    public void loadModel(String path, String version) throws Exception {
        if (version == null) {
            version = LocalTFModelV2.getLastVersion(path);
        }
        this.bundle = LocalTFModelV2.getModelBundle(path, version);
        this.inputName1 = TensorflowUtils.getInputName(ALL_INPUT_1, this.bundle, SIGNATURE_DEF);
        this.inputName2 = TensorflowUtils.getInputName(ALL_INPUT_2, this.bundle, SIGNATURE_DEF);
        this.inputName3 = TensorflowUtils.getInputName(ALL_INPUT_3, this.bundle, SIGNATURE_DEF);
        this.path = path;
        this.version = version;
    }

    public String getOutputName(String output) throws Exception {
        if (this.outputNameMap == null) {
            this.outputNameMap = new HashMap<String, String>();
        }
        if (!this.outputNameMap.containsKey(output)) {
            try {
                String outputName = TensorflowUtils.getOutputName(output, this.bundle, SIGNATURE_DEF);
                if (outputName != null) {
                    this.outputNameMap.put(output, outputName);
                }
            }
            catch (Exception e) {
                System.out.println("getOutputName , empty");
                logger.info("getOutputName empty, output: {}, model path: {}{}", new Object[]{output, this.path, this.version, e});
                throw e;
            }
        }
        if (!this.outputNameMap.containsKey(output)) {
            // empty if block
        }
        return this.outputNameMap.get(output);
    }

    public void close() throws Exception {
        if (this.bundle != null) {
            logger.info("close model  ,path=" + this.path);
            this.bundle.close();
            this.bundle = null;
        }
    }

    public <T> float[] predict(Tensor<Integer> input1, Tensor<String> input2, Tensor<Float> input3, int size, String output) throws Exception {
        float[] ret = null;
        String outputName = this.getOutputName(output);
        if (outputName == null) {
            logger.info("predict(Tensor<T> x, int size, String output)  invalid,outputName= " + outputName);
        }
        Session.Runner runner = this.bundle.session().runner();
        if (this.inputName1 != null) {
            runner.feed(this.inputName1, input1);
        }
        if (this.inputName2 != null) {
            runner.feed(this.inputName2, input2);
        }
        if (this.inputName3 != null) {
            runner.feed(this.inputName3, input3);
        }
        List y = runner.fetch(outputName).run();
        float[] result = new float[size];
        ((Tensor)y.get(0)).copyTo((Object)result);
        ((Tensor)y.get(0)).close();
        ret = result;
        return ret;
    }

    public <T> Map<T, Double> predictAll(Map<T, CodeDo> dataMap, CodeSizeDo codeSizeDo) throws Exception {
        return this.predictAll(dataMap, DF_OUTPUT, codeSizeDo);
    }

    public <T> Map<T, Double> predictAll(Map<T, CodeDo> dataMap, String output, CodeSizeDo codeSizeDo) throws Exception {
        Map<Object, Object> ret = new HashMap();
        if (AssertUtil.isNotEmpty(dataMap)) {
            ArrayList<T> keyList = new ArrayList<T>(dataMap.keySet());
            int sampleSize = keyList.size();
            try (Tensor input1 = LocalTFModelV2.getParamsTensorInt(keyList, dataMap, sampleSize, codeSizeDo.singleSize);
                 Tensor input2 = LocalTFModelV2.getParamsTensorStr(keyList, dataMap, sampleSize, codeSizeDo.multiSize);
                 Tensor input3 = LocalTFModelV2.getParamsTensorFloat(keyList, dataMap, sampleSize, codeSizeDo.floatSize);){
                if (input1 == null) {
                    logger.info(" predictStr input1 is invalid, x=null");
                } else {
                    ret = this.getPredictReturn(keyList, (Tensor<Integer>)input1, (Tensor<String>)input2, (Tensor<Float>)input3, output);
                }
            }
            catch (Exception e) {
                logger.info("predictStr(Map<T, String> dataMap) happend error", (Throwable)e);
                throw e;
            }
        }
        return ret;
    }

    public <T, K> Map<T, Double> getPredictReturn(List<T> keyList, Tensor<Integer> input1, Tensor<String> input2, Tensor<Float> input3, String output) throws Exception {
        HashMap<T, Double> ret = new HashMap<T, Double>(this.DF_MAP_SIZE);
        int sampleSize = keyList.size();
        float[] pred = this.predict(input1, input2, input3, sampleSize, output);
        if (pred != null) {
            for (int i = 0; i < sampleSize; ++i) {
                ret.put(keyList.get(i), MathUtil.formatDouble((double)((double)pred[i] + 1.0E-6), (int)6));
            }
        }
        return ret;
    }

    public static <T> Tensor getParamsTensorFloat(List<T> keyList, Map<T, CodeDo> dataMap, int sampleSize, int featureSize) {
        Tensor ret = null;
        if (featureSize > 0) {
            float[][] matrix = new float[sampleSize][featureSize];
            for (int i = 0; i < sampleSize; ++i) {
                T key = keyList.get(i);
                CodeDo codeDo = dataMap.get(key);
                List<Float> feature = codeDo.getFloatCode();
                for (int j = 0; j < featureSize; ++j) {
                    matrix[i][j] = feature.get(j).floatValue();
                }
            }
            ret = Tensor.create((Object)matrix, Float.class);
        }
        return ret;
    }

    public static <T> Tensor getParamsTensorInt(List<T> keyList, Map<T, CodeDo> dataMap, int sampleSize, int featureSize) {
        Tensor ret = null;
        if (featureSize > 0) {
            int[][] matrix = new int[sampleSize][featureSize];
            for (int i = 0; i < sampleSize; ++i) {
                T key = keyList.get(i);
                CodeDo codeDo = dataMap.get(key);
                List<Integer> feature = codeDo.getSingleCode();
                for (int j = 0; j < featureSize; ++j) {
                    matrix[i][j] = feature.get(j);
                }
            }
            ret = Tensor.create((Object)matrix, Integer.class);
        }
        return ret;
    }

    public static <T> Tensor getParamsTensorStr(List<T> keyList, Map<T, CodeDo> dataMap, int sampleSize, int featureSize) throws Exception {
        Tensor ret = null;
        if (featureSize > 0) {
            byte[][][] matrix = new byte[sampleSize][featureSize][];
            for (int i = 0; i < sampleSize; ++i) {
                T key = keyList.get(i);
                CodeDo codeDo = dataMap.get(key);
                List<String> featureList = codeDo.getMultiCode();
                for (int j = 0; j < featureSize; ++j) {
                    String feature = featureList.get(j);
                    matrix[i][j] = feature.getBytes("UTF-8");
                }
            }
            ret = Tensor.create((Object)matrix, String.class);
        }
        return ret;
    }

    public static void main(String[] args) throws Exception {
        int i;
        LocalTFModelV2 model = new LocalTFModelV2();
        model.loadModel("/Users/lwj/Documents/pythonProjectFiles/tf/data/serving-model/deepfm36/", null);
        System.out.println(JSON.toJSONString((Object)model));
        ArrayList<Integer> tmplist1 = new ArrayList<Integer>();
        ArrayList<String> tmplist2 = new ArrayList<String>();
        for (i = 0; i < 92; ++i) {
            tmplist1.add(10);
        }
        for (i = 0; i < 22; ++i) {
            tmplist2.add("0");
        }
        CodeSizeDo codeSizeDo = new CodeSizeDo();
        codeSizeDo.setSingleSize(92);
        codeSizeDo.setMultiSize(22);
        HashMap<String, CodeDo> dataMap = new HashMap<String, CodeDo>();
        CodeDo codeDo = new CodeDo();
        codeDo.setSingleCode(tmplist1);
        codeDo.setMultiCode(tmplist2);
        dataMap.put("s1", codeDo);
        Map ret = model.predictAll(dataMap, codeSizeDo);
        System.out.println("ret=" + JSON.toJSONString(ret));
    }

    public int getDF_MAP_SIZE() {
        return this.DF_MAP_SIZE;
    }

    public SavedModelBundle getBundle() {
        return this.bundle;
    }

    public String getVersion() {
        return this.version;
    }

    public String getPath() {
        return this.path;
    }

    public String getInputName1() {
        return this.inputName1;
    }

    public String getInputName2() {
        return this.inputName2;
    }

    public String getInputName3() {
        return this.inputName3;
    }

    public Map<String, String> getOutputNameMap() {
        return this.outputNameMap;
    }

    public void setDF_MAP_SIZE(int DF_MAP_SIZE) {
        this.DF_MAP_SIZE = DF_MAP_SIZE;
    }

    public void setBundle(SavedModelBundle bundle) {
        this.bundle = bundle;
    }

    public void setVersion(String version) {
        this.version = version;
    }

    public void setPath(String path) {
        this.path = path;
    }

    public void setInputName1(String inputName1) {
        this.inputName1 = inputName1;
    }

    public void setInputName2(String inputName2) {
        this.inputName2 = inputName2;
    }

    public void setInputName3(String inputName3) {
        this.inputName3 = inputName3;
    }

    public void setOutputNameMap(Map<String, String> outputNameMap) {
        this.outputNameMap = outputNameMap;
    }

    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof LocalTFModelV2)) {
            return false;
        }
        LocalTFModelV2 other = (LocalTFModelV2)o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (this.getDF_MAP_SIZE() != other.getDF_MAP_SIZE()) {
            return false;
        }
        SavedModelBundle this$bundle = this.getBundle();
        SavedModelBundle other$bundle = other.getBundle();
        if (this$bundle == null ? other$bundle != null : !this$bundle.equals(other$bundle)) {
            return false;
        }
        String this$version = this.getVersion();
        String other$version = other.getVersion();
        if (this$version == null ? other$version != null : !this$version.equals(other$version)) {
            return false;
        }
        String this$path = this.getPath();
        String other$path = other.getPath();
        if (this$path == null ? other$path != null : !this$path.equals(other$path)) {
            return false;
        }
        String this$inputName1 = this.getInputName1();
        String other$inputName1 = other.getInputName1();
        if (this$inputName1 == null ? other$inputName1 != null : !this$inputName1.equals(other$inputName1)) {
            return false;
        }
        String this$inputName2 = this.getInputName2();
        String other$inputName2 = other.getInputName2();
        if (this$inputName2 == null ? other$inputName2 != null : !this$inputName2.equals(other$inputName2)) {
            return false;
        }
        String this$inputName3 = this.getInputName3();
        String other$inputName3 = other.getInputName3();
        if (this$inputName3 == null ? other$inputName3 != null : !this$inputName3.equals(other$inputName3)) {
            return false;
        }
        Map<String, String> this$outputNameMap = this.getOutputNameMap();
        Map<String, String> other$outputNameMap = other.getOutputNameMap();
        return !(this$outputNameMap == null ? other$outputNameMap != null : !((Object)this$outputNameMap).equals(other$outputNameMap));
    }

    protected boolean canEqual(Object other) {
        return other instanceof LocalTFModelV2;
    }

    public int hashCode() {
        int PRIME = 59;
        int result = 1;
        result = result * 59 + this.getDF_MAP_SIZE();
        SavedModelBundle $bundle = this.getBundle();
        result = result * 59 + ($bundle == null ? 43 : $bundle.hashCode());
        String $version = this.getVersion();
        result = result * 59 + ($version == null ? 43 : $version.hashCode());
        String $path = this.getPath();
        result = result * 59 + ($path == null ? 43 : $path.hashCode());
        String $inputName1 = this.getInputName1();
        result = result * 59 + ($inputName1 == null ? 43 : $inputName1.hashCode());
        String $inputName2 = this.getInputName2();
        result = result * 59 + ($inputName2 == null ? 43 : $inputName2.hashCode());
        String $inputName3 = this.getInputName3();
        result = result * 59 + ($inputName3 == null ? 43 : $inputName3.hashCode());
        Map<String, String> $outputNameMap = this.getOutputNameMap();
        result = result * 59 + ($outputNameMap == null ? 43 : ((Object)$outputNameMap).hashCode());
        return result;
    }

    public String toString() {
        return "LocalTFModelV2(DF_MAP_SIZE=" + this.getDF_MAP_SIZE() + ", bundle=" + this.getBundle() + ", version=" + this.getVersion() + ", path=" + this.getPath() + ", inputName1=" + this.getInputName1() + ", inputName2=" + this.getInputName2() + ", inputName3=" + this.getInputName3() + ", outputNameMap=" + this.getOutputNameMap() + ")";
    }
}

