package cn.com.duiba.nezha.alg.alg.dpa;

import cn.com.duiba.nezha.alg.alg.constant.DPAConstant;
import cn.com.duiba.nezha.alg.alg.vo.dpa.PackageIdDo;
import cn.com.duiba.nezha.alg.alg.vo.dpa.PackageInfoDo;
import cn.com.duiba.nezha.alg.alg.vo.dpa.PackageRecallDo;
import cn.com.duiba.nezha.alg.alg.vo.dpa.PrizeDo;
import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import static cn.com.duiba.nezha.alg.alg.dpa.PackageRecall.ucbPackage;

/**
 * Recall of candidate prize groups for DPA.
 *
 * <p>Regular prizes and ad prizes are first scored via {@code PackageRecall.ucbPackage}; the scored
 * candidates are then combined into distinct groups of up to {@code slotNum} {@link PrizeDo}s using
 * three strategies: the single best-score group, weighted random picks from the merged top list,
 * and one pick per weighted-randomly chosen prize tag. A group whose prize set was already produced
 * is discarded.
 *
 * @author: tanyan
 * @date: 2020/6/19
 */
public class PrizeRecall {
    private static final Logger logger = LoggerFactory.getLogger(PrizeRecall.class);

    /** Group budget passed to genPrizeGroup by matchPrizeAndAd. */
    private static final int RECALL_GROUP_MAX_NUM = 10;
    /** Upper bound on the merged, score-sorted candidate list the random-pick strategy samples from. */
    private static final int TOTAL_TOPK_PICK_RANGE = 100;
    /** Regular prizes pre-selected into the best-score group (when available and slots allow). */
    private static final int PRIZE_NUM_AT_LEAST = 1;
    /** Ad prizes pre-selected into the best-score group (when available and slots allow). */
    private static final int AD_PRIZE_NUM_AT_LEAST = 1;
    /** Maximum number of ad prizes placed in a tag-based group. */
    private static final int AD_PRIZE_NUM_AT_MOST = 2;
    /** Attempts allowed to produce one not-yet-seen group before giving up on that group. */
    private static final int TRY_NUM_FOR_ONE_GROUP = 3;
    /** NOTE(review): not referenced anywhere in this class. */
    private static final int TAG_STRAGEY_AD_TYPE_NUM = 2;
    /** First synthetic tag id handed out to ad prize types by preprocessAdList. */
    private static final long AD_TAG_ID_OFFSET = 10000L;

    /**
     * External preliminary-screening API: recalls groups from regular prizes only, without any ad
     * prize list.
     *
     * @param pks     candidate regular prizes
     * @param slotNum number of prizes per group
     * @return the recalled groups, or {@code null} when {@code pks} is empty (see matchPrizeAndAd)
     */
    public static List<PackageRecallDo> matchPrize(List<PackageInfoDo> pks, int slotNum) {
        return matchPrizeAndAd(pks, new ArrayList<PackageInfoDo>(), slotNum);
    }

    /**
     * Recall API: scores regular prizes and ad prizes and builds candidate prize groups from them.
     *
     * <p>Side effect: every ad prize gets a synthetic {@code prizeTagId} (see preprocessAdList).
     *
     * @param pks     candidate regular prizes; {@code null} is treated as empty
     * @param adPks   candidate ad prizes; {@code null} is treated as empty
     * @param slotNum number of prizes per group
     * @return the recalled groups, or {@code null} when both inputs are empty according to
     *         {@code AssertUtil.isAnyEmpty}
     */
    public static List<PackageRecallDo> matchPrizeAndAd(List<PackageInfoDo> pks,
                                                        List<PackageInfoDo> adPks,
                                                        int slotNum) {
        if (AssertUtil.isAnyEmpty(pks) && AssertUtil.isAnyEmpty(adPks)) {
            logger.error("PackageRecall matchPrize input params prizes and ad is null");
            return null;
        }
        // Only one side is guaranteed to be non-empty here; a null on the other side must not NPE.
        List<PackageInfoDo> prizes = emptyIfNull(pks);
        List<PackageInfoDo> ads = emptyIfNull(adPks);

        // Tags ad prizes by type so the tag-bucket strategy can group them.
        preprocessAdList(ads);

        List<PackageRecallDo> recallPrizeList = recallPrize(prizes);
        List<PackageRecallDo> recallAdList = recallAdPrize(ads);
        return genPrizeGroup(recallPrizeList, recallAdList, slotNum, RECALL_GROUP_MAX_NUM);
    }

