package cn.com.duiba.nezha.alg.alg.advertexploitbak;

import cn.com.duiba.nezha.alg.common.util.AssertUtil;

import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

/**
 * Advert exploitation: fuses the candidates recalled by several matchers (recall channels) into a
 * single result, weighting each matcher by its recent cost bias and CPM.
 */
public class AdvertExploit {

    /**
     * Advert exploitation - matcher fusion entry point.
     *
     * <p>Steps:
     * <ol>
     *   <li>Merge the time-windowed statistics at three granularities: experiment,
     *       experiment + matcher, and experiment + matcher + ad slot.</li>
     *   <li>For each matcher, take the cost bias and CPM from the slot-level data when its OCPC
     *       consumption exceeds {@code params.getConfidentConsume()}. Otherwise fall back to the
     *       matcher-level data under the same threshold. Matchers below the threshold at both
     *       levels get no entry.</li>
     *   <li>Split the total weight between "cost" and "CPM". The defaults come from the params;
     *       once the experiment's consumption and exposure both exceed their thresholds, the split
     *       is a clamped sigmoid of the experiment cost bias.</li>
     *   <li>Rank matchers by cost bias (smaller is better) and by CPM (larger is better), derive a
     *       per-matcher weight, and sample {@code (int) (weight * candidateCount)} candidates
     *       from each matcher.</li>
     *   <li>Merge the sampled candidates of all matchers into one result keyed by
     *       {@code CandidateDto#getKey()}.</li>
     * </ol>
     *
     * @param exploitInfo      statistics at the three granularities
     * @param matcherOrientMap candidates recalled by each matcher
     * @param orientInfoMap    orientation info keyed by candidate key; a {@code null} map is treated
     *                         as empty, and candidates without an entry are left out of the result
     * @param params           tuning parameters (thresholds, default weights, bounds)
     * @return the fused result, or {@code null} when {@link AssertUtil#isAnyEmpty} reports any of
     *         {@code exploitInfo}, {@code matcherOrientMap} or {@code params} as empty
     */
    public static ExploitResult exploit(ExploitInfo exploitInfo,
                                        Map<MatcherEnum, List<CandidateDto>> matcherOrientMap,
                                        Map<String, OrientationDto> orientInfoMap,
                                        ExploitParams params) {

        if (AssertUtil.isAnyEmpty(exploitInfo, matcherOrientMap, params)) {
            return null;
        }
        Map<String, OrientationDto> orientations = orientInfoMap == null ? Collections.emptyMap() : orientInfoMap;

        // Experiment level
        ExploitDto expData = dataMerge(exploitInfo.getExpData(), params);
        // Experiment + matcher
        Map<MatcherEnum, ExploitDto> matcherDataMap = mergeMatcherData(exploitInfo.getExpMatcherData(), params);
        // Experiment + matcher + ad slot
        Map<MatcherEnum, ExploitDto> slotMatcherDataMap = mergeMatcherData(exploitInfo.getSlotMatcherData(), params);

        // == Matcher effect: prefer the finer (slot) granularity when it is confident enough.
        // Known issue: matchers can be treated unfairly (some matchers have too little consumption
        // and would need a higher-level average instead).
        Map<MatcherEnum, Double> matcherCostMap = new HashMap<>();
        Map<MatcherEnum, Double> matcherCpmMap = new HashMap<>();
        for (Map.Entry<MatcherEnum, ExploitDto> entry : matcherDataMap.entrySet()) {
            ExploitDto matcherData = entry.getValue();
            ExploitDto slotData = slotMatcherDataMap.get(entry.getKey());
            ExploitDto source;
            if (slotData != null && slotData.getOcpcConsume() > params.getConfidentConsume()) {
                // Slot-level data exists and its consumption exceeds the threshold
                source = slotData;
            } else if (matcherData.getOcpcConsume() > params.getConfidentConsume()) {
                // Matcher-level consumption exceeds the threshold
                source = matcherData;
            } else {
                // Not confident at either level: the matcher later falls back to the default weight
                continue;
            }
            matcherCostMap.put(entry.getKey(), source.getCostBias());
            matcherCpmMap.put(entry.getKey(), source.getCpm());
        }

        // == Cost vs CPM weight split
        // When the experiment's cost bias is high, the sigmoid shifts weight toward cost; otherwise
        // CPM gets the larger share.
        double maxCostWeight = params.getCostDefaultWeight();
        double maxCpmWeight = params.getCpmDefaultWeight();

        // Activation condition: experiment consumption and exposure both above their thresholds
        if (expData.getOcpcConsume() != null && expData.getOcpcConsume() > params.getExpConfidentConsume()
                && expData.getExposure() != null && expData.getExposure() > params.getExpConfidentExposure()) {
            double tmpCostWeight = calSigmoid(params.getThetaCost(), params.getBetaCost(), expData.getCostBias());
            maxCostWeight = Math.min(params.getCostWeightUpperBound(), Math.max(params.getCostWeightLowerBound(), tmpCostWeight));
            maxCpmWeight = 1 - maxCostWeight;
        }

        // == Matcher ranking & per-matcher weights
        // TODO this scheme is probably not ideal
        // Smaller cost is better: rank in reverse
        Map<MatcherEnum, MatcherWeightInfo> costWeightInfoMap = getMatcherWeightInfo1(matcherCostMap, maxCostWeight, true);
        // Larger CPM is better: rank ascending
        Map<MatcherEnum, MatcherWeightInfo> cpmWeightInfoMap = getMatcherWeightInfo1(matcherCpmMap, maxCpmWeight, false);

        // == Recall & merge
        HashMap<MatcherEnum, MatcherWeightInfo> resultWeightInfoMap = new HashMap<>();
        HashMap<String, ResultDto> resultDtoMap = new HashMap<>();

        // New matchers (consumption below threshold) may be missing from the statistics: use the
        // default weight. Only read, never stored, so a single instance is shared.
        MatcherWeightInfo defaultWeightInfo = new MatcherWeightInfo(params.getMatcherDefaultWeight(), 0);

        for (Map.Entry<MatcherEnum, List<CandidateDto>> entry : matcherOrientMap.entrySet()) {
            MatcherEnum matcher = entry.getKey();
            List<CandidateDto> candidates = entry.getValue() == null ? Collections.emptyList() : entry.getValue();

            MatcherWeightInfo costWeightInfo = costWeightInfoMap.getOrDefault(matcher, defaultWeightInfo);
            MatcherWeightInfo cpmWeightInfo = cpmWeightInfoMap.getOrDefault(matcher, defaultWeightInfo);

            // Final matcher weight (clamped to [lowerBound, 1]) & number of candidates to recall
            double weight = Math.min(1D, Math.max(params.getMatcherWeightLowerBound(), add(costWeightInfo.getWeight(), cpmWeightInfo.getWeight())));
            int num = (int) (weight * candidates.size());

            resultWeightInfoMap.put(matcher, new MatcherWeightInfo(weight, num));

            // Sampling
            boolean isScored = matcher.isScored();
            List<CandidateDto> matchList = candidateSampling(candidates, orientations, num, isScored, params.getNewAdvertFactor());

            // Merge this matcher's recalled candidates into the overall result
            for (CandidateDto dto : matchList) {
                String key = dto.getKey();
                ResultDto resultDto = resultDtoMap.get(key);
                if (resultDto == null) {
                    OrientationDto orientInfo = orientations.get(key);
                    if (orientInfo == null) {
                        // Without orientation info the ResultDto cannot be built; skip instead of failing the whole request
                        continue;
                    }
                    resultDto = new ResultDto(dto.getAdvertId(), dto.getOrientationId(), new HashMap<>(matcherOrientMap.size()), orientInfo.getNewAdvert(), orientInfo.getExploreFlag(), orientInfo.getPredictBias());
                    resultDtoMap.put(key, resultDto);
                }

                // Record the matcher's score (null for matchers that do not score)
                if (isScored) {
                    resultDto.getScoreMap().put(matcher, dto.getScore());
                } else {
                    resultDto.getScoreMap().put(matcher, null);
                }
            }
        }

        return new ExploitResult(resultDtoMap, maxCostWeight, maxCpmWeight, costWeightInfoMap, cpmWeightInfoMap, resultWeightInfoMap);
    }

