package cn.com.duiba.nezha.alg.api.facade.recall.channel;

import cn.com.duiba.nezha.alg.api.dto.base.BaseAdvertDTO;
import cn.com.duiba.nezha.alg.api.dto.recall.*;
import cn.com.duiba.nezha.alg.api.enums.RecallEnums.RecallChannelType;
import cn.com.duiba.nezha.alg.api.facade.recall.AbstractRecallChanelFacade;
import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import com.google.common.base.Joiner;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;
import java.util.stream.Collectors;

/**
 * Collaborative-filtering recall channel.
 * <p>
 * Intersects the collaborative-filtering scores read from Redis with the candidate adverts from the Nezha
 * configuration, then, per index (metric), filters by threshold, sorts by score and truncates to the
 * configured maximum recall size.
 *
 * @author lijicong
 * @since 2021-09-07
 */
public class CollabFilterRecallChannelFacadeImpl extends AbstractRecallChanelFacade<CollabFilterAdvertDTO, CollabFilterRequest>
        implements CollabFilterRecallChannelFacade {

    static final Logger logger = LoggerFactory.getLogger(CollabFilterRecallChannelFacadeImpl.class);

    /** Prefix of the Redis keys, e.g. {@code NZ_A12_weight2_cf_136}. */
    private static final String REDIS_KEY_PREFIX = "NZ_A12";

    /** Position of the index (metric) name in a Redis key split on "_". */
    private static final int INDEX_POSITION_IN_KEY = 2;

    @Override
    public RecallChannelType getType() {
        return RecallChannelType.collFilter;
    }

    /**
     * Produces the collaborative-filtering recall result.
     *
     * @param collabFilterAdvertDTOMap Nezha advert configurations (the candidate adverts)
     * @param collabFilterRequest      request holding the model-generated data read from Redis, the per-index
     *                                 maximum recall size ({@code num}), the per-index score threshold and the slot id
     * @return recall results keyed by {@link RecallChannelResult#buildKey()}; empty when any required input
     *         is missing
     */
    @Override
    public Map<String, RecallChannelResult> execute(Map<String, CollabFilterAdvertDTO> collabFilterAdvertDTOMap,
                                                    CollabFilterRequest collabFilterRequest) {

        Map<String, RecallChannelResult> res = new HashMap<>();

        if (AssertUtil.isAnyEmpty(collabFilterAdvertDTOMap, collabFilterRequest)) {
            return res;
        }

        Map<String, List<CollabFilterRedisDTO>> redisMap = collabFilterRequest.getRedisMap();
        Map<String, Integer> nums = collabFilterRequest.getNum();
        Map<String, Double> threshold = collabFilterRequest.getThreshold();
        Object slotId = collabFilterRequest.getSlotId();
        if (redisMap == null || redisMap.isEmpty() || nums == null || threshold == null || slotId == null) {
            // Without Redis data, per-index config or the slot id no score can be looked up.
            return res;
        }

        // advertId -> package ids configured for that advert. Null DTOs / advert ids are dropped because
        // groupingBy rejects null keys.
        Map<Long, List<Long>> adOrientMap = collabFilterAdvertDTOMap.values().stream()
                .filter(dto -> dto != null && dto.getAdvertId() != null)
                .collect(Collectors.groupingBy(BaseAdvertDTO::getAdvertId,
                        Collectors.mapping(BaseAdvertDTO::getPackageId, Collectors.toList())));

        // 1. Intersect with the Nezha adverts and regroup as index -> advertId -> (redis key -> score).
        Map<String, Map<Long, Map<String, Double>>> scoresByIndex = groupScoresByIndex(redisMap, adOrientMap.keySet());

        // 2. For each index: filter by threshold, sort, truncate, then expand to one result per package.
        String slotDim = slotId.toString();
        for (Map.Entry<String, Map<Long, Map<String, Double>>> entry : scoresByIndex.entrySet()) {
            String index = entry.getKey();
            Double minScore = threshold.get(index);
            Integer maxNum = nums.get(index);
            if (minScore == null || maxNum == null) {
                logger.warn("collab filter recall: missing threshold or num config for index={}, skipped", index);
                continue;
            }
            if (maxNum <= 0) {
                continue;
            }
            // Only the score stored under this slot's key counts as the advert's final score.
            String slotKey = getKey(index, slotDim);
            List<Map.Entry<Long, Double>> topAdverts = selectTopAdverts(entry.getValue(), slotKey, minScore, maxNum);
            appendPackageResults(res, topAdverts, adOrientMap, index);
        }

        return res;
    }

    /**
     * Keeps only Redis records whose advert is among {@code advertIds} and regroups their scores as
     * index -> advertId -> (redis key -> score). Keys whose index cannot be parsed are skipped.
     */
    private Map<String, Map<Long, Map<String, Double>>> groupScoresByIndex(
            Map<String, List<CollabFilterRedisDTO>> redisMap, Set<Long> advertIds) {
        Map<String, Map<Long, Map<String, Double>>> scoresByIndex = new HashMap<>();
        for (Map.Entry<String, List<CollabFilterRedisDTO>> entry : redisMap.entrySet()) {
            String redisKey = entry.getKey();
            List<CollabFilterRedisDTO> records = entry.getValue();
            if (records == null || records.isEmpty()) {
                continue;
            }
            String index = parseIndex(redisKey);
            if (index == null) {
                logger.warn("collab filter recall: malformed redis key={}, skipped", redisKey);
                continue;
            }
            for (CollabFilterRedisDTO record : records) {
                if (record == null || !advertIds.contains(record.getAdvertId())) {
                    continue;
                }
                scoresByIndex
                        .computeIfAbsent(index, k -> new HashMap<>())
                        .computeIfAbsent(record.getAdvertId(), k -> new HashMap<>())
                        .putAll(getInnerVal(record, index, redisKey));
            }
        }
        return scoresByIndex;
    }

    /**
     * Extracts the index (metric) name from a Redis key such as {@code NZ_A12_weight2_cf_136}.
     *
     * @return the third "_"-separated part, or {@code null} when the key is null or too short
     */
    private static String parseIndex(String redisKey) {
        if (redisKey == null) {
            return null;
        }
        String[] parts = redisKey.split("_");
        return parts.length > INDEX_POSITION_IN_KEY ? parts[INDEX_POSITION_IN_KEY] : null;
    }

    /**
     * Takes each advert's score stored under {@code slotKey} (0.0 when absent or null) and keeps adverts whose
     * score is strictly greater than {@code minScore}, sorted by descending score and truncated to
     * {@code maxNum} entries.
     *
     * @return (advertId, score) pairs in descending score order
     */
    private List<Map.Entry<Long, Double>> selectTopAdverts(Map<Long, Map<String, Double>> advertScores,
                                                           String slotKey, double minScore, int maxNum) {
        List<Map.Entry<Long, Double>> candidates = new ArrayList<>();
        for (Map.Entry<Long, Map<String, Double>> entry : advertScores.entrySet()) {
            Map<String, Double> scores = entry.getValue();
            Double score = scores == null ? null : scores.get(slotKey);
            double finalScore = score == null ? 0.0D : score;
            if (finalScore > minScore) {
                candidates.add(new AbstractMap.SimpleImmutableEntry<Long, Double>(entry.getKey(), finalScore));
            }
        }
        // List.sort is stable: adverts with equal scores keep their map-iteration order.
        candidates.sort(Map.Entry.<Long, Double>comparingByValue().reversed());
        return candidates.size() > maxNum ? candidates.subList(0, maxNum) : candidates;
    }

    /**
     * Expands each selected advert into one {@link RecallChannelResult} per configured package id and
     * stores it in {@code res} under its {@code buildKey()}.
     */
    private void appendPackageResults(Map<String, RecallChannelResult> res, List<Map.Entry<Long, Double>> topAdverts,
                                      Map<Long, List<Long>> adOrientMap, String index) {
        for (Map.Entry<Long, Double> advert : topAdverts) {
            List<Long> packageIds = adOrientMap.getOrDefault(advert.getKey(), Collections.<Long>emptyList());
            for (Long packageId : packageIds) {
                // A fresh object per package: reusing one instance would make every stored entry share the
                // last package id.
                RecallChannelResult result = new RecallChannelResult();
                result.setAdvertId(advert.getKey());
                result.setPackageId(packageId);
                result.setScore(advert.getValue());
                result.setIndex(index);
                res.put(result.buildKey(), result);
            }
        }
    }

    /**
     * Returns the score of the given index type as a single-entry map keyed by the Redis key.
     *
     * @param collabFilterRedisDTO data read from Redis
     * @param type                 index type, one of ctr, cvr0, cvr1, dCvr, consume, weight0, weight1, weight2
     * @param key                  the Redis key
     * @return {@code {key -> score}}, or an empty map when any argument is empty or the type is unknown
     */
    private Map<String, Double> getInnerVal(CollabFilterRedisDTO collabFilterRedisDTO, String type, String key) {
        Map<String, Double> res = new HashMap<>();

        if (AssertUtil.isAnyEmpty(collabFilterRedisDTO, type, key)) {
            return res;
        }

        Double score;
        switch (type) {
            case "ctr":
                score = collabFilterRedisDTO.getCtrScore();
                break;
            case "cvr0":
                score = collabFilterRedisDTO.getCvrScore0();
                break;
            case "cvr1":
                score = collabFilterRedisDTO.getCvrScore1();
                break;
            case "dCvr":
                score = collabFilterRedisDTO.getDCvrScore();
                break;
            case "consume":
                score = collabFilterRedisDTO.getConsumeScore();
                break;
            case "weight0":
                score = collabFilterRedisDTO.getWeightScore0();
                break;
            case "weight1":
                score = collabFilterRedisDTO.getWeightScore1();
                break;
            case "weight2":
                score = collabFilterRedisDTO.getWeightScore2();
                break;
            default:
                logger.info("Wrong input type. Please enter type in array('ctr', 'cvr0', 'cvr1', 'dCvr', 'consume', 'weight0', 'weight1', 'weight2')");
                return res;
        }

        res.put(key, score);
        return res;
    }

    /**
     * Builds the Redis key for an index and a slot.
     *
     * @param index   index (metric) name
     * @param dataDim slot id
     * @return {@code NZ_A12_<index>_cf_<dataDim>}, or {@code null} when either argument is empty
     */
    private String getKey(String index, String dataDim) {
        return AssertUtil.isAnyEmpty(index, dataDim) ? null : Joiner.on("_").join(REDIS_KEY_PREFIX, index, "cf", dataDim);
    }

}