    /**
     * Assigns each ad prize a synthetic {@code prizeTagId} derived from its {@code adPrizeType}:
     * the n-th distinct type (in encounter order) gets id {@code 10000 + n}.
     *
     * @param adPks ad prizes to tag in place; {@code null} is treated as empty
     * @return mapping from ad prize type to the assigned tag id
     */
    public static Map<String, Long> preprocessAdList(List<PackageInfoDo> adPks) {
        Map<String, Long> adTag2Idx = new HashMap<String, Long>();
        if (adPks == null) {
            return adTag2Idx;
        }
        for (PackageInfoDo pk : adPks) {
            PackageIdDo pid = pk.getPackageIdDo();
            String typeString = pid.getAdPrizeType();
            Long idx = adTag2Idx.get(typeString);
            if (idx == null) {
                idx = AD_TAG_ID_OFFSET + adTag2Idx.size();
                adTag2Idx.put(typeString, idx);
            }
            pid.setPrizeTagId(idx);
        }
        return adTag2Idx;
    }

    /**
     * Buckets candidates by {@code prizeTagId} (preserving input order inside each bucket) and sums
     * their match scores per tag.
     */
    private static TagInfo getTagBucket(List<PackageRecallDo> prds) {
        Map<Long, List<PackageRecallDo>> tagBuckets = new HashMap<Long, List<PackageRecallDo>>();
        Map<Long, Double> tagWeight = new HashMap<Long, Double>();
        for (PackageRecallDo pkrd : prds) {
            Long prizeTagId = pkrd.getPackageInfoDo().getPackageIdDo().getPrizeTagId();
            tagBuckets.computeIfAbsent(prizeTagId, key -> new ArrayList<PackageRecallDo>()).add(pkrd);
            tagWeight.merge(prizeTagId, pkrd.getMatchScore(), Double::sum);
        }
        return new TagInfo(tagBuckets, tagWeight);
    }

    /** Candidates grouped by tag, plus the summed match score of each tag. */
    private static final class TagInfo {
        final Map<Long, List<PackageRecallDo>> tagBucket;
        final Map<Long, Double> tagWeight;

        TagInfo(Map<Long, List<PackageRecallDo>> tagBucket, Map<Long, Double> tagWeight) {
            this.tagBucket = tagBucket;
            this.tagWeight = tagWeight;
        }
    }

    /**
     * Merges two lists that are each sorted by descending match score and keeps the first
     * {@code topk} entries of the merge. On equal scores the ad entry is taken first.
     */
    private static List<PackageRecallDo> topKFromMergedPrizeAndAdList(List<PackageRecallDo> pks,
                                                                      List<PackageRecallDo> adPks,
                                                                      int topk) {
        if (pks.isEmpty() && adPks.isEmpty()) {
            logger.warn("No input need to be merged!");
        }

        int totalSize = Math.max(0, Math.min(topk, pks.size() + adPks.size()));
        List<PackageRecallDo> result = new ArrayList<PackageRecallDo>(totalSize);
        int curIdxPK = 0;
        int curIdxAd = 0;
        while (result.size() < totalSize) {
            int remaining = totalSize - result.size();
            if (curIdxPK == pks.size()) {
                result.addAll(adPks.subList(curIdxAd, curIdxAd + remaining));
                break;
            }
            if (curIdxAd == adPks.size()) {
                result.addAll(pks.subList(curIdxPK, curIdxPK + remaining));
                break;
            }
            if (pks.get(curIdxPK).getMatchScore() > adPks.get(curIdxAd).getMatchScore()) {
                result.add(pks.get(curIdxPK++));
            } else {
                result.add(adPks.get(curIdxAd++));
            }
        }
        return result;
    }

