package cn.com.duiba.nezha.alg.common.model.advertexplore;


import cn.com.duiba.nezha.alg.common.model.advertexplore.expcontroller.BidPidController;
import cn.com.duiba.nezha.alg.common.model.advertexplore.expcontroller.CvrPidController;
import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import cn.com.duiba.nezha.alg.common.util.DataUtil;

import java.util.*;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;

/**
 * Stateless V1 advert-exploration strategy: filters candidate configs of one media, blends
 * model and statistical CTR/CVR for the configs that keep exploring, and applies a
 * probabilistic bid filter.
 */
public class AdvertExploreV1 {

    /**
     * Runs the V1 exploration pipeline over the candidate configs of one media.
     *
     * <p>Stages, in order (each stage's count is stored in the result's type map):
     * <ol>
     *   <li>Media filter: if the media is rejected, every config is dropped ({@code G0}).</li>
     *   <li>Exposure filter on 3-day "media &amp; advert" stats: under-exposed configs keep
     *       exploring; the others are either exploited ({@code ET}) or dropped ({@code G2}).</li>
     *   <li>Daily explore-consume filter: configs over their explore budget are dropped ({@code G1}).</li>
     *   <li>Blended CTR/CVR computation for the remaining explore configs.</li>
     *   <li>Probabilistic bid filter: dropped configs are counted as {@code G3}, kept ones as {@code EE}.</li>
     * </ol>
     *
     * @param exploreInfos candidate configs; the media statistics of the first element drive the media filter
     * @param params       algorithm parameters
     * @return the explore result holding {@code params}, the per-type counts and the explore packages
     *         followed by the exploit packages; {@code null} when {@link AssertUtil#isAnyEmpty} reports
     *         either argument as empty
     */
    public static ExploreV1Result explore(
            // candidate configs
            List<ExploreInfo> exploreInfos,
            // algorithm parameters
            ExploreV1Params params
    ) {

        if (AssertUtil.isAnyEmpty(exploreInfos, params)) {
            return null;
        }

        List<ExploreV1PkgData> resList = new ArrayList<>();

        HashMap<ExploreV1Type, Integer> exploreTypeMap = new HashMap<>();

        // Low-quality media filter, based on the first config's 3-day media statistics.
        boolean isKeepMedia = mediaFilter(exploreInfos.get(0).getAppDataInfo().getData3Day(), params);

        if (!isKeepMedia) {
            // The media is rejected: every config on it is dropped.
            exploreTypeMap.put(ExploreV1Type.G0, exploreInfos.size());
            return new ExploreV1Result(params, exploreTypeMap, resList);
        }

        // 1. Exposure filter on 3-day "media & advert" statistics.
        Map<ExploreV1Type, List<ExploreInfo>> exposureFilterMap = exploreInfos.stream().collect(Collectors.groupingBy(p -> {
            ExploreData data = p.getAppAdvertInfo().getData3day();
            // Risk: the "media & config" decision uses cost bias measured at "media & advert" level.
            return exposureFilter(data.getExposure(), data.getConv().getOrDefault(p.getConvertType(), 0L),
                    data.getOcpcConsume(), data.getCostBias(), params);
        }));

        // TODO: blacklist support

        // Count configs dropped by the exposure filter and build the exploit packages.
        exploreTypeMap.put(ExploreV1Type.G2, groupOf(exposureFilterMap, ExploreV1Type.G2).size());
        List<ExploreV1PkgData> etList = groupOf(exposureFilterMap, ExploreV1Type.ET).stream()
                .map(p -> new ExploreV1PkgData(p.getOrientationId(), p.getAdvertId(), 1, p.getTarget(), -1D, -1D))
                .collect(Collectors.toList());
        exploreTypeMap.put(ExploreV1Type.ET, etList.size());

        // Configs that keep exploring continue to the next filter.
        List<ExploreInfo> exposureKept = groupOf(exposureFilterMap, ExploreV1Type.EE);

        // 2. Daily explore-consume filter.
        Map<ExploreV1Type, List<ExploreInfo>> exploreFilterMap = exposureKept.stream().collect(Collectors.groupingBy(p -> {
            AppAdvertInfo appAdvertInfo = p.getAppAdvertInfo();
            return exploreFilter(
                    appAdvertInfo.getAdvertExploreV1Consume(),
                    appAdvertInfo.getAdvertOcpcConsume(), params);
        }));

        exploreTypeMap.put(ExploreV1Type.G1, groupOf(exploreFilterMap, ExploreV1Type.G1).size());

        List<ExploreInfo> consumeKept = groupOf(exploreFilterMap, ExploreV1Type.EE);

        // 3. Blended CTR / CVR for the remaining explore configs.
        List<ExploreV1PkgData> explorePkgs = consumeKept.stream().map(p -> calExploreData(
                p.getOrientationId(), p.getAdvertId(),
                p.getAdvertPkBase(), p.getAdvertPkExpV1(),
                p.getAdvertPkAppBase(), p.getAdvertPkAppExpV1(),
                p.appDataInfo.data3Day, p.appTradeInfo.data3day, p.appAccountInfo.data3day,
                p.appAdvertInfo.data3day, p.appOrientExp3day, p.appOrientClick3day, p.appOrientConv3day,
                p.predictData, p.convertType, p.target, params)
        ).collect(Collectors.toList());

        // 4. Probabilistic bid filter.
        Map<ExploreV1Type, List<ExploreV1PkgData>> bidFilterMap = explorePkgs.stream()
                .collect(Collectors.groupingBy(p -> bidFilter(p, params)));

        exploreTypeMap.put(ExploreV1Type.G3, groupOf(bidFilterMap, ExploreV1Type.G3).size());

        List<ExploreV1PkgData> eeList = groupOf(bidFilterMap, ExploreV1Type.EE);
        exploreTypeMap.put(ExploreV1Type.EE, eeList.size());

        // Explore packages first, then exploit packages.
        resList.addAll(eeList);
        resList.addAll(etList);

        return new ExploreV1Result(params, exploreTypeMap, resList);
    }

