package cn.com.duiba.nezha.alg.alg.dpa;

import cn.com.duiba.nezha.alg.alg.vo.dpa.CandidateActivityDo;
import cn.com.duiba.nezha.alg.alg.vo.dpa.PackageInfoDo;
import cn.com.duiba.nezha.alg.alg.vo.dpa.PackageRecallDo;
import cn.com.duiba.nezha.alg.alg.vo.dpa.PrizeDo;

import java.util.*;
import java.util.stream.Collectors;

/**
 * Injects advertisement prizes into the prize groups of candidate activities.
 *
 * <p>Ads are recalled via {@code PrizeRecall}, bucketed by their prize tag id, and then
 * picked per prize group: tags are chosen with {@code PrizeRecall.weightedPickTags} and one
 * ad per chosen tag with {@code PrizeRecall.weightedRandomPickByMatchscore}.
 *
 * @author: tanyan
 * @date: 2020/7/1
 */
public class AdReplacement {

    /**
     * Recalls ad packages and inserts them into the prize group of every activity.
     *
     * @param activities candidate activities; their prize groups are modified in place
     * @param adPks      candidate ad packages handed to {@code PrizeRecall}
     * @return a new list holding the same activity instances, in input order
     */
    public static List<CandidateActivityDo> recallAdAndReplacePrize(List<CandidateActivityDo> activities, List<PackageInfoDo> adPks){

        // The returned tag->index map is not used here. The call is kept because it may
        // prepare adPks for the recall step below -- TODO confirm whether it has side effects.
        PrizeRecall.preprocessAdList(adPks);

        // UCB-based filtering (per the original author's note), then order by match score, best first.
        List<PackageRecallDo> recallAdList = PrizeRecall.recallAdPrize(adPks);
        List<PackageRecallDo> sortedAdPks = recallAdList.stream()
                .sorted(Comparator.comparing(PackageRecallDo::getMatchScore).reversed())
                .collect(Collectors.toList());

        // Alternative strategy: replaceOverSign(activities, sortedAdPks), which works per
        // distinct prize-group signature instead of per activity.
        return replaceOverAll(activities, sortedAdPks);
    }

    /**
     * Inserts ads into every activity's prize group: activities at even positions (0-based)
     * get one ad, activities at odd positions get two.
     *
     * <p>NOTE(review): if several activities share the same prize-group list instance, that
     * list receives ads once per activity -- confirm against the caller.
     *
     * @param activities activities whose prize groups are modified in place
     * @param adPks      recalled ad packages
     * @return a new list holding the same activity instances, in input order
     */
    private static List<CandidateActivityDo> replaceOverAll(List<CandidateActivityDo> activities,List<PackageRecallDo> adPks){
        List<CandidateActivityDo> result = new ArrayList<>(activities.size());
        AdTagInfo tagInfo = getAdTagInfo(adPks);

        int idx = 0;
        for (CandidateActivityDo act : activities) {
            int adCount = (idx++ % 2 == 0) ? 1 : 2;
            doReplacement(act.getPrizeGroup(), tagInfo, adCount);
            result.add(act);
        }

        return result;
    }

    /**
     * Alternative to {@link #replaceOverAll}: inserts ads once per distinct prize-group
     * signature and then assigns that group to every activity sharing the signature.
     * Half of the groups get one ad, the other half two (picked from different tags).
     *
     * <p>Currently not called from {@link #recallAdAndReplacePrize}.
     *
     * @param activities activities to process; their prize group reference is reassigned
     * @param adPks      recalled ad packages
     * @return the activities, grouped by signature in the iteration order of the group map
     */
    private static List<CandidateActivityDo> replaceOverSign(List<CandidateActivityDo> activities,List<PackageRecallDo> adPks){
        List<CandidateActivityDo> result = new ArrayList<>(activities.size());
        AdTagInfo tagInfo = getAdTagInfo(adPks);

        // Insert ads into each distinct group: alternate between one ad and two ads.
        GroupInfo groupInfo = extractPrizeGroupInfo(activities);
        int idx = 0;
        for (List<PrizeDo> group : groupInfo.groups.values()) {
            int adCount = (idx++ % 2 == 0) ? 1 : 2;
            doReplacement(group, tagInfo, adCount);
        }

        // Point every activity at the (now ad-augmented) representative group of its signature.
        for (Map.Entry<String, List<PrizeDo>> entry : groupInfo.groups.entrySet()) {
            List<PrizeDo> group = entry.getValue();
            for (CandidateActivityDo act : groupInfo.groupBuckets.get(entry.getKey())) {
                act.setPrizeGroup(group);
                result.add(act);
            }
        }
        return result;
    }

