package cn.com.duiba.nezha.alg.api.facade.recall.channel;

import cn.com.duiba.nezha.alg.api.dto.base.BaseAdvertDTO;
import cn.com.duiba.nezha.alg.api.dto.recall.AppIndustryGeneralAdvertDTO;
import cn.com.duiba.nezha.alg.api.dto.recall.AppIndustryGeneralRedisDTO;
import cn.com.duiba.nezha.alg.api.dto.recall.AppIndustryGeneralRequest;
import cn.com.duiba.nezha.alg.api.dto.recall.RecallChannelResult;
import cn.com.duiba.nezha.alg.api.enums.RecallEnums.RecallChannelType;
import cn.com.duiba.nezha.alg.api.facade.recall.AbstractRecallChanelFacade;
import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import com.google.common.base.Joiner;
import org.apache.commons.collections.Bag;
import org.apache.commons.collections.bag.HashBag;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;
import java.util.stream.Collectors;

/**
 * Media (app) x industry generalization recall channel.
 *
 * <p>For every metric ("index", e.g. ctr / cvr0 / weight0) found in the request's redis data, each
 * advert of the current order receives a time-weighted score (1 hour, 1 day, 3 days; slot-level data
 * is preferred over global data). Adverts are then filtered by the per-index threshold, truncated to
 * the per-index recall size, and expanded to one result per configured package.
 *
 * @author lijicong
 * @since 2021-09-07
 */
