package cn.com.duiba.nezha.alg.alg.dpa;

import cn.com.duiba.nezha.alg.alg.constant.DPAConstant;
import cn.com.duiba.nezha.alg.alg.vo.dpa.PackageIdDo;
import cn.com.duiba.nezha.alg.alg.vo.dpa.PackageInfoDo;
import cn.com.duiba.nezha.alg.alg.vo.dpa.PackageRecallDo;
import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;
import java.util.concurrent.PriorityBlockingQueue;

import static cn.com.duiba.nezha.alg.common.model.activityrecommend.WilsonInterval.wilsonCalc;

/**
 * @author lijingzhe
 * @description 推荐引擎-召回
 * @date 2020/6/17
 */
public class PackageRecall {
    private static final Logger logger= LoggerFactory.getLogger(PackageRecall.class);

    private static Comparator<PackageRecallDo> iComparator = new Comparator<PackageRecallDo>() {
        @Override
        public int compare(PackageRecallDo o1, PackageRecallDo o2) {
            return o2.getMatchScore() > o1.getMatchScore() ? 1 : -1;
        }
    };

    /**
     * @author: lijingzhe
     * @date: 2020/6/18
     * @methodParameters: [pks]
     * @methodReturnType: java.util.List<cn.com.duiba.nezha.alg.alg.vo.dpa.PackageRecallDo>
     * @description: <广告位>皮肤召回 5*40->5*1+5, 备注：PackageIdDo.skinId非空,PackageIdDo.(titleId,subTitleId,prizeId)为空
     */
    public static List<PackageRecallDo> matchSkin(List<PackageInfoDo> pks){
        return ucbPackage(pks, DPAConstant.SKINTOPN, DPAConstant.RECALLSKIN, DPAConstant.SKINTAGN);
    }

    /**
     * @author: lijingzhe
     * @date: 2020/6/17
     * @methodParameters: [pks]
     * @methodReturnType: java.util.List<cn.com.duiba.nezha.alg.alg.vo.dpa.PackageRecallDo>
     * @description: <广告位>奖品召回 350->50, 备注：PackageIdDo.prizeId非空,PackageIdDo.(titleId,subTitleId,skinId)为空
     */
    public static List<PackageRecallDo> matchPrize(List<PackageInfoDo> pks){
        return ucbPackage(pks, DPAConstant.PRIZETOPN, DPAConstant.RECALLPRIZE);
    }

    /**
     * @author: lijingzhe
     * @date: 2020/6/17
     * @methodParameters: [pks]
     * @methodReturnType: java.util.List<cn.com.duiba.nezha.alg.alg.vo.dpa.PackageRecallDo>
     * @description: <广告位>主标题+副标题召回 130->2, 备注: PackageIdDo.titleId非空,PackageIdDo.(prizeId,subTitleId,skinId)为空
     */
    public static List<PackageRecallDo> matchTitle(List<PackageInfoDo> pks){
        return ucbPackage(pks, DPAConstant.TITLETOPN, DPAConstant.RECALLTITLE);
    }

    /**
     * @author: lijingzhe
     * @date: 2020/6/21
     * @methodParameters: [pks]
     * @methodReturnType: java.util.List<cn.com.duiba.nezha.alg.alg.vo.dpa.PackageRecallDo>
     * @description: <广告位>副标题召回 2*130->2*2, 备注: PackageIdDo.subTitleId非空,PackageIdDo.(prizeId,titleId,skinId)为空
     */
    public static List<PackageRecallDo> matchSubTitle(List<PackageInfoDo> pks){
        return ucbPackage(pks, DPAConstant.SUBTITLETOPN, DPAConstant.RECALLSUBTITLE);
    }