    /**
     * Merges the time-windowed statistics of every matcher in {@code dataMap} via {@link #dataMerge}.
     *
     * @param dataMap per-matcher statistics; {@code null} is treated as empty
     * @param params  tuning parameters
     * @return a new map from matcher to merged statistics (never {@code null})
     */
    private static Map<MatcherEnum, ExploitDto> mergeMatcherData(Map<MatcherEnum, ExploitData> dataMap, ExploitParams params) {
        if (dataMap == null) {
            return new HashMap<>();
        }
        Map<MatcherEnum, ExploitDto> merged = new HashMap<>(dataMap.size());
        for (Map.Entry<MatcherEnum, ExploitData> entry : dataMap.entrySet()) {
            merged.put(entry.getKey(), dataMerge(entry.getValue(), params));
        }
        return merged;
    }

    /**
     * Candidate sampling.
     *
     * <p>If {@code samplingNum} is at least the number of candidates, {@code candidates} itself is
     * returned unchanged. Otherwise every candidate's score is <b>overwritten in place</b>:
     * <ul>
     *   <li>scored matchers: the existing score times a factor;</li>
     *   <li>unscored matchers: a random value in [0, 1) times the same factor.</li>
     * </ul>
     * The factor is {@code newAdvertFactor} for new adverts and 1 otherwise. The top
     * {@code samplingNum} candidates by the new score (descending) are then returned.
     *
     * @param candidates      the matcher's candidates (scores are mutated when sampling happens)
     * @param orientInfoMap   orientation info by candidate key; a missing entry counts as "not new"
     * @param samplingNum     number of candidates to keep (negative values keep none)
     * @param isScored        whether the matcher provides scores
     * @param newAdvertFactor boost applied to new adverts; {@code null} means no boost
     * @return the sampled candidates
     */
    private static List<CandidateDto> candidateSampling(List<CandidateDto> candidates, Map<String, OrientationDto> orientInfoMap, int samplingNum, boolean isScored, Double newAdvertFactor) {
        if (samplingNum >= candidates.size()) {
            return candidates;
        }

        double newAdvertBoost = newAdvertFactor == null ? 1D : newAdvertFactor;
        Random random = new Random();
        for (CandidateDto c : candidates) {
            OrientationDto orientInfo = orientInfoMap == null ? null : orientInfoMap.get(c.getKey());
            double factor = isNewAdvert(orientInfo) ? newAdvertBoost : 1D;
            double newScore;
            if (isScored) {
                newScore = multiply(c.getScore(), factor);
            } else {
                // Matchers without scores get a random score
                newScore = multiply(random.nextDouble(), factor);
            }
            c.setScore(newScore);
        }

        // Descending by score; List.sort is stable, like the sequential stream sort it replaces
        List<CandidateDto> ranked = new ArrayList<>(candidates);
        ranked.sort(Comparator.comparing(CandidateDto::getScore).reversed());
        return ranked.subList(0, Math.max(0, samplingNum));
    }

