package cn.com.duiba.nezha.alg.model;

import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import cn.com.duiba.nezha.alg.feature.type.FeatureBaseType2;
import cn.com.duiba.nezha.alg.feature.vo.FeatureMapDo;
import cn.com.duiba.nezha.alg.model.vo.CodeDo;
import cn.com.duiba.nezha.alg.model.vo.CodeSizeDo;
import com.alibaba.fastjson.JSON;
import lombok.Data;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Dense feature encoder for a single model with a single encoder. (A later phase may let several
 * models of the same scenario share one encoder to reduce memory usage.)
 *
 * <p>Each feature definition is encoded according to its code type:
 * <ul>
 *   <li>single-value features ({@code S_F_IDX}) yield one index;</li>
 *   <li>multi-value features ({@code M_F_IDX}) yield {@code subLen} indices;</li>
 *   <li>float features ({@code F_F_IDX}) yield {@code subLen} raw float values.</li>
 * </ul>
 */
@Data
public class DeepBaseModelV2 {

    private static final org.slf4j.Logger logger = LoggerFactory.getLogger(DeepBaseModelV2.class);

    // Not referenced in this class; kept because Lombok's @Data exposes accessors for it.
    private int FBT_MAX_SIZE = 64;

    // Initial capacity of the result maps built by the batch methods.
    private int PB_MAX_SIZE = 128;

    // Code-type values compared against FeatureBaseType2#getCodeType().
    private Integer S_F_IDX = 1; // single-value feature
    private Integer M_F_IDX = 2; // multi-value feature
    private Integer F_F_IDX = 3; // float feature

    /**
     * Not referenced by this class; kept because it is public. Treat it as read-only.
     */
    public static Map<String, Integer> tmpIntMap = new HashMap<>();

    /**
     * Not referenced by this class; kept because it is public. Treat it as read-only.
     */
    public static Map tmpMap = new HashMap<>();


    /**
     * Encodes every feature object of a prediction request into dense codes.
     *
     * <p>NOTE: {@code coderCache} is written to while encoding. If a global cache is shared across
     * threads, it must be a thread-safe map.
     *
     * @param featureMaps          feature objects, each providing a static and a dynamic feature map
     * @param coderCache           global code cache ({@code featureId -> raw value -> code}); replaced
     *                             by a fresh request-scoped map when {@code hasSubCache} is true
     * @param modelId              model id, used only in log messages
     * @param hasSubCache          whether to use a request-scoped cache instead of {@code coderCache}
     * @param hasMultValueCache    whether multi-value features should be cached; forced to true when
     *                             {@code hasSubCache} is true
     * @param featureList          ordered feature definitions
     * @param featureDenseCoderMap {@code featureId -> standardized value -> dense index}
     * @param <T>                  key type of {@code featureMaps}
     * @return dense codes keyed like {@code featureMaps}; empty (never null) when there is no input
     * @throws Exception declared by the signature; per-entry encoding errors are logged by
     *                   {@link #getDenseCode}
     */
    public <T> Map<T, CodeDo> getPredDenseCodes(Map<T, FeatureMapDo> featureMaps,
                                                Map<String, Map<String, List<Integer>>> coderCache,
                                                String modelId,
                                                boolean hasSubCache, boolean hasMultValueCache, List<FeatureBaseType2> featureList, Map<String, Map<String, Integer>> featureDenseCoderMap) throws Exception {
        Map<T, CodeDo> ret = new HashMap<>(PB_MAX_SIZE);

        // Request-level cache: shared by all entries of this call only.
        if (hasSubCache) {
            coderCache = new HashMap<>();
            hasMultValueCache = true;
        }

        if (AssertUtil.isNotEmpty(featureMaps)) {
            for (Map.Entry<T, FeatureMapDo> entry : featureMaps.entrySet()) {
                FeatureMapDo featureMapDo = entry.getValue();
                ret.put(entry.getKey(), getDenseCode(featureMapDo.staticFeatureMap, featureMapDo.dynamicFeatureMap, coderCache, hasMultValueCache, modelId, featureList, featureDenseCoderMap));
            }
        }
        if (AssertUtil.isEmpty(ret)) {
            logger.warn(modelId + " getCodes is invalid, featureMap is null or {}");
        }
        return ret;
    }

