package cn.com.duiba.nezha.alg.alg.dpa;

import cn.com.duiba.nezha.alg.alg.constant.DPAConstant;
import cn.com.duiba.nezha.alg.alg.vo.dpa.PackageInfoDo;
import cn.com.duiba.nezha.alg.alg.vo.dpa.PackageRecallDo;
import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import cn.com.duiba.nezha.alg.feature.vo.CandidateActivityDo;
import cn.com.duiba.nezha.alg.feature.vo.PrizeDo;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;
import java.util.stream.Collectors;

import static cn.com.duiba.nezha.alg.alg.dpa.PackageRecall.ucbPackage;

/**
 * Real-time ad-prize replacement for candidate activities.
 *
 * <p>Recalls a small set of servable ads with {@code PackageRecall.ucbPackage}, orders them by
 * match score (best first), and inserts some of them at the front of the prize groups of copies
 * of the candidate activities.
 *
 * @author: tanyan
 * @date: 2020/7/1
 */
public class AdReplacement {
    private static final Logger logger = LoggerFactory.getLogger(AdReplacement.class);

    // NOTE(review): not referenced anywhere in this class.
    private static final int AD_RANDOM_RANGE = 10;

    /** How many of the top-scored ads the single-ad branch of doReplacement samples from. */
    private static final int SINGLE_AD_PICK_RANGE = 4;

    /** Length of the slot cycle replaceOverAll assigns to consecutive activities. */
    private static final int REPLACE_OVER_ALL_CYCLE = 4;

    /**
     * Real-time ad replacement entry point.
     *
     * @param activities activities assembled by the upstream recall stage
     * @param adPks servable ads fetched in real time. Each package ID must have a non-null
     *              adPrizeId and a null prizeId, and at least one of adPrizeTagId / adPrizeType
     *              must be non-null.
     * @return {@code activities} itself when {@code AssertUtil.isAnyEmpty} rejects either argument;
     *         otherwise a new list of copies of the activities, some of which have recalled ads
     *         inserted at the front of their prize group
     */
    public static List<CandidateActivityDo> recallAdAndReplacePrize(List<CandidateActivityDo> activities, List<PackageInfoDo> adPks) {
        if (AssertUtil.isAnyEmpty(activities)) {
            logger.error("AdReplacement recallAdAndReplacePrize input params activities is null");
            return activities;
        }

        if (AssertUtil.isAnyEmpty(adPks)) {
            logger.warn("AdReplacement recallAdAndReplacePrize input params adPks is null");
            return activities;
        }

        // The index->tag map returned here was never used; the call is kept only in case it
        // prepares adPks in place. NOTE(review): drop it if preprocessAdList has no side effects.
        PrizeRecall.preprocessAdList(adPks, null);

        // UCB recall, then sort best match score first; doReplacement and
        // fetchAdFromWeightedPickedTag rely on this ordering.
        List<PackageRecallDo> sortedAdPks = recallAdPrize(adPks).stream()
                .sorted(Comparator.comparing(PackageRecallDo::getMatchScore).reversed())
                .collect(Collectors.toList());

        return replaceOverAll(activities, sortedAdPks);
    }

    /**
     * Recalls candidate ads via UCB explore/exploit selection.
     *
     * @param pks servable ads
     * @return the ads chosen by {@code ucbPackage}, asking for at most
     *         {@code min(DPAConstant.AD_PRIZETOPN, pks.size())} of them
     */
    protected static List<PackageRecallDo> recallAdPrize(List<PackageInfoDo> pks) {
        int topN = Math.min(DPAConstant.AD_PRIZETOPN, pks.size());
        return ucbPackage(pks, topN, DPAConstant.RECALLPRIZE, 2);
    }

    /**
     * Copies every activity and inserts ads into the copies' prize groups. Consecutive activities
     * get slots 0, 1, 2, 3, 0, ...: slots 0 and 1 insert nothing, slot 2 inserts one ad and slot 3
     * inserts up to two (see doReplacement).
     *
     * @param activities activities to copy; not modified
     * @param adPks recalled ads sorted best match score first
     * @return the copied activities, in input order
     */
    private static List<CandidateActivityDo> replaceOverAll(List<CandidateActivityDo> activities, List<PackageRecallDo> adPks) {
        List<CandidateActivityDo> result = new ArrayList<>(activities.size());
        AdTagInfo tagInfo = getAdTagInfo(adPks);

        int idx = 0;
        for (CandidateActivityDo act : activities) {
            CandidateActivityDo newAct = deepCloneAct(act);
            doReplacement(adPks, newAct.getPrizeGroup(), tagInfo, idx++ % REPLACE_OVER_ALL_CYCLE);
            result.add(newAct);
        }
        return result;
    }