    /**
     * Returns the group stored under {@code type}, or an empty (immutable) list when absent.
     */
    private static <T> List<T> groupOf(Map<ExploreV1Type, List<T>> groups, ExploreV1Type type) {
        return groups.getOrDefault(type, Collections.emptyList());
    }

    /**
     * Treats a missing (null) counter as zero observations.
     */
    private static long toCount(Long count) {
        return count == null ? 0L : count;
    }

    /**
     * Sums {@code w * v} over all weight entries.
     */
    private static double weightedSum(WeightInfoList weightInfos) {
        return weightInfos.stream().mapToDouble(w -> w.getW() * w.getV()).sum();
    }

    /**
     * Computes the blended exploration CTR and CVR of one config.
     *
     * <p>CTR and CVR are each a weighted blend of the model prediction and the statistics
     * of several dimensions (see {@link #calCtrWeights} / {@link #calCvrWeightInfos}). CVR is then
     * multiplied by a cost-deviation factor ({@link CvrPidController}), an exposure-count boost
     * ({@link #numWeight}) and a bid-success-rate factor ({@link BidPidController}), and finally
     * capped by {@link #boundCvr}.
     *
     * <p>Side effect: sets the dimension id on {@code media}, {@code mediaAndTrade},
     * {@code mediaAndAccount}, {@code mediaAndAdvert} and {@code predictData}.
     * {@code mediaOrientExpV1} is currently not used.
     *
     * @return the explore package with CTR/CVR rounded to 6 decimals plus all weight details
     */
    private static ExploreV1PkgData calExploreData(Long orientationId, Long advertId,
                                                   ExpHourData orientBase, ExpHourData orientExpV1,
                                                   ExpHourData mediaOrientBase, ExpHourData mediaOrientExpV1,
                                                   ExploreData media, ExploreData mediaAndTrade, ExploreData mediaAndAccount,
                                                   ExploreData mediaAndAdvert,
                                                   Long mediaOrientExp, Long mediaOrientClick, Map<Integer, Long> mediaOrientConv,
                                                   PredictData predictData,
                                                   Integer convertSubType, Long target, ExploreV1Params params) {
        // Dimensions go from finest to coarsest; each one takes part of the remaining weight (total 1).
        // The finest granularity the model can capture is "media & advert", but the model weight is
        // decided by "media & config" statistics.

        // Tag the statistics dimensions.
        media.setId(WeightDimEnum.MEDIA.getId());
        mediaAndTrade.setId(WeightDimEnum.MEDIA_AND_TRADE.getId());
        mediaAndAccount.setId(WeightDimEnum.MEDIA_AND_ACCOUNT.getId());
        mediaAndAdvert.setId(WeightDimEnum.MEDIA_AND_ADVERT.getId());

        // "media & config" statistics: used only to decide the model weight.
        ExploreData mediaAndOrient = new ExploreData();
        mediaAndOrient.setId(WeightDimEnum.MEDIA_AND_ORIENT.getId());
        mediaAndOrient.setExposure(mediaOrientExp);
        mediaAndOrient.setClick(mediaOrientClick);
        mediaAndOrient.setConv(mediaOrientConv);

        // "media & trade" is tagged above but is not part of the blend.
        ExploreData[] statDatas = {mediaAndAdvert, mediaAndAccount, media};

        // Prediction dimension.
        predictData.setId(WeightDimEnum.MODEL.getId());

        // CTR
        WeightInfoList ctrWeightInfos = calCtrWeights(params, predictData, mediaAndOrient, statDatas);
        double mergeCtr = weightedSum(ctrWeightInfos);

        // CVR
        WeightInfoList cvrWeightInfos = calCvrWeightInfos(params, predictData, mediaAndOrient, convertSubType, statDatas);
        double mergeCvr = weightedSum(cvrWeightInfos);

        // Cost-deviation control factor.
        WeightInfo costWeightInfo = CvrPidController.calcControlFactor(orientExpV1, convertSubType,
                params.getCostFactorUpperBound(), params.getCostFactorLowerBound(), params.getKp(), params.getKi(),
                params.getCostControlClickCntThreshold(), WeightDimEnum.CVR_FACTOR);
        mergeCvr *= costWeightInfo.getV();

        // Exposure-count boost on "media & advert"; note it keeps boosting small media.
        WeightInfo cvrNumWeightInfo = numWeight(mediaAndAdvert, params, mergeCvr, false);
        mergeCvr *= cvrNumWeightInfo.getW();

        // Bid-success-rate factor: "media & config" success rate vs "config" success rate.
        WeightInfo bidWeightInfo = BidPidController.calcControlFactor(mediaOrientBase, orientBase, params.getBidReqCntThreshold(), params.getTargetRatio(),
                params.getBidFactorUpperBound(), params.getBidFactorLowerBound(), params.getBidKp(), params.getBidKi(), WeightDimEnum.BID_FACTOR_MEDIA);

        mergeCvr *= bidWeightInfo.getV();

        // CVR upper bound.
        mergeCvr = boundCvr(mergeCvr, cvrWeightInfos, mediaAndAdvert, predictData, convertSubType, params);

        return new ExploreV1PkgData(orientationId, advertId, 0, target,
                DataUtil.formatdouble(mergeCtr, 6), DataUtil.formatdouble(mergeCvr, 6),
                ctrWeightInfos, cvrWeightInfos,
                costWeightInfo, bidWeightInfo, cvrNumWeightInfo);
    }

