/*
 * Decompiled with CFR 0.152.
 */
package cn.com.duiba.nezha.alg.alg.dpa;

import cn.com.duiba.nezha.alg.alg.dpa.RecallUtil;
import cn.com.duiba.nezha.alg.alg.vo.dpa.PackageRecallDo;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.PriorityBlockingQueue;

public class TopKScoreSumGroupGenerater {
    /** Upper bound on the number of index combinations requested from {@link RecallUtil#combination}. */
    private static final int TRUNCATED_COMB_NUM = 250;

    /**
     * Returns up to {@code k} groups of {@code slotNum} packages, best (highest match-score sum) first.
     *
     * <p>The {@code k == 1} and {@code k == 2} fast paths take prefixes of {@code pkrds}, so they assume
     * {@code pkrds} is already sorted by descending match score (NOTE(review): confirm with callers).
     * For {@code k >= 3} the candidates are the index combinations produced by
     * {@link RecallUtil#combination} (at most {@value #TRUNCATED_COMB_NUM}), ranked by score sum.
     *
     * @param pkrds   candidate packages
     * @param k       maximum number of groups to return; any value below 1 yields an empty list
     * @param slotNum number of packages per group
     * @return the selected groups, best first; never {@code null}
     */
    public static List<List<PackageRecallDo>> topKSumGroups(List<PackageRecallDo> pkrds, int k, int slotNum) {
        // Check k before allocating: new ArrayList<>(k) throws for negative capacities.
        if (k < 1) {
            return new ArrayList<>();
        }
        if (k == 1 || pkrds.size() <= slotNum) {
            // Only one distinct group is possible (or requested): the leading slotNum items.
            List<List<PackageRecallDo>> result = new ArrayList<>(1);
            result.add(new ArrayList<>(pkrds.subList(0, Math.min(pkrds.size(), slotNum))));
            return result;
        }
        if (k == 2) {
            // With pkrds sorted descending, the runner-up group replaces the weakest member of the
            // best group (index slotNum - 1) with the strongest non-member (index slotNum).
            // pkrds.size() > slotNum here, so index slotNum exists.
            List<List<PackageRecallDo>> result = new ArrayList<>(2);
            result.add(genBestGroup(pkrds, slotNum));
            List<PackageRecallDo> second = new ArrayList<>(slotNum);
            second.addAll(pkrds.subList(0, slotNum - 1));
            second.add(pkrds.get(slotNum));
            result.add(second);
            return result;
        }
        PriorityBlockingQueue<GroupInfo> candidates = buildCandidateQueue(pkrds, slotNum);
        int take = Math.min(k, candidates.size());
        List<List<PackageRecallDo>> result = new ArrayList<>(take);
        for (int i = 0; i < take; ++i) {
            result.add(candidates.poll().group);
        }
        return result;
    }

    /**
     * Returns the best-scoring group followed by up to {@code k - 1} further groups.
     *
     * <p>The further groups are drawn by {@link RecallUtil#randKWithoutReplacement} from the remaining
     * candidates, using each group's score sum as its weight.
     *
     * @param pkrds   candidate packages
     * @param k       maximum number of groups to return; any value below 1 yields an empty list
     * @param slotNum number of packages per group
     * @return the selected groups, with the best group first when one exists; never {@code null}
     */
    public static List<List<PackageRecallDo>> randomKFromTopPSumGroups(List<PackageRecallDo> pkrds, int k, int slotNum) {
        // Check k before allocating: new ArrayList<>(k) throws for negative capacities.
        List<List<PackageRecallDo>> result = new ArrayList<>();
        if (k < 1) {
            return result;
        }
        PriorityBlockingQueue<GroupInfo> candidates = buildCandidateQueue(pkrds, slotNum);
        GroupInfo best = candidates.poll();
        if (best == null) {
            return result;
        }
        result.add(best.group);
        int leftK = k - 1;
        // Nothing left to request or nothing left to sample from: skip the random draw entirely.
        if (leftK == 0 || candidates.isEmpty()) {
            return result;
        }
        // Drain the rest in descending-score order; weights.get(i) is the score sum of rest.get(i).
        ArrayList<Double> weights = new ArrayList<>(candidates.size());
        List<List<PackageRecallDo>> rest = new ArrayList<>(candidates.size());
        GroupInfo next;
        while ((next = candidates.poll()) != null) {
            weights.add(next.scoreSum);
            rest.add(next.group);
        }
        List<Integer> idxs = RecallUtil.randKWithoutReplacement(weights, leftK);
        for (Integer idx : idxs) {
            result.add(rest.get(idx));
        }
        return result;
    }

    /**
     * Builds one {@link GroupInfo} per index combination returned by {@link RecallUtil#combination}.
     * The groups are placed in a queue whose head is the group with the highest score sum.
     */
    private static PriorityBlockingQueue<GroupInfo> buildCandidateQueue(List<PackageRecallDo> pkrds, int slotNum) {
        PriorityBlockingQueue<GroupInfo> queue = new PriorityBlockingQueue<>(TRUNCATED_COMB_NUM);
        List<int[]> combs = RecallUtil.combination(pkrds.size(), slotNum, TRUNCATED_COMB_NUM);
        for (int[] comb : combs) {
            queue.add(new GroupInfo(pkrds, comb));
        }
        return queue;
    }

    /** Returns a mutable copy of the first {@code slotNum} items; requires {@code sortedPkrds.size() >= slotNum}. */
    private static List<PackageRecallDo> genBestGroup(List<PackageRecallDo> sortedPkrds, int slotNum) {
        return new ArrayList<>(sortedPkrds.subList(0, slotNum));
    }

    /**
     * Wraps every index combination of {@code sortedPkrds} (up to {@value #TRUNCATED_COMB_NUM}) as a {@link GroupInfo}.
     * Returns an empty list when {@code smallK <= slotNum}.
     *
     * <p>NOTE(review): {@code smallK} only gates the result; combinations are taken over all of
     * {@code sortedPkrds}, not its first {@code smallK} items. The method is not called within this class.
     */
    private List<GroupInfo> genAllCombinationInSmallK(List<PackageRecallDo> sortedPkrds, int smallK, int slotNum) {
        List<GroupInfo> result = new ArrayList<>();
        if (smallK <= slotNum) {
            return result;
        }
        List<int[]> combs = RecallUtil.combination(sortedPkrds.size(), slotNum, TRUNCATED_COMB_NUM);
        for (int[] comb : combs) {
            result.add(new GroupInfo(sortedPkrds, comb));
        }
        return result;
    }

    /** A candidate group together with the sum of its members' match scores; ordered highest sum first. */
    private static class GroupInfo implements Comparable<GroupInfo> {
        final double scoreSum;
        final List<PackageRecallDo> group;

        GroupInfo(List<PackageRecallDo> group) {
            this.scoreSum = group.stream().mapToDouble(PackageRecallDo::getMatchScore).sum();
            this.group = group;
        }

        /** Builds the group from the items of {@code pkrds} at positions {@code idxes}, in that order. */
        GroupInfo(List<PackageRecallDo> pkrds, int[] idxes) {
            List<PackageRecallDo> members = new ArrayList<>(idxes.length);
            double sum = 0.0;
            for (int idx : idxes) {
                PackageRecallDo pkrd = pkrds.get(idx);
                sum += pkrd.getMatchScore();
                members.add(pkrd);
            }
            this.scoreSum = sum;
            this.group = members;
        }

        /** Descending by score sum, so a priority queue yields the best group first. */
        @Override
        public int compareTo(GroupInfo other) {
            return Double.compare(other.scoreSum, this.scoreSum);
        }
    }
}

