package cn.com.duiba.nezha.alg.alg.dpa;

import cn.com.duiba.nezha.alg.alg.vo.dpa.CandidateActivityDo;
import cn.com.duiba.nezha.alg.alg.vo.dpa.PackageInfoDo;
import cn.com.duiba.nezha.alg.alg.vo.dpa.PackageRecallDo;
import cn.com.duiba.nezha.alg.alg.vo.dpa.PrizeDo;
import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;
import java.util.stream.Collectors;

/**
 * @author: tanyan
 * @date: 2020/7/1
 */
public class AdReplacement {
    private static final Logger logger= LoggerFactory.getLogger(AdReplacement.class);

    /**
     * When true (current behaviour), a prize group is never asked to receive more ads than there
     * are distinct ad tags, and one ad is taken from each picked tag. When false, a request for
     * more ads than there are tags falls back to a global weighted pick over all recalled ads
     * ({@link #fetchAdFromWeightedList}).
     */
    private static final boolean ONE_AD_PER_TAG = true;

    /**
     * Real-time coupon (ad) replacement entry point.
     *
     * @param activities candidate activity list assembled by the upstream selection stage
     * @param adPks      real-time list of deliverable ads; each package id object must have a
     *                   non-null adPrizeId and a null prizeId, and at least one of adPrizeTagId /
     *                   adPrizeType must be non-null
     * @return the same activity objects, each with recalled ads inserted at the head of its prize
     *         group; {@code activities} is returned untouched when {@code AssertUtil} reports
     *         either argument as empty
     */
    public static List<CandidateActivityDo> recallAdAndReplacePrize(List<CandidateActivityDo> activities, List<PackageInfoDo> adPks) {
        if (AssertUtil.isAnyEmpty(activities)) {
            logger.error("AdReplacement recallAdAndReplacePrize input params activities is null");
            return activities;
        }

        if (AssertUtil.isAnyEmpty(adPks)) {
            logger.warn("AdReplacement recallAdAndReplacePrize input params adPks is null");
            return activities;
        }

        // NOTE(review): the returned tag->index map was never used; the call is kept in case
        // preprocessAdList prepares adPks in place (e.g. fills prize tag ids) - confirm before removing.
        PrizeRecall.preprocessAdList(adPks);

        // UCB-based screening (per the original comment), then order by match score, highest first.
        List<PackageRecallDo> recallAdList = PrizeRecall.recallAdPrize(adPks);
        List<PackageRecallDo> sortedAdPks = recallAdList.stream()
                .sorted(Comparator.comparing(PackageRecallDo::getMatchScore).reversed())
                .collect(Collectors.toList());

        // replaceOverSign is an alternative strategy that replaces once per distinct prize-group signature.
        return replaceOverAll(activities, sortedAdPks);
    }

    /**
     * Inserts ads into every activity's prize group, alternating 1 ad (even positions) and
     * 2 ads (odd positions).
     *
     * @param activities activities to update; their order is preserved in the result
     * @param adPks      recalled ads, expected in descending match-score order
     * @return the activities, each holding a fresh prize-group list
     */
    private static List<CandidateActivityDo> replaceOverAll(List<CandidateActivityDo> activities, List<PackageRecallDo> adPks) {
        List<CandidateActivityDo> result = new ArrayList<>(activities.size());
        AdTagInfo tagInfo = getAdTagInfo(adPks);

        int idx = 0;
        for (CandidateActivityDo act : activities) {
            int replaceCount = (idx++ % 2 == 0) ? 1 : 2;
            // Work on a private copy so activities that share one prize-group list instance (or hold
            // an unmodifiable one) do not receive each other's ads.
            List<PrizeDo> group = copyGroup(act.getPrizeGroup());
            doReplacement(adPks, group, tagInfo, replaceCount);
            act.setPrizeGroup(group);
            result.add(act);
        }

        return result;
    }

    /**
     * Inserts ads once per distinct prize-group signature and shares the resulting list among all
     * activities with that signature. Half of the groups get 1 ad and half get 2; in the 2-ad case
     * the ads are taken from different picked tags.
     *
     * @param activities activities to update
     * @param adPks      recalled ads, expected in descending match-score order
     * @return the activities, grouped by signature in first-occurrence order
     */
    private static List<CandidateActivityDo> replaceOverSign(List<CandidateActivityDo> activities, List<PackageRecallDo> adPks) {
        List<CandidateActivityDo> result = new ArrayList<>(activities.size());
        AdTagInfo tagInfo = getAdTagInfo(adPks);

        GroupInfo groupInfo = extractPrizeGroupInfo(activities);
        int idx = 0;
        for (Map.Entry<String, List<PrizeDo>> entry : groupInfo.groups.entrySet()) {
            int replaceCount = (idx++ % 2 == 0) ? 1 : 2;
            List<PrizeDo> group = copyGroup(entry.getValue());
            doReplacement(adPks, group, tagInfo, replaceCount);
            // Every activity with this signature gets the same replaced group.
            for (CandidateActivityDo act : groupInfo.groupBuckets.get(entry.getKey())) {
                act.setPrizeGroup(group);
                result.add(act);
            }
        }
        return result;
    }

    /** Returns a mutable copy of {@code group}, or an empty list when it is null. */
    private static List<PrizeDo> copyGroup(List<PrizeDo> group) {
        return group == null ? new ArrayList<>() : new ArrayList<>(group);
    }

