package cn.com.duiba.nezha.alg.alg.material;

import cn.com.duiba.nezha.alg.alg.vo.material.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;

/**
 * Material recall / matching.
 *
 * <p>Splits the deliverable materials of a request into:
 * <ul>
 *   <li>a ranked pool ({@code materialMatchDoList}) ordered by Wilson score lower bound of CTR;</li>
 *   <li>a trial-exposure pool ({@code materialExposeDoList}) of material ids with few exposures;</li>
 *   <li>an offline cost-match pool passed through from the recall input.</li>
 * </ul>
 */
public class MaterialMatch {

    private static final Logger logger = LoggerFactory.getLogger(MaterialMatch.class);

    /**
     * Candidate-pool size threshold. At or below it, every candidate is returned unscored. Above it,
     * the ranked pool is truncated to {@code SHORT_CUT_SIZE / 2} entries.
     */
    private static final int SHORT_CUT_SIZE = 100;

    /** Materials whose exposure count is at or below this value go into the trial-exposure pool. */
    private static final double EXPOSE_MAX_COUNT = 500;

    /**
     * Value used as the z-score in the Wilson interval.
     *
     * <p>NOTE(review): the original comment says "99% confidence", but 0.99 is used directly as z.
     * A two-sided 99% interval would use z ≈ 2.576. Kept at 0.99 to preserve the current ranking;
     * confirm the intent.
     */
    private static final double WILSON_Z = 0.99;

    /**
     * Recalls candidate materials.
     *
     * @param materialRecallDo recall input. A {@code null} input or a {@code null} material list is
     *                         treated as an empty candidate pool.
     * @return the extraction result: Wilson-score-ranked pool, trial-exposure pool and offline
     *         cost-match pool
     */
    public static MaterialExtractDo match(MaterialRecallDo materialRecallDo) {
        return onlineVersion(materialRecallDo);
    }

    /**
     * Random strategy, intended for the initial rollout to collect training samples.
     *
     * <p>Every deliverable material goes into the ranked pool with a score of 0.0. The trial-exposure
     * pool and the cost-match pool are left empty. Not referenced elsewhere in this file.
     *
     * @param materialList deliverable materials; {@code null} is treated as empty
     * @return extraction result with all materials unscored
     */
    private static MaterialExtractDo randomStrategy(List<MaterialStatInfo> materialList) {
        MaterialExtractDo materialExtractDo = new MaterialExtractDo();
        materialExtractDo.setMaterialMatchDoList(toZeroScoreMatches(materialList));
        materialExtractDo.setMaterialExposeDoList(new ArrayList<>());
        // Keep the cost-match pool non-null, consistent with the short-cut path of onlineVersion.
        materialExtractDo.setMaterialCostMatchDoList(new ArrayList<>());
        return materialExtractDo;
    }

