package cn.com.duiba.nezha.alg.model.tf;

import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import cn.com.duiba.nezha.alg.common.util.MathUtil;
import cn.com.duiba.nezha.alg.model.util.OSSUtils;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.NodeInfo;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.*;
import cn.com.duiba.nezha.alg.model.vo.CodeDo;
import cn.com.duiba.nezha.alg.model.vo.CodeSizeDo;
import com.alibaba.fastjson.JSON;
import lombok.Data;
import org.slf4j.LoggerFactory;
import org.tensorflow.Tensor;

import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.util.*;

/**
 * Thin wrapper around an ONNX Runtime session whose model file is fetched from OSS.
 *
 * <p>Typical usage: call {@link #getModelFromOss(String)} once to load the newest model found
 * under a directory, then call {@link #predictAll(Map, String, CodeSizeDo)} to score a batch of
 * samples.
 *
 * <p>NOTE(review): {@code env} and {@code session} are {@code static}, so every instance of this
 * class shares the most recently loaded model, while {@code version} is per instance. Loading a
 * second model through another instance replaces the session used by all instances.
 */
@Data
public class OnnxModel extends OSSUtils {
    private static final org.slf4j.Logger logger = LoggerFactory.getLogger(OnnxModel.class);

    /** Default name of the model output holding the predicted probability. */
    public static String DF_OUTPUT = "prob";

    /** Model input name for single-valued (int) features. */
    private static final String ALL_INPUT_1 = "input_s";
    /** Model input name for multi-valued (int sequence) features. */
    private static final String ALL_INPUT_2 = "input_m";
    /** Model input name for float features. */
    private static final String ALL_INPUT_3 = "input_f";

    /** Value used to pad multi-valued features up to the longest sequence in the batch. */
    private static final int MULTI_PADDING = -1;

    private static OrtEnvironment env = null;
    private static OrtSession session = null;

    /** Version (file name without extension) of the model most recently loaded by this instance. */
    private String version = null;

    /**
     * Downloads the latest model found under {@code dirName} in the {@code algorithm-models}
     * bucket and creates an ONNX Runtime session from it.
     *
     * <p>Failures are logged and swallowed (best-effort load); on failure the previously loaded
     * session and {@code version} are left untouched.
     *
     * @param dirName OSS directory that contains the model files
     * @throws Exception declared for interface compatibility; load failures are logged, not thrown
     */
    public void getModelFromOss(String dirName) throws Exception {
        // Name of the OSS bucket that stores the models.
        String bucketName = "algorithm-models";

        try {
            String latestFileName = getLatestModelFromOSS(bucketName, dirName);
            if (latestFileName == null) {
                logger.error("ONNX Model READ FROM OSS FAILURE, no model file found, bucket={}, dir={}",
                        bucketName, dirName);
                return;
            }
            String latestVersion = extractVersion(latestFileName);
            byte[] modelBytes = downloadModelFromOSS(bucketName, latestFileName);
            OrtEnvironment environment = OrtEnvironment.getEnvironment();
            // SessionOptions owns native memory; it is only needed while the session is created.
            try (OrtSession.SessionOptions options = new OrtSession.SessionOptions()) {
                OrtSession loaded = environment.createSession(modelBytes, options);
                // Publish the new state only after everything succeeded so that version and
                // session never disagree.
                // NOTE(review): the previous session is not closed here because in-flight
                // predictions may still be using it; it is released only by GC/finalization.
                env = environment;
                session = loaded;
                this.version = latestVersion;
            }
        } catch (Exception e) {
            logger.error("ONNX Model READ FROM OSS FAILURE", e);
        }
    }

    /**
     * Derives the model version from an OSS object key: the last path segment up to its first dot,
     * e.g. {@code "adx_cvr_models/202412181535.onnx"} yields {@code "202412181535"}.
     */
    private static String extractVersion(String objectKey) {
        String fileName = objectKey.substring(objectKey.lastIndexOf('/') + 1);
        int dot = fileName.indexOf('.');
        return dot >= 0 ? fileName.substring(0, dot) : fileName;
    }