public class AppIndustryGeneralRecallChannelFacadeImpl extends AbstractRecallChanelFacade<AppIndustryGeneralAdvertDTO, AppIndustryGeneralRequest>
        implements AppIndustryGeneralRecallChannelFacade {

    static final Logger logger = LoggerFactory.getLogger(AppIndustryGeneralRecallChannelFacadeImpl.class);

    // TODO: the values in this holder should be made configurable later.
    private static final class Constants {
        /** Redis key prefix; could be shared with other channels through an enum. */
        static final String PREFIX = "NZ_A4";
        /**
         * Position of the metric inside a '_'-split redis key. PREFIX itself contains '_', so a key
         * PREFIX_index_app_time_data splits into (NZ, A4, index, app, time, data).
         */
        static final int INDEX_SEGMENT = PREFIX.split("_").length;

        static final double WEIGHT_3_DAY = 0.96D;
        static final double WEIGHT_1_DAY = 0.98D;
        static final double WEIGHT_1_HOUR = 1.0D;

        static final String TIME_1_HOUR = "1hour";
        static final String TIME_1_DAY = "1day";
        static final String TIME_3_DAY = "3day";
        static final String DATA_GLOBAL = "global";

        private Constants() {
        }
    }

    @Override
    public RecallChannelType getType() {
        return RecallChannelType.AppIndustryGeneral;
    }

    /**
     * Builds the recall results of this channel.
     *
     * @param appIndustryGeneralAdvertDTOMap adverts / packages of the current nezha order
     * @param appIndustryGeneralRequest      request carrying the redis statistics ({@code getRedisMap()}),
     *                                       the per-index recall size ({@code getNum()}), the per-index
     *                                       score threshold ({@code getThreshold()}) and the slot id
     * @return results keyed by {@link RecallChannelResult#buildKey()} (documented in this code as
     *         "advertId_packageId", e.g. 379245_138832); empty when any required input is missing
     */
    @Override
    public Map<String, RecallChannelResult> execute(Map<String, AppIndustryGeneralAdvertDTO> appIndustryGeneralAdvertDTOMap,
                                                    AppIndustryGeneralRequest appIndustryGeneralRequest) {
        Map<String, RecallChannelResult> res = new HashMap<>();

        if (AssertUtil.isAnyEmpty(appIndustryGeneralAdvertDTOMap, appIndustryGeneralRequest)) {
            return res;
        }

        // Redis data keyed by the full redis key (PREFIX_index_app_time_data).
        Map<String, List<AppIndustryGeneralRedisDTO>> redisMap = appIndustryGeneralRequest.getRedisMap();
        Map<String, Integer> nums = appIndustryGeneralRequest.getNum();
        Map<String, Double> threshold = appIndustryGeneralRequest.getThreshold();
        if (redisMap == null || redisMap.isEmpty()) {
            return res;
        }
        if (nums == null || threshold == null) {
            logger.warn("AppIndustryGeneral recall skipped: num or threshold config is missing");
            return res;
        }

        // advertId -> package ids of this order; its key set is the order's full advert set.
        Map<Long, List<Long>> adOrientMap = groupPackagesByAdvert(appIndustryGeneralAdvertDTOMap);

        // 1. Intersect with the order's adverts and regroup as index -> advertId -> (redisKey -> score).
        Map<String, Map<Long, Map<String, Double>>> helperMap = buildHelperMap(redisMap, adOrientMap.keySet());

        // Computed once; a null slot id means only the global data is consulted.
        String slot = Objects.toString(appIndustryGeneralRequest.getSlotId(), null);

        // 2. Per index: score every advert, filter by threshold, sort descending, truncate, expand to packages.
        for (Map.Entry<String, Map<Long, Map<String, Double>>> entry : helperMap.entrySet()) {
            String index = entry.getKey();
            Integer num = nums.get(index);
            Double minScore = threshold.get(index);
            if (num == null || minScore == null) {
                logger.warn("AppIndustryGeneral recall: no num/threshold configured for index {}, skipped", index);
                continue;
            }
            if (num <= 0) {
                // Stream.limit rejects negative sizes, and a size of 0 recalls nothing.
                continue;
            }

            List<RecallChannelResult> topAdverts = entry.getValue().entrySet().stream()
                    .map(adScores -> scoreAdvert(adScores.getKey(), adScores.getValue(), index, slot))
                    .filter(scored -> scored.getScore() > minScore)
                    .sorted(Comparator.comparing(RecallChannelResult::getScore).reversed())
                    .limit(num)
                    .collect(Collectors.toList());

            // NOTE(review): an advert/package recalled under several indexes keeps only the entry of the
            // index processed last (HashMap iteration order) if buildKey() ignores the index — confirm intent.
            for (RecallChannelResult top : topAdverts) {
                for (Long packageId : adOrientMap.getOrDefault(top.getAdvertId(), Collections.emptyList())) {
                    // A fresh instance per package: sharing one instance made every entry carry the last packageId.
                    RecallChannelResult result = new RecallChannelResult();
                    result.setAdvertId(top.getAdvertId());
                    result.setPackageId(packageId);
                    result.setScore(top.getScore());
                    result.setIndex(index);
                    res.put(result.buildKey(), result);
                }
            }
        }

        return res;
    }

    /**
     * Maps each advert id of the order to its package ids.
     * Null DTOs and null advert ids are skipped (Collectors.groupingBy would throw on a null key).
     */
    private static Map<Long, List<Long>> groupPackagesByAdvert(Map<String, AppIndustryGeneralAdvertDTO> advertDTOMap) {
        Map<Long, List<Long>> adOrientMap = new HashMap<>();
        for (AppIndustryGeneralAdvertDTO dto : advertDTOMap.values()) {
            if (dto == null) {
                continue;
            }
            Long advertId = dto.getAdvertId();
            if (advertId == null) {
                continue;
            }
            adOrientMap.computeIfAbsent(advertId, k -> new ArrayList<>()).add(dto.getPackageId());
        }
        return adOrientMap;
    }

    /**
     * Keeps only redis candidates whose advert belongs to the order and regroups their scores as
     * index -> advertId -> (redisKey -> score). Malformed keys and candidates without a score are skipped.
     */
    private Map<String, Map<Long, Map<String, Double>>> buildHelperMap(Map<String, List<AppIndustryGeneralRedisDTO>> redisMap,
                                                                      Set<Long> advertIds) {
        Map<String, Map<Long, Map<String, Double>>> helperMap = new HashMap<>();
        for (Map.Entry<String, List<AppIndustryGeneralRedisDTO>> entry : redisMap.entrySet()) {
            String key = entry.getKey();
            List<AppIndustryGeneralRedisDTO> candidates = entry.getValue();
            if (!AssertUtil.isNotEmpty(candidates)) {
                continue;
            }
            String curIndex = parseIndex(key);   // current metric, e.g. cvr0, weight0
            if (curIndex == null) {
                logger.warn("AppIndustryGeneral recall: malformed redis key {}, skipped", key);
                continue;
            }

            for (AppIndustryGeneralRedisDTO candidate : candidates) {
                if (candidate == null || !advertIds.contains(candidate.getAdvertId())) {
                    continue;
                }
                Map<String, Double> innerVal = getInnerVal(candidate, curIndex, key);
                if (innerVal.isEmpty()) {
                    continue;
                }
                helperMap.computeIfAbsent(curIndex, k -> new HashMap<>())
                        .computeIfAbsent(candidate.getAdvertId(), k -> new HashMap<>())
                        .putAll(innerVal);
            }
        }
        return helperMap;
    }

    /** Extracts the metric segment from a redis key, or returns null when the key is null or too short. */
    private static String parseIndex(String key) {
        if (key == null) {
            return null;
        }
        String[] parts = key.split("_");
        return parts.length > Constants.INDEX_SEGMENT ? parts[Constants.INDEX_SEGMENT] : null;
    }

    /** Computes the time-weighted final score of one advert under one index. */
    private RecallChannelResult scoreAdvert(Long advertId, Map<String, Double> dimScores, String index, String slot) {
        Double oneHour = getCurDimsScore(dimScores, index, Constants.TIME_1_HOUR, Constants.DATA_GLOBAL, slot);
        Double oneDay = getCurDimsScore(dimScores, index, Constants.TIME_1_DAY, Constants.DATA_GLOBAL, slot);
        Double threeDays = getCurDimsScore(dimScores, index, Constants.TIME_3_DAY, Constants.DATA_GLOBAL, slot);

        RecallChannelResult scored = new RecallChannelResult();
        scored.setAdvertId(advertId);
        // Order matters: getFinalScore expects time dimensions from shortest to longest.
        scored.setScore(getFinalScore(Arrays.asList(oneHour, oneDay, threeDays)));
        return scored;
    }

    /**
     * Returns the score of the given metric from a redis DTO, keyed by the redis key.
     *
     * @param appIndustryGeneralRedisDTO redis record
     * @param type                       metric name (ctr, cvr0, cvr1, dCvr, consume, weight0, weight1)
     * @param key                        redis key the record came from
     * @return a single-entry map {key -> score}; empty when an input is empty, the metric is unknown,
     *         or the record holds no score for it
     */
    private Map<String, Double> getInnerVal(AppIndustryGeneralRedisDTO appIndustryGeneralRedisDTO, String type, String key) {
        Map<String, Double> res = new HashMap<>();

        if (AssertUtil.isAnyEmpty(appIndustryGeneralRedisDTO, type, key)) {
            return res;
        }

        Double score;
        switch (type) {
            case "ctr":
                score = appIndustryGeneralRedisDTO.getCtrScore();
                break;
            case "cvr0":
                score = appIndustryGeneralRedisDTO.getCvrScore0();
                break;
            case "cvr1":
                score = appIndustryGeneralRedisDTO.getCvrScore1();
                break;
            case "dCvr":
                score = appIndustryGeneralRedisDTO.getDCvrScore();
                break;
            case "consume":
                score = appIndustryGeneralRedisDTO.getConsumeScore();
                break;
            case "weight0":
                score = appIndustryGeneralRedisDTO.getWeightScore0();
                break;
            case "weight1":
                score = appIndustryGeneralRedisDTO.getWeightScore1();
                break;
            default:
                logger.warn("Wrong input type: {}. Please enter type in array('ctr', 'cvr0', 'cvr1', 'dCvr', 'consume', 'weight0', 'weight1')", type);
                return res;
        }

        // Null scores are dropped so later arithmetic never unboxes null.
        if (score != null) {
            res.put(key, score);
        }
        return res;
    }

    /**
     * Builds redis keys for every (time, data) combination. Currently unused in this class.
     *
     * @param index    metric
     * @param timeDims time dimensions
     * @param dataDims data dimensions
     * @return the redis keys; empty when any input is empty
     */
    private List<String> getKey(String index, List<String> timeDims, List<String> dataDims) {
        List<String> res = new ArrayList<>();

        if (AssertUtil.isAnyEmpty(index, timeDims, dataDims)) {
            return res;
        }

        for (String time : timeDims) {
            for (String data : dataDims) {
                res.add(getKey(index, time, data));
            }
        }

        return res;
    }

    /**
     * Builds one redis key: PREFIX_index_app_timeDim_dataDim.
     *
     * @param index   metric
     * @param timeDim time dimension
     * @param dataDim data dimension (global or a slot id)
     * @return the redis key, or null when any input is empty
     */
    private String getKey(String index, String timeDim, String dataDim) {
        return AssertUtil.isAnyEmpty(index, timeDim, dataDim) ? null : Joiner.on("_").join(Constants.PREFIX, index, "app", timeDim, dataDim);
    }

    /**
     * Returns the score of one time dimension, preferring slot-level data over global data.
     *
     * @param advertBeRecalledRes scores of one advert keyed by redis key
     * @param index               metric
     * @param timeDim             time dimension
     * @param global              global data-dimension name
     * @param slot                slot id; may be null/empty, in which case only global data is used
     * @return the slot score if present, else the global score, else 0.0
     */
    private Double getCurDimsScore(Map<String, Double> advertBeRecalledRes, String index, String timeDim, String global, String slot) {
        if (AssertUtil.isAnyEmpty(advertBeRecalledRes, index, timeDim, global)) {
            return 0.0D;
        }

        if (slot != null && !slot.isEmpty()) {
            Double slotScore = advertBeRecalledRes.get(getKey(index, timeDim, slot));
            if (slotScore != null) {
                return slotScore;
            }
        }
        Double globalScore = advertBeRecalledRes.get(getKey(index, timeDim, global));
        return globalScore != null ? globalScore : 0.0D;
    }

    /**
     * Time-weighted score: the weighted sum of the scores divided by the number of non-zero scores.
     *
     * @param scoreList scores ordered from the shortest to the longest time dimension (1 hour, 1 day,
     *                  3 days); null entries count as 0, entries beyond the third are ignored
     * @return the advert's final score under the metric, or 0.0 when every score is zero
     */
    private Double getFinalScore(List<Double> scoreList) {
        double[] weights = {Constants.WEIGHT_1_HOUR, Constants.WEIGHT_1_DAY, Constants.WEIGHT_3_DAY};
        double weightedSum = 0.0D;
        int nonZero = 0;
        for (int i = 0; i < weights.length && i < scoreList.size(); i++) {
            Double score = scoreList.get(i);
            // Numeric comparison: -0.0 is treated as zero too (Double.equals would not).
            if (score == null || score == 0.0D) {
                continue;
            }
            weightedSum += weights[i] * score;
            nonZero++;
        }
        return nonZero != 0 ? weightedSum / nonZero : 0.0D;
    }

    /**
     * Leftover test stub that always returns an empty map.
     *
     * @return an empty, mutable map
     * @deprecated not a real test: JUnit 4 rejects {@code @Test} methods that return a value, and test
     *     code does not belong in this production facade. Move real tests into the test source tree.
     */
    @Deprecated
    public Map<String, RecallChannelResult> test1() {
        return new HashMap<>();
    }

}
