package cn.com.duiba.nezha.alg.alg.dpa;

import cn.com.duiba.nezha.alg.alg.constant.DPAConstant;
import cn.com.duiba.nezha.alg.alg.vo.dpa.PackageIdDo;
import cn.com.duiba.nezha.alg.alg.vo.dpa.PackageInfoDo;
import cn.com.duiba.nezha.alg.alg.vo.dpa.PackageRecallDo;
import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;

import static cn.com.duiba.nezha.alg.common.model.activityrecommend.WilsonInterval.wilsonCalc;

/**
 * @author lijingzhe
 * @description 推荐引擎-召回
 * @date 2020/6/17
 */
public class PackageRecall {
    private static final Logger logger= LoggerFactory.getLogger(PackageRecall.class);

    private static Comparator<PackageRecallDo> iComparator = new Comparator<PackageRecallDo>() {
        @Override
        public int compare(PackageRecallDo o1, PackageRecallDo o2) {
            return o2.getMatchScore() > o1.getMatchScore() ? 1 : -1;
        }
    };

    /**
     * @author: lijingzhe
     * @date: 2020/6/18
     * @methodParameters: [pks]
     * @methodReturnType: java.util.List<cn.com.duiba.nezha.alg.alg.vo.dpa.PackageRecallDo>
     * @description: <广告位>皮肤召回 40-1
     */
    public static List<PackageRecallDo> matchSkin(List<PackageInfoDo> pks){
        return ucbPackage(pks, DPAConstant.SKINTOPN, DPAConstant.RECALLSKIN);
    }

    /**
     * @author: lijingzhe
     * @date: 2020/6/17
     * @methodParameters: [pks]
     * @methodReturnType: java.util.List<cn.com.duiba.nezha.alg.alg.vo.dpa.PackageRecallDo>
     * @description: <广告位>奖品召回 350->100
     */
    public static List<PackageRecallDo> matchPrize(List<PackageInfoDo> pks){
        return ucbPackage(pks, DPAConstant.PRIZETOPN, DPAConstant.GLOBALRECALLPRIZE, DPAConstant.RECALLPRIZE);
    }

    /**
     * @author: lijingzhe
     * @date: 2020/6/17
     * @methodParameters: [pks]
     * @methodReturnType: java.util.List<cn.com.duiba.nezha.alg.alg.vo.dpa.PackageRecallDo>
     * @description: <广告位>主标题+副标题召回 130->20
     */
    public static List<PackageRecallDo> matchTitle(List<PackageInfoDo> pks){
        return ucbPackage(pks, DPAConstant.TITLETOPN, DPAConstant.GLOBALRECALLTITLE, DPAConstant.RECALLTITLE);
    }

    /**
     * @author: lijingzhe
     * @date: 2020/6/21
     * @methodParameters: [pks]
     * @methodReturnType: java.util.List<cn.com.duiba.nezha.alg.alg.vo.dpa.PackageRecallDo>
     * @description: <广告位>副标题召回 20*130->20
     */
    public static List<PackageRecallDo> matchSubTitle(List<PackageInfoDo> pks){
        return ucbPackage(pks, DPAConstant.GLOBALRECALLSUBTITLE, DPAConstant.RECALLSUBTITLE);
    }

    /**
     * @author: lijingzhe
     * @date: 2020/6/17
     * @methodParameters: [pks]
     * @methodReturnType: java.util.List<cn.com.duiba.nezha.alg.alg.vo.dpa.PackageRecallDo>
     * @description: <广告位>校验合规的[皮肤,奖品,主标题,副标题] X(X<=10*10*7)->100
     */
    public static List<PackageRecallDo> matchPackageAct(List<PackageInfoDo> pks){
        return ucbPackage(pks, DPAConstant.TOPK, DPAConstant.RECALLPACKAGE);
//        return matchPackage(pks, DPAConstant.TITLETOPN, DPAConstant.TITLETOPN, DPAConstant.RECALLTITLE);
    }

    /**
     * @author: lijingzhe
     * @date: 2020/6/17
     * @methodParameters: [pks, topN, topK]
     * @methodReturnType: java.util.List<cn.com.duiba.nezha.alg.alg.vo.dpa.PackageRecallDo>
     * @description: 组件召回, 威尔逊置信下限评估+冷启动扶持
     */
    public static List<PackageRecallDo> matchPackage(List<PackageInfoDo> pks, int topN, int topK, int recallType) {
        if(AssertUtil.isAnyEmpty(pks)){
            logger.error("PackageRecall matchTitle input params is null");
            return null;
        }
        List<PackageRecallDo> result = new ArrayList<>();
        Queue<PackageRecallDo> pkrs = new PriorityQueue<>(pks.size(), iComparator);
        Set<PackageIdDo> ids = new HashSet<>();
        int supportIdLen = 0;
        for (PackageInfoDo pk : pks) {
            PackageRecallDo pkrd = new PackageRecallDo();
            pkrd.setPackageInfoDo(pk);
            // 冷启动扶持
            if(pk.getReuqest() < DPAConstant.CONFIDENTREQUEST && System.currentTimeMillis() - pk.getPackageIdDo().getCreateTime() < 3 * 24 * 3600 * 100){
                if(Math.random() < DPAConstant.PROB && supportIdLen < topK){
                    result.add(pkrd);
                    ids.add(pk.getPackageIdDo());
                    supportIdLen += 1;
                }
            }else{
                double matchScore = wilsonCalc(pk.getCost()/100, pk.getReuqest()).lowerBound;
                pkrd.setMatchScore(matchScore);
                pkrd.setMatchType(recallType);
                pkrs.add(pkrd);
            }
        }

        for (int i = 0; i < pkrs.size(); i++) {
            PackageRecallDo pr = pkrs.poll();
            if(result.size() < topN && !ids.contains(pr.getPackageInfoDo().getPackageIdDo())){
                result.add(pr);
            }
        }
        return result;
    }