    /**
     * Copies an activity field by field. Only the prize-group list itself is copied (a null group
     * stays null); every other value, including the PrizeDo elements, is shared with {@code act}.
     * NOTE(review): any CandidateActivityDo field not set below is left at its default — confirm
     * the list is complete.
     */
    private static CandidateActivityDo deepCloneAct(CandidateActivityDo act) {
        CandidateActivityDo retAct = new CandidateActivityDo();

        retAct.setActivityId(act.getActivityId());
        retAct.setActivityMatchScore(act.getActivityMatchScore());
        retAct.setActPackageType(act.getActPackageType());

        List<PrizeDo> prizeGroup = act.getPrizeGroup();
        retAct.setPrizeGroup(prizeGroup == null ? null : new ArrayList<PrizeDo>(prizeGroup));

        retAct.setRoutineActFeature(act.getRoutineActFeature());

        retAct.setSkinFeature(act.getSkinFeature());
        retAct.setSkinId(act.getSkinId());

        retAct.setSubTitleFeature(act.getSubTitleFeature());
        retAct.setSubTitleId(act.getSubTitleId());
        retAct.setSubTitleTagId(act.getSubTitleTagId());
        retAct.setSubTitleMatchScore(act.getSubTitleMatchScore());

        retAct.setTitleId(act.getTitleId());
        retAct.setTitleMatchScore(act.getTitleMatchScore());
        retAct.setMainTitleFeature(act.getMainTitleFeature());
        retAct.setTitleTagId(act.getTitleTagId());
        return retAct;
    }

    /**
     * Alternative strategy (currently not called): activities are grouped by prize-group
     * signature, ads are inserted once per distinct group, and every activity in the group
     * receives a copy of the result. Input activities and their prize-group lists are not
     * modified. Output order follows HashMap iteration over the signatures.
     *
     * NOTE(review): the slot {@code (idx % 2) + 1} is 1 or 2, so with doReplacement inserting
     * {@code k - 1} ads, half the groups get none and half get one — not "half one, half two" as
     * originally described. Confirm intent before re-enabling.
     */
    private static List<CandidateActivityDo> replaceOverSign(List<CandidateActivityDo> activities, List<PackageRecallDo> adPks) {
        List<CandidateActivityDo> result = new ArrayList<>(activities.size());
        AdTagInfo tagInfo = getAdTagInfo(adPks);

        GroupInfo groupInfo = extractPrizeGroupInfo(activities);
        int idx = 0;
        for (Map.Entry<String, List<PrizeDo>> entry : groupInfo.groups.entrySet()) {
            // The map holds the input activities' own lists, so replace on a copy.
            List<PrizeDo> original = entry.getValue();
            List<PrizeDo> replaced = original == null ? null : new ArrayList<PrizeDo>(original);
            doReplacement(adPks, replaced, tagInfo, (idx++ % 2) + 1);

            for (CandidateActivityDo act : groupInfo.groupBuckets.get(entry.getKey())) {
                CandidateActivityDo newAct = deepCloneAct(act);
                // Each activity gets its own list so later edits do not leak between activities.
                newAct.setPrizeGroup(replaced == null ? null : new ArrayList<PrizeDo>(replaced));
                result.add(newAct);
            }
        }
        return result;
    }

    /**
     * Inserts up to {@code k - 1} recalled ads at the front of {@code group}, in place.
     *
     * <p>The count is capped at the number of distinct prize tags (at most one ad per tag for
     * now). With replaceOverAll's slots 0..3 this means half of the activities are untouched, a
     * quarter get one ad and a quarter get up to two.
     *
     * @param adPks recalled ads sorted best match score first
     * @param group prize group to modify; {@code null} is ignored
     * @param tagInfo ads bucketed by prize tag
     * @param k replacement slot; values below 2 insert nothing
     */
    private static void doReplacement(List<PackageRecallDo> adPks, List<PrizeDo> group, AdTagInfo tagInfo, int k) {
        if (group == null || k < 2 || tagInfo.tagBucket.isEmpty()) {
            return;
        }
        int tagCount = tagInfo.tagBucket.size();
        // For now recall at most one ad per tag: if more are requested than there are tags, insert fewer.
        int count = Math.min(k - 1, tagCount);

        List<PackageRecallDo> ads;
        if (count == 1) {
            // Weighted random pick among the top few ads.
            int range = Math.min(adPks.size(), SINGLE_AD_PICK_RANGE);
            ads = new ArrayList<>(1);
            ads.add(PrizeRecall.weightedRandomPickByMatchscore(adPks.subList(0, range)));
        } else if (count == 2 && tagCount == 1) {
            // Unreachable while the per-tag cap above is in place (count <= tagCount).
            int range = Math.min(adPks.size(), DPAConstant.AD_PRIZETOPN);
            ads = new ArrayList<>(2);
            ads.add(adPks.get(0));
            if (range > 1) {
                ads.add(PrizeRecall.weightedRandomPickByMatchscore(adPks.subList(1, range)));
            }
        } else if (count > tagCount) {
            // Unreachable while the per-tag cap above is in place.
            logger.warn("召回的tag基数小于目标替换数，每个tag下召回一个改为全局随机召回");
            ads = fetchAdFromWeightedList(adPks, count);
        } else {
            ads = fetchAdFromWeightedPickedTag(tagInfo, count);
        }

        if (ads == null) {
            return;
        }
        // Insert at the front, preserving the order of ads; skip anything a picker failed to return.
        int pos = 0;
        for (PackageRecallDo ad : ads) {
            if (ad != null) {
                group.add(pos++, prizeDoFromAd(ad));
            }
        }
    }

