package cn.com.duiba.nezha.alg.alg.adx.flowfilter;

import cn.com.duiba.nezha.alg.alg.adx.AdxStatData;
import cn.com.duiba.nezha.alg.alg.adx.rcmd2.*;
import cn.com.duiba.nezha.alg.alg.adx.rtbbid2.AdxConstant;
import cn.com.duiba.nezha.alg.alg.vo.adx.directly.AdxIndexStatsDo;
import cn.com.duiba.nezha.alg.alg.vo.adx.flowfilter.AdxIndexStatDo;
import cn.com.duiba.nezha.alg.alg.vo.adx.flowfilter.FilterReqDo;
import cn.com.duiba.nezha.alg.alg.vo.adx.flowfilter.FilterRetDo;
import cn.com.duiba.nezha.alg.alg.vo.adx.pd.AdxStatsDo;
import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import cn.com.duiba.nezha.alg.common.util.DataUtil;
import cn.com.duiba.nezha.alg.common.util.MathUtil;
import cn.com.duiba.nezha.alg.feature.parse.AdxFeatureParse;
import cn.com.duiba.nezha.alg.feature.vo.AdxFeatureDo;
import cn.com.duiba.nezha.alg.feature.vo.FeatureMapDo;
import com.alibaba.fastjson.JSON;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ThreadLocalRandom;

public class FlowFilterAlg {

    private static final Logger logger = LoggerFactory.getLogger(FlowFilterAlg.class);



    /**
     * Step 1: build the static flow feature map (no plan/creative features).
     *
     * <p>The resource+app statistics of the "1hour" window are copied into the
     * {@code ResoAppExpCntDay}, {@code ResoAppClickCntDay} and {@code ResoAppAdCostDay}
     * fields of {@code adxFeatureDo} before the map is generated.
     * NOTE(review): 1-hour statistics are written into fields named "...Day" — confirm intended.
     *
     * @param adxFeatureDo feature holder; mutated in place
     * @param resoAppStats resource+app statistics
     * @return the static feature map, or an empty map when {@code adxFeatureDo} is not valid
     */
    public static Map<String,String> getStaticFeatureMap(AdxFeatureDo adxFeatureDo, AdxStatsDo resoAppStats) {

        // Validate before touching the object: previously the setters ran first, so a null
        // adxFeatureDo threw an NPE before this warning could ever be logged.
        if (AssertUtil.isEmpty(adxFeatureDo)) {
            logger.warn("FlowFilterAlg.getStaticFeatureMap adxFeatureDo is not valid");
            return new HashMap<>();
        }

        AdxIndexStatsDo resoApp1HourInfo = AdxStatData.getAdxTimeIndex(resoAppStats, "1hour");
        // Defensive: it is not visible here whether getAdxTimeIndex can return null.
        if (resoApp1HourInfo != null) {
            adxFeatureDo.setResoAppExpCntDay(resoApp1HourInfo.getExpCnt());
            adxFeatureDo.setResoAppClickCntDay(resoApp1HourInfo.getClickCnt());
            adxFeatureDo.setResoAppAdCostDay(resoApp1HourInfo.getAdvertConsume());
        }

        return AdxFeatureParse.generateFeatureMapStatic(adxFeatureDo);
    }


    /**
     * Step 2: CTR prediction used by the flow pre-filter.
     *
     * <p>model: the ctr model, stored under key TAE:ALGBID:MODEL:adx_mid_ftrl_fm_ctr_v101.
     * predictType: PredictType.CTR (other types currently yield {@code null}).
     *
     * @param staticFeatureMap static features produced by {@link #getStaticFeatureMap}
     * @param model            prediction model
     * @param predictType      prediction type; only {@code PredictType.CTR} is handled
     * @return the predicted value, or {@code null} if inputs are invalid or no prediction was produced
     * @throws Exception propagated from {@code model.predictCtr}
     */
    public static Double predict(Map<String,String> staticFeatureMap, Model model, PredictType predictType) throws Exception {

        Double ret = null;

        if(AssertUtil.isAnyEmpty(staticFeatureMap, model, predictType)) {
            logger.warn("FlowFilterAlg.predict params is not valid featureMap:{},predictType:{}", JSON.toJSONString(staticFeatureMap), JSON.toJSONString(predictType));
            return ret;
        }

        // The model API works on a batch keyed by FeatureIndex; wrap the single request
        // under a fixed dummy index and read it back afterwards.
        Map<FeatureIndex, FeatureMapDo> featureMap = new HashMap<>();
        FeatureIndex featureIndex = new FeatureIndex(1L, 1L);
        FeatureMapDo featureMapDo = new FeatureMapDo();
        featureMapDo.setStaticFeatureMap(staticFeatureMap);
        featureMap.put(featureIndex, featureMapDo);

        if (PredictType.CTR.equals(predictType)) {
            Map<FeatureIndex, Double> preMap = model.predictCtr(featureMap);

            if (AssertUtil.isEmpty(preMap)) {
                logger.warn("FlowFilterAlg.predict preMap is null");
                return ret;
            }

            ret = preMap.get(featureIndex);
            if (AssertUtil.isEmpty(ret)) {
                logger.warn("FlowFilterAlg.predict preValue is null");
            }

        }
        return ret;
    }