    /**
     * Builds the recalled groups: the best-score group, then about 40% of {@code groupNum} groups
     * from the swap (random top-k pick) strategy, then about 50% from the tag-bucket strategy.
     * Groups whose signature was already produced are skipped.
     */
    private static List<PackageRecallDo> genPrizeGroup(List<PackageRecallDo> pks,
                                                       List<PackageRecallDo> adPks,
                                                       Integer slotNum,
                                                       Integer groupNum) {
        Set<String> existingGroupSet = new HashSet<String>();
        List<PackageRecallDo> result = new ArrayList<PackageRecallDo>();

        List<PackageRecallDo> sortedPks = sortByMatchScoreDesc(emptyIfNull(pks));
        List<PackageRecallDo> sortedAdPks = sortByMatchScoreDesc(emptyIfNull(adPks));

        // The group made of the best-scored candidates.
        PackageRecallDo theBestGroup = genBestScoreGroup(sortedPks, sortedAdPks, slotNum);
        result.add(theBestGroup);
        existingGroupSet.add(groupSignature(theBestGroup));

        // Expand with weighted random picks from the merged top-k, ignoring tags (swap strategy).
        int swapGenNum = (int) Math.ceil(groupNum * 0.4);
        result.addAll(genByRandomPick(sortedPks, sortedAdPks, existingGroupSet, slotNum, swapGenNum));

        // Supplement with weighted random picks over tag buckets.
        int tagGenNum = (int) Math.ceil(groupNum * 0.5);
        result.addAll(genByPickByTag(sortedPks, sortedAdPks, existingGroupSet, slotNum, tagGenNum));

        return result;
    }

    /** Returns a new list with the candidates sorted by descending match score. */
    private static List<PackageRecallDo> sortByMatchScoreDesc(List<PackageRecallDo> prds) {
        return prds.stream()
                .sorted(Comparator.comparing(PackageRecallDo::getMatchScore).reversed())
                .collect(Collectors.toList());
    }

    /**
     * Tag-bucket strategy: each group takes up to AD_PRIZE_NUM_AT_MOST ads and fills the remaining
     * slots with regular prizes, one candidate per weighted-randomly chosen tag.
     */
    private static List<PackageRecallDo> genByPickByTag(List<PackageRecallDo> pks,
                                                        List<PackageRecallDo> adPks,
                                                        Set<String> existingGroupSet,
                                                        int slotNum,
                                                        Integer groupNum) {
        if (pks.isEmpty() && adPks.isEmpty()) {
            return new ArrayList<PackageRecallDo>();
        }
        TagInfo prizeTagInfo = getTagBucket(pks);
        TagInfo adTagInfo = getTagBucket(adPks);
        return collectDistinctGroups(() -> genOneGroupByTagInfo(prizeTagInfo, adTagInfo, slotNum),
                existingGroupSet, groupNum);
    }

    /**
     * Calls {@code generator} until {@code groupNum} groups with unseen signatures are collected,
     * allowing TRY_NUM_FOR_ONE_GROUP attempts per group. New signatures are added to
     * {@code existingGroupSet}.
     */
    private static List<PackageRecallDo> collectDistinctGroups(Supplier<PackageRecallDo> generator,
                                                               Set<String> existingGroupSet,
                                                               int groupNum) {
        List<PackageRecallDo> result = new ArrayList<PackageRecallDo>(Math.max(groupNum, 0));
        for (int i = 0; i < groupNum; i++) {
            for (int tryNum = 0; tryNum < TRY_NUM_FOR_ONE_GROUP; tryNum++) {
                PackageRecallDo pkrd = generator.get();
                // Set.add returns false when the signature is already present.
                if (existingGroupSet.add(groupSignature(pkrd))) {
                    result.add(pkrd);
                    break;
                }
            }
        }
        return result;
    }

    /**
     * Builds one group: ads first (at most AD_PRIZE_NUM_AT_MOST, never more than slotNum), then
     * regular prizes for the remaining slots. The group may be shorter than slotNum when there are
     * fewer tags than slots.
     */
    private static PackageRecallDo genOneGroupByTagInfo(TagInfo prizeTagInfo, TagInfo adTagInfo, int slotNum) {
        PackageRecallDo result = new PackageRecallDo();
        int capacity = Math.max(slotNum, 0);
        int adNum = Math.min(Math.min(adTagInfo.tagBucket.size(), AD_PRIZE_NUM_AT_MOST), capacity);
        int prizeNum = capacity - adNum;
        List<PackageRecallDo> prizes = fetchPrizeFromRandomPickedTag(prizeTagInfo, prizeNum);
        List<PackageRecallDo> ads = fetchPrizeFromRandomPickedTag(adTagInfo, adNum);
        List<PrizeDo> group = new ArrayList<PrizeDo>(capacity);
        for (PackageRecallDo pkrd : ads) {
            group.add(prizeDoFromPackageRecallDo(pkrd));
        }
        for (PackageRecallDo pkrd : prizes) {
            group.add(prizeDoFromPackageRecallDo(pkrd));
        }
        result.getPackageInfoDo().getPackageIdDo().setPrizeGroup(group);
        return result;
    }

