package cn.com.duiba.nezha.alg.common.model.advertexplore;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * UCB-style exploration scoring for the advertisers shown on one media (app).
 *
 * <p>Each score linearly combines a media-level rate, a media-and-advertiser-level rate (both
 * Wilson lower bounds) and two exploration weights, using four coefficients from
 * {@code ExploreParams}.
 *
 * <p>The public arithmetic helpers are null-tolerant: when an operand is {@code null} (or a
 * divisor is zero) they return 0 instead of throwing.
 */
public class AdExploreUcb {

    /** Conversion-objective keys {@code 0 .. OBJECTIVE_COUNT-1} are always present in results. */
    private static final int OBJECTIVE_COUNT = 9;

    /**
     * Value used as z in the Wilson formula.
     * NOTE(review): an older comment claimed 99% confidence, but 0.95 is plugged in directly as z
     * (the z-score for 95% is ~1.96, for 99% ~2.576). Kept unchanged because it shifts every
     * score — confirm the intended value.
     */
    private static final double WILSON_Z = 0.95;

    /**
     * Lower bound of the Wilson score interval for {@code clickCnt} successes in
     * {@code exposeCnt} trials.
     *
     * @param exposeCnt number of trials; {@code null} or non-positive yields 0
     * @param clickCnt  number of successes; {@code null} yields 0
     * @return the Wilson lower bound (never NaN)
     */
    private static double calWilsonScore(Long exposeCnt, Long clickCnt) {
        if (exposeCnt == null || clickCnt == null || exposeCnt <= 0) {
            return 0.0;
        }
        double n = exposeCnt;
        // Clamp to [0, 1]: success counts outside [0, n] (e.g. more conversions than clicks)
        // would make the radicand below negative and the score NaN.
        double ratio = Math.min(1.0, Math.max(0.0, clickCnt / n));
        double zSquare = WILSON_Z * WILSON_Z;
        return (ratio + zSquare / (2 * n)
                - WILSON_Z * Math.sqrt(4 * n * ratio * (1 - ratio) + zSquare) / (2 * n))
                / (1 + zSquare / n);
    }

    /**
     * Square root that returns 0 for zero, negative or NaN input (e.g. {@code log10(0) = -Infinity}
     * in an exploration-weight numerator) instead of propagating NaN.
     */
    private static double safeSqrt(double x) {
        return x > 0 ? Math.sqrt(x) : 0.0;
    }

    /**
     * Returns a copy of {@code conv} in which every objective key {@code 0 .. OBJECTIVE_COUNT-1}
     * is present, missing keys mapped to 0. The caller's map is never modified.
     */
    private static Map<Integer, Long> withAllObjectives(Map<Integer, Long> conv) {
        Map<Integer, Long> filled = (conv != null) ? new HashMap<>(conv) : new HashMap<>();
        for (int i = 0; i < OBJECTIVE_COUNT; i++) {
            if (!filled.containsKey(i)) {
                filled.put(i, 0L);
            }
        }
        return filled;
    }

    /**
     * Linear combination {@code alpha[0]*val1 + alpha[1]*val2 + alpha[2]*explore1 + alpha[3]*explore2}.
     *
     * @param alpha at least four coefficients
     * @return the combination, or 0.0 if any of the four values is {@code null}
     */
    public static Double concatUcb(Double val1, Double val2,
                                   Double explore1, Double explore2,
                                   List<Double> alpha) {
        return (val1 == null || val2 == null || explore1 == null || explore2 == null
                ) ? 0.0 : alpha.get(0) * val1 + alpha.get(1) * val2 +
                            alpha.get(2) * explore1 + alpha.get(3) * explore2;
    }

    /** {@code v1 / v2} as a double; 0.0 if either is {@code null} or {@code v2 == 0}. */
    public static Double division(Long v1, Long v2) {
        return (v1 != null && v2 != null && v2 != 0)  ? v1.doubleValue() / v2.doubleValue() : 0.0;
    }

