package cn.com.duiba.nezha.alg.alg.adx.rcmd2;

import cn.com.duiba.nezha.alg.alg.base.Roulette;
import cn.com.duiba.nezha.alg.alg.vo.adx.pd.AdxStatsDo;
import cn.com.duiba.nezha.alg.alg.vo.adx.rcmd2.*;
import cn.com.duiba.nezha.alg.alg.vo.adx.rcmd2.AdxIdeaFeatureDo;
import cn.com.duiba.nezha.alg.alg.vo.adx.rtb2.AdxFactorDo;
import cn.com.duiba.nezha.alg.alg.vo.adx.rtb2.AdxFactorReqDo;
import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import cn.com.duiba.nezha.alg.common.util.DataUtil;
import cn.com.duiba.nezha.alg.feature.vo.AdxFeatureDo;
import cn.com.duiba.nezha.alg.feature.vo.FeatureMapDo;
import com.alibaba.fastjson.JSON;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;
import java.util.stream.Collectors;

public class AdxRcmd {

    private static final Logger logger = LoggerFactory.getLogger(AdxRcmd.class);

    public static final double COLD_RATE = 0.25;

    /**
     * 1.粗排
     * 描述：可投<计划，创意> 列表与召回池取交集，按ecpm排序，个数限制 limitSize=20；
     *
     * @param adIdeaDos 参竞列表(not null)
     * @param recallDos 召回池
     * @param limitSize 粗排结果长度(not null)
     * @return 粗排结果
     */
    public static List<AdIdeaDo> rawRank(List<AdIdeaDo> adIdeaDos, List<AdIdeaDo> recallDos, Integer limitSize) {

        try {

            if (AssertUtil.isEmpty(adIdeaDos) || adIdeaDos.size() < 1) {
                logger.warn("AdxRcmd.rawRank adIdeaDos is null");
                return null;
            }

            //参竞列表展平
            List<IdeaUnitDo> ideaUnitList = adIdeaDos.stream().flatMap(ideaUnit ->
                    ideaUnit.getIdeaUnitDos().stream()).collect(Collectors.toList());
            //冷启动状态
            Map<Long, Boolean> coldStartMap = adIdeaDos.stream().collect(Collectors.toMap(AdIdeaDo::getIdeaId, AdIdeaDo::getIsColdStart));

            //召回池为空时，随机选取
            if (AssertUtil.isEmpty(recallDos) || recallDos.size() < 1) {
                logger.warn("AdxRcmd.rawRank recallDos is null，resId is {}, ideaUnitList size is {}", adIdeaDos.get(0).getResId(), ideaUnitList.size());
                return AdxRcmdBase.getRandomList(ideaUnitList, limitSize, coldStartMap);
            }

            //参竞列表&召回池 获取交集
            List<AdIdeaDo> interDos = AdxRcmdBase.getIntersection(adIdeaDos, recallDos);

            //交集个数为空, 随机选取
            if (interDos == null || interDos.size() < 1) {
                logger.warn("AdxRcmd.rawRank getIntersection is null，resId is {}, ideaUnitList size is {}", adIdeaDos.get(0).getResId(), ideaUnitList.size());
                logger.info("AdxRcmd.rawRank getIntersection is null，resId is {}, ideaUnitList size is {}, adIdeaDos is {}, recallDos is {}",
                        adIdeaDos.get(0).getResId(), ideaUnitList.size(), JSON.toJSONString(adIdeaDos), JSON.toJSONString(recallDos));
                return AdxRcmdBase.getRandomList(ideaUnitList, limitSize, coldStartMap);
            }

            List<IdeaUnitDo> interList = interDos.stream().flatMap(ideaUnit -> ideaUnit.getIdeaUnitDos().stream()).collect(Collectors.toList());
            logger.info("AdxRcmd.rawRank getIntersection size is {}, resId is {},", interList.size(), adIdeaDos.get(0).getResId());


            //交集个数>limitSize, 按新创意25%和统计ecpm排序75%进行截断
            if (interList.size() > limitSize) {

                //粗排池分配
                List<IdeaUnitDo> newList = interList.stream().filter(ideaUnitDo -> ideaUnitDo.getIsNew()).collect(Collectors.toList());
                List<IdeaUnitDo> oldList = interList.stream().filter(ideaUnitDo -> !ideaUnitDo.getIsNew()).collect(Collectors.toList());
                long coldSize = AdxRcmdBase.getRecallSize(newList, oldList, limitSize, COLD_RATE, "COLD");
                long bestSize = AdxRcmdBase.getRecallSize(newList, oldList, limitSize, COLD_RATE, "BEST");
                List<IdeaUnitDo> coldList = newList.stream().sorted(Comparator.comparing(IdeaUnitDo::getLast1DayExpCnt)).limit(coldSize).collect(Collectors.toList());
                List<IdeaUnitDo> bestList = oldList.stream().sorted(Comparator.comparing(IdeaUnitDo::getStatAdEcpm).reversed()).limit(bestSize).collect(Collectors.toList());

                List<IdeaUnitDo> interList2 = new ArrayList<>();
                interList2.addAll(coldList);
                interList2.addAll(bestList);
                List<AdIdeaDo> interCandidates = AdxRcmdBase.getCovertList(interList2);
                interCandidates.forEach(s -> s.setIsColdStart(coldStartMap.get(s.getIdeaId())));
                return interCandidates;

//                List<IdeaUnitDo> interList2 = interList.stream().sorted(Comparator.comparing(IdeaUnitDo::getStatAdEcpm, Comparator.nullsLast(Comparator.naturalOrder())).reversed()).
//                        limit(limitSize).collect(Collectors.toList());
            } else {
                return interDos;
            }


        } catch (Exception e) {
            logger.error("AdxRcmd.rawRank error", e);
            return null;
        }

    }