    /** Converts a candidate to a PrizeDo; it counts as an ad when its adPrizeId is non-null. */
    private static PrizeDo prizeDoFromPackageRecallDo(PackageRecallDo pkrd) {
        PackageIdDo curPid = pkrd.getPackageInfoDo().getPackageIdDo();
        boolean isAd = curPid.getAdPrizeId() != null;
        Long id = isAd ? curPid.getAdPrizeId() : curPid.getPrizeId();
        return new PrizeDo(isAd, id);
    }

    /** Builds an ad PrizeDo from the candidate's adPrizeId. */
    private static PrizeDo adPrizeDo(PackageRecallDo pkrd) {
        return new PrizeDo(true, pkrd.getPackageInfoDo().getPackageIdDo().getAdPrizeId());
    }

    /** Builds a regular-prize PrizeDo from the candidate's prizeId. */
    private static PrizeDo regularPrizeDo(PackageRecallDo pkrd) {
        return new PrizeDo(false, pkrd.getPackageInfoDo().getPackageIdDo().getPrizeId());
    }

    /**
     * Picks up to {@code k} distinct tags weighted by tag score and returns the head of each chosen
     * bucket. Buckets keep input order, so the head is the best-scored entry when the input was
     * sorted by descending score.
     */
    private static List<PackageRecallDo> fetchPrizeFromRandomPickedTag(TagInfo tagInfo, int k) {
        List<PackageRecallDo> result = new ArrayList<PackageRecallDo>();
        for (Long tagId : weightedPickTags(tagInfo.tagWeight, k)) {
            result.add(tagInfo.tagBucket.get(tagId).get(0));
        }
        return result;
    }

    /**
     * Picks up to {@code topk} distinct tags without replacement, each draw proportional to the
     * tag's weight. When fewer tags exist than requested, all tags are returned (in random order)
     * instead of failing.
     *
     * @param tagWeight tag id to weight; {@code null} or empty yields an empty result
     * @param topk      number of tags requested; values {@code <= 0} yield an empty result
     * @return the chosen tag ids
     */
    protected static List<Long> weightedPickTags(Map<Long, Double> tagWeight, int topk) {
        if (topk <= 0 || tagWeight == null || tagWeight.isEmpty()) {
            return new ArrayList<Long>();
        }
        if (topk > tagWeight.size()) {
            // Degraded but valid: fall back to picking every tag.
            logger.warn("k larger than tag numbers");
        }
        int k = Math.min(topk, tagWeight.size());

        List<Long> tags = new ArrayList<Long>(tagWeight.size());
        List<Double> weights = new ArrayList<Double>(tagWeight.size());
        for (Map.Entry<Long, Double> entry : tagWeight.entrySet()) {
            tags.add(entry.getKey());
            weights.add(entry.getValue());
        }
        List<Long> result = new ArrayList<Long>(k);
        for (int idx : randKWithoutReplacement(weights, k)) {
            result.add(tags.get(idx));
        }
        return result;
    }

    /**
     * Swap strategy: each group is up to slotNum candidates drawn without replacement, weighted by
     * match score, from the top TOTAL_TOPK_PICK_RANGE of the merged prize and ad lists.
     */
    private static List<PackageRecallDo> genByRandomPick(List<PackageRecallDo> pks,
                                                         List<PackageRecallDo> adPks,
                                                         Set<String> existingGroupSet,
                                                         Integer slotNum,
                                                         Integer groupNum) {
        int mergeNum = Math.min(TOTAL_TOPK_PICK_RANGE, pks.size() + adPks.size());
        List<PackageRecallDo> sortedPrizes = topKFromMergedPrizeAndAdList(pks, adPks, mergeNum);
        if (sortedPrizes.isEmpty()) {
            return new ArrayList<PackageRecallDo>();
        }
        return collectDistinctGroups(() -> randomPickOneGroup(sortedPrizes, slotNum), existingGroupSet, groupNum);
    }