    /**
     * Whether the orientation is flagged as a new advert: its {@code getNewAdvert()} is non-null
     * and not equal to 0.
     *
     * @param orientInfo orientation info, may be {@code null} (treated as not new)
     * @return {@code true} for a new advert
     */
    private static boolean isNewAdvert(OrientationDto orientInfo) {
        return orientInfo != null && orientInfo.getNewAdvert() != null && !orientInfo.getNewAdvert().equals(0);
    }

    /**
     * Fixed-ratio decaying weights.
     *
     * <p>Not implemented yet: always returns an empty (mutable) map so callers can iterate safely.
     *
     * @param map       matcher metric values
     * @param maxWeight weight assigned to the best matcher
     * @param reversed  whether smaller values are better
     * @return an empty map
     */
    private static Map<MatcherEnum, MatcherWeightInfo> getMatcherWeightInfo(Map<MatcherEnum, Double> map, Double maxWeight, boolean reversed) {
        // TODO implement fixed-ratio decay
        return new HashMap<>();
    }

    /**
     * Linear matcher weights by rank position.
     *
     * <p>Matchers are sorted ascending by value, or descending when {@code reversed}. The i-th
     * matcher in that order (1-based) gets weight {@code i * maxWeight / n}, so the last one gets
     * {@code maxWeight}. With {@code reversed = true}, the smallest value therefore gets the
     * largest weight.
     *
     * @param map       matcher metric values (values must be non-null)
     * @param maxWeight weight of the top-ranked matcher; {@code null} yields all-zero weights
     * @param reversed  whether smaller values are better
     * @return weight info per matcher (the {@code num} field is 0)
     */
    private static Map<MatcherEnum, MatcherWeightInfo> getMatcherWeightInfo1(Map<MatcherEnum, Double> map, Double maxWeight, boolean reversed) {
        Comparator<Map.Entry<MatcherEnum, Double>> byValue = Map.Entry.comparingByValue();
        List<MatcherEnum> ranked = map.entrySet().stream()
                .sorted(reversed ? byValue.reversed() : byValue)
                .map(Map.Entry::getKey)
                .collect(Collectors.toList());

        double step = divide(maxWeight, ranked.size());
        Map<MatcherEnum, MatcherWeightInfo> matcherInfoMap = new HashMap<>();
        for (int i = 0; i < ranked.size(); i++) {
            matcherInfoMap.put(ranked.get(i), new MatcherWeightInfo((i + 1) * step, 0));
        }
        return matcherInfoMap;
    }