    /**
     * Encodes a batch of samples (each a flat {@code featureId -> value} map) into dense codes.
     *
     * @param featureMaps          samples keyed by caller-defined id
     * @param coderCache           code cache, mutated while encoding; may be null to disable caching
     * @param modelId              model id, used only in log messages
     * @param featureList          ordered feature definitions
     * @param featureDenseCoderMap {@code featureId -> standardized value -> dense index}
     * @param <T>                  key type of {@code featureMaps}
     * @return dense codes keyed like {@code featureMaps}; empty (never null) when there is no input
     * @throws Exception declared by the signature; encoding errors are logged by {@link #getDenseCode}
     */
    public <T> Map<T, CodeDo> getSampleDenses(Map<T, Map<String, String>> featureMaps,
                                              Map<String, Map<String, List<Integer>>> coderCache,
                                              String modelId,
                                              List<FeatureBaseType2> featureList, Map<String, Map<String, Integer>> featureDenseCoderMap) throws Exception {
        Map<T, CodeDo> ret = new HashMap<>(PB_MAX_SIZE);

        if (AssertUtil.isNotEmpty(featureMaps)) {
            for (Map.Entry<T, Map<String, String>> entry : featureMaps.entrySet()) {
                ret.put(entry.getKey(), getDenseCode(entry.getValue(), null, coderCache, false, modelId, featureList, featureDenseCoderMap));
            }
        }
        if (AssertUtil.isEmpty(ret)) {
            logger.warn(modelId + " getCodes is invalid, featureMap is null or {}");
        }
        return ret;
    }


    /**
     * Encodes one sample (a flat {@code featureId -> value} map) into dense codes.
     *
     * @param featureMap           the sample's features; null is treated as empty
     * @param coderCache           code cache, mutated while encoding; may be null to disable caching
     * @param modelId              model id, used only in log messages
     * @param featureList          ordered feature definitions
     * @param featureDenseCoderMap {@code featureId -> standardized value -> dense index}
     * @return the sample's codes
     * @throws Exception declared by the signature; encoding errors are logged by {@link #getDenseCode}
     */
    public CodeDo getSampleDense(Map<String, String> featureMap,
                                 Map<String, Map<String, List<Integer>>> coderCache,
                                 String modelId,
                                 List<FeatureBaseType2> featureList, Map<String, Map<String, Integer>> featureDenseCoderMap) throws Exception {
        return getDenseCode(featureMap, null, coderCache, false, modelId, featureList, featureDenseCoderMap);
    }


    /**
     * Computes the lengths of the three code lists produced for {@code featureList}.
     *
     * <ul>
     *   <li>Each single-value feature contributes 1.</li>
     *   <li>Each multi-value or float feature contributes its {@code subLen}.</li>
     * </ul>
     *
     * @param featureList ordered feature definitions
     * @return the single/multi/float sizes; if iteration fails (e.g. a null list), the error is
     *         logged and the sizes accumulated so far are returned
     * @throws Exception declared by the signature; errors are caught and logged
     */
    public CodeSizeDo getCodeSize(List<FeatureBaseType2> featureList) throws Exception {
        CodeSizeDo ret = new CodeSizeDo();

        int singleSize = 0;
        int multiSize = 0;
        int floatSize = 0;
        try {
            for (FeatureBaseType2 featureBaseType : featureList) {
                if (S_F_IDX.equals(featureBaseType.getCodeType())) {
                    singleSize += 1;
                } else if (M_F_IDX.equals(featureBaseType.getCodeType())) {
                    multiSize += featureBaseType.subLen;
                } else if (F_F_IDX.equals(featureBaseType.getCodeType())) {
                    floatSize += featureBaseType.subLen;
                }
            }
        } catch (Exception e) {
            logger.warn(" getCodeSize error ", e);
        }
        ret.setSingleSize(singleSize);
        ret.setMultiSize(multiSize);
        ret.setFloatSize(floatSize);
        return ret;
    }

