package cn.com.duiba.nezha.alg.model;

import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import cn.com.duiba.nezha.alg.feature.type.FeatureBaseType2;
import cn.com.duiba.nezha.alg.feature.vo.FeatureMapDo;
import lombok.Data;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Dense feature encoder: single model + single encoder version (a later phase may let several
 * models in the same scenario share one encoder to reduce memory use).
 *
 * <p>Turns raw string feature maps into one flat list of integer codes by looking each feature
 * value up in a per-feature coder map. Values that are missing or unknown to the coder map encode
 * as {@code 0}.
 */
@Data
public class DeepBaseModel {

    private static final org.slf4j.Logger logger = LoggerFactory.getLogger(DeepBaseModel.class);

    /** Not referenced by the methods of this class; exposed through the Lombok-generated accessors. */
    private int FBT_MAX_SIZE = 64;

    /** Initial capacity of the result maps built by the batch encoding methods. */
    private int PB_MAX_SIZE = 128;

    /**
     * Retained for backward compatibility only; this class no longer reads it. Do not mutate.
     */
    public static Map<String, Integer> tmpIntMap = new HashMap<>();

    /**
     * Retained for backward compatibility only; this class no longer reads or writes it. It used to
     * serve both as the "empty" feature map and as a fallback encoding cache shared by every
     * feature, which mixed encodings across features. Do not mutate.
     */
    public static Map tmpMap = new HashMap<>();


    /**
     * Encodes every feature object in {@code featureMaps} for model prediction.
     *
     * @param featureMaps          feature objects to encode, keyed by a caller-defined id
     * @param coderCache           global encoding cache (feature name -> raw value -> codes); may be
     *                             {@code null} to disable caching
     * @param modelId              model id, used only in log messages
     * @param hasSubCache          when {@code true}, {@code coderCache} is ignored. A fresh cache that
     *                             lives only for this call (and is shared by all entries of
     *                             {@code featureMaps}) is used instead, and all features are cached
     * @param hasMultValueCache    when {@code true}, single-value features are cached as well;
     *                             multi-value features are cached whenever a cache is available
     * @param featureList          features to encode, in output order
     * @param featureDenseCoderMap coder map: feature name -> standardized value -> code
     * @param <T>                  key type of {@code featureMaps}
     * @return the code list for each key of {@code featureMaps}; empty (and a warning is logged)
     *         when there is nothing to encode
     * @throws Exception declared by the signature; encoding failures inside
     *                   {@link #getDenseCode} are caught and logged there
     */
    public <T> Map<T, List<Integer>> getPredDenseCodes(Map<T, FeatureMapDo> featureMaps,
                                                       Map<String, Map<String, List<Integer>>> coderCache,
                                                       String modelId,
                                                       boolean hasSubCache, boolean hasMultValueCache, List<FeatureBaseType2> featureList, Map<String, Map<String, Integer>> featureDenseCoderMap) throws Exception {
        Map<T, List<Integer>> ret = new HashMap<>(PB_MAX_SIZE);

        // Request-level cache: swap the global cache for a local one that is discarded when this
        // call returns, so it is safe to cache single-value features too.
        if (hasSubCache) {
            coderCache = new HashMap<>();
            hasMultValueCache = true;
        }

        if (AssertUtil.isNotEmpty(featureMaps)) {
            for (Map.Entry<T, FeatureMapDo> entry : featureMaps.entrySet()) {
                FeatureMapDo featureMapDo = entry.getValue();
                ret.put(entry.getKey(), getDenseCode(featureMapDo.staticFeatureMap,
                        featureMapDo.dynamicFeatureMap, coderCache, hasMultValueCache, modelId,
                        featureList, featureDenseCoderMap));
            }
        }
        if (AssertUtil.isEmpty(ret)) {
            logger.warn(modelId + " getCodes is invalid, featureMap is null or {}");
        }
        return ret;

    }

