package cn.com.duiba.nezha.alg.alg.dpa;

import cn.com.duiba.nezha.alg.alg.vo.dpa.PackageRecallDo;

import java.util.*;
import java.util.concurrent.PriorityBlockingQueue;

/**
 * Builds groups ("combinations") of recalled components ranked by the sum of their match scores.
 *
 * <p>This implementation is intended for the case {@code k << list.size()}.
 *
 * @author tanyan
 */
public class TopKScoreSumGroupGenerater {
    // Cap on the number of combinations enumerated, to avoid timeouts.
    private static final int TRUNCATED_COMB_NUM = 250;

    /**
     * Returns up to {@code k} component groups with the highest match-score sums.
     *
     * <p>For {@code k == 1} and {@code k == 2} the groups are built directly from the list order,
     * which assumes {@code pkrds} is sorted by match score in descending order. For larger
     * {@code k}, candidates come from {@code RecallUtil.combination(size, slotNum,
     * TRUNCATED_COMB_NUM)} and are ranked by score sum, so only the (at most
     * {@code TRUNCATED_COMB_NUM}) combinations that helper yields are considered.
     *
     * @param pkrds   components, assumed sorted by match score, highest first
     * @param k       number of groups to generate; values below 1 yield an empty list
     * @param slotNum number of components per group; expected to be at least 1
     * @return the groups, best first; if {@code pkrds.size() <= slotNum} a single group holding
     *         all components is returned
     */
    public static List<List<PackageRecallDo>> topKSumGroups(List<PackageRecallDo> pkrds, int k, int slotNum){
        // Check k before allocating: ArrayList(int) throws on a negative capacity.
        if (k < 1) {
            return new ArrayList<>();
        }
        // At most TRUNCATED_COMB_NUM groups can be produced, so never pre-allocate more.
        List<List<PackageRecallDo>> result = new ArrayList<>(Math.min(k, TRUNCATED_COMB_NUM));

        if (k == 1 || pkrds.size() <= slotNum) {
            result.add(genBestGroup(pkrds, Math.min(pkrds.size(), slotNum)));
            return result;
        }

        if (k == 2) {
            result.add(genBestGroup(pkrds, slotNum));
            // For a descending list the second-best sum swaps the weakest member of the best
            // group (index slotNum-1) for the strongest item outside it (index slotNum).
            // pkrds.size() > slotNum here, so index slotNum is valid.
            List<PackageRecallDo> second = new ArrayList<>(slotNum);
            second.addAll(pkrds.subList(0, slotNum - 1));
            second.add(pkrds.get(slotNum));
            result.add(second);
            return result;
        }

        List<GroupInfo> candidates = genSortedCandidates(pkrds, slotNum);
        int limit = Math.min(k, candidates.size());
        for (int i = 0; i < limit; i++) {
            result.add(candidates.get(i).group);
        }
        return result;
    }

    /**
     * Generates up to {@code k} groups: the first is the highest-scoring candidate, and the
     * remaining {@code k-1} are drawn without replacement from the other candidates (at most
     * {@code TRUNCATED_COMB_NUM} combinations in total), weighted by each group's score sum.
     *
     * <p>NOTE(review): the weighted draw is delegated to
     * {@code RecallUtil.randKWithoutReplacement}; it presumably expects non-negative weights —
     * confirm.
     *
     * @param pkrds   component list
     * @param k       number of groups to return; values below 1 yield an empty list
     * @param slotNum number of components per group
     * @return the best group followed by the sampled groups; may contain fewer than {@code k}
     *         groups when there are not enough candidates
     */
    public static List<List<PackageRecallDo>> randomKFromTopPSumGroups(List<PackageRecallDo> pkrds, int k, int slotNum){
        // Check k before allocating: ArrayList(int) throws on a negative capacity.
        if (k < 1) {
            return new ArrayList<>();
        }
        List<List<PackageRecallDo>> result = new ArrayList<>(Math.min(k, TRUNCATED_COMB_NUM));

        List<GroupInfo> candidates = genSortedCandidates(pkrds, slotNum);
        if (candidates.isEmpty()) {
            return result;
        }

        // The best-scoring group is always included.
        result.add(candidates.get(0).group);

        List<GroupInfo> rest = candidates.subList(1, candidates.size());
        // Never ask for more draws than there are remaining candidates.
        int leftK = Math.min(k - 1, rest.size());
        if (leftK <= 0) {
            return result;
        }

        List<Double> weights = new ArrayList<>(rest.size());
        List<List<PackageRecallDo>> pool = new ArrayList<>(rest.size());
        for (GroupInfo groupInfo : rest) {
            weights.add(groupInfo.scoreSum);
            pool.add(groupInfo.group);
        }

        List<Integer> idxs = RecallUtil.randKWithoutReplacement(weights, leftK);
        for (Integer idx : idxs) {
            result.add(pool.get(idx));
        }
        return result;
    }