    /**
     * Linear matcher weights by min-max normalization.
     *
     * <p>Each matcher gets {@code maxWeight * (v - min) / (max - min)}, or
     * {@code maxWeight * (1 - (v - min) / (max - min))} when {@code reversed}. When all values
     * are equal (including a single matcher) every matcher gets {@code maxWeight}, which avoids
     * a 0/0 NaN.
     *
     * @param map       matcher metric values (values must be non-null)
     * @param maxWeight weight of the best matcher; {@code null} is treated as 0
     * @param reversed  whether smaller values are better
     * @return weight info per matcher (the {@code num} field is 0); empty for a null/empty map
     */
    private static Map<MatcherEnum, MatcherWeightInfo> getMatcherWeightInfo2(Map<MatcherEnum, Double> map, Double maxWeight, boolean reversed) {
        Map<MatcherEnum, MatcherWeightInfo> result = new HashMap<>();
        if (map == null || map.isEmpty()) {
            return result;
        }

        double minValue = Collections.min(map.values());
        double maxValue = Collections.max(map.values());
        double range = maxValue - minValue;
        double cap = maxWeight == null ? 0D : maxWeight;

        for (Map.Entry<MatcherEnum, Double> e : map.entrySet()) {
            double tmpWeight;
            if (range == 0) {
                // All values tie: every matcher is equally the best
                tmpWeight = 1D;
            } else {
                double normalized = (e.getValue() - minValue) / range;
                tmpWeight = reversed ? 1 - normalized : normalized;
            }
            result.put(e.getKey(), new MatcherWeightInfo(cap * tmpWeight, 0));
        }
        return result;
    }

    /**
     * Merges the time-window statistics (3-day, today, last hour) of one granularity.
     *
     * <p>Exposure, click, target convert, conv, OCPC consume and adjusted OCPC consume are summed
     * over the 3-day and today windows only. The last-hour window is not added; presumably it is
     * already part of today (TODO confirm).
     *
     * <p>Cost bias and CPM (consume / exposure) are weighted averages over all three windows. A
     * window contributes only when {@link #isConfident} holds for it, using
     * {@code params.getData3dayWeight()}, {@code getDataTodayWeight()} or
     * {@code getDataLastHourWeight()}. When no window qualifies, both are 0.
     *
     * @param data   the windowed statistics; {@code null} data or windows are treated as empty
     * @param params tuning parameters
     * @return a new merged DTO
     */
    private static ExploitDto dataMerge(ExploitData data, ExploitParams params) {
        ExploitDto data3day = orEmpty(data == null ? null : data.getData3day());
        ExploitDto dataToday = orEmpty(data == null ? null : data.getDataToday());
        ExploitDto dataLastHour = orEmpty(data == null ? null : data.getDataLastHour());

        ExploitDto res = new ExploitDto();
        res.setExposure(add(data3day.getExposure(), dataToday.getExposure()));
        res.setClick(add(data3day.getClick(), dataToday.getClick()));
        res.setTargetConvert(add(data3day.getTargetConvert(), dataToday.getTargetConvert()));
        res.setConv(mergeMap(data3day.getConv(), dataToday.getConv(), 8));
        res.setOcpcConsume(add(data3day.getOcpcConsume(), dataToday.getOcpcConsume()));
        res.setAdjOcpcConsume(add(data3day.getAdjOcpcConsume(), dataToday.getAdjOcpcConsume()));

        // Window weights: initially (low consumption) every weight is 0
        double data3dayWeight = isConfident(data3day, params) ? params.getData3dayWeight() : 0D;
        double dataTodayWeight = isConfident(dataToday, params) ? params.getDataTodayWeight() : 0D;
        double dataLastHourWeight = isConfident(dataLastHour, params) ? params.getDataLastHourWeight() : 0D;
        double totalWeight = data3dayWeight + dataTodayWeight + dataLastHourWeight;

        // Weighted cost bias. With no confident window it is 0, which effectively favors
        // low-consumption matchers.
        Double data3dayBiasPart = multiply(data3dayWeight, data3day.getCostBias());
        Double dataTodayBiasPart = multiply(dataTodayWeight, dataToday.getCostBias());
        Double dataLastHourBiasPart = multiply(dataLastHourWeight, dataLastHour.getCostBias());
        res.setCostBias(divide(data3dayBiasPart + dataTodayBiasPart + dataLastHourBiasPart, totalWeight));

        // Weighted CPM (consume / exposure); 0 when no window is confident
        Double data3dayCpmPart = multiply(data3dayWeight, divide(data3day.getOcpcConsume(), data3day.getExposure()));
        Double dataTodayCpmPart = multiply(dataTodayWeight, divide(dataToday.getOcpcConsume(), dataToday.getExposure()));
        Double dataLastHourCpmPart = multiply(dataLastHourWeight, divide(dataLastHour.getOcpcConsume(), dataLastHour.getExposure()));
        res.setCpm(divide(data3dayCpmPart + dataTodayCpmPart + dataLastHourCpmPart, totalWeight));

        return res;
    }