    /**
     * Computes an exposure-count boost: the fewer exposures, the larger the weight.
     *
     * <p>weight = 1 + (1 - sigmoid(thetaNum, betaNum, exposure)) * numWeightThreshold, i.e. in
     * (1, 1 + numWeightThreshold) for a positive threshold. Exposure is always used (never clicks) so
     * that a low-CTR pair does not keep boosting CVR. A missing exposure counts as zero.
     *
     * @param statData statistics to count (callers pass "media &amp; advert")
     * @param params   algorithm parameters
     * @param value    value being boosted, stored in the result for tracing
     * @param isCtr    selects the CTR or CVR dimension id of the result
     * @return weight info with the exposure count, the weight (4 decimals) and the value (6 decimals)
     */
    public static WeightInfo numWeight(ExploreData statData, ExploreV1Params params, double value, boolean isCtr) {
        Long cnt = statData.getExposure();
        String dimName = isCtr ? WeightDimEnum.CTR_NUM.getId() : WeightDimEnum.CVR_NUM.getId();

        double sigmoidValue = calSigmoid(params.getThetaNum(), params.getBetaNum(), toCount(cnt));
        double weight = 1 + (1 - sigmoidValue) * params.getNumWeightThreshold();

        return new WeightInfo(dimName, cnt, DataUtil.formatDouble(weight, 4),
                DataUtil.formatDouble(value, 6));
    }