    /**
     * Picks ads and inserts them at the head of {@code group}: the i-th picked ad goes to index i.
     *
     * @param adPks   all recalled ads; only used by the global fallback when ONE_AD_PER_TAG is false
     * @param group   prize group, modified in place
     * @param tagInfo recalled ads bucketed by prize tag, with per-tag weights
     * @param k       requested number of ads
     */
    private static void doReplacement(List<PackageRecallDo> adPks, List<PrizeDo> group, AdTagInfo tagInfo, int k) {
        int tagCount = tagInfo.tagBucket.size();

        List<PackageRecallDo> ads;
        if (k > tagCount && !ONE_AD_PER_TAG) {
            logger.warn("召回的tag基数小于目标替换数，每个tag下召回一个改为全局随机召回");
            ads = fetchAdFromWeightedList(adPks, k);
        } else {
            // Never request more tags than exist, so one ad is taken per picked tag.
            ads = fetchAdFromWeightedPickedTag(tagInfo, Math.min(k, tagCount));
        }

        for (int i = 0; i < ads.size(); i++) {
            PackageRecallDo ad = ads.get(i);
            PrizeDo pd = new PrizeDo(true, ad.getPackageInfoDo().getPackageIdDo().getAdPrizeId(), ad.getMatchScore());
            group.add(i, pd);
        }
    }

    /** Global fallback: weighted pick of {@code k} ads across all recalled ads. */
    private static List<PackageRecallDo> fetchAdFromWeightedList(List<PackageRecallDo> adPks, int k) {
        return PrizeRecall.fetchPkrdFromWeightedList(adPks, k);
    }

    /**
     * Picks {@code k} tags by weight and takes the first ad of each picked tag's bucket. Buckets
     * keep the order of the list passed to {@link #getAdTagInfo}, which from
     * {@code recallAdAndReplacePrize} is descending match score, so this is the best-scoring ad of the tag.
     */
    private static List<PackageRecallDo> fetchAdFromWeightedPickedTag(AdTagInfo adTagInfo, int k) {
        List<PackageRecallDo> result = new ArrayList<>();
        List<Long> chosenAdTags = PrizeRecall.weightedPickTags(adTagInfo.tagWeight, k);
        for (Long tagId : chosenAdTags) {
            result.add(adTagInfo.tagBucket.get(tagId).get(0));
        }
        return result;
    }

    /**
     * Extracts prize-group information.
     *
     * <p>Groups activities by {@link PrizeRecall#groupSignature} of their prize group. The
     * representative list stored for each signature is the first activity's own list. Both maps
     * iterate in first-occurrence order of the signatures, so downstream processing follows the
     * input ranking rather than hash order.
     *
     * <p>NOTE(review): this method is public but returns the private type {@code GroupInfo}, so
     * callers outside this class cannot read its fields.
     *
     * @param activities candidate activities from the upstream selection stage
     * @return the extracted prize-group information
     */
    public static GroupInfo extractPrizeGroupInfo(List<CandidateActivityDo> activities) {
        Map<String, List<CandidateActivityDo>> groupBuckets = new LinkedHashMap<>();
        Map<String, List<PrizeDo>> groups = new LinkedHashMap<>();
        for (CandidateActivityDo act : activities) {
            List<PrizeDo> group = act.getPrizeGroup();
            String signature = PrizeRecall.groupSignature(group);
            if (!groups.containsKey(signature)) {
                groups.put(signature, group);
            }
            groupBuckets.computeIfAbsent(signature, s -> new ArrayList<>()).add(act);
        }
        return new GroupInfo(groups, groupBuckets);
    }

    /** Holder: signature -> representative prize group, and signature -> activities sharing it. */
    private static class GroupInfo {
        public final Map<String, List<PrizeDo>> groups;
        public final Map<String, List<CandidateActivityDo>> groupBuckets;

        public GroupInfo(Map<String, List<PrizeDo>> groups, Map<String, List<CandidateActivityDo>> groupBuckets) {
            this.groups = groups;
            this.groupBuckets = groupBuckets;
        }
    }

    /**
     * Buckets the recalled ads by prize tag id (preserving input order within each bucket) and
     * sums each tag's match scores as its weight.
     */
    private static AdTagInfo getAdTagInfo(List<PackageRecallDo> prds) {
        Map<Long, List<PackageRecallDo>> tagBuckets = new HashMap<>();
        Map<Long, Double> tagWeight = new HashMap<>();
        for (PackageRecallDo pkrd : prds) {
            Long prizeTagId = pkrd.getPackageInfoDo().getPackageIdDo().getPrizeTagId();
            tagBuckets.computeIfAbsent(prizeTagId, id -> new ArrayList<>()).add(pkrd);
            tagWeight.put(prizeTagId, tagWeight.getOrDefault(prizeTagId, 0.0d) + pkrd.getMatchScore());
        }
        return new AdTagInfo(tagBuckets, tagWeight);
    }

    /** Holder: prize tag id -> ads with that tag, and prize tag id -> summed match score. */
    private static class AdTagInfo {
        public final Map<Long, List<PackageRecallDo>> tagBucket;
        public final Map<Long, Double> tagWeight;

        public AdTagInfo(Map<Long, List<PackageRecallDo>> tagBucket, Map<Long, Double> tagWeight) {
            this.tagBucket = tagBucket;
            this.tagWeight = tagWeight;
        }
    }
}