    /**
     * Encodes one sample into single-value, multi-value and float codes, optionally using a cache.
     *
     * <p>Value lookup: for each feature in {@code featureList}, the raw value is read by
     * {@code featureId}. {@code dynamicFeatureMap} is tried first and {@code staticFeatureMap}
     * second.
     *
     * <p>Index offsets: each single-value and multi-value feature occupies a block of
     * {@code denseLen + 1} indices, so returned indices include the running start offset of their
     * feature.
     *
     * <p>Caching: when {@code coderCache} is non-null, codes are stored in it under
     * {@code featureId -> raw value}, and a per-feature map is created on first use. Cached codes
     * already include the start offset. A cache must therefore only be shared between calls that
     * use the same {@code featureList}.
     *
     * @param staticFeatureMap        static features ({@code featureId -> value}); null is treated as empty
     * @param dynamicFeatureMap       dynamic features, taking precedence; null is treated as empty
     * @param coderCache              code cache, mutated by this method; may be null to disable caching
     * @param hasStaticMultValueCache whether multi-value codes of static features should be cached.
     *                                NOTE(review): not read by this method; multi-value codes are
     *                                cached whenever {@code coderCache} is non-null. Confirm intent.
     * @param modelId                 model id, used only in log messages
     * @param featureList             ordered feature definitions
     * @param featureDenseCoderMap    {@code featureId -> standardized value -> dense index}
     * @return the codes; if encoding fails, the error is logged and a {@code CodeDo} whose code
     *         lists were not set is returned
     * @throws Exception declared by the signature; encoding errors are caught and logged
     */
    public CodeDo getDenseCode(Map<String, String> staticFeatureMap,
                               Map<String, String> dynamicFeatureMap,
                               Map<String, Map<String, List<Integer>>> coderCache,
                               boolean hasStaticMultValueCache,
                               String modelId,
                               List<FeatureBaseType2> featureList,
                               Map<String, Map<String, Integer>> featureDenseCoderMap) throws Exception {
        CodeDo ret = new CodeDo();

        List<Integer> sCode = new ArrayList<>(); // single-value codes: one per feature, offset by sStartIndex
        List<Integer> mCode = new ArrayList<>(); // multi-value codes: subLen per feature, offset by mStartIndex
        List<Float> fCode = new ArrayList<>();   // float codes: subLen raw values per feature

        // Running offsets: the Nth single (multi) feature starts at the sum of (denseLen + 1)
        // over the single (multi) features before it.
        int sStartIndex = 0;
        int mStartIndex = 0;
        try {
            // Read-only empty fallbacks; never a shared mutable map.
            if (staticFeatureMap == null) {
                staticFeatureMap = Collections.emptyMap();
            }
            if (dynamicFeatureMap == null) {
                dynamicFeatureMap = Collections.emptyMap();
            }

            for (FeatureBaseType2 featureBaseType : featureList) {
                // Dynamic values take precedence over static ones. Both maps are keyed by featureId.
                String value = dynamicFeatureMap.get(featureBaseType.featureId);
                if (value == null) {
                    value = staticFeatureMap.get(featureBaseType.featureId);
                }

                if (S_F_IDX.equals(featureBaseType.getCodeType())) {
                    sCode.addAll(getCachedFieldDenseCode(featureBaseType, value, sStartIndex, coderCache, featureDenseCoderMap, false));
                    sStartIndex += featureBaseType.denseLen + 1;
                } else if (M_F_IDX.equals(featureBaseType.getCodeType())) {
                    mCode.addAll(getCachedFieldDenseCode(featureBaseType, value, mStartIndex, coderCache, featureDenseCoderMap, true));
                    mStartIndex += featureBaseType.denseLen + 1;
                } else if (F_F_IDX.equals(featureBaseType.getCodeType())) {
                    fCode.addAll(getFloatFieldDenseCode(featureBaseType, value));
                }
            }
            ret.setFloatCode(fCode);
            ret.setSingleCode(sCode);
            ret.setMultiCode(mCode);

        } catch (Exception e) {
            logger.warn(modelId + " getDenseCode ", e);
        }

        return ret;
    }

    /**
     * Returns the single- or multi-value code for one feature value.
     *
     * <p>When {@code coderCache} is non-null, the per-feature cache is consulted and populated.
     * Each feature gets its own {@code value -> code} map, so codes are never shared between
     * features.
     *
     * @param multiValue true to encode with {@link #getMultiFieldDenseCode}, false for
     *                   {@link #getSingleFieldDenseCode}
     */
    private List<Integer> getCachedFieldDenseCode(FeatureBaseType2 featureBaseType,
                                                  String value,
                                                  int startIndex,
                                                  Map<String, Map<String, List<Integer>>> coderCache,
                                                  Map<String, Map<String, Integer>> featureDenseCoderMap,
                                                  boolean multiValue) throws Exception {
        if (coderCache == null) {
            return encodeField(featureBaseType, value, startIndex, featureDenseCoderMap, multiValue);
        }
        Map<String, List<Integer>> cache = coderCache.computeIfAbsent(featureBaseType.featureId, k -> new HashMap<>());
        List<Integer> code = cache.get(value);
        if (code == null) {
            code = encodeField(featureBaseType, value, startIndex, featureDenseCoderMap, multiValue);
            cache.put(value, code);
        }
        return code;
    }

    /**
     * Dispatches to the single-value or multi-value encoder.
     */
    private List<Integer> encodeField(FeatureBaseType2 featureBaseType,
                                      String value,
                                      int startIndex,
                                      Map<String, Map<String, Integer>> featureDenseCoderMap,
                                      boolean multiValue) throws Exception {
        return multiValue
                ? getMultiFieldDenseCode(featureBaseType, value, startIndex, featureDenseCoderMap)
                : getSingleFieldDenseCode(featureBaseType, value, startIndex, featureDenseCoderMap);
    }