    /**
     * Whether a window has enough consumption to be trusted: its OCPC consume is non-null and not
     * below {@code params.getConfidentConsume()}.
     */
    private static boolean isConfident(ExploitDto dto, ExploitParams params) {
        return !(dto.getOcpcConsume() == null || dto.getOcpcConsume() < params.getConfidentConsume());
    }

    /**
     * Returns {@code dto}, or a fresh empty {@link ExploitDto} when it is {@code null}.
     */
    private static ExploitDto orEmpty(ExploitDto dto) {
        return dto == null ? new ExploitDto() : dto;
    }

    /**
     * Merges two maps by summing the values of each key (missing or {@code null} values count as 0).
     *
     * @param map1         first map, {@code null} treated as empty
     * @param map2         second map, {@code null} treated as empty
     * @param initCapacity initial capacity of the result map
     * @return a new map containing every key of both inputs
     */
    private static Map<Integer, Long> mergeMap(Map<Integer, Long> map1, Map<Integer, Long> map2, int initCapacity) {
        Map<Integer, Long> first = map1 == null ? Collections.emptyMap() : map1;
        Map<Integer, Long> second = map2 == null ? Collections.emptyMap() : map2;

        Map<Integer, Long> resMap = new HashMap<>(initCapacity);
        // Collect the union of keys, then replace every value with the key-wise sum
        resMap.putAll(first);
        resMap.putAll(second);
        resMap.replaceAll((k, v) -> add(first.get(k), second.get(k)));

        return resMap;
    }

    /**
     * Logistic function {@code 1 / (1 + exp(-theta * value + beta))}, i.e. the sigmoid of
     * {@code theta * value - beta}.
     */
    public static double calSigmoid(double theta, double beta, double value) {
        return 1 / (1 + Math.exp(-1.0 * theta * value + beta));
    }

    /** Null-safe sum; {@code null} operands count as 0. */
    private static Long add(Long v1, Long v2) {
        Long v1New = v1 == null ? 0L : v1;
        Long v2New = v2 == null ? 0L : v2;

        return v1New + v2New;
    }

    /** Null-safe sum; {@code null} operands count as 0. */
    private static Double add(Double v1, Double v2) {
        Double v1New = v1 == null ? 0D : v1;
        Double v2New = v2 == null ? 0D : v2;

        return v1New + v2New;
    }

    /** Product, or 0 when either operand is {@code null}. */
    private static Double multiply(Double v1, Double v2) {
        return (v1 != null && v2 != null) ? v1 * v2 : 0D;
    }

    /** Quotient, or 0 when either operand is {@code null} or the divisor is 0. */
    private static Double divide(Long v1, Long v2) {
        return (v1 != null && v2 != null && v2 != 0) ? v1.doubleValue() / v2 : 0D;
    }

    /** Quotient, or 0 when either operand is {@code null} or the divisor is 0. */
    private static Double divide(Double v1, Integer v2) {
        return (v1 != null && v2 != null && v2 != 0) ? v1 / v2 : 0D;
    }

    /** Quotient, or 0 when either operand is {@code null} or the divisor is 0. */
    private static Double divide(Double v1, Double v2) {
        return (v1 != null && v2 != null && v2 != 0) ? v1 / v2 : 0D;
    }

    public static void main(String[] args) {
        Map<MatcherEnum, Double> map = new HashMap<>();
//        map.put(MatcherEnum.EXPLORE, 1.8);
//        map.put(MatcherEnum.TRUSTEESHIP, 2.0);
//        map.put(MatcherEnum.DSSM, 0.9);
//        map.put(MatcherEnum.PREFER_C_7D, 0.1);
//        map.put(MatcherEnum.PREFER_V_3D, -0.3);


        Map<MatcherEnum, MatcherWeightInfo> resMap = getMatcherWeightInfo(map, 0.5, true);
        for (Map.Entry<MatcherEnum, MatcherWeightInfo> entry : resMap.entrySet()) {
            System.out.println(entry.getKey() + ": " + entry.getValue());
        }

    }

}
