package cn.com.duiba.nezha.alg.common.model.advertexplore;


import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import cn.com.duiba.nezha.alg.common.util.DataUtil;

import java.util.*;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;

/**
 * Advert exploration, version 1.
 *
 * <p>For every explorable orientation config on a media this class fuses the model-predicted
 * CTR/CVR with statistics collected at several aggregation levels (media &amp; advert,
 * media &amp; account, media &amp; trade, media), applies a cost-bias correction and an
 * exposure-count boost to the CVR, caps the CVR, and finally drops configs at random according
 * to media quality and bid level.
 *
 * <p>All methods are static and hold no state of their own. Note that
 * {@link #explore(List, ExploreV1Params)} writes dimension ids into the statistics objects it is
 * given.
 */
public class AdvertExploreV1 {

    /** Target offline cost bias (+2%) that the linear cost correction steers towards. */
    private static final double TARGET_COST_BIAS = 0.02;

    /** Media consume must exceed this value before its CPM takes part in the media filter. */
    private static final int MEDIA_MIN_CONSUME = 1000;

    /** Media clicks must exceed this value before its CVR takes part in the media filter. */
    private static final int MEDIA_MIN_CLICK = 50;

    /** Key of the landing-page conversion count inside the conversion map. */
    private static final int LANDING_PAGE_CONV_TYPE = 0;

    /** Below this remaining weight a dimension simply takes everything that is left. */
    private static final double MIN_REMAIN_WEIGHT = 0.001;

    /**
     * Exploration entry point.
     *
     * @param exploreInfos explorable configs; media-level data is read from the first element
     *                     only (presumably all elements belong to the same media — confirm
     *                     against the caller)
     * @param params       algorithm parameters
     * @return the exploration result holding the number of filtered configs and the surviving
     *         configs with their fused CTR/CVR, or {@code null} when
     *         {@code AssertUtil.isAnyEmpty} rejects the arguments
     */
    public static ExploreV1Result explore(
            // config information
            List<ExploreInfo> exploreInfos,
            // parameters
            ExploreV1Params params
    ) {

        if (AssertUtil.isAnyEmpty(exploreInfos, params)) {
            return null;
        }

        List<ExploreV1PkgData> resList = new ArrayList<>();
        int filterNum;

        // Low-quality media filter (randomised, see mediaFilter).
        boolean isKeepMedia = mediaFilter(exploreInfos.get(0).getAppDataInfo().getData3Day(), params);

        if (!isKeepMedia) {
            // Every config on this media is given up.
            filterNum = exploreInfos.size();
        } else {
            // Drop configs that are no longer eligible for exploration.
            List<ExploreInfo> exploreOrientList = exploreInfos.stream()
                    .filter(p -> exploreFilter(
                            p.appAdvertInfo.advertExploreV1Consume, p.appAdvertInfo.advertOcpcConsume,
                            p.appAdvertInfo.data3day.exposure, params))
                    .collect(Collectors.toList());

            // Compute the fused CTR/CVR, then randomly give up low bids.
            resList = exploreOrientList.stream()
                    .map(p -> calExploreData(
                            p.getOrientationId(), p.getAdvertId(),
                            p.appDataInfo.data3Day, p.appTradeInfo.data3day, p.appAccountInfo.data3day,
                            p.appAdvertInfo.data3day, p.appOrientExp3day, p.appOrientClick3day, p.appOrientConv3day,
                            p.predictData, p.convertType, p.target, params))
                    .filter(pkg -> bidFilter(pkg, params))
                    .collect(Collectors.toList());

            filterNum = exploreInfos.size() - resList.size();
        }

        return new ExploreV1Result(params, filterNum, resList);
    }