    /**
     * Draws min(slotNum, candidates) candidates without replacement, weighted by match score, and
     * wraps them as one group.
     */
    private static PackageRecallDo randomPickOneGroup(List<PackageRecallDo> sortedPrizes, Integer slotNum) {
        PackageRecallDo result = new PackageRecallDo();
        int pickNum = Math.max(0, Math.min(slotNum, sortedPrizes.size()));
        List<Double> weights = sortedPrizes.stream().map(PackageRecallDo::getMatchScore).collect(Collectors.toList());
        List<PrizeDo> group = new ArrayList<PrizeDo>(pickNum);
        for (int idx : randKWithoutReplacement(weights, pickNum)) {
            group.add(prizeDoFromPackageRecallDo(sortedPrizes.get(idx)));
        }
        result.getPackageInfoDo().getPackageIdDo().setPrizeGroup(group);
        return result;
    }

    /**
     * Builds the best-score group from lists sorted by descending match score.
     *
     * <p>First pre-selects AD_PRIZE_NUM_AT_LEAST ads and then PRIZE_NUM_AT_LEAST regular prizes,
     * both capped by availability and by slotNum. It then fills the remaining slots by repeatedly
     * taking the higher-scored head of the two lists (ties go to the ad). The group is shorter
     * than slotNum when candidates run out.
     */
    private static PackageRecallDo genBestScoreGroup(List<PackageRecallDo> pks,
                                                     List<PackageRecallDo> adPks,
                                                     int slotNum) {
        PackageRecallDo pkrd = new PackageRecallDo();
        List<PrizeDo> prizeGroup = new ArrayList<PrizeDo>();

        // Never pre-select more entries than there are slots.
        int capacity = Math.max(slotNum, 0);
        int adPreSelectNum = Math.min(Math.min(AD_PRIZE_NUM_AT_LEAST, adPks.size()), capacity);
        int prizePreSelectNum = Math.min(Math.min(PRIZE_NUM_AT_LEAST, pks.size()), capacity - adPreSelectNum);

        for (int i = 0; i < adPreSelectNum; i++) {
            prizeGroup.add(adPrizeDo(adPks.get(i)));
        }
        for (int i = 0; i < prizePreSelectNum; i++) {
            prizeGroup.add(regularPrizeDo(pks.get(i)));
        }

        int leftNum = capacity - adPreSelectNum - prizePreSelectNum;
        int pkIdx = prizePreSelectNum;
        int adPkIdx = adPreSelectNum;
        for (int i = 0; i < leftNum; i++) {
            boolean prizeLeft = pkIdx < pks.size();
            boolean adLeft = adPkIdx < adPks.size();
            if (!prizeLeft && !adLeft) {
                logger.error("No sufficient prize or ad candidate");
                break;
            }
            if (!adLeft || (prizeLeft && pks.get(pkIdx).getMatchScore() > adPks.get(adPkIdx).getMatchScore())) {
                prizeGroup.add(regularPrizeDo(pks.get(pkIdx++)));
            } else {
                prizeGroup.add(adPrizeDo(adPks.get(adPkIdx++)));
            }
        }

        pkrd.getPackageInfoDo().getPackageIdDo().setPrizeGroup(prizeGroup);
        return pkrd;
    }

    /**
     * Helper: picks one entry at random with probability proportional to its match score.
     *
     * @param pkrds candidates
     * @return the chosen entry, or {@code null} (after logging) when {@code pkrds} is null or empty
     */
    static PackageRecallDo weightedRandomPickByMatchscore(List<PackageRecallDo> pkrds) {
        if (pkrds == null || pkrds.isEmpty()) {
            logger.error("Weighted random pick from empty list");
            return null;
        }
        double scoreSum = pkrds.stream().mapToDouble(PackageRecallDo::getMatchScore).sum();
        double randomNum = Math.random() * scoreSum;
        int findIdx = 0;
        for (int i = 0; i < pkrds.size(); i++) {
            findIdx = i;
            randomNum -= pkrds.get(i).getMatchScore();
            if (randomNum <= 0.0) {
                break;
            }
        }
        // If rounding leaves randomNum slightly positive, the last entry is returned.
        return pkrds.get(findIdx);
    }

