/*
 * Decompiled with CFR 0.152.
 */
package cn.com.duiba.nezha.alg.model.tf;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;
import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import cn.com.duiba.nezha.alg.common.util.MathUtil;
import cn.com.duiba.nezha.alg.model.util.OSSUtils;
import cn.com.duiba.nezha.alg.model.vo.CodeDo;
import cn.com.duiba.nezha.alg.model.vo.CodeSizeDo;
import com.alibaba.fastjson.JSON;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Scores samples with an ONNX model downloaded from OSS
 * ({@code algorithm-models/dcvr_model/dcvr_model.onnx}).
 *
 * <p>Each sample is described by a {@link CodeDo} with three feature groups. They are packed into
 * the model inputs {@code input_s} (single int codes), {@code input_m} (comma-separated
 * multi-valued int codes, right-padded with 0 to a common length) and {@code input_f} (float
 * features).
 *
 * <p>NOTE(review): the ONNX environment and session are static, so all instances share one model.
 * Reloading through {@link #getModelFromOss()} is not synchronized with in-flight predictions.
 * The replaced session is therefore not closed, to avoid breaking a concurrent run.
 */
public class OnnxModel extends OSSUtils {
    private static final Logger logger = LoggerFactory.getLogger(OnnxModel.class);
    public static String DF_OUTPUT = "prob";
    private static final String ALL_INPUT_1 = "input_s";
    private static final String ALL_INPUT_2 = "input_m";
    private static final String ALL_INPUT_3 = "input_f";
    private static OrtEnvironment env = null;
    private static OrtSession session = null;

    /**
     * Downloads the model bytes from OSS and replaces the shared ONNX session with a new one.
     *
     * <p>Failures are logged and swallowed. The previously loaded session, if any, stays in place
     * because {@code session} is only assigned after the new session was created.
     *
     * @throws Exception declared for API compatibility; download/parse errors are caught and logged
     */
    public void getModelFromOss() throws Exception {
        String bucketName = "algorithm-models";
        String modelName = "dcvr_model/dcvr_model.onnx";
        try {
            byte[] modelBytes = OnnxModel.downloadModelFromOSS(bucketName, modelName);
            OrtEnvironment environment = OnnxModel.environment();
            // SessionOptions holds native resources; it is only needed while creating the session.
            try (OrtSession.SessionOptions options = new OrtSession.SessionOptions()) {
                session = environment.createSession(modelBytes, options);
            }
        }
        catch (Exception e) {
            logger.error("ONNX Model READ FROM OSS FAILURE", e);
        }
    }

    /**
     * Returns the shared ONNX environment, creating it on first use.
     *
     * <p>Tensors need an environment even before any model has been loaded. The static {@code env}
     * field alone would still be null at that point.
     */
    private static OrtEnvironment environment() {
        if (env == null) {
            env = OrtEnvironment.getEnvironment();
        }
        return env;
    }

    /**
     * Builds the {@code [sampleSize, featureSize]} int tensor from each sample's single codes.
     *
     * @param keyList     sample keys, in output order
     * @param dataMap     features per key
     * @param sampleSize  number of samples taken from {@code keyList}
     * @param featureSize expected number of single codes per sample
     * @return the tensor, or {@code null} when {@code featureSize <= 0}, the code counts do not
     *     match the shape, or tensor creation fails (logged)
     */
    public <T> OnnxTensor getParamsTensorInt(List<T> keyList, Map<T, CodeDo> dataMap, int sampleSize, int featureSize) {
        OnnxTensor ret = null;
        if (featureSize > 0) {
            List<Integer> flatIntFeat = new ArrayList<Integer>(sampleSize * featureSize);
            for (int i = 0; i < sampleSize; ++i) {
                flatIntFeat.addAll(dataMap.get(keyList.get(i)).getSingleCode());
            }
            int[] inputIntData = flatIntFeat.stream().mapToInt(Integer::intValue).toArray();
            if (inputIntData.length != sampleSize * featureSize) {
                logger.error("Create SingleFeature FAILURE: expected {} codes, got {}",
                        sampleSize * featureSize, inputIntData.length);
                return null;
            }
            long[] inputIntShape = new long[]{sampleSize, featureSize};
            try {
                ret = OnnxTensor.createTensor(OnnxModel.environment(), IntBuffer.wrap(inputIntData), inputIntShape);
            }
            catch (Exception e) {
                logger.error("Create SingleFeature FAILURE", e);
            }
        }
        return ret;
    }

    /**
     * Builds the {@code [sampleSize, featureSize, maxLen]} int tensor from each sample's
     * multi-valued codes.
     *
     * <p>Each entry of {@code getMultiCode()} is a comma-separated list of ints.
     * {@code maxLen} is the longest list across all samples, and shorter lists are right-padded
     * with 0.
     *
     * @return the tensor, or {@code null} when {@code featureSize <= 0} or tensor creation fails
     *     (logged)
     * @throws NumberFormatException if a code is not an integer
     */
    public <T> OnnxTensor getParamsTensorMulti(List<T> keyList, Map<T, CodeDo> dataMap, int sampleSize, int featureSize) {
        OnnxTensor ret = null;
        if (featureSize > 0) {
            // First pass: parse every row and find the longest one.
            int featureMaxLength = 0;
            List<List<int[]>> multiFeat = new ArrayList<List<int[]>>(sampleSize);
            for (int i = 0; i < sampleSize; ++i) {
                List<String> codes = dataMap.get(keyList.get(i)).getMultiCode();
                List<int[]> rows = new ArrayList<int[]>(featureSize);
                for (int j = 0; j < featureSize; ++j) {
                    int[] row = OnnxModel.parseCodes(codes.get(j));
                    featureMaxLength = Math.max(featureMaxLength, row.length);
                    rows.add(row);
                }
                multiFeat.add(rows);
            }
            // Second pass: copy each row into its fixed-width slot. The new array is zero-filled,
            // so skipping the rest of the slot is exactly the 0-padding.
            int[] inputMultiData = new int[sampleSize * featureSize * featureMaxLength];
            int offset = 0;
            for (List<int[]> rows : multiFeat) {
                for (int[] row : rows) {
                    System.arraycopy(row, 0, inputMultiData, offset, row.length);
                    offset += featureMaxLength;
                }
            }
            long[] inputMultiShape = new long[]{sampleSize, featureSize, featureMaxLength};
            try {
                ret = OnnxTensor.createTensor(OnnxModel.environment(), IntBuffer.wrap(inputMultiData), inputMultiShape);
            }
            catch (Exception exception) {
                logger.error("Create MultiFeature FAILURE", exception);
            }
        }
        return ret;
    }

    /** Parses a comma-separated list of int codes; a null or blank string yields an empty row. */
    private static int[] parseCodes(String csv) {
        if (csv == null || csv.trim().isEmpty()) {
            return new int[0];
        }
        String[] parts = csv.split(",");
        int[] row = new int[parts.length];
        for (int k = 0; k < parts.length; ++k) {
            row[k] = Integer.parseInt(parts[k].trim());
        }
        return row;
    }

    /**
     * Builds the {@code [sampleSize, featureSize]} float tensor from each sample's float codes.
     *
     * @return the tensor, or {@code null} when {@code featureSize <= 0} or tensor creation fails
     *     (logged)
     */
    public <T> OnnxTensor getParamsTensorFloat(List<T> keyList, Map<T, CodeDo> dataMap, int sampleSize, int featureSize) {
        OnnxTensor ret = null;
        if (featureSize > 0) {
            float[] inputFloatData = new float[sampleSize * featureSize];
            for (int i = 0; i < sampleSize; ++i) {
                List<Float> feature = dataMap.get(keyList.get(i)).getFloatCode();
                for (int j = 0; j < featureSize; ++j) {
                    inputFloatData[i * featureSize + j] = feature.get(j).floatValue();
                }
            }
            long[] inputFloatShape = new long[]{sampleSize, featureSize};
            try {
                ret = OnnxTensor.createTensor(OnnxModel.environment(), FloatBuffer.wrap(inputFloatData), inputFloatShape);
            }
            catch (Exception e) {
                logger.error("Create FloatFeature FAILURE", e);
            }
        }
        return ret;
    }

    /**
     * Runs the shared ONNX session and returns one score per sample.
     *
     * <p>The model is downloaded on the first call only; call {@link #getModelFromOss()} to
     * refresh it. Each non-null tensor is fed under its input name, but only when the loaded
     * model declares that input.
     *
     * @param input_s     single-code tensor (may be null)
     * @param input_m     multi-code tensor (may be null)
     * @param input_f     float-feature tensor (may be null)
     * @param output_name output to read; when null or not produced, the first output is used
     * @return scores read from a rank-1 output, or from an {@code [N, 1]} output
     * @throws IllegalStateException if no model could be loaded or the output type is unsupported
     * @throws Exception             on ONNX runtime failures
     */
    public float[] predict(OnnxTensor input_s, OnnxTensor input_m, OnnxTensor input_f, String output_name) throws Exception {
        if (session == null) {
            this.getModelFromOss();
        }
        OrtSession current = session;
        if (current == null) {
            throw new IllegalStateException("ONNX session is not available; loading the model from OSS failed");
        }
        Map<String, OnnxTensor> inputs = new HashMap<String, OnnxTensor>();
        String[] names = {ALL_INPUT_1, ALL_INPUT_2, ALL_INPUT_3};
        OnnxTensor[] tensors = {input_s, input_m, input_f};
        for (int k = 0; k < names.length; ++k) {
            if (tensors[k] != null && current.getInputNames().contains(names[k])) {
                inputs.put(names[k], tensors[k]);
            }
        }
        // Result owns native output buffers; getValue() copies them into Java arrays first.
        try (OrtSession.Result result = current.run(inputs)) {
            Object value = (output_name == null
                    ? result.get(0)
                    : result.get(output_name).orElse(result.get(0))).getValue();
            float[] prob = OnnxModel.toScores(value);
            logger.debug("\u63a8\u7406\u7ed3\u679c: {}", Arrays.toString(prob));
            return prob;
        }
    }

    /**
     * Converts a raw output value into one score per sample.
     *
     * <p>Accepts a {@code float[]} as-is, or a {@code float[][]} whose rows each hold one value.
     */
    private static float[] toScores(Object value) {
        if (value instanceof float[]) {
            return (float[]) value;
        }
        if (value instanceof float[][]) {
            float[][] matrix = (float[][]) value;
            float[] scores = new float[matrix.length];
            for (int i = 0; i < matrix.length; ++i) {
                if (matrix[i].length != 1) {
                    throw new IllegalStateException("Unsupported ONNX output row width: " + matrix[i].length);
                }
                scores[i] = matrix[i][0];
            }
            return scores;
        }
        throw new IllegalStateException("Unsupported ONNX output type: "
                + (value == null ? "null" : value.getClass().getName()));
    }

    /**
     * Scores every entry of {@code dataMap}.
     *
     * @param dataMap     features per key; an empty map yields an empty result
     * @param output_name model output to read (see {@link #predict})
     * @param codeSizeDo  feature counts per group; a group with size 0 is not sent
     * @return key → score (rounded to 6 decimals after adding 1e-6); empty if {@code input_s}
     *     could not be built
     * @throws Exception any failure while building tensors or running the model (logged, rethrown)
     */
    public <T> Map<T, Double> predictAll(Map<T, CodeDo> dataMap, String output_name, CodeSizeDo codeSizeDo) throws Exception {
        Map<T, Double> ret = new HashMap<T, Double>();
        if (AssertUtil.isNotEmpty(dataMap)) {
            List<T> keyList = new ArrayList<T>(dataMap.keySet());
            int sampleSize = keyList.size();
            // A null resource is allowed in try-with-resources; close() is simply skipped for it.
            try (OnnxTensor input_s = this.getParamsTensorInt(keyList, dataMap, sampleSize, codeSizeDo.singleSize);
                 OnnxTensor input_m = this.getParamsTensorMulti(keyList, dataMap, sampleSize, codeSizeDo.multiSize);
                 OnnxTensor input_f = this.getParamsTensorFloat(keyList, dataMap, sampleSize, codeSizeDo.floatSize)) {
                if (input_s == null) {
                    logger.info(" predictStr input_s is invalid, x=null");
                } else {
                    float[] pred = this.predict(input_s, input_m, input_f, output_name);
                    if (pred != null) {
                        if (pred.length < sampleSize) {
                            throw new IllegalStateException("Model returned " + pred.length
                                    + " scores for " + sampleSize + " samples");
                        }
                        for (int i = 0; i < sampleSize; ++i) {
                            ret.put(keyList.get(i), MathUtil.formatDouble((double) pred[i] + 1.0E-6, 6));
                        }
                    }
                }
            }
            catch (Exception e) {
                logger.error("predictStr(Map<T, String> dataMap) happend error", e);
                throw e;
            }
        }
        return ret;
    }

    /** Manual smoke test: scores one synthetic sample against the model in OSS. */
    public static void main(String[] args) throws Exception {
        OnnxModel onnxModel = new OnnxModel();
        onnxModel.getModelFromOss();
        List<Integer> singleCodes = new ArrayList<Integer>();
        List<String> multiCodes = new ArrayList<String>();
        List<Float> floatCodes = new ArrayList<Float>();
        for (int i = 0; i < 60; ++i) {
            singleCodes.add(10);
        }
        for (int i = 0; i < 3; ++i) {
            multiCodes.add("0,1,2,4,5");
        }
        for (int i = 0; i < 9; ++i) {
            floatCodes.add(Float.valueOf(0.3f));
        }
        CodeSizeDo codeSizeDo = new CodeSizeDo();
        codeSizeDo.setSingleSize(60);
        codeSizeDo.setMultiSize(3);
        codeSizeDo.setFloatSize(9);
        Map<String, CodeDo> dataMap = new HashMap<String, CodeDo>();
        CodeDo codeDo = new CodeDo();
        codeDo.setSingleCode(singleCodes);
        codeDo.setMultiCode(multiCodes);
        codeDo.setFloatCode(floatCodes);
        dataMap.put("s1", codeDo);
        Map<String, Double> ret = onnxModel.predictAll(dataMap, "prob", codeSizeDo);
        System.out.println("ret=" + JSON.toJSONString(ret));
    }
}