    /**
     * Step 3: flow pre-filter.
     *
     * <p>Estimates the ad spend per mille as {@code ecpm = ctr * statClickValue * 1000}. It then
     * compares ecpm with the floor price {@code basePrice} to pick a filter probability, scales
     * that probability by {@link #getFilterWeight}, and draws a random number to decide.
     *
     * @param filterReqDo filter request
     * @return the filter decision; a default (unset) result when inputs are invalid, statistics
     *         are insufficient, or an exception occurs
     */
    public static FilterRetDo getFilter(FilterReqDo filterReqDo) {

        FilterRetDo ret = new FilterRetDo();
        try {

            if (!isValid(filterReqDo)) {
                return ret;
            }

            // Statistics for resource slot + flowPolyId, blended with weights 0.7/0.2/0.1
            // (presumably across time windows — see AdxStatsDo.mergeStat*).
            Double statCtr = null, statClickValue = null;
            if (filterReqDo.getFlowPolyIdStatsDo() != null) {
                AdxStatsDo statsDo = filterReqDo.getFlowPolyIdStatsDo();
                statCtr = AdxStatsDo.mergeStatCtr(statsDo, 0.7, 0.2, 0.1);
                statClickValue = AdxStatsDo.mergeStatClickValue(statsDo, 0.7, 0.2, 0.1);
            }

            // Overall resource-slot statistics, per traffic-split group.
            // Bug fix: group1/group2 previously both read group tag 0, so the filter weight always
            // compared group 0 with itself and the summed stats counted group 0 three times.
            AdxIndexStatDo group0 = filterReqDo.getGroupTag0StatsDo(); // control group 0
            AdxIndexStatDo group1 = filterReqDo.getGroupTag1StatsDo(); // control group 1
            AdxIndexStatDo group2 = filterReqDo.getGroupTag2StatsDo(); // experiment group 2
            AdxIndexStatDo sumStatDo = getSumStatDo(group0, group1, group2);
            // Fall back to slot-level stats only when there is enough volume to trust them.
            if (statCtr == null && (sumStatDo.getExp() > 100 || sumStatDo.getClick() > 50)) {
                statCtr = MathUtil.stdwithBoundary(DataUtil.division(sumStatDo.getClick(), sumStatDo.getExp()),0,1.0);
            }
            if (statClickValue == null && (sumStatDo.getClick() > 50 || sumStatDo.getAdCost() > 2000)) {
                statClickValue = DataUtil.division(sumStatDo.getAdCost(), sumStatDo.getClick());
            }

            // Estimate: ad spend per mille, ecpm = ctr * statClickValue * 1000
            Double preCtr = filterReqDo.getFilterPreCtr();
            Double ctr = MathUtil.mean(preCtr, statCtr, 0.9);
            if (statClickValue == null) {
                logger.info("FlowFilterAlg.getFilter statClickValue is null, resId = {}", filterReqDo.getResId());
                return ret;
            }
            if (ctr == null) {
                // Guard explicitly instead of relying on an NPE being swallowed by the catch below.
                logger.info("FlowFilterAlg.getFilter ctr is null, resId = {}", filterReqDo.getResId());
                return ret;
            }
            double ecpm = ctr * statClickValue * 1000;

            // Floor price; isValid guarantees it is non-null and > 0.
            double basePrice = filterReqDo.getBasePrice();

            // Filter probability, scaled by how much group 2 is already being filtered vs group 1.
            double p = getBaseFilterRate(ecpm, basePrice) * getFilterWeight(group1, group2);

            // ThreadLocalRandom avoids contention on the shared generator behind Math.random()
            // in this per-request path.
            boolean isFilter = ThreadLocalRandom.current().nextDouble() < p;
            ret.setFilter(isFilter);
            return ret;

        } catch (Exception e) {
            logger.error("FlowFilterAlg.getFilter", e);
        }
        return ret;
    }


    /**
     * Maps the ratio between estimated ecpm and floor price to a base filter probability.
     *
     * <p>The mapping is: below 50% of the floor gives 0.8, below 70% gives 0.5, below 90% gives
     * 0.2, and anything else gives 0.0.
     */
    private static double getBaseFilterRate(double ecpm, double basePrice) {
        if (ecpm < basePrice * 0.5) {
            return 0.8;
        }
        if (ecpm < basePrice * 0.7) {
            return 0.5;
        }
        if (ecpm < basePrice * 0.9) {
            return 0.2;
        }
        return 0.0;
    }