    /**
     * Computes a linear cost-bias correction weight on "media &amp; advert" statistics.
     *
     * <p>With fewer than {@code costExpThreshold} exposures (a missing exposure counts as zero)
     * the weight is 1. Otherwise it is {@code 1.02 / (costBias + 1)} (targeting a +2% bias),
     * clamped to {@code [1 - costWeightThreshold, 1 + costWeightThreshold]}. Not referenced
     * elsewhere in this class.
     *
     * @return weight info with the exposure, the weight (4 decimals) and the value (4 decimals)
     */
    public static WeightInfo calCostWeight(ExploreData statData, ExploreV1Params params, double value) {
        double costWeight;
        if (toCount(statData.getExposure()) < params.getCostExpThreshold()) {
            // Not enough exposure to trust the offline cost bias.
            costWeight = 1.0;
        } else {
            // Offline bias reacts slowly, so a gentle linear adjustment is used.
            double linerWeight = (0.02 + 1) / (statData.getCostBias() + 1);
            costWeight = Math.min(Math.max(linerWeight, 1 - params.getCostWeightThreshold()), 1 + params.getCostWeightThreshold());
        }

        return new WeightInfo(WeightDimEnum.COST.getId(), statData.getExposure(), DataUtil.formatDouble(costWeight, 4),
                DataUtil.formatDouble(value, 4));
    }

    /**
     * Computes the CTR weight of the model and of each statistics dimension.
     *
     * <p>The model weight is decided by the exposure of {@code modelStatData}; the statistics
     * dimensions then share the remaining weight in order. A missing exposure counts as zero.
     *
     * @param params        algorithm parameters
     * @param predictData   model prediction
     * @param modelStatData statistics that decide the model weight ("media &amp; config")
     * @param statDatas     statistics dimensions, finest first
     * @return model entry followed by the statistics entries, after {@code reWeight()}
     */
    public static WeightInfoList calCtrWeights(ExploreV1Params params, PredictData predictData, ExploreData modelStatData, ExploreData... statDatas) {
        double remainWeight = 1.0;
        double modelWeight = calDimWeight(remainWeight, params.getCtrWeightThreshold(),
                params.getThetaCtr(), params.getBetaCtr(), toCount(modelStatData.getExposure()));

        WeightInfo modelWeightInfo = new WeightInfo(predictData.getId(), modelStatData.getExposure(), modelWeight, DataUtil.formatDouble(predictData.getCtr(), 4));
        List<WeightInfo> statWeightInfos = getStatWeightInfos(remainWeight - modelWeight, true, 0, params, statDatas);

        WeightInfoList weightInfos = new WeightInfoList(statWeightInfos.size() + 1);
        weightInfos.add(modelWeightInfo);
        weightInfos.addAll(statWeightInfos);

        return weightInfos.reWeight();
    }

    /**
     * Computes the CVR weight of the model and of each statistics dimension.
     *
     * <p>The model weight is decided by the clicks of {@code modelStatData}; the statistics
     * dimensions then share the remaining weight in order. A missing click count counts as zero.
     *
     * @param params         algorithm parameters
     * @param predictData    model prediction
     * @param modelStatData  statistics that decide the model weight ("media &amp; config")
     * @param convertSubtype conversion type used to read each dimension's CVR
     * @param statDatas      statistics dimensions, finest first
     * @return model entry followed by the statistics entries, after {@code reWeight()}
     */
    public static WeightInfoList calCvrWeightInfos(ExploreV1Params params, PredictData predictData, ExploreData modelStatData, Integer convertSubtype, ExploreData... statDatas) {
        double remainWeight = 1.0;
        double modelWeight = calDimWeight(remainWeight, params.getCvrWeightThreshold(),
                params.getThetaCvr(), params.getBetaCvr(), toCount(modelStatData.getClick()));

        WeightInfo modelWeightInfo = new WeightInfo(predictData.getId(), modelStatData.getClick(), modelWeight, DataUtil.formatDouble(predictData.getCvr(), 4));
        List<WeightInfo> statWeightInfos = getStatWeightInfos(remainWeight - modelWeight, false, convertSubtype, params, statDatas);

        WeightInfoList weightInfos = new WeightInfoList(statWeightInfos.size() + 1);
        weightInfos.add(modelWeightInfo);
        weightInfos.addAll(statWeightInfos);

        return weightInfos.reWeight();
    }