    /**
     * Encodes a batch of plain feature maps (for example, training samples).
     *
     * <p>Each map is treated as the static feature map with no dynamic overrides. Only multi-value
     * features are cached in {@code coderCache}.
     *
     * @param featureMaps          feature maps to encode, keyed by a caller-defined id
     * @param coderCache           encoding cache (feature name -> raw value -> codes); may be {@code null}
     * @param modelId              model id, used only in log messages
     * @param featureList          features to encode, in output order
     * @param featureDenseCoderMap coder map: feature name -> standardized value -> code
     * @param <T>                  key type of {@code featureMaps}
     * @return the code list for each key; empty (and a warning is logged) when there is nothing to encode
     * @throws Exception declared by the signature; encoding failures are caught and logged
     */
    public <T> Map<T, List<Integer>> getSampleDenses(Map<T, Map<String, String>> featureMaps,
                                                     Map<String, Map<String, List<Integer>>> coderCache,
                                                     String modelId,
                                                     List<FeatureBaseType2> featureList, Map<String, Map<String, Integer>> featureDenseCoderMap) throws Exception {
        Map<T, List<Integer>> ret = new HashMap<>(PB_MAX_SIZE);

        if (AssertUtil.isNotEmpty(featureMaps)) {
            for (Map.Entry<T, Map<String, String>> entry : featureMaps.entrySet()) {
                ret.put(entry.getKey(), getDenseCode(entry.getValue(), null, coderCache, false,
                        modelId, featureList, featureDenseCoderMap));
            }
        }
        if (AssertUtil.isEmpty(ret)) {
            logger.warn(modelId + " getCodes is invalid, featureMap is null or {}");
        }
        return ret;

    }


    /**
     * Encodes a single plain feature map (for example, one training sample).
     *
     * @param featureMap           feature name -> raw value; may be {@code null}
     * @param coderCache           encoding cache (feature name -> raw value -> codes); may be {@code null}
     * @param modelId              model id, used only in log messages
     * @param featureList          features to encode, in output order
     * @param featureDenseCoderMap coder map: feature name -> standardized value -> code
     * @return the concatenated codes of all features in {@code featureList}
     * @throws Exception declared by the signature; encoding failures are caught and logged
     */
    public List<Integer> getSampleDense(Map<String, String> featureMap,
                                        Map<String, Map<String, List<Integer>>> coderCache,
                                        String modelId,
                                        List<FeatureBaseType2> featureList, Map<String, Map<String, Integer>> featureDenseCoderMap) throws Exception {
        return getDenseCode(featureMap, null, coderCache, false, modelId, featureList, featureDenseCoderMap);
    }