    /**
     * @author: lijingzhe
     * @date: 2020/6/17
     * @methodParameters: [pks]
     * @methodReturnType: java.util.List<cn.com.duiba.nezha.alg.alg.vo.dpa.PackageRecallDo>
     * @description: <广告位>校验合规的[皮肤,奖品,主标题,副标题] X(X<=10*10*7)->50
     */
    public static List<PackageRecallDo> matchPackageAct(List<PackageInfoDo> pks){
        return ucbPackage(pks, DPAConstant.TOPK, DPAConstant.RECALLPACKAGE);
//        return matchPackage(pks, DPAConstant.TITLETOPN, DPAConstant.TITLETOPN, DPAConstant.RECALLTITLE);
    }

    /**
     * @author: lijingzhe
     * @date: 2020/6/17
     * @methodParameters: [pks, topN, topK]
     * @methodReturnType: java.util.List<cn.com.duiba.nezha.alg.alg.vo.dpa.PackageRecallDo>
     * @description: 组件召回, 威尔逊置信下限评估+冷启动扶持
     */
    public static List<PackageRecallDo> matchPackage(List<PackageInfoDo> pks, int topN, int topK, int recallType) {
        if(AssertUtil.isAnyEmpty(pks)){
            logger.error("PackageRecall matchTitle input params is null");
            return null;
        }
        List<PackageRecallDo> result = new ArrayList<>();
        Queue<PackageRecallDo> pkrs = new PriorityBlockingQueue<>(pks.size(), iComparator);
        Set<PackageIdDo> ids = new HashSet<>();
        int supportIdLen = 0;
        for (PackageInfoDo pk : pks) {
            PackageRecallDo pkrd = new PackageRecallDo();
            pkrd.setPackageInfoDo(pk);
            // 冷启动扶持
            if(pk.getRequest() < DPAConstant.CONFIDENTREQUEST && System.currentTimeMillis() - pk.getPackageIdDo().getCreateTime() < 3 * 24 * 3600 * 100){
                if(Math.random() < DPAConstant.PROB && supportIdLen < topK){
                    result.add(pkrd);
                    ids.add(pk.getPackageIdDo());
                    supportIdLen += 1;
                }
            }else{
                double matchScore = wilsonCalc(pk.getCost()/100, pk.getRequest()).lowerBound;
                pkrd.setMatchScore(matchScore);
                pkrd.setMatchType(recallType);
                pkrs.add(pkrd);
            }
        }

        for (int i = 0; i < pkrs.size(); i++) {
            PackageRecallDo pr = pkrs.poll();
            if(result.size() < topN && !ids.contains(pr.getPackageInfoDo().getPackageIdDo())){
                result.add(pr);
            }
        }
        return result;
    }