    /**
     * Splits {@code remainWeight} across the statistics dimensions, finest first.
     *
     * <p>CTR weights are driven by exposure, CVR weights by clicks (missing counts are zero).
     * For CVR, a "media &amp; trade" dimension always gets weight 0.
     *
     * @param remainWeight   weight still available to the statistics dimensions
     * @param isCtr          {@code true} for CTR, {@code false} for CVR
     * @param convertSubtype conversion type used to read each dimension's CVR
     * @param params         algorithm parameters
     * @param statDatas      statistics dimensions, finest first
     * @return one weight entry per dimension, in input order
     */
    public static List<WeightInfo> getStatWeightInfos(double remainWeight, boolean isCtr, Integer convertSubtype, ExploreV1Params params, ExploreData... statDatas) {
        ArrayList<WeightInfo> weights = new ArrayList<>(statDatas.length);

        if (isCtr) {
            for (ExploreData data : statDatas) {
                double dimWeight = calDimWeight(remainWeight, params.getCtrWeightThreshold(),
                        params.getThetaCtr(), params.getBetaCtr(), toCount(data.getExposure()));
                remainWeight -= dimWeight;
                weights.add(new WeightInfo(data.getId(), data.getExposure(), dimWeight, data.getCtr()));
            }
        } else {
            for (ExploreData data : statDatas) {
                // Possible refinement: adapt the count threshold to the conversion type.
                double dimWeight = calDimWeight(remainWeight, params.getCvrWeightThreshold(),
                        params.getThetaCvr(), params.getBetaCvr(), toCount(data.getClick()));
                // Ids are Strings: compare by value, not by reference.
                if (Objects.equals(data.getId(), WeightDimEnum.MEDIA_AND_TRADE.getId())) {
                    dimWeight = 0.0;
                }
                remainWeight -= dimWeight;
                weights.add(new WeightInfo(data.getId(), data.getClick(), dimWeight, data.getCvr(convertSubtype)));
            }
        }

        return weights;
    }

    /**
     * Weight of one dimension: {@code min(remainWeight * sigmoid(theta, beta, value), weightThreshold)}.
     *
     * <p>When less than 0.001 weight remains, the whole remainder is returned (uncapped).
     *
     * @param remainWeight    weight still available
     * @param weightThreshold upper bound of the returned weight
     * @param theta           sigmoid slope
     * @param beta            sigmoid offset
     * @param value           sample count of the dimension
     * @return the dimension weight
     */
    public static double calDimWeight(double remainWeight, double weightThreshold, double theta, double beta, double value) {
        if (remainWeight < 0.001) {
            return remainWeight;
        }
        return Math.min(remainWeight * calSigmoid(theta, beta, value), weightThreshold);
    }

    /**
     * Logistic function {@code 1 / (1 + exp(-theta * value + beta))}.
     */
    public static double calSigmoid(double theta, double beta, double value) {
        return 1 / (1 + Math.exp(-1.0 * theta * value + beta));
    }

    /**
     * Returns {@code log_alpha(value + beta)}, or 0 when {@code value + beta <= 0} or {@code alpha <= 0}.
     */
    public static double calLogValue(double alpha, double beta, double value) {
        if ((value + beta) <= 0 || alpha <= 0) {
            return 0.0;
        }
        return Math.log(value + beta) / Math.log(alpha);
    }

    /**
     * Probabilistic bid filter.
     *
     * <p>With bid = e2 * target, the package keeps exploring with probability
     * {@code min(1, bid / (bottomFee * bottomRatio))}. An under-predicted e2 makes the package
     * easier to drop, an over-predicted one makes it more likely to pass. Packages with a missing
     * target or a missing/non-positive e2 always keep exploring.
     *
     * @return {@code EE} to keep exploring, {@code G3} to drop
     */
    public static ExploreV1Type bidFilter(ExploreV1PkgData pkgData, ExploreV1Params params) {
        if (pkgData.getTarget() != null && pkgData.getE2() != null && pkgData.getE2() > 0) {
            double bid = pkgData.getE2() * pkgData.getTarget();
            double prob = Math.min(1.0, bid / (params.getBottomFee() * params.getBottomRatio()));
            return ThreadLocalRandom.current().nextDouble() < prob ? ExploreV1Type.EE : ExploreV1Type.G3;
        }

        return ExploreV1Type.EE;
    }

    /**
     * Whether a pair is good enough to exploit: consume above 20000, more than 10 conversions
     * and an absolute cost bias below 0.15. Any null argument yields {@code false}.
     */
    public static boolean isExploit(Long consume, Long convert, Double bias) {
        return consume != null && consume > 20000
                && convert != null && convert > 10
                && bias != null && Math.abs(bias) < 0.15;
    }