    /** {@code v1 / v2}; 0.0 if either is {@code null} or {@code v2 == 0}. */
    public static Double division(Double v1, Long v2) {
        return (v1 != null && v2 != null && v2 != 0)  ? v1 / v2.doubleValue() : 0.0;
    }

    /** {@code v1 * v2} as a double; 0.0 if either is {@code null}. */
    public static Double dot(Long v1, Long v2) {
        return (v1 != null && v2 != null)  ? v1.doubleValue() * v2.doubleValue() : 0.0;
    }

    /** {@code v1 * v2}; 0.0 if either is {@code null}. */
    public static Double dot(Double v1, Double v2) {
        return (v1 != null && v2 != null)  ? v1 * v2 : 0.0;
    }

    /** {@code v1 + v2}; 0 (not the other operand) if either is {@code null}. */
    public static Long add(Long v1, Long v2) {
        return (v1 != null && v2 != null)  ? v1 + v2 : 0;
    }

    /** {@code v1 + v2}; 0.0 (not the other operand) if either is {@code null}. */
    public static Double add(Double v1, Double v2) {
        return (v1 != null && v2 != null)  ? v1 + v2 : 0;
    }

    /** {@code v1 - v2}; 0.0 if either is {@code null}. */
    public static Double subtract(Double v1, Double v2) {
        return (v1 != null && v2 != null)  ? v1 - v2 : 0;
    }

    /** Natural log; 0.0 for {@code null} ({@code -Infinity} for 0, per {@link Math#log}). */
    public static Double log(Long v1) {
        return (v1 != null)  ? Math.log(v1) : 0;
    }

    /** Base-10 log; 0.0 for {@code null} ({@code -Infinity} for 0, per {@link Math#log10}). */
    public static Double log10(Long v1) {
        return (v1 != null)  ? Math.log10(v1) : 0;
    }

