package cn.com.duiba.nezha.alg.alg.dpa;

import cn.com.duiba.nezha.alg.alg.vo.dpa.PackageIdDo;
import cn.com.duiba.nezha.alg.alg.vo.dpa.PackageInfoDo;
import cn.com.duiba.nezha.alg.alg.vo.dpa.PackageRecallDo;
import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import cn.com.duiba.nezha.alg.common.vo.GenericPair;
import cn.com.duiba.nezha.alg.feature.vo.PrizeDo;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Builds candidate prize groups from recalled regular prizes and ad prizes.
 *
 * <p>Each returned {@link PackageRecallDo} carries one candidate group in
 * {@code getPackageInfoDo().getPackageIdDo().getPrizeGroup()}. Groups are produced by three strategies:
 * <ol>
 *   <li>one best-score group ({@link #genBestScoreGroup});</li>
 *   <li>groups sampled, weighted by match score, from the merged top-K of prizes and ads;</li>
 *   <li>groups sampled across prize/ad tags so that each group favours distinct tags.</li>
 * </ol>
 * Groups whose {@link #groupSignature(List)} was already produced are discarded.
 */
public class PrizeRecall {
    private static final Logger logger = LoggerFactory.getLogger(PrizeRecall.class);

    /** Base number used to derive how many groups each strategy produces. */
    private static final int RECALL_GROUP_MAX_NUM = 10;
    /** Size of the merged (prize + ad) top-K pool that random groups are sampled from. */
    private static final int TOTAL_TOPK_PICK_RANGE = 100;
    /** Regular prizes reserved in the best-score group (when available). */
    private static final int PRIZE_NUM_AT_LEAST = 1;
    /** Ad prizes reserved in the best-score group (when available). */
    private static final int AD_PRIZE_NUM_AT_LEAST = 1;
    /** Maximum number of ad prizes in a tag-sampled group. */
    private static final int AD_PRIZE_NUM_AT_MOST = 2;
    /** Attempts to produce a not-yet-seen group before giving up on that slot. */
    private static final int TRY_NUM_FOR_ONE_GROUP = 3;
    /** Share of {@link #RECALL_GROUP_MAX_NUM} produced by weighted random pick (rounded up). */
    private static final double RANDOM_PICK_GROUP_RATIO = 0.4d;
    /** Share of {@link #RECALL_GROUP_MAX_NUM} produced by tag-based pick (rounded up). */
    private static final double TAG_PICK_GROUP_RATIO = 0.5d;
    /** First synthetic tag id assigned to ad prize types (presumably chosen to avoid real prize tag ids). */
    private static final long AD_TAG_ID_BASE = 10000L;
    /** Maximum number of regular prizes requested from the UCB recall. */
    private static final int PRIZE_RECALL_TOPK = 100;
    /** Maximum number of ad prizes requested from the UCB recall. */
    private static final int AD_RECALL_TOPK = 50;

    /**
     * Node of the 1-indexed, array-backed binary tree used by {@link #randKWithoutReplacement(List, int)}.
     * {@code weight} is the node's own remaining weight (set to 0 once popped); {@code totalWeight} is the
     * remaining weight of the node's whole subtree; {@code value} is the original index of the rate.
     */
    public static class HeapNode {
        double weight;
        int value;
        double totalWeight;

        public HeapNode(double weight, int value) {
            this.weight = weight;
            this.value = value;
            this.totalWeight = weight;
        }
    }

    /** Candidates grouped by prize tag id, together with the summed match score of each tag. */
    public static class TagInfo {
        public Map<Long, List<PackageRecallDo>> tagBucket;
        public Map<Long, Double> tagWeight;

        public TagInfo(Map<Long, List<PackageRecallDo>> tagBucket, Map<Long, Double> tagWeight) {
            this.tagBucket = tagBucket;
            this.tagWeight = tagWeight;
        }
    }

    /**
     * Builds prize groups from regular prizes only.
     *
     * @param list candidate prizes
     * @param i    number of prizes per group
     * @return the generated groups, or {@code null} when {@code list} is empty
     */
    public static List<PackageRecallDo> matchPrize(List<PackageInfoDo> list, int i) {
        return matchPrizeAndAd(list, new ArrayList<>(), i);
    }

    /**
     * Builds prize groups mixing regular prizes and ad prizes.
     *
     * <p>Ad prizes get synthetic tag ids assigned in place (see {@link #preprocessAdList(List)}).
     *
     * @param list  candidate regular prizes (may be {@code null})
     * @param list2 candidate ad prizes (may be {@code null}); their tag ids are overwritten
     * @param i     number of prizes per group
     * @return the generated groups, or {@code null} when both inputs are empty per {@code AssertUtil.isAnyEmpty}
     */
    public static List<PackageRecallDo> matchPrizeAndAd(List<PackageInfoDo> list, List<PackageInfoDo> list2, int i) {
        if (AssertUtil.isAnyEmpty(new Object[]{list}) && AssertUtil.isAnyEmpty(new Object[]{list2})) {
            logger.error("PackageRecall matchPrize input params prizes and ad is null");
            return null;
        }
        // Only "both empty" is rejected above, so either side may still be null here.
        List<PackageInfoDo> prizes = list == null ? new ArrayList<>() : list;
        List<PackageInfoDo> ads = list2 == null ? new ArrayList<>() : list2;
        preprocessAdList(ads);
        return genPrizeGroup(recallPrize(prizes), recallAdPrize(ads), i, RECALL_GROUP_MAX_NUM);
    }

    /**
     * Assigns every ad a synthetic prize tag id derived from its ad prize type.
     *
     * <p>Distinct types get consecutive ids starting at {@value #AD_TAG_ID_BASE}, in order of first appearance.
     * The ids are written back into each element's {@code PackageIdDo}.
     *
     * @param list ads to tag in place; {@code null} is treated as empty
     * @return mapping from ad prize type to the tag id assigned to it
     */
    public static Map<String, Long> preprocessAdList(List<PackageInfoDo> list) {
        Map<String, Long> tagIdByAdType = new HashMap<>();
        if (list == null) {
            return tagIdByAdType;
        }
        for (PackageInfoDo packageInfoDo : list) {
            PackageIdDo packageIdDo = packageInfoDo.getPackageIdDo();
            String adPrizeType = packageIdDo.getAdPrizeType();
            Long tagId = tagIdByAdType.get(adPrizeType);
            if (tagId == null) {
                tagId = AD_TAG_ID_BASE + tagIdByAdType.size();
                tagIdByAdType.put(adPrizeType, tagId);
            }
            packageIdDo.setPrizeTagId(tagId);
        }
        return tagIdByAdType;
    }

    /** Groups candidates by prize tag id and sums their match scores per tag. */
    private static TagInfo getTagBucket(List<PackageRecallDo> list) {
        Map<Long, List<PackageRecallDo>> buckets = new HashMap<>();
        Map<Long, Double> weights = new HashMap<>();
        for (PackageRecallDo candidate : list) {
            Long tagId = candidate.getPackageInfoDo().getPackageIdDo().getPrizeTagId();
            // Buckets keep the input order, so with score-sorted input the first element is the tag's best.
            buckets.computeIfAbsent(tagId, key -> new ArrayList<>()).add(candidate);
            weights.merge(tagId, candidate.getMatchScore(), Double::sum);
        }
        return new TagInfo(buckets, weights);
    }

    /**
     * Merges two lists that are each sorted by descending match score and keeps the first {@code i} elements.
     * On equal scores the element from {@code list2} is taken first.
     */
    private static List<PackageRecallDo> topKFromMergedPrizeAndAdList(List<PackageRecallDo> list,
                                                                      List<PackageRecallDo> list2, int i) {
        if (list.isEmpty() && list2.isEmpty()) {
            logger.warn("No input need to be merged!");
        }
        int k = Math.max(0, Math.min(i, list.size() + list2.size()));
        List<PackageRecallDo> merged = new ArrayList<>(k);
        int prizeIdx = 0;
        int adIdx = 0;
        while (merged.size() < k) {
            int need = k - merged.size();
            if (prizeIdx == list.size()) {
                merged.addAll(list2.subList(adIdx, adIdx + need));
                break;
            }
            if (adIdx == list2.size()) {
                merged.addAll(list.subList(prizeIdx, prizeIdx + need));
                break;
            }
            if (list.get(prizeIdx).getMatchScore() > list2.get(adIdx).getMatchScore()) {
                merged.add(list.get(prizeIdx++));
            } else {
                merged.add(list2.get(adIdx++));
            }
        }
        return merged;
    }

    /**
     * Produces all candidate groups.
     *
     * <p>It returns one best-score group, then up to {@code ceil(maxGroupNum * 0.4)} random-pick groups, then up
     * to {@code ceil(maxGroupNum * 0.5)} tag-pick groups. Duplicates are skipped.
     */
    private static List<PackageRecallDo> genPrizeGroup(List<PackageRecallDo> list, List<PackageRecallDo> list2,
                                                       int groupSize, int maxGroupNum) {
        Comparator<PackageRecallDo> byScoreDesc = Comparator.comparing(PackageRecallDo::getMatchScore).reversed();
        List<PackageRecallDo> prizes = list.stream().sorted(byScoreDesc).collect(Collectors.toList());
        List<PackageRecallDo> ads = list2.stream().sorted(byScoreDesc).collect(Collectors.toList());

        Set<String> seenSignatures = new HashSet<>();
        List<PackageRecallDo> groups = new ArrayList<>();
        PackageRecallDo best = genBestScoreGroup(prizes, ads, groupSize);
        groups.add(best);
        seenSignatures.add(groupSignature(best));
        groups.addAll(genByRandomPick(prizes, ads, seenSignatures, groupSize,
                (int) Math.ceil(maxGroupNum * RANDOM_PICK_GROUP_RATIO)));
        groups.addAll(genByPickByTag(prizes, ads, seenSignatures, groupSize,
                (int) Math.ceil(maxGroupNum * TAG_PICK_GROUP_RATIO)));
        return groups;
    }

    /**
     * Generates up to {@code groupNum} tag-diverse groups that are not yet in {@code set}.
     * Each new signature is added to {@code set}.
     */
    private static List<PackageRecallDo> genByPickByTag(List<PackageRecallDo> list, List<PackageRecallDo> list2,
                                                        Set<String> set, int i, int groupNum) {
        List<PackageRecallDo> groups = new ArrayList<>(Math.max(groupNum, 0));
        TagInfo prizeTags = getTagBucket(list);
        TagInfo adTags = getTagBucket(list2);
        for (int slot = 0; slot < groupNum; slot++) {
            for (int attempt = 0; attempt < TRY_NUM_FOR_ONE_GROUP; attempt++) {
                PackageRecallDo group = genOneGroupByTagInfo(list, prizeTags, adTags, i);
                if (set.add(groupSignature(group))) {
                    groups.add(group);
                    break;
                }
            }
        }
        return groups;
    }

    /**
     * Builds one group of up to {@code i} prizes.
     *
     * <p>It first picks up to {@link #AD_PRIZE_NUM_AT_MOST} ads, at most one per ad tag. The remaining slots are
     * filled with the best prize of randomly chosen prize tags. If there are fewer prize tags than slots, the
     * prizes are instead sampled from the whole prize list, weighted by match score.
     */
    private static PackageRecallDo genOneGroupByTagInfo(List<PackageRecallDo> list, TagInfo tagInfo,
                                                        TagInfo tagInfo2, int i) {
        int groupSize = Math.max(i, 0);
        // The ad count is capped by the number of ad tags, by AD_PRIZE_NUM_AT_MOST, and by the group size itself.
        int adNum = Math.min(Math.min(tagInfo2.tagBucket.size(), AD_PRIZE_NUM_AT_MOST), groupSize);
        int prizeNum = groupSize - adNum;

        List<PackageRecallDo> ads = fetchPrizeFromRandomPickedTag(tagInfo2, adNum);
        List<PackageRecallDo> prizes = prizeNum <= tagInfo.tagBucket.size()
                ? fetchPrizeFromRandomPickedTag(tagInfo, prizeNum)
                : fetchPkrdFromWeightedList(list, prizeNum);

        List<PrizeDo> prizeGroup = new ArrayList<>(groupSize);
        for (PackageRecallDo ad : ads) {
            prizeGroup.add(prizeDoFromPackageRecallDo(ad));
        }
        for (PackageRecallDo prize : prizes) {
            prizeGroup.add(prizeDoFromPackageRecallDo(prize));
        }
        return newGroup(prizeGroup);
    }

    /** Converts a candidate to a {@link PrizeDo}, treating it as an ad iff it has a non-null ad prize id. */
    private static PrizeDo prizeDoFromPackageRecallDo(PackageRecallDo packageRecallDo) {
        PackageIdDo packageIdDo = packageRecallDo.getPackageInfoDo().getPackageIdDo();
        return toPrizeDo(packageRecallDo, packageIdDo.getAdPrizeId() != null);
    }

    /** Converts a candidate to a {@link PrizeDo}, using the ad prize id when {@code isAd} and the prize id otherwise. */
    private static PrizeDo toPrizeDo(PackageRecallDo candidate, boolean isAd) {
        PackageIdDo packageIdDo = candidate.getPackageInfoDo().getPackageIdDo();
        if (isAd) {
            return new PrizeDo(true, packageIdDo.getAdPrizeId(), packageIdDo.getPrizeTagId(), candidate.getMatchScore());
        }
        return new PrizeDo(false, packageIdDo.getPrizeId(), packageIdDo.getPrizeTagId(), candidate.getMatchScore());
    }

    /**
     * Wraps a prize group in a fresh {@link PackageRecallDo}.
     * NOTE(review): assumes the no-arg constructor initializes the nested PackageInfoDo/PackageIdDo.
     */
    private static PackageRecallDo newGroup(List<PrizeDo> prizeGroup) {
        PackageRecallDo group = new PackageRecallDo();
        group.getPackageInfoDo().getPackageIdDo().setPrizeGroup(prizeGroup);
        return group;
    }

    /**
     * Picks {@code i} tags weighted by tag score and returns the first candidate of each picked tag.
     * Callers must ensure {@code i} does not exceed the number of tags.
     */
    private static List<PackageRecallDo> fetchPrizeFromRandomPickedTag(TagInfo tagInfo, int i) {
        List<PackageRecallDo> picked = new ArrayList<>();
        for (Long tagId : weightedPickTags(tagInfo.tagWeight, i)) {
            picked.add(tagInfo.tagBucket.get(tagId).get(0));
        }
        return picked;
    }

    /**
     * Picks {@code i} distinct tag ids, weighted by their scores.
     *
     * <p>If {@code i} exceeds the number of tags, it is truncated to the number of tags (with a warning). When
     * every tag is requested, all tags are returned in {@link #argSort(List)} order.
     *
     * @param map tag id to weight
     * @param i   number of tags to pick
     * @return picked tag ids; empty when {@code i <= 0}
     */
    public static List<Long> weightedPickTags(Map<Long, Double> map, int i) {
        if (i <= 0) {
            return new ArrayList<>();
        }
        List<Long> tagIds = new ArrayList<>(map.size());
        List<Double> weights = new ArrayList<>(map.size());
        for (Map.Entry<Long, Double> entry : map.entrySet()) {
            tagIds.add(entry.getKey());
            weights.add(entry.getValue());
        }
        int k = i;
        if (k > map.size()) {
            logger.warn("k大于类目数,截断成类目数");
            k = map.size();
        }
        // randKWithoutReplacement rejects k == size, so a full selection goes through argSort instead.
        List<Integer> order = k == map.size() ? argSort(weights) : randKWithoutReplacement(weights, k);
        List<Long> picked = new ArrayList<>(k);
        for (Integer idx : order) {
            picked.add(tagIds.get(idx));
        }
        return picked;
    }

    /**
     * Generates up to {@code groupNum} groups, each sampled (weighted by match score) from the merged top
     * {@link #TOTAL_TOPK_PICK_RANGE} candidates. Groups already in {@code set} are skipped.
     */
    private static List<PackageRecallDo> genByRandomPick(List<PackageRecallDo> list, List<PackageRecallDo> list2,
                                                         Set<String> set, int groupSize, int groupNum) {
        List<PackageRecallDo> groups = new ArrayList<>(Math.max(groupNum, 0));
        List<PackageRecallDo> pool = topKFromMergedPrizeAndAdList(list, list2,
                Math.min(TOTAL_TOPK_PICK_RANGE, list.size() + list2.size()));
        for (int slot = 0; slot < groupNum; slot++) {
            for (int attempt = 0; attempt < TRY_NUM_FOR_ONE_GROUP; attempt++) {
                PackageRecallDo group = randomPickOneGroup(pool, groupSize);
                if (set.add(groupSignature(group))) {
                    groups.add(group);
                    break;
                }
            }
        }
        return groups;
    }

    /** Samples one group of {@code groupSize} candidates from {@code list}, weighted by match score. */
    private static PackageRecallDo randomPickOneGroup(List<PackageRecallDo> list, int groupSize) {
        List<PrizeDo> prizeGroup = new ArrayList<>(Math.max(groupSize, 0));
        for (PackageRecallDo candidate : fetchPkrdFromWeightedList(list, groupSize)) {
            prizeGroup.add(prizeDoFromPackageRecallDo(candidate));
        }
        return newGroup(prizeGroup);
    }

    /**
     * Samples {@code i} distinct candidates without replacement, weighted by match score.
     *
     * @param list candidates
     * @param i    number to pick
     * @return a new list. It contains every candidate when {@code i >= list.size()} (with a warning when strictly
     *     greater). It is empty when {@code i <= 0}.
     */
    public static List<PackageRecallDo> fetchPkrdFromWeightedList(List<PackageRecallDo> list, int i) {
        if (i >= list.size()) {
            if (i > list.size()) {
                logger.warn("无放回抽样数大于样本数,返回原List");
            }
            // Taking every candidate needs no sampling (and randKWithoutReplacement rejects k == size).
            return new ArrayList<>(list);
        }
        if (i <= 0) {
            return new ArrayList<>();
        }
        List<Double> scores = list.stream().map(PackageRecallDo::getMatchScore).collect(Collectors.toList());
        List<PackageRecallDo> picked = new ArrayList<>(i);
        for (Integer idx : randKWithoutReplacement(scores, i)) {
            picked.add(list.get(idx));
        }
        return picked;
    }

    /**
     * Builds the deterministic best-score group of up to {@code i} entries.
     *
     * <p>It reserves the top {@link #AD_PRIZE_NUM_AT_LEAST} ads and the top {@link #PRIZE_NUM_AT_LEAST} prizes.
     * The remaining slots are filled by merging the rest of both score-descending lists; ads win ties. An
     * error is logged if the candidates run out first.
     */
    private static PackageRecallDo genBestScoreGroup(List<PackageRecallDo> list, List<PackageRecallDo> list2, int i) {
        List<PrizeDo> prizeGroup = new ArrayList<>();
        int adIdx = Math.min(AD_PRIZE_NUM_AT_LEAST, list2.size());
        int prizeIdx = Math.min(PRIZE_NUM_AT_LEAST, list.size());
        for (int n = 0; n < adIdx; n++) {
            prizeGroup.add(toPrizeDo(list2.get(n), true));
        }
        for (int n = 0; n < prizeIdx; n++) {
            prizeGroup.add(toPrizeDo(list.get(n), false));
        }
        int remaining = i - adIdx - prizeIdx;
        for (int n = 0; n < remaining; n++) {
            boolean prizesLeft = prizeIdx < list.size();
            boolean adsLeft = adIdx < list2.size();
            if (!prizesLeft && !adsLeft) {
                logger.error("No sufficient prize or ad candidate");
                break;
            }
            boolean takePrize = !adsLeft
                    || (prizesLeft && list.get(prizeIdx).getMatchScore() > list2.get(adIdx).getMatchScore());
            if (takePrize) {
                prizeGroup.add(toPrizeDo(list.get(prizeIdx++), false));
            } else {
                prizeGroup.add(toPrizeDo(list2.get(adIdx++), true));
            }
        }
        return newGroup(prizeGroup);
    }

    /**
     * Picks one candidate with probability proportional to its match score.
     *
     * @return the picked candidate, or {@code null} (with an error log) when {@code list} is null or empty
     */
    static PackageRecallDo weightedRandomPickByMatchscore(List<PackageRecallDo> list) {
        if (list == null || list.isEmpty()) {
            logger.error("Weighted random pick from empty list");
            return null;
        }
        double remaining = Math.random() * list.stream().mapToDouble(c -> c.getMatchScore()).sum();
        for (PackageRecallDo candidate : list) {
            remaining -= candidate.getMatchScore();
            if (remaining <= 0.0d) {
                return candidate;
            }
        }
        // Floating-point rounding can leave a tiny positive remainder; fall back to the last candidate.
        return list.get(list.size() - 1);
    }

    /**
     * Like {@link #weightedRandomPickByMatchscore(List)} but ignores candidates contained (by {@code equals})
     * in {@code list2}. A {@code null} exclusion list excludes nothing.
     */
    protected static PackageRecallDo weightedRandomPickByMatchscore(List<PackageRecallDo> list,
                                                                    List<PackageRecallDo> list2) {
        // Capacity is list.size(): list.size() - list2.size() may be negative and would throw.
        List<PackageRecallDo> candidates = new ArrayList<>(list.size());
        for (PackageRecallDo candidate : list) {
            if (list2 == null || !list2.contains(candidate)) {
                candidates.add(candidate);
            }
        }
        return weightedRandomPickByMatchscore(candidates);
    }

    /**
     * Recalls up to {@link #PRIZE_RECALL_TOPK} regular prizes via {@code PackageRecall.ucbPackage}.
     * NOTE(review): the meaning of the trailing arguments (2, 3, 2) is defined by PackageRecall — confirm there.
     */
    private static List<PackageRecallDo> recallPrize(List<PackageInfoDo> list) {
        if (list == null || list.isEmpty()) {
            return new ArrayList<>();
        }
        return PackageRecall.ucbPackage(list, Math.min(PRIZE_RECALL_TOPK, list.size()), 2, 3, 2);
    }

    /**
     * Recalls up to {@link #AD_RECALL_TOPK} ad prizes via {@code PackageRecall.ucbPackage}.
     *
     * @return the recalled ads; empty for a {@code null} or empty input
     */
    public static List<PackageRecallDo> recallAdPrize(List<PackageInfoDo> list) {
        if (list == null || list.isEmpty()) {
            return new ArrayList<>();
        }
        return PackageRecall.ucbPackage(list, Math.min(AD_RECALL_TOPK, list.size()), 2, 3, 2);
    }

    /**
     * Merges score-descending prizes and ads into at most {@code i} entries.
     *
     * <p>It reserves {@link #PRIZE_NUM_AT_LEAST} prizes and {@link #AD_PRIZE_NUM_AT_LEAST} ads (when available),
     * then fills the remaining slots by score, with ads winning ties. The result is sorted by ascending match
     * score. The inputs are not modified. Currently unused within this class.
     */
    private static List<PackageRecallDo> mergePrize(List<PackageRecallDo> list, List<PackageRecallDo> list2, int i) {
        List<PackageRecallDo> merged = new ArrayList<>(Math.max(i, 0));
        int prizeIdx = Math.min(PRIZE_NUM_AT_LEAST, list.size());
        int adIdx = Math.min(AD_PRIZE_NUM_AT_LEAST, list2.size());
        merged.addAll(list.subList(0, prizeIdx));
        merged.addAll(list2.subList(0, adIdx));
        while (merged.size() < i) {
            boolean prizesLeft = prizeIdx < list.size();
            boolean adsLeft = adIdx < list2.size();
            if (!prizesLeft && !adsLeft) {
                logger.error("insufficient recalled prizes!");
                break;
            }
            if (!adsLeft || (prizesLeft && list.get(prizeIdx).getMatchScore() > list2.get(adIdx).getMatchScore())) {
                merged.add(list.get(prizeIdx++));
            } else {
                merged.add(list2.get(adIdx++));
            }
        }
        merged.sort(Comparator.comparing(PackageRecallDo::getMatchScore));
        return merged;
    }

    private static String groupSignature(PackageInfoDo packageInfoDo) {
        return groupSignature(packageInfoDo.getPackageIdDo().getPrizeGroup());
    }

    private static String groupSignature(PackageRecallDo packageRecallDo) {
        return groupSignature(packageRecallDo.getPackageInfoDo());
    }

    /**
     * Order-independent signature of a prize group.
     *
     * <p>The prizes are sorted by {@code PrizeDo}'s natural ordering, rendered with {@code toString()}, and
     * joined with {@code "#"}.
     */
    public static String groupSignature(List<PrizeDo> list) {
        return list.stream().sorted().map(PrizeDo::toString).collect(Collectors.joining("#"));
    }

    /**
     * Draws {@code i} indices with replacement, each with probability proportional to its rate.
     * If all rates sum to zero or less, indices are drawn uniformly. Rates are assumed non-negative.
     *
     * @param list sampling rates
     * @param i    number of draws; values {@code <= 0} yield an empty list
     * @return the drawn indices, in draw order
     * @throws RuntimeException if {@code list} is null or empty
     */
    protected static List<Integer> randKWithReplacement(List<Double> list, int i) {
        if (null == list || list.isEmpty()) {
            logger.error("the rates list is null or empty");
            throw new RuntimeException("the rates list is null or empty");
        }
        int n = list.size();
        double[] cumulative = new double[n];
        double total = 0.0d;
        for (int idx = 0; idx < n; idx++) {
            total += list.get(idx);
            cumulative[idx] = total;
        }
        List<Integer> picked = new ArrayList<>(Math.max(i, 0));
        for (int draw = 0; draw < i; draw++) {
            if (total > 0.0d) {
                picked.add(indexForCumulativeTarget(cumulative, Math.random() * total));
            } else {
                picked.add((int) (Math.random() * n));
            }
        }
        return picked;
    }

    /**
     * Returns the first index whose cumulative weight exceeds {@code target}. Because the comparison is strict,
     * zero-weight entries are never chosen. Falls back to the last index to absorb rounding.
     */
    private static int indexForCumulativeTarget(double[] cumulative, double target) {
        for (int idx = 0; idx < cumulative.length; idx++) {
            if (target < cumulative[idx]) {
                return idx;
            }
        }
        return cumulative.length - 1;
    }

    /**
     * Returns all indices of {@code list}, ordered by the natural ordering of {@code GenericPair(index, value)}.
     * NOTE(review): the exact sort key/direction is defined by GenericPair's {@code compareTo} — confirm there.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    protected static List<Integer> argSort(List<Double> list) {
        List<GenericPair> pairs = new ArrayList<>(list.size());
        for (int i = 0; i < list.size(); i++) {
            pairs.add(new GenericPair(Integer.valueOf(i), list.get(i)));
        }
        pairs.sort(null);
        List<Integer> order = new ArrayList<>(pairs.size());
        for (GenericPair pair : pairs) {
            order.add((Integer) pair.first);
        }
        return order;
    }

    /**
     * Draws {@code i} distinct indices without replacement, each draw proportional to the remaining rates.
     *
     * @param list sampling rates (assumed positive; see {@link #heapPop(List)})
     * @param i    number of indices; must be smaller than {@code list.size()}
     * @return the drawn indices, in draw order
     * @throws RuntimeException if {@code list} is null/empty or {@code i >= list.size()}
     */
    protected static List<Integer> randKWithoutReplacement(List<Double> list, int i) {
        if (null == list || list.isEmpty()) {
            logger.error("the rates list is null or empty");
            throw new RuntimeException("the rates list is null or empty");
        }
        if (i >= list.size()) {
            logger.error("k is bigger than rates' size");
            throw new RuntimeException("k is bigger than rates' size");
        }
        List<HeapNode> nodes = new ArrayList<>(list.size());
        for (int idx = 0; idx < list.size(); idx++) {
            nodes.add(new HeapNode(list.get(idx), idx));
        }
        List<Integer> picked = new ArrayList<>(i);
        List<HeapNode> heap = buildHeap(nodes);
        for (int draw = 0; draw < i; draw++) {
            picked.add(heapPop(heap));
        }
        return picked;
    }

    /**
     * Lays the nodes out as a 1-indexed implicit binary tree. Slot 0 is {@code null}, and the children of
     * {@code k} are {@code 2k} and {@code 2k+1}. Each node's {@code totalWeight} is set to its subtree sum.
     * The given nodes are mutated.
     */
    private static List<HeapNode> buildHeap(List<HeapNode> list) {
        List<HeapNode> heap = new ArrayList<>(list.size() + 1);
        heap.add(null);
        heap.addAll(list);
        // Walking from the last slot upwards finishes every child's subtree sum before it is added to its parent.
        for (int idx = heap.size() - 1; idx > 1; idx--) {
            heap.get(idx >> 1).totalWeight += heap.get(idx).totalWeight;
        }
        return heap;
    }

    /**
     * Draws one node with probability proportional to its remaining weight and returns its {@code value}.
     * The node's weight is then zeroed and the subtree totals on its path to the root are reduced.
     *
     * <p>NOTE(review): this assumes positive weights. When the remaining total is 0, the descent stops at the
     * root and may return an already-popped index. Floating-point drift in {@code totalWeight} could possibly
     * step past the last slot — confirm callers only pass strictly positive rates.
     */
    private static int heapPop(List<HeapNode> list) {
        double gas = list.get(1).totalWeight * Math.random();
        int i = 1;
        while (gas > list.get(i).weight) {
            gas -= list.get(i).weight;
            i <<= 1;
            if (gas > list.get(i).totalWeight) {
                gas -= list.get(i).totalWeight;
                i++;
            }
        }
        HeapNode chosen = list.get(i);
        double removed = chosen.weight;
        int value = chosen.value;
        chosen.weight = 0.0d;
        while (i > 0) {
            list.get(i).totalWeight -= removed;
            i >>= 1;
        }
        return value;
    }
}