    /**
     * Production strategy.
     *
     * <p>If there are at most {@link #SHORT_CUT_SIZE} candidates, all of them go into the ranked pool
     * with score 0.0, and both the trial pool and the cost-match pool are empty. Otherwise:
     * <ul>
     *   <li>every candidate is scored with the Wilson lower bound of its CTR;</li>
     *   <li>the ranked pool keeps the top {@code SHORT_CUT_SIZE / 2} by score. NOTE(review): earlier
     *       docs said "top 100"; the code keeps 50. Confirm which is intended.</li>
     *   <li>ids with exposure count {@code <= EXPOSE_MAX_COUNT} form the trial-exposure pool;</li>
     *   <li>the cost-match pool is taken from {@code materialRecallDo.getMaterialCostMatchDoList()}.</li>
     * </ul>
     *
     * @param materialRecallDo recall input; may be {@code null}
     * @return the populated extraction result
     */
    private static MaterialExtractDo onlineVersion(MaterialRecallDo materialRecallDo) {
        List<MaterialStatInfo> materialList =
                materialRecallDo == null ? null : materialRecallDo.getMaterialList();
        MaterialExtractDo materialExtractDo = new MaterialExtractDo();

        // Small (or missing) pool: take everything unscored, with no trial exposure and no offline recall.
        if (materialList == null || materialList.size() <= SHORT_CUT_SIZE) {
            materialExtractDo.setMaterialMatchDoList(toZeroScoreMatches(materialList));
            materialExtractDo.setMaterialExposeDoList(new ArrayList<>());
            materialExtractDo.setMaterialCostMatchDoList(new ArrayList<>());
            return materialExtractDo;
        }

        List<MaterialMatchDo> scoredList = new ArrayList<>(materialList.size());
        List<Long> exposeIdList = new ArrayList<>();
        for (MaterialStatInfo materialStatInfo : materialList) {
            double wilsonScore = calWilsonScore(materialStatInfo.getExposeCnt(), materialStatInfo.getClickCnt());
            scoredList.add(new MaterialMatchDo(materialStatInfo.getMaterialId(), wilsonScore));
            // Trial pool: materials with little exposure. Per the original author, the exposure count
            // covers the last three days (not verifiable from this file).
            if (materialStatInfo.getExposeCnt() <= EXPOSE_MAX_COUNT) {
                exposeIdList.add(materialStatInfo.getMaterialId());
            }
        }

        // Keep the highest-scoring candidates. Stream sorting is stable, so ties keep input order.
        List<MaterialMatchDo> topList = scoredList
                .stream()
                .sorted(Comparator.comparing(MaterialMatchDo::getCtr).reversed())
                .limit(SHORT_CUT_SIZE / 2)
                .collect(Collectors.toList());

        materialExtractDo.setMaterialMatchDoList(topList);
        materialExtractDo.setMaterialExposeDoList(exposeIdList);
        materialExtractDo.setMaterialCostMatchDoList(materialRecallDo.getMaterialCostMatchDoList());
        return materialExtractDo;
    }

    /**
     * Wraps every material in a {@link MaterialMatchDo} with a score of 0.0, preserving input order.
     *
     * @param materialList materials to wrap; {@code null} yields an empty list
     * @return a new mutable list
     */
    private static List<MaterialMatchDo> toZeroScoreMatches(List<MaterialStatInfo> materialList) {
        List<MaterialMatchDo> result = new ArrayList<>();
        if (materialList == null) {
            return result;
        }
        for (MaterialStatInfo materialStatInfo : materialList) {
            result.add(new MaterialMatchDo(materialStatInfo.getMaterialId(), 0.0));
        }
        return result;
    }

    /**
     * Wilson score lower bound of CTR, using {@link #WILSON_Z} as the z-score.
     *
     * @param exposeCnt number of exposures
     * @param clickCnt  number of clicks
     * @return lower bound in [0, 1]; 0.0 when {@code exposeCnt <= 0}
     */
    private static double calWilsonScore(long exposeCnt, long clickCnt) {
        return calWilsonScore(exposeCnt, clickCnt, WILSON_Z);
    }

    /**
     * Wilson score lower bound of a binomial proportion.
     *
     * <p>Computes {@code (p + z²/2n - z·sqrt(4n·p(1-p) + z²)/2n) / (1 + z²/n)}.
     *
     * @param exposeCnt number of trials (exposures)
     * @param clickCnt  number of successes (clicks)
     * @param z         z-score of the interval
     * @return lower bound; 0.0 when {@code exposeCnt <= 0}
     */
    private static double calWilsonScore(long exposeCnt, long clickCnt, double z) {
        if (exposeCnt <= 0) {
            return 0.0;
        }
        double n = exposeCnt;
        // Clamp: inconsistent counters (clicks > exposures, or negative clicks) would otherwise make
        // p*(1-p) negative and sqrt return NaN. NaN sorts ahead of every real score in a reversed
        // Double comparator, so those rows would rank first.
        double ratio = Math.min(1.0, Math.max(0.0, clickCnt / n));
        double zSquare = z * z;
        double center = ratio + zSquare / (2 * n);
        double margin = z * Math.sqrt(4 * n * ratio * (1 - ratio) + zSquare) / (2 * n);
        return (center - margin) / (1 + zSquare / n);
    }
}