    /**
     * Computes the exploration CTR and CVR of one config.
     *
     * <p>Weights are assigned from the finest dimension to the coarsest, all of them sharing a
     * total weight of 1. The finest granularity the model can capture is "media &amp; advert",
     * but the "media &amp; orientation" statistics are what decide the model's own weight.
     *
     * <p>Side effect: the dimension id of {@code media}, {@code mediaAndTrade},
     * {@code mediaAndAccount}, {@code mediaAndAdvert} and {@code predictData} is overwritten.
     *
     * @param orientationId    orientation config id, copied into the result
     * @param advertId         advert id, copied into the result
     * @param media            statistics of the media
     * @param mediaAndTrade    statistics of media &amp; trade
     * @param mediaAndAccount  statistics of media &amp; account
     * @param mediaAndAdvert   statistics of media &amp; advert
     * @param mediaOrientExp   exposures of media &amp; orientation
     * @param mediaOrientClick clicks of media &amp; orientation
     * @param mediaOrientConv  conversions of media &amp; orientation (presumably keyed by
     *                         conversion type — confirm)
     * @param predictData      model prediction
     * @param convertSubType   conversion type used to pick the statistical CVR
     * @param target           target value, copied into the result
     * @param params           algorithm parameters
     * @return the package data carrying the fused CTR/CVR (rounded via
     *         {@code DataUtil.formatdouble(.., 6)}) and all intermediate weight information
     */
    private static ExploreV1PkgData calExploreData(Long orientationId, Long advertId,
                                                   ExploreData media, ExploreData mediaAndTrade, ExploreData mediaAndAccount,
                                                   ExploreData mediaAndAdvert,
                                                   Long mediaOrientExp, Long mediaOrientClick, Map<Integer,Long> mediaOrientConv,
                                                   PredictData predictData,
                                                   Integer convertSubType, Long target, ExploreV1Params params) {
        // Statistical dimensions.
        media.setId(WeightDimEnum.MEDIA.getId());
        mediaAndTrade.setId(WeightDimEnum.MEDIA_AND_TRADE.getId());
        mediaAndAccount.setId(WeightDimEnum.MEDIA_AND_ACCOUNT.getId());
        mediaAndAdvert.setId(WeightDimEnum.MEDIA_AND_ADVERT.getId());

        ExploreData mediaAndOrient = new ExploreData();
        mediaAndOrient.setId(WeightDimEnum.MEDIA_AND_ORIENT.getId());
        mediaAndOrient.setExposure(mediaOrientExp);
        mediaAndOrient.setClick(mediaOrientClick);
        mediaAndOrient.setConv(mediaOrientConv);

        // Ordered from the finest to the coarsest dimension.
        ExploreData[] statDatas = {mediaAndAdvert, mediaAndAccount, mediaAndTrade, media};

        // Prediction dimension.
        predictData.setId(WeightDimEnum.MODEL.getId());

        // CTR
        WeightInfoList ctrWeightInfos = calCtrWeights(params, predictData, mediaAndOrient, statDatas);
        double mergeCtr = weightedSum(ctrWeightInfos);

        // CVR
        WeightInfoList cvrWeightInfos = calCvrWeightInfos(params, predictData, mediaAndOrient, convertSubType, statDatas);
        double mergeCvr = weightedSum(cvrWeightInfos);

        // Cost-bias correction.
        WeightInfo costWeightInfo = calCostWeight(mediaAndAdvert, params, mergeCvr);
        mergeCvr *= costWeightInfo.getW();

        // Count-based boost. The CTR boost is currently disabled, so only a placeholder that
        // carries the dimension id is reported for it.
        WeightInfo ctrNumWeightInfo = new WeightInfo();
        ctrNumWeightInfo.setId(WeightDimEnum.CTR_NUM.getId());

        // The "media & advert" boost keeps affecting small media for as long as they stay small.
        WeightInfo cvrNumWeightInfo = numWeight(mediaAndAdvert, params, mergeCvr, false);
        mergeCvr *= cvrNumWeightInfo.getW();

        // Upper bound of the CVR.
        mergeCvr = boundCvr(mergeCvr, cvrWeightInfos, mediaAndAdvert, predictData, convertSubType, params);

        return new ExploreV1PkgData(orientationId, advertId, target,
                DataUtil.formatdouble(mergeCtr, 6), DataUtil.formatdouble(mergeCvr, 6),
                ctrWeightInfos, cvrWeightInfos,
                costWeightInfo, ctrNumWeightInfo, cvrNumWeightInfo);
    }