    /**
     * Encodes one sample into a flat code list, optionally using a cache.
     *
     * <p>For every feature in {@code featureList}, the value is taken from
     * {@code dynamicFeatureMap} when present there, otherwise from {@code staticFeatureMap}.
     * Features are looked up in {@code coderCache} when it is non-null and the feature is
     * multi-value ({@code subLen > 1}) or {@code hasStaticMultValueCache} is set. A feature name
     * missing from the cache gets its own new per-feature map inserted into {@code coderCache}.
     *
     * <p>Any exception is logged and swallowed. In that case the codes produced so far are
     * returned, so the result may be shorter than a full encoding.
     *
     * @param staticFeatureMap        static features (name -> raw value); may be {@code null}
     * @param dynamicFeatureMap       dynamic features that override static ones; may be {@code null}
     * @param coderCache              encoding cache (feature name -> raw value -> codes); may be
     *                                {@code null} to disable caching
     * @param hasStaticMultValueCache whether single-value features are cached as well
     * @param modelId                 model id, used only in log messages
     * @param featureList             features to encode, in output order
     * @param featureDenseCoderMap    coder map: feature name -> standardized value -> code
     * @return the concatenated codes of all features in {@code featureList}
     * @throws Exception declared by the signature; exceptions are caught and logged here
     */
    public List<Integer> getDenseCode(Map<String, String> staticFeatureMap,
                                      Map<String, String> dynamicFeatureMap,
                                      Map<String, Map<String, List<Integer>>> coderCache,
                                      boolean hasStaticMultValueCache,
                                      String modelId,
                                      List<FeatureBaseType2> featureList,
                                      Map<String, Map<String, Integer>> featureDenseCoderMap) throws Exception {
        List<Integer> ret = new ArrayList<>();

        try {
            // Use an immutable empty map instead of the shared mutable static default.
            if (staticFeatureMap == null) {
                staticFeatureMap = Collections.emptyMap();
            }

            if (dynamicFeatureMap == null) {
                dynamicFeatureMap = Collections.emptyMap();
            }

            for (FeatureBaseType2 featureBaseType : featureList) {
                String featureName = featureBaseType.featureName;

                // Dynamic features take precedence over static ones.
                String value = dynamicFeatureMap.get(featureName);
                if (value == null) {
                    value = staticFeatureMap.get(featureName);
                }

                List<Integer> retSub;
                if (coderCache != null && (featureBaseType.subLen > 1 || hasStaticMultValueCache)) {
                    // Cached path. Each feature needs its own value -> codes map; a shared fallback
                    // map would mix encodings of different features that have equal raw values.
                    Map<String, List<Integer>> cache =
                            coderCache.computeIfAbsent(featureName, k -> new HashMap<>());

                    if (!cache.containsKey(value)) {
                        cache.put(value, getFieldDenseCode(featureBaseType, value, featureDenseCoderMap));
                    }

                    retSub = cache.get(value);
                } else {
                    // Uncached path.
                    retSub = getFieldDenseCode(featureBaseType, value, featureDenseCoderMap);
                }

                ret.addAll(retSub);
            }

        } catch (Exception e) {
            logger.warn(modelId + " getDenseCode ", e);
        }

        return ret;
    }

    /**
     * Encodes a single feature value.
     *
     * <p>Missing or unknown values are handled as follows:
     * <ol>
     *   <li>Single-value feature ({@code subLen == 1}): one code, defaulting to {@code 0}.</li>
     *   <li>Multi-value feature ({@code subLen > 1}): exactly {@code subLen} codes. Elements beyond
     *       the split value, or unknown to the coder map, are {@code 0}; the model is expected to
     *       average-pool them.</li>
     * </ol>
     * A {@code subLen} below 1 yields an empty list.
     *
     * @param featureBaseType      feature definition (name, length, standardization and splitting)
     * @param featureValue         raw feature value; may be {@code null}
     * @param featureDenseCoderMap coder map: feature name -> standardized value -> code
     * @return the codes for this feature
     * @throws Exception declared by the signature
     */
    public List<Integer> getFieldDenseCode(FeatureBaseType2 featureBaseType,
                                           String featureValue,
                                           Map<String, Map<String, Integer>> featureDenseCoderMap) throws Exception {
        List<Integer> ret = new ArrayList<>();
        String featureValueStd = featureBaseType.std(featureValue);

        if (featureBaseType.subLen == 1) {
            Map<String, Integer> coder =
                    featureDenseCoderMap.getOrDefault(featureBaseType.featureName, Collections.emptyMap());
            ret.add(coder.getOrDefault(featureValueStd, 0));
        }

        if (featureBaseType.subLen > 1) {
            String[] featureArray = featureBaseType.split(featureValueStd);
            int featuresSize = featureArray == null ? 0 : featureArray.length;

            // Look the coder up once per feature rather than once per element.
            Map<String, Integer> coder =
                    featureDenseCoderMap.getOrDefault(featureBaseType.featureName, Collections.emptyMap());

            for (int i = 0; i < featureBaseType.subLen; i++) {
                if (i < featuresSize) {
                    String valueStd = featureBaseType.std(featureArray[i]);
                    ret.add(coder.getOrDefault(valueStd, 0));
                } else {
                    // Pad to the fixed length.
                    ret.add(0);
                }
            }
        }

        return ret;
    }


}