    /** Builds an ad PrizeDo from a recalled ad package: ad prize id, prize tag id, match score and ad tag (adPrizeType). */
    private static PrizeDo prizeDoFromAd(PackageRecallDo ad) {
        PrizeDo pd = new PrizeDo(true,
                ad.getPackageInfoDo().getPackageIdDo().getAdPrizeId(),
                ad.getPackageInfoDo().getPackageIdDo().getPrizeTagId(),
                ad.getMatchScore());
        pd.setAdTagId(ad.getPackageInfoDo().getPackageIdDo().getAdPrizeType());
        return pd;
    }

    /** Picks {@code k} ads from the whole list, delegating to PrizeRecall.fetchPkrdFromWeightedList. */
    private static List<PackageRecallDo> fetchAdFromWeightedList(List<PackageRecallDo> adPks, int k) {
        return PrizeRecall.fetchPkrdFromWeightedList(adPks, k);
    }

    /**
     * Chooses {@code k} tags by weight and takes the first ad of each chosen tag's bucket. Buckets
     * keep the order of the list they were built from, so with a best-first list this is the
     * highest-scoring ad of the tag.
     */
    private static List<PackageRecallDo> fetchAdFromWeightedPickedTag(AdTagInfo adTagInfo, int k) {
        List<PackageRecallDo> result = new ArrayList<>();
        List<Long> choosenAdTags = PrizeRecall.weightedPickTags(adTagInfo.tagWeight, k);
        for (Long tagId : choosenAdTags) {
            result.add(adTagInfo.tagBucket.get(tagId).get(0));
        }
        return result;
    }

    /**
     * Groups activities by prize-group signature.
     *
     * @param activities activities assembled by the upstream recall stage
     * @return for each signature, the prize group of the first activity seen with it (the
     *         activity's own list, not a copy) and all activities sharing it, in input order
     */
    public static GroupInfo extractPrizeGroupInfo(List<CandidateActivityDo> activities) {
        Map<String, List<CandidateActivityDo>> groupBuckets = new HashMap<>();
        Map<String, List<PrizeDo>> groups = new HashMap<>();
        for (CandidateActivityDo act : activities) {
            List<PrizeDo> group = act.getPrizeGroup();
            String signature = PrizeRecall.groupSignature(group);
            List<CandidateActivityDo> bucket = groupBuckets.get(signature);
            if (bucket == null) {
                bucket = new ArrayList<>();
                groupBuckets.put(signature, bucket);
                groups.put(signature, group);
            }
            bucket.add(act);
        }
        return new GroupInfo(groups, groupBuckets);
    }

    /** Prize groups and the activities that share them, both keyed by group signature. */
    private static class GroupInfo {
        public final Map<String, List<PrizeDo>> groups;
        public final Map<String, List<CandidateActivityDo>> groupBuckets;

        public GroupInfo(Map<String, List<PrizeDo>> groups, Map<String, List<CandidateActivityDo>> groupBuckets) {
            this.groups = groups;
            this.groupBuckets = groupBuckets;
        }
    }

    /**
     * Buckets ads by prize tag id (input order preserved within each bucket) and sums each tag's
     * match scores into its weight.
     */
    private static AdTagInfo getAdTagInfo(List<PackageRecallDo> prds) {
        Map<Long, List<PackageRecallDo>> tagBuckets = new HashMap<>();
        Map<Long, Double> tagWeight = new HashMap<>();
        for (PackageRecallDo pkrd : prds) {
            Long prizeTagId = pkrd.getPackageInfoDo().getPackageIdDo().getPrizeTagId();
            tagBuckets.computeIfAbsent(prizeTagId, id -> new ArrayList<>()).add(pkrd);
            tagWeight.put(prizeTagId, tagWeight.getOrDefault(prizeTagId, 0.0d) + pkrd.getMatchScore());
        }
        return new AdTagInfo(tagBuckets, tagWeight);
    }

    /** Ads bucketed by prize tag id, plus each tag's summed match score. */
    private static class AdTagInfo {
        public final Map<Long, List<PackageRecallDo>> tagBucket;
        public final Map<Long, Double> tagWeight;

        public AdTagInfo(Map<Long, List<PackageRecallDo>> tagBucket, Map<Long, Double> tagWeight) {
            this.tagBucket = tagBucket;
            this.tagWeight = tagWeight;
        }
    }
}