    /** Sums {@code w * v} over all entries of the list, starting from 0.0. */
    private static double weightedSum(WeightInfoList weightInfos) {
        return weightInfos.stream().map(w -> w.getW() * w.getV()).reduce(0.0, Double::sum);
    }

    /**
     * Count-based boost on the "media &amp; advert" dimension: the fewer exposures, the larger
     * the weight (for a positive {@code thetaNum}).
     *
     * <p>The weight is {@code 1 + (1 - sigmoid(exposure)) * numWeightThreshold}, i.e. it lies in
     * {@code (1, 1 + numWeightThreshold)} for a positive threshold. The exposure count is used
     * for both CTR and CVR on purpose: using clicks for CVR would keep boosting the CVR of
     * configs with a low click-through rate.
     *
     * @param statData statistics whose exposure count drives the weight
     * @param params   algorithm parameters ({@code thetaNum}, {@code betaNum},
     *                 {@code numWeightThreshold})
     * @param value    the value being weighted; only recorded in the result, not modified
     * @param isCtr    selects the reported dimension id ({@code CTR_NUM} or {@code CVR_NUM});
     *                 it does not change the computation
     * @return weight information with the weight rounded to 4 and the value to 6 decimals
     */
    public static WeightInfo numWeight(ExploreData statData, ExploreV1Params params, double value, boolean isCtr) {
        Long cnt = statData.getExposure();
        String dimName = isCtr ? WeightDimEnum.CTR_NUM.getId() : WeightDimEnum.CVR_NUM.getId();

        // calSigmoid yields a value in (0, 1).
        double sigmoidValue = calSigmoid(params.getThetaNum(), params.getBetaNum(), cnt);
        double weight = 1 + (1 - sigmoidValue) * params.getNumWeightThreshold();

        return new WeightInfo(dimName, cnt, DataUtil.formatDouble(weight, 4),
                DataUtil.formatDouble(value, 6));
    }

    /**
     * Cost-bias correction weight on the "media &amp; advert" dimension.
     *
     * <p>When the exposure count is below {@code costExpThreshold} the bias is not trusted and
     * the weight is 1. Otherwise a linear correction towards a +2% bias is computed and clamped
     * to {@code [1 - costWeightThreshold, 1 + costWeightThreshold]}. The offline cost bias
     * reacts slowly, so a linear adjustment is expected to stay gentle. The
     * "media &amp; orientation" dimension has too little data to be trusted here.
     *
     * @param statData statistics providing the exposure count and the cost bias
     * @param params   algorithm parameters ({@code costExpThreshold}, {@code costWeightThreshold})
     * @param value    the value being weighted; only recorded in the result, not modified
     * @return weight information with both weight and value rounded to 4 decimals
     */
    public static WeightInfo calCostWeight(ExploreData statData, ExploreV1Params params, double value) {
        double costWeight = 1.0;

        if (!(statData.getExposure() < params.getCostExpThreshold())) {
            double linerWeight = (TARGET_COST_BIAS + 1) / (statData.getCostBias() + 1);

            double lowerBound = 1 - params.getCostWeightThreshold();
            double upperBound = 1 + params.getCostWeightThreshold();
            costWeight = Math.min(Math.max(linerWeight, lowerBound), upperBound);
        }

        return new WeightInfo(WeightDimEnum.COST.getId(), statData.getExposure(), DataUtil.formatDouble(costWeight, 4),
                DataUtil.formatDouble(value, 4));
    }