    /**
     * @author: lijingzhe
     * @date: 2020/6/21
     * @methodParameters: [pks, args]
     * @methodReturnType: java.util.List<cn.com.duiba.nezha.alg.alg.vo.dpa.PackageRecallDo>
     * @description: UCB
     */
    public static List<PackageRecallDo> ucbPackage(List<PackageInfoDo> pks, int... args) {
        List<PackageRecallDo> result = new ArrayList<>();
        if(AssertUtil.isAnyEmpty(pks)){
            logger.error("PackageRecall ucbPackage input params is null");
            return null;
        }
        if(args.length >= 2){
            int topN = args[0];
            int recallType = args[1];
            int tagN = args.length >= 3 ? args[2] : 0;

            // 全局
            Queue<PackageRecallDo> pkrs = new PriorityBlockingQueue<>(pks.size(), iComparator);
            Set<PackageIdDo> ids = new HashSet<>();
            Map<Long, Queue<PackageRecallDo>> tagMap = new HashMap<>();
            double gtotal = 0, gbMax = 0, gbMin = 0 ,stotal = 0, bMax = 0, bMin = 0;
            Long tagId = null;
            for (PackageInfoDo pk : pks) {
                gtotal += pk.getGlobalCost();
                stotal += pk.getCost();
                double guvCost = pk.getGlobalRequest() > 0 ? pk.getGlobalCost()/pk.getGlobalRequest(): 0;
                gbMax = Math.max(guvCost, gbMax);
                gbMin = Math.min(guvCost, gbMin);
                double suvCost = pk.getRequest() > 0 ? pk.getCost()/pk.getRequest(): 0;
                bMax = Math.max(suvCost, bMax);
                bMin = Math.min(suvCost, bMin);
            }

            for (PackageInfoDo pk : pks) {
                if(recallType == DPAConstant.RECALLSKIN){
                    tagId = pk.getPackageIdDo().getSkinTagId();
                }else if(recallType == DPAConstant.RECALLPRIZE){
                    tagId = pk.getPackageIdDo().getPrizeTagId();
                }

                PackageRecallDo pkrd = new PackageRecallDo();
                pkrd.setPackageInfoDo(pk);
                // 初始化全局试探概率
                double globalProbTemp = pk.getGlobalCost()/(pk.getGlobalRequest() + 1);
                // 归一化(0,1)
                double globalProb = (globalProbTemp - gbMin)/(gbMax - gbMin);
                // 全局奖励
                double globalBonus = Math.sqrt(2 * Math.log(gtotal)/(pk.getGlobalRequest() + 1));
                double matchScore = DPAConstant.GLOBAL_MERGE_RATE * (globalProb + globalBonus);

                if(pk.getCost() > 0){
                    // 初始化广告位试探概率
                    double slotProbTemp = pk.getCost()/(pk.getRequest() + 1);
                    double prob = (slotProbTemp - bMin)/(bMax - bMin);
                    // 广告位奖励
                    double bonus = Math.sqrt(2 * Math.log(stotal)/(pk.getRequest() + 1));
                    matchScore = prob + bonus + matchScore;
                }
                pkrd.setMatchScore(matchScore);
                pkrd.setMatchType(recallType);
                pkrs.add(pkrd);
                if(tagId != null && tagN > 0){
                    Queue<PackageRecallDo> tpkrs  = tagMap.getOrDefault(tagId, new PriorityBlockingQueue<>(tagN, iComparator));
                    tpkrs.add(pkrd);
                    tagMap.put(tagId, tpkrs);
                }
            }

            // 优先取标签top组件(广告位+全局)
            if(tagMap.size() > 0 ){
                if(tagMap.keySet().size() > topN){
                    Queue<PackageRecallDo> tagNPrs = new PriorityBlockingQueue<>(tagMap.keySet().size(), iComparator);
                    for (Queue<PackageRecallDo> tagPackage : tagMap.values()) {
                        for (int i = 0; i < tagPackage.size() && i < tagN ; i++) {
                            PackageRecallDo pkrd = tagPackage.poll();
                            tagNPrs.add(pkrd);
                        }
                    }
                    for (int i = 0; i < tagNPrs.size() && result.size() < topN; i++) {
                        PackageRecallDo pkrd = tagNPrs.poll();
                        if(!ids.contains(pkrd.getPackageInfoDo().getPackageIdDo())){
                            result.add(pkrd);
                            ids.add(pkrd.getPackageInfoDo().getPackageIdDo());
                        }
                    }
                }else{
                    for (Queue<PackageRecallDo> tagPackage : tagMap.values()) {
                        for (int i = 0; i < tagPackage.size() && i < tagN ; i++) {
                            PackageRecallDo pkrd = tagPackage.poll();
                            if(!ids.contains(pkrd.getPackageInfoDo().getPackageIdDo())){
                                result.add(pkrd);
                                ids.add(pkrd.getPackageInfoDo().getPackageIdDo());
                            }
                        }
                    }
                }
            }

           // 不区分标签填充top组件
            for (int i = 0; i < pkrs.size() && result.size() < topN; i++) {
                PackageRecallDo pkrd = pkrs.poll();
                if(!ids.contains(pkrd.getPackageInfoDo().getPackageIdDo())){
                    result.add(pkrd);
                    ids.add(pkrd.getPackageInfoDo().getPackageIdDo());
                }
            }
        }
        return result;
    }
}