    /**
     * Weighted random pick (by match score) among the entries of {@code pks} that are not in
     * {@code choosen} (membership via {@code List.contains}).
     */
    protected static PackageRecallDo weightedRandomPickByMatchscore(List<PackageRecallDo> pks, List<PackageRecallDo> choosen) {
        // Sized by pks alone: pks.size() - choosen.size() can be negative when choosen holds extra items.
        List<PackageRecallDo> candidates = new ArrayList<PackageRecallDo>(pks.size());
        for (PackageRecallDo pk : pks) {
            if (!choosen.contains(pk)) {
                candidates.add(pk);
            }
        }
        return weightedRandomPickByMatchscore(candidates);
    }

    /**
     * Scores regular prize candidates via {@code ucbPackage} with the DPAConstant PRIZETOPN /
     * RECALLPRIZE / GLOBALRECALLPRIZE settings.
     *
     * @param pks regular prize candidates; {@code null} or empty yields an empty list
     * @return the list returned by {@code ucbPackage}
     */
    private static List<PackageRecallDo> recallPrize(List<PackageInfoDo> pks) {
        if (pks == null || pks.isEmpty()) {
            return new ArrayList<PackageRecallDo>();
        }
        return ucbPackage(pks, DPAConstant.PRIZETOPN, DPAConstant.RECALLPRIZE, DPAConstant.GLOBALRECALLPRIZE, 2);
    }

    /**
     * Scores ad prize candidates via {@code ucbPackage}.
     * NOTE(review): uses exactly the same DPAConstant settings as recallPrize — confirm that is
     * intended for ads.
     *
     * @param pks ad prize candidates; {@code null} or empty yields an empty list
     * @return the list returned by {@code ucbPackage}
     */
    protected static List<PackageRecallDo> recallAdPrize(List<PackageInfoDo> pks) {
        if (pks == null || pks.isEmpty()) {
            return new ArrayList<PackageRecallDo>();
        }
        return ucbPackage(pks, DPAConstant.PRIZETOPN, DPAConstant.RECALLPRIZE, DPAConstant.GLOBALRECALLPRIZE, 2);
    }

    /**
     * Merges two descending-score lists into at most {@code slotNum} entries: pre-selects
     * PRIZE_NUM_AT_LEAST prizes and AD_PRIZE_NUM_AT_LEAST ads (capped by availability and slots),
     * then takes the higher-scored head repeatedly. The inputs are not modified. The result is
     * sorted by ascending match score.
     * NOTE(review): not called anywhere in this class.
     */
    private static List<PackageRecallDo> mergePrize(List<PackageRecallDo> recallList,
                                                    List<PackageRecallDo> recallAdList,
                                                    int slotNum) {
        int capacity = Math.max(slotNum, 0);
        List<PackageRecallDo> result = new ArrayList<PackageRecallDo>(capacity);
        int prizeIdx = Math.min(Math.min(PRIZE_NUM_AT_LEAST, recallList.size()), capacity);
        int adIdx = Math.min(Math.min(AD_PRIZE_NUM_AT_LEAST, recallAdList.size()), capacity - prizeIdx);
        result.addAll(recallList.subList(0, prizeIdx));
        result.addAll(recallAdList.subList(0, adIdx));

        while (result.size() < capacity) {
            boolean prizeLeft = prizeIdx < recallList.size();
            boolean adLeft = adIdx < recallAdList.size();
            if (!prizeLeft && !adLeft) {
                logger.error("insufficient recalled prizes!");
                break;
            }
            if (!adLeft || (prizeLeft && recallList.get(prizeIdx).getMatchScore() > recallAdList.get(adIdx).getMatchScore())) {
                result.add(recallList.get(prizeIdx++));
            } else {
                result.add(recallAdList.get(adIdx++));
            }
        }

        // Value-based comparison; the old `==` check could compare boxed Doubles by reference.
        result.sort(Comparator.comparing(PackageRecallDo::getMatchScore));
        return result;
    }

    private static String groupSignature(PackageInfoDo packageInfo) {
        return groupSignature(packageInfo.getPackageIdDo().getPrizeGroup());
    }

    private static String groupSignature(PackageRecallDo pkrd) {
        return groupSignature(pkrd.getPackageInfoDo());
    }

    /**
     * Order-independent signature of a group: the members' natural order (PrizeDo must be
     * Comparable) rendered with toString and joined by '#'.
     */
    protected static String groupSignature(List<PrizeDo> group) {
        return group.stream().sorted().map(PrizeDo::toString).collect(Collectors.joining("#"));
    }