    /**
     * Builds one candidate per combination returned by
     * {@code RecallUtil.combination(pkrds.size(), slotNum, TRUNCATED_COMB_NUM)} and sorts them
     * by score sum, highest first. The sort is stable, so equal-score groups keep the order in
     * which the combinations were produced.
     */
    private static List<GroupInfo> genSortedCandidates(List<PackageRecallDo> pkrds, int slotNum) {
        List<int[]> combs = RecallUtil.combination(pkrds.size(), slotNum, TRUNCATED_COMB_NUM);
        List<GroupInfo> candidates = new ArrayList<>(combs.size());
        for (int[] comb : combs) {
            candidates.add(new GroupInfo(pkrds, comb));
        }
        Collections.sort(candidates);
        return candidates;
    }

    /**
     * Returns a new list with the first {@code slotNum} components, which is the highest-sum
     * group when the input is sorted in descending score order.
     *
     * @param sortedPkrds components, assumed sorted by match score, highest first
     * @param slotNum     group size; must not exceed {@code sortedPkrds.size()}
     * @return a fresh, mutable list of the first {@code slotNum} components
     */
    private static List<PackageRecallDo> genBestGroup(List<PackageRecallDo> sortedPkrds, int slotNum){
        return new ArrayList<>(sortedPkrds.subList(0, slotNum));
    }

    /**
     * Builds an unsorted candidate for each combination from {@code RecallUtil.combination}.
     * Returns an empty list when {@code smallK <= slotNum}.
     *
     * <p>NOTE(review): not called anywhere in this class; kept for reference.
     */
    private List<GroupInfo> genAllCombinationInSmallK(List<PackageRecallDo> sortedPkrds, int smallK, int slotNum){
        List<GroupInfo> result = new ArrayList<>();
        if(smallK <= slotNum){
            return result;
        }
        List<int[]> combs = RecallUtil.combination(sortedPkrds.size(), slotNum, TRUNCATED_COMB_NUM);
        for (int[] comb : combs) {
            result.add(new GroupInfo(sortedPkrds, comb));
        }
        return result;
    }

    /**
     * A component group together with the sum of its members' match scores. Instances order by
     * score sum in descending order, so the best group sorts first.
     */
    private static final class GroupInfo implements Comparable<GroupInfo> {
        final double scoreSum;
        final List<PackageRecallDo> group;

        GroupInfo(List<PackageRecallDo> group) {
            this.scoreSum = group.stream().mapToDouble(PackageRecallDo::getMatchScore).sum();
            this.group = group;
        }

        /** Builds the group from the components of {@code pkrds} at the given indexes. */
        GroupInfo(List<PackageRecallDo> pkrds, int[] idxes){
            List<PackageRecallDo> members = new ArrayList<>(idxes.length);
            double sum = 0.0;
            for (int idx : idxes) {
                PackageRecallDo item = pkrds.get(idx);
                sum += item.getMatchScore();
                members.add(item);
            }
            this.scoreSum = sum;
            this.group = members;
        }

        @Override
        public int compareTo(GroupInfo other) {
            // Reversed arguments: higher score sums come first.
            return Double.compare(other.scoreSum, this.scoreSum);
        }
    }
}