    /**
     * Computes UCB scores for every advertiser of one media. Intended to be called once per
     * traffic dimension.
     *
     * <p>For each advertiser entry of {@code adExploreUcbData.getAppAcc3dRt()}:
     * <ul>
     *   <li>{@code concatCtr = a1[0]*appCtr + a1[1]*accCtr + a1[2]*appWeight + a1[3]*accWeight}</li>
     *   <li>{@code concatCvr(obj) = a2[0]*appCvr(obj) + a2[1]*accCvr(obj) + a2[2]*appWeight + a2[3]*accWeight}</li>
     *   <li>{@code ucb(obj) = concatCtr * concatCvr(obj)}</li>
     * </ul>
     * All rates are Wilson lower bounds; CVR uses clicks as trials and conversions as successes.
     * {@code appWeight = sqrt(log10(total test exposure) / (media test exposure + 1))} and
     * {@code accWeight = sqrt(log10(media test exposure) / (media-and-advertiser test exposure + 1))};
     * a weight with a missing denominator or a non-positive radicand is 0.
     *
     * <p>{@code a1} is {@code ucbAlpha1} when the advertiser's exposure exceeds
     * {@code ucbAlphaCnt1}, otherwise {@code ucbAlpha1start}. {@code a2} is {@code ucbAlpha2} when
     * its clicks exceed {@code ucbAlphaCnt2}, otherwise {@code ucbAlpha2start}.
     *
     * <p>Input maps are never modified.
     *
     * @param app3dRt          media statistics (3-day + real-time)
     * @param adExploreUcbData UCB exploration statistics (3-day + real-time)
     * @param params           UCB hyper-parameters
     * @return map keyed like {@code getAppAcc3dRt()} (advertiser IDs) to that advertiser's
     *         {@code UcbResult}; empty if {@code app3dRt}, {@code adExploreUcbData} or its
     *         advertiser map is {@code null}. Entries whose statistics are {@code null} are skipped.
     * @throws IllegalArgumentException if a selected coefficient list is {@code null} or does not
     *         have exactly four elements
     */
    public static Map<Long, UcbResult> calcUcb(ExploreData app3dRt,
                                               AdExploreUcbData adExploreUcbData,
                                               ExploreParams params
                                               ) throws IllegalArgumentException {
        Map<Long, UcbResult> tradeUcb = new HashMap<>();
        if (app3dRt == null || adExploreUcbData == null) {
            return tradeUcb;
        }

        // Media exploration weight. division() yields 0 instead of Infinity/NaN when the
        // denominator is missing (add() returns 0 for a null operand).
        double appWeight = safeSqrt(division(log10(adExploreUcbData.getTestExpose3dRt()),
                add(adExploreUcbData.getAppTestExpose3dRt(), 1L)));
        // Loop-invariant numerator of every advertiser weight.
        Double appLogExpose = log10(adExploreUcbData.getAppTestExpose3dRt());

        // Media-level CTR and per-objective CVR.
        double appCtr = calWilsonScore(app3dRt.getExposure(), app3dRt.getClick());
        Map<Integer, Double> appConv = new HashMap<>();
        for (Map.Entry<Integer, Long> objConv : withAllObjectives(app3dRt.getConv()).entrySet()) {
            appConv.put(objConv.getKey(), calWilsonScore(app3dRt.getClick(), objConv.getValue()));
        }

        Map<Long, ExploreData> acc3dRt = adExploreUcbData.getAppAcc3dRt();
        if (acc3dRt == null) {
            return tradeUcb;
        }
        Map<?, Long> accTestExpose = adExploreUcbData.getAppAccTestExpose3dRt();

        for (Map.Entry<Long, ExploreData> entry : acc3dRt.entrySet()) {
            ExploreData acc = entry.getValue();
            if (acc == null) {
                continue;
            }

            List<Double> ucbAlpha1 = (acc.getExposure() != null
                    && acc.getExposure() > params.getUcbAlphaCnt1())
                    ? params.getUcbAlpha1() : params.getUcbAlpha1start();
            // CVR's sample size is clicks, so the warm-up switch for its coefficients is on clicks.
            List<Double> ucbAlpha2 = (acc.getClick() != null
                    && acc.getClick() > params.getUcbAlphaCnt2())
                    ? params.getUcbAlpha2() : params.getUcbAlpha2start();
            if (ucbAlpha1 == null || ucbAlpha1.size() != 4
                    || ucbAlpha2 == null || ucbAlpha2.size() != 4) {
                throw new IllegalArgumentException("Params is illegal");
            }

            // Advertiser exploration weight within this media.
            Long accTest = (accTestExpose != null) ? accTestExpose.get(entry.getKey()) : null;
            double accWeight = safeSqrt(division(appLogExpose, add(accTest, 1L)));

            // f(media CTR, media-and-advertiser CTR).
            double accCtr = calWilsonScore(acc.getExposure(), acc.getClick());
            double concatCtr = concatUcb(appCtr, accCtr, appWeight, accWeight, ucbAlpha1);

            Map<Integer, Double> result = new HashMap<>();
            Map<Integer, Double> resultS1 = new HashMap<>();
            Map<Integer, Double> resultS2 = new HashMap<>();
            for (Map.Entry<Integer, Long> objConv : withAllObjectives(acc.getConv()).entrySet()) {
                // f(media CVR, media-and-advertiser CVR); objectives unknown to the media level
                // have a null appConv value, which makes concatUcb return 0.0.
                double concatCvr = concatUcb(appConv.get(objConv.getKey()),
                        calWilsonScore(acc.getClick(), objConv.getValue()),
                        appWeight, accWeight, ucbAlpha2);

                result.put(objConv.getKey(), concatCtr * concatCvr);
                resultS1.put(objConv.getKey(), concatCtr);
                resultS2.put(objConv.getKey(), concatCvr);
            }

            UcbResult ucbResult = new UcbResult();
            ucbResult.setObjUcb(result);
            ucbResult.setObjUcbScore1(resultS1);
            ucbResult.setObjUcbScore2(resultS2);
            tradeUcb.put(entry.getKey(), ucbResult);
        }

        return tradeUcb;
    }
}