    /**
     * Draws {@code k} distinct indices of {@code rates}; each draw picks a remaining index with
     * probability proportional to its weight. Uses a 1-based implicit binary tree in which every
     * node stores its own weight and the total weight of its subtree.
     *
     * <p>Non-positive, NaN or null weights are treated as zero. When no positive weight remains,
     * zero-weight indices are still returned, without duplicates.
     *
     * @param rates per-index weights
     * @param k     number of indices to draw; {@code k <= 0} yields an empty list
     * @return the drawn indices in draw order
     * @throws RuntimeException if {@code rates} is null or empty, or {@code k > rates.size()}
     */
    protected static List<Integer> randKWithoutReplacement(List<Double> rates, int k) {
        if (null == rates || rates.isEmpty()) {
            logger.error("the rates list is null or empty");
            throw new RuntimeException("the rates list is null or empty");
        }
        // k == rates.size() is valid: it simply draws every index.
        if (k > rates.size()) {
            logger.error("k is bigger than rates' size");
            throw new RuntimeException("k is bigger than rates' size");
        }

        List<Integer> result = new ArrayList<Integer>(Math.max(k, 0));
        if (k <= 0) {
            return result;
        }

        List<HeapNode> nodes = new ArrayList<HeapNode>(rates.size());
        for (int index = 0; index < rates.size(); index++) {
            Double rate = rates.get(index);
            // Negative or NaN weights would corrupt the subtree totals.
            nodes.add(new HeapNode(rate != null && rate > 0 ? rate : 0.0, index));
        }

        List<HeapNode> heap = buildHeap(nodes);
        for (int index = 0; index < k; index++) {
            result.add(heapPop(heap));
        }
        return result;
    }

    /** Lays nodes out at indices 1..n (index 0 unused) and accumulates subtree totals bottom-up. */
    private static List<HeapNode> buildHeap(List<HeapNode> nodes) {
        List<HeapNode> heap = new ArrayList<HeapNode>(nodes.size() + 1);
        heap.add(null);
        heap.addAll(nodes);
        // Children have larger indices than parents, so a reverse sweep finishes each subtree first.
        for (int index = heap.size() - 1; index > 1; index--) {
            heap.get(index >> 1).totalWeight += heap.get(index).totalWeight;
        }
        return heap;
    }

    /**
     * Draws one not-yet-drawn node proportionally to weight, zeroes its weight, subtracts it from
     * every ancestor's total and returns the node's original index.
     */
    private static int heapPop(List<HeapNode> heap) {
        int picked = descend(heap, heap.get(1).totalWeight * Math.random());
        if (picked < 0) {
            // Zero remaining mass or floating-point drift: choose deterministically instead.
            picked = heaviestUnpicked(heap);
        }

        HeapNode node = heap.get(picked);
        double weight = node.weight;
        node.weight = 0;
        node.picked = true;
        for (int i = picked; i > 0; i >>= 1) {
            heap.get(i).totalWeight -= weight;
        }
        return node.value;
    }

    /**
     * Walks from the root: stop at a node whose own weight covers the remaining gas. Otherwise
     * subtract that weight and enter the left child, or the right child when the gas also exceeds
     * the left subtree's total. Returns -1 when the walk leaves the tree or lands on an
     * already-drawn node.
     */
    private static int descend(List<HeapNode> heap, double gas) {
        int size = heap.size();
        int i = 1;
        while (gas > heap.get(i).weight) {
            gas -= heap.get(i).weight;
            i <<= 1;
            if (i >= size) {
                return -1;
            }
            if (gas > heap.get(i).totalWeight) {
                gas -= heap.get(i).totalWeight;
                i++;
                if (i >= size) {
                    return -1;
                }
            }
        }
        return heap.get(i).picked ? -1 : i;
    }

    /** Index of the heaviest node not drawn yet (lowest index on ties), or -1 if all are drawn. */
    private static int heaviestUnpicked(List<HeapNode> heap) {
        int best = -1;
        for (int i = 1; i < heap.size(); i++) {
            HeapNode node = heap.get(i);
            if (!node.picked && (best < 0 || node.weight > heap.get(best).weight)) {
                best = i;
            }
        }
        return best;
    }

    /** A tree node: own weight, original index, subtree total weight and whether it was drawn. */
    private static class HeapNode {
        double weight;
        int value;
        double totalWeight;
        boolean picked;

        HeapNode(double weight, int value) {
            this.weight = weight;
            this.value = value;
            this.totalWeight = weight;
        }
    }
}