    /**
     * Computes the CTR weight information of every dimension.
     *
     * @param params        algorithm parameters
     * @param predictData   model prediction
     * @param modelStatData statistics deciding the model weight ("media &amp; orientation")
     * @param statDatas     statistics of each statistical dimension, finest first
     * @return the model entry followed by one entry per statistical dimension, after
     *         {@code WeightInfoList.reWeight()} has been applied
     */
    public static WeightInfoList calCtrWeights(ExploreV1Params params, PredictData predictData, ExploreData modelStatData, ExploreData... statDatas) {
        double remainWeight = 1.0;
        double modelWeight = calDimWeight(remainWeight, params.getCtrWeightThreshold(),
                params.getThetaCtr(), params.getBetaCtr(), modelStatData.getExposure());

        // Model weight information.
        WeightInfo modelWeightInfo = new WeightInfo(predictData.getId(), modelStatData.getExposure(), modelWeight, DataUtil.formatDouble(predictData.getCtr(), 4));
        // Statistical weight information; they share whatever the model left over.
        List<WeightInfo> statWeightInfos = getStatWeightInfos(remainWeight - modelWeight, true, 0, params, statDatas);

        WeightInfoList weightInfos = new WeightInfoList(statWeightInfos.size() + 1);
        weightInfos.add(modelWeightInfo);
        weightInfos.addAll(statWeightInfos);

        return weightInfos.reWeight();
    }

    /**
     * Computes the CVR weight information of every dimension.
     *
     * @param params         algorithm parameters
     * @param predictData    model prediction
     * @param modelStatData  statistics deciding the model weight ("media &amp; orientation")
     * @param convertSubtype conversion type
     * @param statDatas      statistics of each statistical dimension, finest first
     * @return the model entry followed by one entry per statistical dimension, after
     *         {@code WeightInfoList.reWeight()} has been applied
     */
    public static WeightInfoList calCvrWeightInfos(ExploreV1Params params, PredictData predictData, ExploreData modelStatData, Integer convertSubtype, ExploreData... statDatas) {
        double remainWeight = 1.0;
        double modelWeight = calDimWeight(remainWeight, params.getCvrWeightThreshold(),
                params.getThetaCvr(), params.getBetaCvr(), modelStatData.getClick());

        // Model weight information.
        WeightInfo modelWeightInfo = new WeightInfo(predictData.getId(), modelStatData.getClick(), modelWeight, DataUtil.formatDouble(predictData.getCvr(), 4));
        // Statistical weight information; they share whatever the model left over.
        List<WeightInfo> statWeightInfos = getStatWeightInfos(remainWeight - modelWeight, false, convertSubtype, params, statDatas);

        WeightInfoList weightInfos = new WeightInfoList(statWeightInfos.size() + 1);
        weightInfos.add(modelWeightInfo);
        weightInfos.addAll(statWeightInfos);

        return weightInfos.reWeight();
    }

    /**
     * Computes the weight information of each statistical dimension.
     *
     * <p>Dimensions are processed in the given order; each one takes its share of the remaining
     * weight, which is then reduced accordingly. CTR weights are driven by exposures, CVR
     * weights by clicks. For CVR the "media &amp; trade" dimension always gets weight 0.
     *
     * @param remainWeight   weight still available
     * @param isCtr          {@code true} for CTR, {@code false} for CVR
     * @param convertSubtype conversion type (only used for CVR)
     * @param params         algorithm parameters
     * @param statDatas      statistics of each statistical dimension
     * @return one weight entry per element of {@code statDatas}, in the same order
     */
    public static List<WeightInfo> getStatWeightInfos(double remainWeight, boolean isCtr, Integer convertSubtype, ExploreV1Params params, ExploreData... statDatas) {
        List<WeightInfo> weights = new ArrayList<>(statDatas.length);

        for (ExploreData data : statDatas) {
            if (isCtr) {
                double dimWeight = calDimWeight(remainWeight, params.getCtrWeightThreshold(),
                        params.getThetaCtr(), params.getBetaCtr(), data.getExposure());
                remainWeight -= dimWeight;
                weights.add(new WeightInfo(data.getId(), data.getExposure(), dimWeight, data.getCtr()));
            } else {
                // Ids are compared by value: a reference comparison only works while both sides
                // happen to be the very same object.
                boolean isMediaAndTrade = Objects.equals(data.getId(), WeightDimEnum.MEDIA_AND_TRADE.getId());
                // Possible refinement: adjust the count thresholds by conversion type.
                double dimWeight = isMediaAndTrade
                        ? 0.0
                        : calDimWeight(remainWeight, params.getCvrWeightThreshold(),
                                params.getThetaCvr(), params.getBetaCvr(), data.getClick());
                remainWeight -= dimWeight;
                weights.add(new WeightInfo(data.getId(), data.getClick(), dimWeight, data.getCvr(convertSubtype)));
            }
        }

        return weights;
    }