    /**
     * 2.特征生成
     *
     */
    public static Map<FeatureIndex, FeatureMapDo> getFeatureMap(List<AdxIdeaFeatureDo> adxIdeaFeatureDos, AdxFeatureDo adxFeatureDo,
                                                                Map<Long, AdxStatsDo> ideaAppStatsList, AdxStatsDo resoAppStats) {
        return AdxRecommend.getFeatureMap(adxIdeaFeatureDos, adxFeatureDo, ideaAppStatsList, resoAppStats);
    }




    /**
     * 3.模型预估
     */
    public static Map<FeatureIndex, Double> predict(Map<FeatureIndex, FeatureMapDo> featureMap, Model model, PredictType predictType)
            throws Exception {
        return AdxRecommend.predict(featureMap, model, predictType);
    }




    /**
     * 4.出价
     */
    public static List<AdxBidRet> bidding(List<AdIdeaDo> adIdeaDos, AdxAppReqDo adxAppReqDo) {

        List<AdxBidRet> adxBidRets = new ArrayList<>();

        try {
            adxBidRets = adIdeaDos.stream().map(adIdeaDo -> {
                Long ideaId = adIdeaDo.getIdeaId();
                Long resId = adIdeaDo.getResId();
                AdxFactorDo adxfactorDo = Optional.ofNullable(adIdeaDo.getAdxFactorDo()).orElse(new AdxFactorDo());
                return AdxBid.buildAdxBidReq(adIdeaDo, adxfactorDo, adxAppReqDo);
            }).flatMap(Collection::stream).map(AdxBid :: getBid).collect(Collectors.toList());

        } catch (Exception e) {
            logger.error("AdxRcmd.bidding", e);
        }

        return adxBidRets;
    }