    /**
     * Checks whether the filter request is usable.
     *
     * <p>Only RTB traffic ({@code pmpType == 0}; 1 = PD, 2 = PDB) is eligible. It also requires a
     * non-empty {@code filterPreCtr} and a positive {@code basePrice}.
     *
     * @param filterReqDo filter request
     * @return {@code true} when the request may be filtered
     */
    public static Boolean isValid(FilterReqDo filterReqDo) {

        if (AssertUtil.isEmpty(filterReqDo)) {
            logger.warn("FlowFilterAlg.getFilter filterReqDo is null");
            return false;
        }

        Integer pmpType = filterReqDo.getPmpType(); // 0-RTB, 1-PD, 2-PDB
        if (AssertUtil.isEmpty(pmpType) || !pmpType.equals(0)) {
            return false;
        }

        // Predicted ctr
        Double preCtr = filterReqDo.getFilterPreCtr();
        if (AssertUtil.isEmpty(preCtr)) {
            logger.warn("FlowFilterAlg.getFilter filterPreCtr is null, resId = {}, pmpType = {}",
                    filterReqDo.getResId(), filterReqDo.getPmpType());
            return false;
        }

        // Floor price
        Double basePrice = filterReqDo.getBasePrice();
        if (AssertUtil.isEmpty(basePrice) || basePrice <= 0) {
            logger.warn("FlowFilterAlg.getFilter basePrice is null, resId = {}, appId = {}, slotId = {}",
                    filterReqDo.getResId(), filterReqDo.getAppId(), filterReqDo.getSlotId());
            return false;
        }

        return true;

    }


    /**
     * Estimates a weight for the filter probability from the observed filter rate.
     *
     * <p>Assumes control group 1 and experiment group 2 receive equal traffic splits. Once both
     * have more than 1000 bids, the observed filter rate is computed as
     * {@code (bid1 - bid2) / bid1}, presumably clamped to [0, 1] by {@code stdwithBoundary}. That
     * rate maps to a weight: below 0.2 gives 1.1, below 0.5 gives 1.0, below 0.8 gives 0.9, and
     * otherwise 0.8. Below the sample threshold the weight is 1.0.
     *
     * @param group1 stats of control group 1
     * @param group2 stats of experiment group 2
     * @return multiplicative weight applied to the base filter probability
     */
    public static double getFilterWeight(AdxIndexStatDo group1, AdxIndexStatDo group2) {

        Long group1BidCnt = AdxIndexStatDo.getBidCnt(group1);
        Long group2BidCnt = AdxIndexStatDo.getBidCnt(group2);

        double ret = 1.0;

        if (group1BidCnt > 1000L && group2BidCnt > 1000L) {
            double filterRate = MathUtil.stdwithBoundary(DataUtil.division((group1BidCnt - group2BidCnt), group1BidCnt, 3), 0, 1.0);
            ret = filterRate < 0.2 ? 1.1 : (filterRate < 0.5 ? 1.0 : (filterRate < 0.8 ? 0.9 : 0.8));
        }

        return ret;
    }


    /**
     * Sums the bid, exposure, click, adx-cost and ad-cost counters of the three traffic-split groups.
     *
     * <p>Per-group values are read through the static {@code AdxIndexStatDo.getXxx(group)}
     * accessors. NOTE(review): it is assumed these accessors return non-null (e.g. 0) for a null
     * group; otherwise the unboxed sums would NPE — confirm.
     *
     * @return a new {@code AdxIndexStatDo} holding the summed counters
     */
    public static AdxIndexStatDo getSumStatDo(AdxIndexStatDo group0, AdxIndexStatDo group1, AdxIndexStatDo group2) {

        AdxIndexStatDo ret = new AdxIndexStatDo();

        Long group0BidCnt = AdxIndexStatDo.getBidCnt(group0);
        Long group1BidCnt = AdxIndexStatDo.getBidCnt(group1);
        Long group2BidCnt = AdxIndexStatDo.getBidCnt(group2);
        ret.setBid(group0BidCnt + group1BidCnt + group2BidCnt);

        Long group0ExpCnt = AdxIndexStatDo.getExpCnt(group0);
        Long group1ExpCnt = AdxIndexStatDo.getExpCnt(group1);
        Long group2ExpCnt = AdxIndexStatDo.getExpCnt(group2);
        ret.setExp(group0ExpCnt + group1ExpCnt + group2ExpCnt);

        Long group0ClickCnt = AdxIndexStatDo.getClickCnt(group0);
        Long group1ClickCnt = AdxIndexStatDo.getClickCnt(group1);
        Long group2ClickCnt = AdxIndexStatDo.getClickCnt(group2);
        ret.setClick(group0ClickCnt + group1ClickCnt + group2ClickCnt);

        Long group0AdxCost = AdxIndexStatDo.getAdxCost(group0);
        Long group1AdxCost = AdxIndexStatDo.getAdxCost(group1);
        Long group2AdxCost = AdxIndexStatDo.getAdxCost(group2);
        ret.setAdxCost(group0AdxCost + group1AdxCost + group2AdxCost);

        Long group0AdCost = AdxIndexStatDo.getAdCost(group0);
        Long group1AdCost = AdxIndexStatDo.getAdCost(group1);
        Long group2AdCost = AdxIndexStatDo.getAdCost(group2);
        ret.setAdCost(group0AdCost + group1AdCost + group2AdCost);

        return ret;
    }

}