    /**
     * Weight formula: {@code min(remainWeight * sigmoid(theta * value - beta), weightThreshold)}.
     *
     * <p>When less than 0.001 of weight remains, the whole remainder is returned as is (the
     * threshold is not applied in that case).
     *
     * @param remainWeight    weight still available
     * @param weightThreshold upper bound of the returned weight
     * @param theta           slope of the sigmoid
     * @param beta            offset of the sigmoid
     * @param value           sample count driving the weight
     * @return the weight assigned to the dimension
     */
    public static double calDimWeight(double remainWeight, double weightThreshold, double theta, double beta, double value) {
        if (remainWeight < MIN_REMAIN_WEIGHT) {
            return remainWeight;
        }
        return Math.min(remainWeight * calSigmoid(theta, beta, value), weightThreshold);
    }

    /**
     * Shifted logistic function {@code 1 / (1 + exp(-theta * value + beta))}.
     *
     * @param theta slope
     * @param beta  offset
     * @param value input
     * @return a value between 0 and 1
     */
    public static double calSigmoid(double theta, double beta, double value) {
        return 1 / (1 + Math.exp(-1.0 * theta * value + beta));
    }

    /**
     * Logarithm of {@code value + beta} to the base {@code alpha}.
     *
     * @param alpha base of the logarithm
     * @param beta  offset added to the value
     * @param value input
     * @return {@code log_alpha(value + beta)}, or 0.0 when the argument is not positive or the
     *         base is invalid ({@code alpha <= 0} or {@code alpha == 1}, for which the logarithm
     *         is undefined and the division would produce an infinite or NaN result)
     */
    public static double calLogValue(double alpha, double beta, double value) {
        if ((value + beta) <= 0 || alpha <= 0 || alpha == 1.0) {
            return 0.0;
        }
        return Math.log(value + beta) / Math.log(alpha);
    }

    /**
     * Randomly gives up configs whose bid is low.
     *
     * <p>The bid is {@code e2 * target}; the config is kept with probability
     * {@code min(1, bid / bottomFee * bottomRatio)}. Configs without a target or without a
     * positive {@code e2} are always kept.
     *
     * @param pkgData computed package data of the config
     * @param params  algorithm parameters ({@code bottomFee}, {@code bottomRatio})
     * @return {@code true} to keep the config
     */
    public static boolean bidFilter(ExploreV1PkgData pkgData, ExploreV1Params params) {
        if (pkgData.getTarget() != null && pkgData.getE2() != null && pkgData.getE2() > 0) {
            double bid = pkgData.getE2() * pkgData.getTarget();
            double prob = Math.min(1.0, bid / params.getBottomFee() * params.getBottomRatio());

            // ThreadLocalRandom avoids allocating and seeding a generator for every config.
            return ThreadLocalRandom.current().nextDouble() < prob;
        }

        return true;
    }

    /**
     * Decides whether a config is still eligible for exploration.
     *
     * <p>A config is rejected when
     * <ul>
     *   <li>its exposure count exceeds {@code mediaAdExpThreshold}, or</li>
     *   <li>its exploration consume exceeds {@code exploreMinConsume} and the share of
     *       exploration consume in the total consume exceeds {@code exploreConsumeRatio}, or</li>
     *   <li>its exploration consume exceeds {@code exploreMaxConsume}.</li>
     * </ul>
     * Missing ({@code null}) figures never cause a rejection; a missing or non-positive total
     * consume yields a share of 0.
     *
     * @param exploreConsume consume spent on exploration, may be {@code null}
     * @param consume        total consume, may be {@code null}
     * @param exp            exposure count, may be {@code null}
     * @param params         algorithm parameters
     * @return {@code true} to keep exploring the config
     */
    public static boolean exploreFilter(Long exploreConsume, Long consume, Long exp, ExploreV1Params params) {
        // Exposure limit.
        if (exp != null && exp > params.getMediaAdExpThreshold()) {
            return false;
        }

        // Without an exploration consume none of the consume limits can trigger.
        if (exploreConsume == null) {
            return true;
        }

        double exploreConsumeRatio = (consume != null && consume > 0) ? exploreConsume * 1.0 / consume : 0.0;

        boolean overRatio = exploreConsume > params.getExploreMinConsume()
                && exploreConsumeRatio > params.getExploreConsumeRatio();

        return !(overRatio || exploreConsume > params.getExploreMaxConsume());
    }