    /**
     * Builds the {@code [sampleSize, featureSize]} int tensor of single-valued features.
     *
     * @param keyList     sample keys, in batch order
     * @param dataMap     feature codes per sample key
     * @param sampleSize  number of samples to read from {@code keyList}
     * @param featureSize number of single-valued features every sample must provide
     * @return the tensor (caller must close it), or {@code null} when {@code featureSize <= 0},
     *         when a sample does not carry exactly {@code featureSize} values, or when tensor
     *         creation fails
     */
    public <T> OnnxTensor getParamsTensorInt(List<T> keyList, Map<T, CodeDo> dataMap, int sampleSize, int featureSize) {
        if (featureSize <= 0) {
            return null;
        }

        int[] inputIntData = new int[sampleSize * featureSize];
        for (int i = 0; i < sampleSize; i++) {
            T key = keyList.get(i);
            CodeDo codeDo = dataMap.get(key);
            List<Integer> feature = codeDo == null ? null : codeDo.getSingleCode();
            // Every row must have exactly featureSize values, otherwise rows would be misaligned
            // in the flattened buffer.
            if (feature == null || feature.size() != featureSize) {
                logger.error("Create SingleFeature FAILURE, key={}, expectedSize={}, actualSize={}",
                        key, featureSize, feature == null ? null : feature.size());
                return null;
            }
            int base = i * featureSize;
            for (int j = 0; j < featureSize; j++) {
                inputIntData[base + j] = feature.get(j);
            }
        }

        long[] inputIntShape = {sampleSize, featureSize};
        try {
            return OnnxTensor.createTensor(env, IntBuffer.wrap(inputIntData), inputIntShape);
        } catch (Exception e) {
            logger.error("Create SingleFeature FAILURE", e);
            return null;
        }
    }

    /**
     * Builds the {@code [sampleSize, featureSize, maxLen]} int tensor of multi-valued features.
     *
     * <p>Each multi-valued feature is a comma-separated list of ids. Sequences shorter than the
     * longest one in the batch are right-padded with {@code -1}; according to the original
     * author's note the model drops these padding values during inference. Blank tokens are
     * ignored, so an empty feature becomes a row of padding.
     *
     * @param keyList     sample keys, in batch order
     * @param dataMap     feature codes per sample key
     * @param sampleSize  number of samples to read from {@code keyList}
     * @param featureSize number of multi-valued features read from every sample
     * @return the tensor (caller must close it), or {@code null} when {@code featureSize <= 0},
     *         when a sample provides fewer than {@code featureSize} features, when an id cannot be
     *         parsed, or when tensor creation fails
     */
    public <T> OnnxTensor getParamsTensorMulti(List<T> keyList, Map<T, CodeDo> dataMap, int sampleSize, int featureSize) {
        if (featureSize <= 0) {
            return null;
        }

        // parsed[i][j] holds the ids of feature j of sample i (variable length).
        int[][][] parsed = new int[sampleSize][featureSize][];
        // At least 1 so the tensor never has a zero-length dimension.
        int featureMaxLength = 1;
        for (int i = 0; i < sampleSize; i++) {
            T key = keyList.get(i);
            CodeDo codeDo = dataMap.get(key);
            List<String> feature = codeDo == null ? null : codeDo.getMultiCode();
            if (feature == null || feature.size() < featureSize) {
                logger.error("Create MultiFeature FAILURE, key={}, expectedSize={}, actualSize={}",
                        key, featureSize, feature == null ? null : feature.size());
                return null;
            }
            for (int j = 0; j < featureSize; j++) {
                try {
                    parsed[i][j] = parseIds(feature.get(j));
                } catch (NumberFormatException e) {
                    logger.error("Create MultiFeature FAILURE, key={}, featureIndex={}", key, j, e);
                    return null;
                }
                featureMaxLength = Math.max(featureMaxLength, parsed[i][j].length);
            }
        }

        // Pre-fill with padding, then copy every sequence to the start of its slot.
        int[] inputMultiData = new int[sampleSize * featureSize * featureMaxLength];
        Arrays.fill(inputMultiData, MULTI_PADDING);
        int offset = 0;
        for (int[][] sample : parsed) {
            for (int[] ids : sample) {
                System.arraycopy(ids, 0, inputMultiData, offset, ids.length);
                offset += featureMaxLength;
            }
        }

        long[] inputMultiShape = {sampleSize, featureSize, featureMaxLength};
        try {
            return OnnxTensor.createTensor(env, IntBuffer.wrap(inputMultiData), inputMultiShape);
        } catch (Exception e) {
            logger.error("Create MultiFeature FAILURE", e);
            return null;
        }
    }