    /**
     * Encodes a single-value feature into exactly one dense index.
     *
     * <p>Encoding steps:
     * <ol>
     *   <li>Standardize the value with {@code featureBaseType.std}.</li>
     *   <li>Look it up in {@code featureDenseCoderMap} under the feature's id.</li>
     *   <li>Map unknown values, and indices {@code >= denseLen}, to 0 (empty / illegal).</li>
     *   <li>Add {@code startIndex} to the resulting index.</li>
     * </ol>
     *
     * @param featureBaseType      feature definition
     * @param featureValue         raw value, possibly null
     * @param startIndex           offset of this feature's index block
     * @param featureDenseCoderMap {@code featureId -> standardized value -> dense index}
     * @return a one-element mutable list
     * @throws Exception declared by the signature
     */
    public List<Integer> getSingleFieldDenseCode(FeatureBaseType2 featureBaseType,
                                                 String featureValue,
                                                 int startIndex,
                                                 Map<String, Map<String, Integer>> featureDenseCoderMap
    ) throws Exception {
        List<Integer> ret = new ArrayList<>();
        String featureValueStd = featureBaseType.std(featureValue);

        int index = featureDenseCoderMap
                .getOrDefault(featureBaseType.featureId, Collections.emptyMap())
                .getOrDefault(featureValueStd, 0);
        // Out of range is treated as illegal: fall back to the empty index.
        if (index >= featureBaseType.denseLen) {
            index = 0;
        }
        ret.add(index + startIndex);
        return ret;
    }

    /**
     * Encodes a multi-value feature into exactly {@code subLen} dense indices.
     *
     * <p>Encoding steps:
     * <ol>
     *   <li>Standardize the value and split it with {@code featureBaseType.split}.</li>
     *   <li>Standardize each part again and look it up like a single value.</li>
     *   <li>Encode as 0 any missing position (fewer parts than {@code subLen}), unknown value, or
     *       index {@code >= denseLen}. The model applies average pooling over these slots.</li>
     *   <li>Ignore parts beyond {@code subLen}.</li>
     *   <li>Offset every index by {@code startIndex}.</li>
     * </ol>
     *
     * @param featureBaseType      feature definition
     * @param featureValue         raw value, possibly null
     * @param startIndex           offset of this feature's index block
     * @param featureDenseCoderMap {@code featureId -> standardized value -> dense index}
     * @return a mutable list of {@code subLen} indices
     * @throws Exception declared by the signature
     */
    public List<Integer> getMultiFieldDenseCode(FeatureBaseType2 featureBaseType,
                                                String featureValue,
                                                int startIndex,
                                                Map<String, Map<String, Integer>> featureDenseCoderMap
    ) throws Exception {
        List<Integer> ret = new ArrayList<>();
        String featureValueStd = featureBaseType.std(featureValue);

        String[] featureArray = featureBaseType.split(featureValueStd);
        int featuresSize = featureArray == null ? 0 : featureArray.length;

        for (int i = 0; i < featureBaseType.subLen; i++) {
            int index = 0;
            if (i < featuresSize) {
                String valueStd = featureBaseType.std(featureArray[i]);
                index = featureDenseCoderMap
                        .getOrDefault(featureBaseType.featureId, Collections.emptyMap())
                        .getOrDefault(valueStd, 0);
                // Out of range is treated as illegal: fall back to the empty index.
                if (index >= featureBaseType.denseLen) {
                    index = 0;
                }
            }
            ret.add(index + startIndex);
        }
        return ret;
    }


    /**
     * Encodes a float feature into exactly {@code subLen} float values.
     *
     * <p>Encoding steps:
     * <ol>
     *   <li>Standardize the value and split it with {@code featureBaseType.split}.</li>
     *   <li>Standardize each part and parse it with {@link Float#parseFloat}.</li>
     *   <li>Yield 0f for missing positions and for parts that standardize to null.</li>
     *   <li>Ignore parts beyond {@code subLen}.</li>
     * </ol>
     *
     * @param featureBaseType feature definition
     * @param featureValue    raw value, possibly null
     * @return a mutable list of {@code subLen} floats
     * @throws NumberFormatException if a standardized part is not a parseable float
     * @throws Exception             declared by the signature
     */
    public List<Float> getFloatFieldDenseCode(FeatureBaseType2 featureBaseType,
                                              String featureValue
    ) throws Exception {
        List<Float> ret = new ArrayList<>();
        String featureValueStd = featureBaseType.std(featureValue);

        String[] featureArray = featureBaseType.split(featureValueStd);
        int featuresSize = featureArray == null ? 0 : featureArray.length;

        for (int i = 0; i < featureBaseType.subLen; i++) {
            String valueStd = i < featuresSize ? featureBaseType.std(featureArray[i]) : null;
            ret.add(valueStd == null ? 0f : Float.parseFloat(valueStd));
        }
        return ret;
    }

}