    /**
     * Exposure filter on 3-day "media &amp; advert" statistics.
     *
     * <p>Below {@code mediaAdExpThreshold} exposures (a missing exposure counts as zero) the pair
     * keeps exploring; otherwise it is exploited when {@link #isExploit} holds, else dropped.
     *
     * @return {@code EE}, {@code ET} or {@code G2}
     */
    public static ExploreV1Type exposureFilter(Long exp, Long convert, Long consume, Double bias, ExploreV1Params params) {
        if (exp == null || exp < params.getMediaAdExpThreshold()) {
            return ExploreV1Type.EE;
        }

        // Exposure is above the limit: either exploit or give up, depending on historical cost.
        return isExploit(consume, convert, bias) ? ExploreV1Type.ET : ExploreV1Type.G2;
    }

    /**
     * Daily explore-consume filter.
     *
     * <p>Drops the config when its explore consume exceeds {@code exploreMinConsume} and its share
     * of total consume exceeds {@code exploreConsumeRatio}, or when it exceeds
     * {@code exploreMaxConsume}. A missing explore consume always keeps exploring.
     *
     * @return {@code G1} to drop, {@code EE} to keep exploring
     */
    public static ExploreV1Type exploreFilter(Long exploreConsume, Long consume, ExploreV1Params params) {
        if (exploreConsume == null) {
            return ExploreV1Type.EE;
        }

        double exploreConsumeRatio = (consume != null && consume > 0) ? exploreConsume * 1.0 / consume : 0.0;

        if ((exploreConsume > params.getExploreMinConsume() && exploreConsumeRatio > params.getExploreConsumeRatio())
                || exploreConsume > params.getExploreMaxConsume()) {
            return ExploreV1Type.G1;
        }

        return ExploreV1Type.EE;
    }

    /**
     * Probabilistic low-quality media filter.
     *
     * <p>cpm = consume / exposure (only when consume &gt; 1000) and landing-page cvr =
     * conv[0] / clicks (only when conv[0] is present and clicks &gt; 50). The media is kept with
     * probability {@code max(appMinRatio, min(cpmProb, cvrProb))}, where each prob is the metric
     * divided by (ratio * base), capped at 1. A media with cpm = 0 or cvr = 0 therefore survives
     * only with probability {@code appMinRatio}.
     *
     * @return {@code true} to keep the media
     */
    public static boolean mediaFilter(ExploreData appDataInfo, ExploreV1Params params) {
        double cpm = 0.0;
        double cvr = 0.0;

        if (appDataInfo.getConsume() > 1000) {
            // A high bid inflates cpm and may favour expensive media; an adjusted cpm could help.
            cpm = appDataInfo.getExposure() > 0 ? appDataInfo.getConsume() * 1.0 / appDataInfo.getExposure() : 0.0;
        }

        // Consume does not imply conversions: landing-page CVR.
        long click = toCount(appDataInfo.getClick());
        Map<Integer, Long> conv = appDataInfo.getConv();
        Long landingConv = conv == null ? null : conv.get(0);
        if (landingConv != null && click > 50) {
            cvr = landingConv * 1.0 / click;
        }

        // Reference values noted by the author: market cpm 23 * 0.5, market average cvr 0.12 * 0.3.
        double cpmProb = Math.min(1.0, cpm / (params.getAppCpmRatio() * params.getAppCpmBase()));
        double cvrProb = Math.min(1.0, cvr / (params.getAppCvrRatio() * params.getAppCvrBase()));

        // min: only media good on both cpm and cvr pass with high probability.
        double prob = Math.max(params.getAppMinRatio(), Math.min(cpmProb, cvrProb));

        return ThreadLocalRandom.current().nextDouble() < prob;
    }

    /**
     * Caps the blended CVR at {@code cvrBoundRatio} times a reference CVR.
     *
     * <p>The reference is the model CVR when {@code getMw()} of the weight list (presumably the
     * model weight — confirm) is below {@code cvrBoundWeightThreshold}, otherwise the
     * "media &amp; advert" statistical CVR.
     */
    public static double boundCvr(double mergeCvr, WeightInfoList cvrWeighInfoList, ExploreData mediaAndAdvert,
                                  PredictData predictData, Integer convertSubType, ExploreV1Params params) {
        double referenceCvr = cvrWeighInfoList.getMw() < params.getCvrBoundWeightThreshold()
                ? predictData.getCvr()
                : mediaAndAdvert.getCvr(convertSubType);
        return Math.min(mergeCvr, params.getCvrBoundRatio() * referenceCvr);
    }

    /**
     * Leftover debug entry point; prints {@link Long#MAX_VALUE}. Not used by the exploration logic.
     */
    public static void main(String[] args) {
        System.out.println(Long.MAX_VALUE);
    }


}