    /**
     * Parses a comma-separated id list such as {@code "0,1,2"}; surrounding whitespace is trimmed
     * and blank tokens are skipped.
     *
     * @throws NumberFormatException if a non-blank token is not a valid int
     */
    private static int[] parseIds(String raw) {
        if (raw == null) {
            return new int[0];
        }
        String[] tokens = raw.split(",");
        int[] ids = new int[tokens.length];
        int count = 0;
        for (String token : tokens) {
            String trimmed = token.trim();
            if (!trimmed.isEmpty()) {
                ids[count++] = Integer.parseInt(trimmed);
            }
        }
        return count == ids.length ? ids : Arrays.copyOf(ids, count);
    }

    /**
     * Builds the {@code [sampleSize, featureSize]} float tensor of float features.
     *
     * @param keyList     sample keys, in batch order
     * @param dataMap     feature codes per sample key
     * @param sampleSize  number of samples to read from {@code keyList}
     * @param featureSize number of float features read from every sample
     * @return the tensor (caller must close it), or {@code null} when {@code featureSize <= 0},
     *         when a sample provides fewer than {@code featureSize} values, or when tensor
     *         creation fails
     */
    public <T> OnnxTensor getParamsTensorFloat(List<T> keyList, Map<T, CodeDo> dataMap, int sampleSize, int featureSize) {
        if (featureSize <= 0) {
            return null;
        }

        float[] inputFloatData = new float[sampleSize * featureSize];
        for (int i = 0; i < sampleSize; i++) {
            T key = keyList.get(i);
            CodeDo codeDo = dataMap.get(key);
            List<Float> feature = codeDo == null ? null : codeDo.getFloatCode();
            if (feature == null || feature.size() < featureSize) {
                logger.error("Create FloatFeature FAILURE, key={}, expectedSize={}, actualSize={}",
                        key, featureSize, feature == null ? null : feature.size());
                return null;
            }
            int base = i * featureSize;
            for (int j = 0; j < featureSize; j++) {
                inputFloatData[base + j] = feature.get(j);
            }
        }

        long[] inputFloatShape = {sampleSize, featureSize};
        try {
            return OnnxTensor.createTensor(env, FloatBuffer.wrap(inputFloatData), inputFloatShape);
        } catch (Exception e) {
            logger.error("Create FloatFeature FAILURE", e);
            return null;
        }
    }

    /**
     * Runs the loaded model on the given input tensors.
     *
     * <p>Tensors that are {@code null} are not passed to the session. The output named
     * {@code output_name} is returned when the model exposes it; otherwise the first output is
     * used. The selected output is expected to be a one-dimensional float tensor.
     *
     * @param input_s     single-valued feature tensor, may be {@code null}
     * @param input_m     multi-valued feature tensor, may be {@code null}
     * @param input_f     float feature tensor, may be {@code null}
     * @param output_name preferred output name, may be {@code null}
     * @return the predicted values copied out of the native result
     * @throws IllegalStateException if no model has been loaded yet
     * @throws Exception             if inference fails
     */
    public float[] predict(OnnxTensor input_s, OnnxTensor input_m, OnnxTensor input_f, String output_name) throws Exception {
        OrtSession activeSession = session;
        if (activeSession == null) {
            throw new IllegalStateException("ONNX session is not initialized, call getModelFromOss first");
        }

        Map<String, OnnxTensor> inputs = new HashMap<>();
        putIfPresent(inputs, ALL_INPUT_1, input_s);
        putIfPresent(inputs, ALL_INPUT_2, input_m);
        putIfPresent(inputs, ALL_INPUT_3, input_f);

        // Result holds native memory; getValue() copies into a Java array, so closing the result
        // right after reading it is safe and prevents a leak on every call.
        try (OrtSession.Result result = activeSession.run(inputs)) {
            OnnxValue output = output_name == null ? null : result.get(output_name).orElse(null);
            if (output == null) {
                output = result.get(0);
            }
            float[] prob = (float[]) output.getValue();
            if (logger.isDebugEnabled()) {
                logger.debug("推理结果: {}", Arrays.toString(prob));
            }
            return prob;
        }
    }

