/*
 * Decompiled with CFR 0.152.
 */
package cn.com.duiba.nezha.alg.alg.dpa;

import cn.com.duiba.nezha.alg.alg.dpa.PackageRecall;
import cn.com.duiba.nezha.alg.alg.dpa.RecallUtil;
import cn.com.duiba.nezha.alg.alg.dpa.TopKScoreSumGroupGenerater;
import cn.com.duiba.nezha.alg.alg.vo.dpa.PackageIdDo;
import cn.com.duiba.nezha.alg.alg.vo.dpa.PackageInfoDo;
import cn.com.duiba.nezha.alg.alg.vo.dpa.PackageRecallDo;
import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import cn.com.duiba.nezha.alg.feature.vo.PrizeDo;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class PrizeRecall {
    private static final Logger logger = LoggerFactory.getLogger(PrizeRecall.class);
    private static final Integer RECALL_GROUP_MAX_NUM = 10;
    private static final Integer TOTAL_TOPK_PICK_RANGE = 100;
    private static final Integer PRIZE_NUM_AT_LEAST = 1;
    private static final Integer AD_PRIZE_NUM_AT_LEAST = 1;
    private static final Integer AD_PRIZE_NUM_AT_MOST = 2;
    private static final Integer TRY_NUM_FOR_ONE_GROUP = 3;
    private static final Integer TOP_SCORE_SUM_GROUP_NUM = 3;
    private static final Integer SWAP_RANDOM_RANGE = 20;

    public static List<PackageRecallDo> matchPrize(List<PackageInfoDo> pks, int slotNum) {
        ArrayList<PackageInfoDo> adPks = new ArrayList<PackageInfoDo>();
        return PrizeRecall.matchPrizeAndAd(pks, adPks, slotNum);
    }

    public static List<PackageRecallDo> matchPrizeAndAd(List<PackageInfoDo> pks, List<PackageInfoDo> adPks, int slotNum) {
        if (AssertUtil.isAnyEmpty((Object[])new Object[]{pks}) && AssertUtil.isAnyEmpty((Object[])new Object[]{adPks})) {
            logger.error("PackageRecall matchPrize input params prizes and ad has null");
            return null;
        }
        Map<String, Long> adTag2Idx = PrizeRecall.preprocessAdList(adPks, pks);
        List<PackageRecallDo> recallPrizeList = PrizeRecall.recallPrize(pks);
        List<PackageRecallDo> recallAdList = PrizeRecall.recallAdPrize(adPks);
        return PrizeRecall.genPrizeGroup(recallPrizeList, recallAdList, slotNum, RECALL_GROUP_MAX_NUM);
    }

    public static Map<String, Long> preprocessAdList(List<PackageInfoDo> adPks, List<PackageInfoDo> pks) {
        long maxPkId = 0L;
        long maxPkTagId = 0L;
        long offset = 100000L;
        if (pks == null || pks.size() == 0) {
            maxPkId = offset;
            maxPkTagId = offset;
        } else {
            for (PackageInfoDo pk : pks) {
                PackageIdDo pid = pk.getPackageIdDo();
                maxPkId = pid.getPrizeId() > maxPkId ? pid.getPrizeId() : maxPkId;
                maxPkTagId = pid.getPrizeTagId() > maxPkTagId ? pid.getPrizeTagId() : maxPkTagId;
            }
        }
        long fakePrizeId4Ad = maxPkId + offset;
        long startFakeTagId4Ad = maxPkTagId + offset;
        HashMap<String, Long> adTag2Idx = new HashMap<String, Long>();
        for (PackageInfoDo pk : adPks) {
            Long tagIdx;
            pk.getPackageIdDo().setPrizeId(fakePrizeId4Ad++);
            String typeString = pk.getPackageIdDo().getAdPrizeType();
            if (adTag2Idx.containsKey(typeString)) {
                tagIdx = (Long)adTag2Idx.get(typeString);
            } else {
                tagIdx = (long)adTag2Idx.size() + startFakeTagId4Ad;
                adTag2Idx.put(typeString, tagIdx);
            }
            pk.getPackageIdDo().setPrizeTagId(tagIdx);
        }
        return adTag2Idx;
    }

    private static TagInfo getTagBucket(List<PackageRecallDo> prds) {
        HashMap<Long, List<PackageRecallDo>> tagBuckets = new HashMap<Long, List<PackageRecallDo>>();
        HashMap<Long, Double> tagWeight = new HashMap<Long, Double>();
        for (PackageRecallDo pkrd : prds) {
            Long prizeTagId = pkrd.getPackageInfoDo().getPackageIdDo().getPrizeTagId();
            if (tagBuckets.containsKey(prizeTagId)) {
                ((List)tagBuckets.get(prizeTagId)).add(pkrd);
            } else {
                ArrayList<PackageRecallDo> bucket = new ArrayList<PackageRecallDo>();
                bucket.add(pkrd);
                tagBuckets.put(prizeTagId, bucket);
            }
            tagWeight.put(prizeTagId, tagWeight.getOrDefault(prizeTagId, 0.0) + pkrd.getMatchScore());
        }
        return new TagInfo(tagBuckets, tagWeight);
    }

    private static List<PackageRecallDo> topKFromMergedPrizeAndAdList(List<PackageRecallDo> pks, List<PackageRecallDo> adPks, int topk) {
        if (pks.size() == 0 && adPks.size() == 0) {
            logger.warn("No input need to be merged!");
        }
        int curIdxPK = 0;
        int curIdxAd = 0;
        int totalSize = Math.min(topk, pks.size() + adPks.size());
        ArrayList<PackageRecallDo> result = new ArrayList<PackageRecallDo>(totalSize);
        for (int i = 0; i < totalSize; ++i) {
            if (curIdxPK == pks.size()) {
                result.addAll(adPks.subList(curIdxAd, curIdxAd + totalSize - i));
                break;
            }
            if (curIdxAd == adPks.size()) {
                result.addAll(pks.subList(curIdxPK, curIdxPK + totalSize - i));
                break;
            }
            if (pks.get(curIdxPK).getMatchScore() > adPks.get(curIdxAd).getMatchScore()) {
                result.add(pks.get(curIdxPK));
                ++curIdxPK;
                continue;
            }
            result.add(adPks.get(curIdxAd));
            ++curIdxAd;
        }
        return result;
    }

    private static List<PackageRecallDo> genPrizeGroup(List<PackageRecallDo> pks, List<PackageRecallDo> adPks, int slotNum, int groupNum) {
        HashSet<String> existingGroupSet = new HashSet<String>();
        ArrayList<PackageRecallDo> result = new ArrayList<PackageRecallDo>();
        List<PackageRecallDo> sortedPks = pks.stream().sorted(Comparator.comparing(PackageRecallDo::getMatchScore).reversed()).collect(Collectors.toList());
        List<PackageRecallDo> sortedAdPks = adPks.stream().sorted(Comparator.comparing(PackageRecallDo::getMatchScore).reversed()).collect(Collectors.toList());
        ArrayList<PackageRecallDo> allPkrds = new ArrayList<PackageRecallDo>(pks.size() + adPks.size());
        allPkrds.addAll(sortedPks);
        allPkrds.addAll(adPks);
        List<PackageRecallDo> sortedAllPks = allPkrds.stream().sorted(Comparator.comparing(PackageRecallDo::getMatchScore).reversed()).collect(Collectors.toList());
        List<List<PackageRecallDo>> topPkLists = PrizeRecall.getTopKScoreSumPkLists(sortedAllPks, slotNum, 10);
        List<PackageRecallDo> topKScoreSumGroups = PrizeRecall.genTopKScoreSumGroups(topPkLists, 3);
        topKScoreSumGroups.forEach(pkrd -> existingGroupSet.add(PrizeRecall.groupSignature(pkrd)));
        result.addAll(topKScoreSumGroups);
        int swapGenNum = 4;
        List<PackageRecallDo> swapGroups = PrizeRecall.genSwapGroups(topPkLists, sortedPks, swapGenNum, existingGroupSet);
        result.addAll(swapGroups);
        int ramdomPickNum = 3;
        List<PackageRecallDo> randomPickGroups = PrizeRecall.genByRandomPick(sortedPks, sortedAdPks, existingGroupSet, slotNum, ramdomPickNum);
        result.addAll(randomPickGroups);
        int tagGenNum = 5;
        List<PackageRecallDo> tagGroups = PrizeRecall.genByPickByTag(sortedPks, sortedAdPks, existingGroupSet, slotNum, tagGenNum);
        result.addAll(tagGroups);
        return result;
    }

    private static List<List<PackageRecallDo>> getTopKScoreSumPkLists(List<PackageRecallDo> sortedAllPks, int slotNum, int k) {
        List<List<PackageRecallDo>> result = TopKScoreSumGroupGenerater.randomKFromTopPSumGroups(sortedAllPks, k, slotNum);
        return result;
    }

    private static List<PackageRecallDo> genTopKScoreSumGroups(List<List<PackageRecallDo>> pkGroups, int groupNum) {
        ArrayList<PackageRecallDo> result = new ArrayList<PackageRecallDo>(groupNum);
        if (pkGroups.size() == 0 || groupNum < 1) {
            return result;
        }
        groupNum = Math.min(groupNum, pkGroups.size());
        for (List<PackageRecallDo> pkGroup : pkGroups.subList(0, groupNum)) {
            PackageRecallDo pkWithGroup = PrizeRecall.groupPkrdFromPkrdList(pkGroup);
            result.add(pkWithGroup);
        }
        return result;
    }

    private static List<PackageRecallDo> genByPickByTag(List<PackageRecallDo> pks, List<PackageRecallDo> adPks, Set<String> existingGroupSet, int slotNum, Integer groupNum) {
        ArrayList<PackageRecallDo> result = new ArrayList<PackageRecallDo>(groupNum);
        TagInfo prizeTagInfo = PrizeRecall.getTagBucket(pks);
        TagInfo adTagInfo = PrizeRecall.getTagBucket(adPks);
        for (int i = 0; i < groupNum; ++i) {
            boolean isNewOne = false;
            for (int tryNum = 0; tryNum < TRY_NUM_FOR_ONE_GROUP && !isNewOne; ++tryNum) {
                PackageRecallDo pkrd = PrizeRecall.genOneGroupByTagInfo(pks, adPks, prizeTagInfo, adTagInfo, slotNum);
                String groupSign = PrizeRecall.groupSignature(pkrd);
                if (existingGroupSet.contains(groupSign)) continue;
                existingGroupSet.add(groupSign);
                result.add(pkrd);
                isNewOne = true;
            }
        }
        return result;
    }

    private static PackageRecallDo genOneGroupByTagInfo(List<PackageRecallDo> pks, List<PackageRecallDo> adPks, TagInfo prizeTagInfo, TagInfo adTagInfo, int slotNum) {
        int adNum = Math.min(adTagInfo.tagBucket.size(), AD_PRIZE_NUM_AT_MOST);
        int prizeNum = slotNum - adNum;
        List<PackageRecallDo> ads = adNum <= adTagInfo.tagBucket.size() ? PrizeRecall.fetchPrizeFromRandomPickedTag(adTagInfo, adNum) : PrizeRecall.fetchPkrdFromWeightedList(adPks, adNum);
        List<PackageRecallDo> prizes = prizeNum <= prizeTagInfo.tagBucket.size() ? PrizeRecall.fetchPrizeFromRandomPickedTag(prizeTagInfo, prizeNum) : PrizeRecall.fetchPkrdFromWeightedList(pks, prizeNum);
        ArrayList<PackageRecallDo> pkrdList = new ArrayList<PackageRecallDo>(slotNum);
        ads.forEach(ad -> pkrdList.add((PackageRecallDo)ad));
        prizes.forEach(prize -> pkrdList.add((PackageRecallDo)prize));
        return PrizeRecall.groupPkrdFromPkrdList(pkrdList);
    }

    private static PrizeDo prizeDoFromPackageRecallDo(PackageRecallDo pkrd) {
        PackageIdDo curPid = pkrd.getPackageInfoDo().getPackageIdDo();
        boolean isAd = curPid.getAdPrizeId() != null;
        Long prizeTagId = curPid.getPrizeTagId();
        Long id = isAd ? curPid.getAdPrizeId() : curPid.getPrizeId();
        return new PrizeDo(isAd, id, prizeTagId, pkrd.getMatchScore());
    }

    private static List<PackageRecallDo> fetchPrizeFromRandomPickedTag(TagInfo tagInfo, int k) {
        ArrayList<PackageRecallDo> result = new ArrayList<PackageRecallDo>();
        List<Long> choosenAdTags = PrizeRecall.weightedPickTags(tagInfo.tagWeight, k);
        for (Long tagId : choosenAdTags) {
            PackageRecallDo pkrd = tagInfo.tagBucket.get(tagId).get(0);
            result.add(pkrd);
        }
        return result;
    }

    protected static List<Long> weightedPickTags(Map<Long, Double> tagWeight, int topk) {
        ArrayList<Long> result = new ArrayList<Long>(topk);
        if (topk <= 0) {
            return result;
        }
        ArrayList<Long> tags = new ArrayList<Long>(tagWeight.size());
        ArrayList<Double> weights = new ArrayList<Double>(tagWeight.size());
        for (Long tagId : tagWeight.keySet()) {
            tags.add(tagId);
            weights.add(tagWeight.get(tagId));
        }
        if (topk > tagWeight.size()) {
            logger.warn("k\u5927\u4e8e\u7c7b\u76ee\u6570,\u622a\u65ad\u6210\u7c7b\u76ee\u6570");
            topk = tagWeight.size();
        }
        List<Integer> resultIdx = topk == tagWeight.size() ? RecallUtil.argSort(weights) : RecallUtil.randKWithoutReplacement(weights, topk);
        for (Integer idx : resultIdx) {
            result.add((Long)tags.get(idx));
        }
        return result;
    }

    private static List<PackageRecallDo> genByRandomPick(List<PackageRecallDo> pks, List<PackageRecallDo> adPks, Set<String> existingGroupSet, Integer slotNum, Integer groupNum) {
        ArrayList<PackageRecallDo> result = new ArrayList<PackageRecallDo>(groupNum);
        int mergeNum = Math.min(TOTAL_TOPK_PICK_RANGE, pks.size() + adPks.size());
        List<PackageRecallDo> sortedPrizes = PrizeRecall.topKFromMergedPrizeAndAdList(pks, adPks, mergeNum);
        for (int i = 0; i < groupNum; ++i) {
            boolean isNewOne = false;
            for (int tryNum = 0; tryNum < TRY_NUM_FOR_ONE_GROUP && !isNewOne; ++tryNum) {
                PackageRecallDo pkrd = PrizeRecall.randomPickOneGroup(sortedPrizes, slotNum);
                String groupSign = PrizeRecall.groupSignature(pkrd);
                if (existingGroupSet.contains(groupSign)) continue;
                existingGroupSet.add(groupSign);
                result.add(pkrd);
                isNewOne = true;
            }
        }
        return result;
    }

    private static PackageRecallDo randomPickOneGroup(List<PackageRecallDo> sortedPrizes, Integer slotNum) {
        PackageRecallDo result = new PackageRecallDo();
        ArrayList<PrizeDo> group = new ArrayList<PrizeDo>(slotNum);
        List<PackageRecallDo> pkrds = PrizeRecall.fetchPkrdFromWeightedList(sortedPrizes, slotNum);
        for (PackageRecallDo pkrd : pkrds) {
            PrizeDo prizeDo = PrizeRecall.prizeDoFromPackageRecallDo(pkrd);
            group.add(prizeDo);
        }
        result.getPackageInfoDo().getPackageIdDo().setPrizeGroup(group);
        return result;
    }

    protected static List<PackageRecallDo> fetchPkrdFromWeightedList(List<PackageRecallDo> sortedPrizes, int k) {
        ArrayList<PackageRecallDo> result = new ArrayList<PackageRecallDo>(k);
        if (k > sortedPrizes.size()) {
            logger.warn("\u65e0\u653e\u56de\u62bd\u6837\u6570\u5927\u4e8e\u6837\u672c\u6570,\u8fd4\u56de\u539fList");
            return sortedPrizes;
        }
        List<Double> weights = sortedPrizes.stream().map(PackageRecallDo::getMatchScore).collect(Collectors.toList());
        List<Integer> idxs = RecallUtil.randKWithoutReplacement(weights, k);
        for (Integer idx : idxs) {
            result.add(sortedPrizes.get(idx));
        }
        return result;
    }

    private static PackageRecallDo genBestScoreGroup(List<PackageRecallDo> pks, List<PackageRecallDo> adPks, int slotNum) {
        int i;
        ArrayList<PackageRecallDo> pkrdGroup = new ArrayList<PackageRecallDo>(slotNum);
        int adPreSelectNum = Math.min(AD_PRIZE_NUM_AT_LEAST, adPks.size());
        int prizePreSelectNum = Math.min(PRIZE_NUM_AT_LEAST, pks.size());
        for (i = 0; i < adPreSelectNum; ++i) {
            pkrdGroup.add(adPks.get(i));
        }
        for (i = 0; i < prizePreSelectNum; ++i) {
            pkrdGroup.add(pks.get(i));
        }
        int leftNum = slotNum - adPreSelectNum - prizePreSelectNum;
        int pkIdx = prizePreSelectNum;
        int adPkIdx = adPreSelectNum;
        for (int i2 = 0; i2 < leftNum; ++i2) {
            if (pkIdx == pks.size() && adPkIdx == adPks.size()) {
                logger.error("No sufficient prize or ad candidate");
                break;
            }
            if (pkIdx == pks.size()) {
                pkrdGroup.add(adPks.get(adPkIdx++));
                continue;
            }
            if (adPkIdx == adPks.size()) {
                pkrdGroup.add(pks.get(pkIdx++));
                continue;
            }
            if (pks.get(pkIdx).getMatchScore() > adPks.get(adPkIdx).getMatchScore()) {
                pkrdGroup.add(pks.get(pkIdx++));
                continue;
            }
            pkrdGroup.add(adPks.get(adPkIdx++));
        }
        return PrizeRecall.groupPkrdFromPkrdList(pkrdGroup);
    }

    static PackageRecallDo weightedRandomPickByMatchscore(List<PackageRecallDo> pkrds) {
        if (pkrds.size() == 0) {
            logger.error("Weighted random pick from empty list");
            return null;
        }
        double scoreSum = pkrds.stream().mapToDouble(PackageRecallDo::getMatchScore).sum();
        double randomNum = Math.random() * scoreSum;
        int findIdx = 0;
        for (int i = 0; i < pkrds.size(); ++i) {
            findIdx = i;
            if ((randomNum -= pkrds.get(i).getMatchScore().doubleValue()) <= 0.0) break;
        }
        return pkrds.get(findIdx);
    }

    protected static PackageRecallDo weightedRandomPickByMatchscore(List<PackageRecallDo> pks, List<PackageRecallDo> choosen) {
        ArrayList<PackageRecallDo> candidates = new ArrayList<PackageRecallDo>(pks.size() - choosen.size());
        for (PackageRecallDo pk : pks) {
            if (choosen.contains(pk)) continue;
            candidates.add(pk);
        }
        return PrizeRecall.weightedRandomPickByMatchscore(candidates);
    }

    private static List<PackageRecallDo> recallPrize(List<PackageInfoDo> pks) {
        if (pks.size() == 0) {
            return new ArrayList<PackageRecallDo>();
        }
        int topN = Math.min(100, pks.size());
        return PackageRecall.ucbPackage(pks, topN, 2, 2);
    }

    protected static List<PackageRecallDo> recallAdPrize(List<PackageInfoDo> pks) {
        if (pks.size() == 0) {
            return new ArrayList<PackageRecallDo>();
        }
        int topN = Math.min(20, pks.size());
        return PackageRecall.ucbPackage(pks, topN, 2, 2);
    }

    private static String groupSignature(PackageInfoDo packageInfo) {
        List<PrizeDo> group = packageInfo.getPackageIdDo().getPrizeGroup();
        return PrizeRecall.groupSignature(group);
    }

    private static String groupSignature(PackageRecallDo pkrd) {
        return PrizeRecall.groupSignature(pkrd.getPackageInfoDo());
    }

    protected static String groupSignature(List<PrizeDo> group) {
        return group.stream().sorted().map(PrizeDo::toString).collect(Collectors.joining("#"));
    }

    private static List<PackageRecallDo> genSwapGroups(List<List<PackageRecallDo>> sourceGroups, List<PackageRecallDo> pks, int groupNum, Set<String> existingGroupSet) {
        if (AssertUtil.isEmpty(sourceGroups)) {
            logger.error("Sourcegroups of swap action is empty!");
            return null;
        }
        ArrayList<PackageRecallDo> result = new ArrayList<PackageRecallDo>(groupNum);
        for (int i = 0; i < groupNum; ++i) {
            boolean isNewOne = false;
            for (int tryNum = 0; tryNum < TRY_NUM_FOR_ONE_GROUP && !isNewOne; ++tryNum) {
                int sourceIdx = (int)Math.random() * sourceGroups.size();
                List<PackageRecallDo> pkrdList = PrizeRecall.aSwapGroup(sourceGroups.get(sourceIdx), pks);
                PackageRecallDo pkrd = PrizeRecall.groupPkrdFromPkrdList(pkrdList);
                String groupSign = PrizeRecall.groupSignature(pkrd);
                if (existingGroupSet.contains(groupSign)) continue;
                existingGroupSet.add(groupSign);
                result.add(pkrd);
                isNewOne = true;
            }
        }
        return result;
    }

    private static List<PackageRecallDo> aSwapGroup(List<PackageRecallDo> sourceGroup, List<PackageRecallDo> pks) {
        if (pks.size() <= sourceGroup.size()) {
            logger.warn("Number is not match, no swap!");
            return sourceGroup;
        }
        ArrayList<PackageRecallDo> result = new ArrayList<PackageRecallDo>(sourceGroup.size());
        Set existIds = sourceGroup.stream().map(pk -> pk.getPackageInfoDo().getPackageIdDo().getPrizeId()).collect(Collectors.toSet());
        ArrayList<PackageRecallDo> candidatePrizes = new ArrayList<PackageRecallDo>(pks.size() - sourceGroup.size());
        ArrayList<Double> weights = new ArrayList<Double>(pks.size() - sourceGroup.size());
        for (PackageRecallDo pk2 : pks) {
            if (existIds.contains(pk2.getPackageInfoDo().getPackageIdDo().getPrizeId())) continue;
            candidatePrizes.add(pk2);
            weights.add(pk2.getMatchScore());
        }
        int sourceSwapRange = 4;
        int sourceIdx = (int)(Math.random() * (double)sourceSwapRange);
        int candidateIdx = RecallUtil.weightedRandPick(weights);
        for (int i = 0; i < sourceGroup.size(); ++i) {
            if (i == sourceIdx) {
                result.add((PackageRecallDo)candidatePrizes.get(candidateIdx));
                continue;
            }
            result.add(sourceGroup.get(i));
        }
        return result;
    }

    private static PackageRecallDo groupPkrdFromPkrdList(List<PackageRecallDo> pkrdList) {
        PackageRecallDo result = new PackageRecallDo();
        ArrayList<PrizeDo> prizeGroup = new ArrayList<PrizeDo>(pkrdList.size());
        pkrdList.forEach(pk -> prizeGroup.add(PrizeRecall.prizeDoFromPackageRecallDo(pk)));
        result.getPackageInfoDo().getPackageIdDo().setPrizeGroup(prizeGroup);
        return result;
    }

    private static class TagInfo {
        public Map<Long, List<PackageRecallDo>> tagBucket;
        public Map<Long, Double> tagWeight;

        public TagInfo(Map<Long, List<PackageRecallDo>> tagBucket, Map<Long, Double> tagWeight) {
            this.tagBucket = tagBucket;
            this.tagWeight = tagWeight;
        }
    }
}