    /**
     * Randomised low-quality media filter.
     *
     * <p>The media is kept with probability
     * {@code max(appMinRatio, min(1, cpm / (appCpmRatio * appCpmBase)), min(1, cvr / (appCvrRatio * appCvrBase)))}.
     * The CPM is only taken into account when the consume exceeds 1000, and the landing-page
     * CVR only when the clicks exceed 50 and a landing-page conversion count exists; otherwise
     * the respective figure counts as 0. Having consume does not imply having conversions.
     *
     * @param appDataInfo media statistics
     * @param params      algorithm parameters
     * @return {@code true} to keep the media
     */
    public static boolean mediaFilter(ExploreData appDataInfo, ExploreV1Params params) {
        double cpm = 0.0;
        double cvr = 0.0;

        if (appDataInfo.getConsume() > MEDIA_MIN_CONSUME) {
            cpm = appDataInfo.getExposure() > 0 ? appDataInfo.getConsume() * 1.0 / appDataInfo.getExposure() : 0.0;
        }

        // Landing-page CVR. A missing conversion map is treated like a missing conversion count.
        Map<Integer, Long> convMap = appDataInfo.getConv();
        Long landingConv = convMap == null ? null : convMap.get(LANDING_PAGE_CONV_TYPE);
        if (landingConv != null && appDataInfo.getClick() > MEDIA_MIN_CLICK) {
            cvr = landingConv * 1.0 / appDataInfo.getClick();
        }

        // Both figures are measured against a fraction of a platform-wide baseline.
        double cpmProb = Math.min(1.0, cpm / (params.getAppCpmRatio() * params.getAppCpmBase()));
        double cvrProb = Math.min(1.0, cvr / (params.getAppCvrRatio() * params.getAppCvrBase()));

        double prob = Math.max(params.getAppMinRatio(), Math.max(cpmProb, cvrProb));

        // ThreadLocalRandom avoids allocating and seeding a generator on every call.
        return ThreadLocalRandom.current().nextDouble() < prob;
    }

    /**
     * Caps the fused CVR at {@code cvrBoundRatio} times a reference CVR.
     *
     * <p>The reference is the predicted CVR while {@code cvrWeighInfoList.getMw()} is below
     * {@code cvrBoundWeightThreshold}, and the statistical "media &amp; advert" CVR otherwise.
     *
     * @param mergeCvr         fused CVR
     * @param cvrWeighInfoList CVR weight information
     * @param mediaAndAdvert   statistics of media &amp; advert
     * @param predictData      model prediction
     * @param convertSubType   conversion type used to pick the statistical CVR
     * @param params           algorithm parameters
     * @return the capped CVR
     */
    public static double boundCvr(double mergeCvr, WeightInfoList cvrWeighInfoList, ExploreData mediaAndAdvert,
                                  PredictData predictData,Integer convertSubType, ExploreV1Params params) {
        boolean usePrediction = cvrWeighInfoList.getMw() < params.getCvrBoundWeightThreshold();
        double referenceCvr = usePrediction ? predictData.getCvr() : mediaAndAdvert.getCvr(convertSubType);

        return Math.min(mergeCvr, params.getCvrBoundRatio() * referenceCvr);
    }

    /**
     * Manual smoke entry; prints {@code Long.MAX_VALUE}.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        System.out.println(Long.MAX_VALUE);
    }


}