    /**
     * Picks up to {@code k} ads (never more than the number of available tags) and inserts
     * them at the head of {@code group}.
     *
     * <p>NOTE(review): despite the name, ads are inserted with {@code List.add(index, e)};
     * existing prizes are shifted right and none is removed -- confirm this is intended.
     *
     * @param group   prize group to modify in place
     * @param tagInfo ad buckets and weights per tag
     * @param k       requested number of ads
     */
    private static void doReplacement(List<PrizeDo> group, AdTagInfo tagInfo , int k) {
        int adCount = Math.min(k, tagInfo.tagBucket.size());
        if (adCount <= 0) {
            // No ad tags were recalled (or nothing was requested): leave the group untouched
            // instead of asking the picker for zero tags from an empty weight map.
            return;
        }
        List<PackageRecallDo> ads = fetchAdFromWeightedPickedTag(tagInfo, adCount);
        for (int i = 0; i < ads.size(); i++) {
            PrizeDo pd = new PrizeDo(true, ads.get(i).getPackageInfoDo().getPackageIdDo().getAdPrizeId());
            group.add(i, pd);
        }
    }

    /**
     * Chooses {@code k} tags via {@code PrizeRecall.weightedPickTags} and, for each chosen
     * tag, one ad from that tag's bucket via {@code PrizeRecall.weightedRandomPickByMatchscore}.
     *
     * @param adTagInfo ad buckets and weights per tag
     * @param k         number of tags to request
     * @return one picked ad per chosen tag, in the order the tags were returned
     */
    private static List<PackageRecallDo> fetchAdFromWeightedPickedTag(AdTagInfo adTagInfo, int k){
        List<PackageRecallDo> result = new ArrayList<>();
        List<Long> chosenAdTags = PrizeRecall.weightedPickTags(adTagInfo.tagWeight, k);
        for (Long tagId : chosenAdTags) {
            // Alternative: always take the first ad of the bucket, adTagInfo.tagBucket.get(tagId).get(0).
            result.add(PrizeRecall.weightedRandomPickByMatchscore(adTagInfo.tagBucket.get(tagId)));
        }
        return result;
    }

    /**
     * Groups activities by the signature of their prize group.
     *
     * <p>For each signature the prize group of the first activity seen is kept as the
     * representative group. Both maps preserve first-occurrence order of the signatures so
     * that downstream iteration is deterministic with respect to the input order.
     *
     * @param activities activities to group
     * @return the representative group and the activities per signature
     */
    public static GroupInfo extractPrizeGroupInfo(List<CandidateActivityDo> activities){
        Map<String, List<CandidateActivityDo>> groupBuckets = new LinkedHashMap<>();
        Map<String, List<PrizeDo>> groups = new LinkedHashMap<>();
        for (CandidateActivityDo act : activities) {
            List<PrizeDo> group = act.getPrizeGroup();
            String signature = PrizeRecall.groupSignature(group);
            List<CandidateActivityDo> bucket = groupBuckets.get(signature);
            if (bucket == null) {
                // First activity with this signature: its group becomes the representative.
                bucket = new ArrayList<>();
                groupBuckets.put(signature, bucket);
                groups.put(signature, group);
            }
            bucket.add(act);
        }
        return new GroupInfo(groups, groupBuckets);
    }

    /** Prize groups and the activities that share them, both keyed by group signature. */
    private static class GroupInfo{
        /** Signature -> representative prize group. */
        public final Map<String, List<PrizeDo>> groups;
        /** Signature -> activities whose prize group has that signature. */
        public final Map<String, List<CandidateActivityDo>> groupBuckets;

        public GroupInfo(Map<String, List<PrizeDo>> groups, Map<String, List<CandidateActivityDo>> groupBuckets){
            this.groups = groups;
            this.groupBuckets = groupBuckets;
        }
    }

    /**
     * Buckets recalled ads by prize tag id and sums the match scores per tag.
     *
     * @param prds recalled ad packages
     * @return buckets (in the order of {@code prds}) and accumulated match score per tag
     */
    private static AdTagInfo getAdTagInfo(List<PackageRecallDo> prds){
        Map<Long, List<PackageRecallDo>> tagBuckets = new HashMap<>();
        Map<Long, Double> tagWeight = new HashMap<>();
        for (PackageRecallDo pkrd : prds) {
            Long prizeTagId = pkrd.getPackageInfoDo().getPackageIdDo().getPrizeTagId();
            tagBuckets.computeIfAbsent(prizeTagId, key -> new ArrayList<>()).add(pkrd);
            tagWeight.put(prizeTagId, tagWeight.getOrDefault(prizeTagId, 0.0d) + pkrd.getMatchScore());
        }
        return new AdTagInfo(tagBuckets,tagWeight);
    }

    /** Recalled ads per prize tag id together with the summed match score of each tag. */
    private static class AdTagInfo{
        /** Prize tag id -> ads carrying that tag. */
        public final Map<Long, List<PackageRecallDo>> tagBucket;
        /** Prize tag id -> sum of the match scores of the ads in the bucket. */
        public final Map<Long, Double> tagWeight;

        public AdTagInfo(Map<Long, List<PackageRecallDo>> tagBucket, Map<Long, Double> tagWeight){
            this.tagBucket = tagBucket;
            this.tagWeight = tagWeight;
        }
    }
}