    /**
     * @author: lijingzhe
     * @date: 2020/6/21
     * @methodParameters: [pks, args]
     * @methodReturnType: java.util.List<cn.com.duiba.nezha.alg.alg.vo.dpa.PackageRecallDo>
     * @description: UCB
     */
    public static List<PackageRecallDo> ucbPackage(List<PackageInfoDo> pks, int... args) {
        List<PackageRecallDo> result = new ArrayList<>();
        if(AssertUtil.isAnyEmpty(pks)){
            logger.error("PackageRecall matchTitle input params is null");
            return null;
        }
        if(args.length >= 2){
            int topN = args[0];
            int recallType = args[1];
            int globalRecallType = args.length >= 3 ? args[2] : 0;

            // 全局
            Queue<PackageRecallDo> gpkrs = new PriorityQueue<>(pks.size(), iComparator);
            // 广告位
            Queue<PackageRecallDo> spkrs = new PriorityQueue<>(pks.size(), iComparator);
            Set<PackageIdDo> ids = new HashSet<>();
            int sNum = 0;
            double gtotal = 0, gbMax = 0, gbMin = 0 ,stotal = 0, bMax = 0, bMin = 0;
            for (PackageInfoDo pk : pks) {
                gtotal += pk.getGlobalCost();
                stotal += pk.getCost();
                double guvCost = pk.getGlobalReuqest() > 0 ? pk.getGlobalCost()/pk.getGlobalReuqest(): 0;
                gbMax = Math.max(guvCost, gbMax);
                gbMin = Math.min(guvCost, gbMin);
                double suvCost = pk.getReuqest() > 0 ? pk.getCost()/pk.getReuqest(): 0;
                bMax = Math.max(suvCost, bMax);
                bMin = Math.min(suvCost, bMin);
                if(pk.getCost() > 0){
                    sNum += 1;
                }
            }
            // 广告位/(广告位+全局)候选集占比
            double rate = sNum/pks.size();
            int slotTopN = (int) (rate * topN);

            for (PackageInfoDo pk : pks) {
                PackageRecallDo pkrd = new PackageRecallDo();
                pkrd.setPackageInfoDo(pk);
                // 初始化全局试探概率
                double globalProbTemp =pk.getGlobalCost() > 0 ? pk.getGlobalCost()/pk.getGlobalReuqest() : 0;
                // 归一化(0,1)
                double globalProb = (globalProbTemp - gbMin)/(gbMax - gbMin);
                // 全局奖励
                double globalBonus = pk.getGlobalCost() > 0 ? Math.sqrt(2 * Math.log(gtotal)/pk.getGlobalReuqest()): 0;
                double globalMatchScore = globalProb + globalBonus;
                pkrd.setMatchScore(globalMatchScore);
                pkrd.setMatchType(globalRecallType);
                gpkrs.add(pkrd);
                if(pk.getCost() > 0){
                    // 初始化广告位试探概率
                    double slotProbTemp = pk.getCost() > 0 ? pk.getCost()/pk.getReuqest() : 0;
                    double prob = (slotProbTemp - bMin)/(bMax - bMin);
                    // 广告位奖励
                    double bonus = pk.getCost() > 0 ? Math.sqrt(2 * Math.log(stotal)/pk.getReuqest()): 0;
                    double matchScore = prob + bonus;
                    pkrd.setMatchScore(matchScore);
                    pkrd.setMatchType(recallType);
                    spkrs.add(pkrd);
                }
            }
            if(slotTopN > 0 && Math.random() < (1 - DPAConstant.GLOBAL_W)){
                for (int i = 0; i < spkrs.size() && i < slotTopN; i++) {
                    PackageRecallDo pkrd = spkrs.poll();
                    if(!ids.contains(pkrd.getPackageInfoDo().getPackageIdDo())){
                        result.add(pkrd);
                        ids.add(pkrd.getPackageInfoDo().getPackageIdDo());
                    }
                }
            }else{
                for (int i = 0; i < gpkrs.size() && result.size() < topN; i++) {
                    PackageRecallDo pkrd = gpkrs.poll();
                    if(!ids.contains(pkrd.getPackageInfoDo().getPackageIdDo())){
                        result.add(pkrd);
                        ids.add(pkrd.getPackageInfoDo().getPackageIdDo());
                    }
                }
            }
        }
        return result;
    }
}