    /**
     * 5.精排
     */
    public static List<AdxBidRet> fineRank(List<AdxBidRet> adxBidRets) {

        List<AdxBidRet> fineRankRets = adxBidRets.stream().sorted(Comparator.comparing(AdxBidRet :: getRankScore).reversed()).collect(Collectors.toList());
        AdxBidRet ret = fineRankRets.get(0);
        if (Math.random() < 0.2) {
            Map<AdxBidRet, Double> scoreMap = adxBidRets.stream().collect(Collectors.toMap(s -> s, AdxBidRet::getRankScore));
            ret = Roulette.doubleMap(scoreMap);
        }

        //新创意试投（调整出价大于底价以上）
        if (ret != null && ret.getBasePrice() != null
                && (ret.getIsNew() == null || ret.getIsNew())) {
            Long basePrice = DataUtil.double2Long(ret.getBasePrice());
            ret.setAdxAlgoPrice(Math.max(ret.getAdxAlgoPrice(), (basePrice + 1L)));
        }
        return Arrays.asList(ret);
    }





    /**
     * 6.1 创意召回任务
     * 描述：每个比例区间下，召回创意--冷启动池（top5），优选池（top10）
     * 部署：5min执行1次，不同资源位调用一次算法接口，先执行recallIdea（选创意），再执行recallAdIdea（选计划+创意）
     *      输出按资源位存redis，打印log
     *
     * @param recallReqDo 创意入参
     * @return 创意召回池
     *
     */
    public static AdxRecallRetDo recallIdea(AdxRecallReqDo recallReqDo) {

        AdxRecallRetDo ret = new AdxRecallRetDo();
        try {
            ret = AdxRcmdAlg.recallIdeaRun(recallReqDo);

        } catch (Exception e) {
            logger.error("AdxRcmd.recallIdea", e);
        }
        return ret;
    }




    /**
     * 6.2 计划创意召回任务
     * 描述：每个比例区间下，召回<计划，创意> top50
     * 部署：5min执行1次，不同资源位调用一次算法接口，先执行recallIdea（选创意），再执行recallAdIdea（选计划+创意）
     *      输出按资源位存redis，打印log
     *
     * @param recallReqDo <计划，创意>入参
     * @return <计划，创意>召回池
     *
     */
    public static AdxRecallRetDo recallAdIdea(AdxRecallReqDo recallReqDo) {

        AdxRecallRetDo ret = new AdxRecallRetDo();
        try {
            ret = AdxRcmdAlg.recallAdIdeaRun(recallReqDo);

        } catch (Exception e) {
            logger.error("AdxRcmd.recallAdIdea", e);
        }
        return ret;
    }




    /**
     * 7.计划控制参数任务
     * 描述：ROI控制策略，PID算法，三种维度（计划，计划+联盟媒体，计划+联盟媒体行业），生成roi调节因子；
     *      计划冷启动探价控制，监控近12h的探价竞价成功率和预算消耗程度 ，生成探价因子；
     * 部署：10min执行1次，按 计划+联盟媒体 维度调用算法接口，输出按三种维度控制参数，存redis，打印log
     *
     * @param adxFactorReqDo 入参
     * @return 控制参数
     *
     */
    public static AdxFactorDo adIdeaFactorRun(AdxFactorReqDo adxFactorReqDo) {

        AdxFactorDo ret = adxFactorReqDo.getAdxFactorDo();
        try {
            ret = AdxRcmdAlg.adIdeaFactorRun(adxFactorReqDo);

        } catch (Exception e) {
            logger.error("AdxRcmd.ideaFactorRun", e);
        }
        return ret;
    }




    /**
     * 8.广告位控制参数任务
     * 描述：广告位探量控制，生成广告位探量因子[1,3]（初始值=1.5），探价流量比例[0%,10%]（初始值=5%）
     * 部署：10min执行1次，不同联盟广告位id调用一次算法接口，输出按联盟广告位id存redis，打印log（广告位id较多）
     *
     * @param slotFactorReqDo 入参
     * @return 控制参数
     *
     */
    public static SlotFactorDo slotFactorRun(SlotFactorReqDo slotFactorReqDo) {

        SlotFactorDo ret = slotFactorReqDo.getSlotFactorDo();
        try {
            ret = AdxRcmdAlg.slotFactorRun(slotFactorReqDo);

        } catch (Exception e) {
            logger.error("AdxRcmd.slotFactorRun", e);
        }
        return ret;
    }


}