    /** Adds {@code tensor} to {@code inputs} under {@code name} unless it is {@code null}. */
    private static void putIfPresent(Map<String, OnnxTensor> inputs, String name, OnnxTensor tensor) {
        if (tensor != null) {
            inputs.put(name, tensor);
        }
    }

    /**
     * Scores every sample in {@code dataMap} in a single batch.
     *
     * <p>All input tensors are closed before this method returns. When a required tensor cannot
     * be built, the problem is logged and an empty map is returned.
     *
     * @param dataMap     feature codes per sample key
     * @param output_name preferred model output name
     * @param codeSizeDo  number of single-valued, multi-valued and float features per sample
     * @return prediction per sample key (value plus 1e-6, formatted to 6 decimals); empty when
     *         {@code dataMap} is empty or the inputs are invalid
     * @throws Exception if inference fails; the error is logged before being rethrown
     */
    public <T> Map<T, Double> predictAll(Map<T, CodeDo> dataMap, String output_name, CodeSizeDo codeSizeDo) throws Exception {
        Map<T, Double> ret = new HashMap<>();
        if (!AssertUtil.isNotEmpty(dataMap)) {
            return ret;
        }

        List<T> keyList = new ArrayList<>(dataMap.keySet());
        int sampleSize = keyList.size();

        // try-with-resources skips null resources, so tensors that were not built are fine here.
        try (OnnxTensor input_s = getParamsTensorInt(keyList, dataMap, sampleSize, codeSizeDo.singleSize);
             OnnxTensor input_m = getParamsTensorMulti(keyList, dataMap, sampleSize, codeSizeDo.multiSize);
             OnnxTensor input_f = getParamsTensorFloat(keyList, dataMap, sampleSize, codeSizeDo.floatSize)) {
            if (input_s == null) {
                logger.info(" predictAll input_s is invalid, x=null");
                return ret;
            }
            // A tensor that was requested (size > 0) but could not be built means bad input data.
            if (codeSizeDo.multiSize > 0 && input_m == null) {
                logger.info(" predictAll input_m is invalid, x=null");
                return ret;
            }
            if (codeSizeDo.floatSize > 0 && input_f == null) {
                logger.info(" predictAll input_f is invalid, x=null");
                return ret;
            }

            float[] pred = predict(input_s, input_m, input_f, output_name);
            if (pred != null) {
                for (int i = 0; i < sampleSize; i++) {
                    ret.put(keyList.get(i), MathUtil.formatDouble(pred[i] + 0.000001, 6));
                }
            }
        } catch (Exception e) {
            logger.error("predictAll(Map<T, CodeDo> dataMap) happend error", e);
            throw e;
        }

        return ret;
    }

    /**
     * Manual smoke test: loads the latest model under {@code adx_cvr_models} and scores one
     * synthetic sample with 74 single-valued, 5 multi-valued and 9 float features.
     */
    public static void main(String[] args) throws Exception {
        OnnxModel onnxModel = new OnnxModel();
        onnxModel.getModelFromOss("adx_cvr_models");

        // Build the synthetic feature lists.
        List<Integer> singleCodes = new ArrayList<>(Collections.nCopies(74, 10));

        List<String> multiCodes = new ArrayList<>();
        multiCodes.add("0,1,2,4,5");
        multiCodes.addAll(Collections.nCopies(4, "1,2"));

        List<Float> floatCodes = new ArrayList<>(Collections.nCopies(9, 0.3f));

        CodeSizeDo codeSizeDo = new CodeSizeDo();
        codeSizeDo.setSingleSize(74);
        codeSizeDo.setMultiSize(5);
        codeSizeDo.setFloatSize(9);

        CodeDo codeDo = new CodeDo();
        codeDo.setSingleCode(singleCodes);
        codeDo.setMultiCode(multiCodes);
        codeDo.setFloatCode(floatCodes);

        Map<String, CodeDo> dataMap = new HashMap<>();
        dataMap.put("s1", codeDo);

        Map<String, Double> ret = onnxModel.predictAll(dataMap, "prob", codeSizeDo);

        System.out.println("ret=" + JSON.toJSONString(ret));
    }
}




