package cn.com.duiba.nezha.alg.alg.advert;

import cn.com.duiba.nezha.alg.alg.vo.AdvertRcmdDo;
import cn.com.duiba.nezha.alg.feature.vo.FeatureMapDo;
import cn.com.duiba.nezha.alg.model.FM;
import cn.com.duiba.nezha.alg.model.IModel;

import java.util.*;
import java.util.stream.Collectors;

/**
 * @author yudian
 * 2020-09-22
 */

public class AdvertRcmder {

    /** Channel id attached to every recommendation produced by the model-prediction channel (channel 1). */
    private static final Long MODEL_CHANNEL_ID = 1L;

    /**
     * Threshold used for an advert that has no (or a {@code null}) entry in the threshold map. An advert is dropped
     * only when its score is strictly below its threshold, so this fallback drops every advert scoring below 1.0.
     */
    private static final double DEFAULT_THRESHOLD = 1D;

    /**
     * Ranking order: highest score first. Equal scores are broken by ascending map key (null keys last) so the
     * output order is deterministic instead of depending on hash-map iteration order.
     */
    private static final Comparator<Map.Entry<Long, Double>> BY_SCORE_DESC_THEN_KEY =
            Map.Entry.<Long, Double>comparingByValue().reversed()
                    .thenComparing(Map.Entry.<Long, Double>comparingByKey(
                            Comparator.nullsLast(Comparator.<Long>naturalOrder())));

    /**
     * Advert recall entry point. Currently delegates to the model-prediction channel only.
     *
     * @param model             model used by the prediction channel
     * @param featureMap        per-advert features for the prediction channel, passed straight to the model
     *                          (keys presumably advert ids — confirm against callers)
     * @param modelThresholdMap per-advert filtering thresholds for the prediction channel; adverts without an
     *                          entry use a threshold of 1.0, and {@code null} is treated as an empty map
     * @return recommendations whose score is not below their threshold, best score first
     * @throws Exception whatever {@code IModel.predictsNew} throws
     */
    public static List<AdvertRcmdDo> rcmds(IModel model, Map<Long, FeatureMapDo> featureMap, Map<Long, Double> modelThresholdMap) throws Exception {
        return modelRcmds(model, featureMap, modelThresholdMap);
    }


    /**
     * Model-prediction channel (channel 1): scores every advert with the model, drops adverts scoring strictly
     * below their threshold, and ranks the survivors by descending score.
     *
     * @param model        prediction model; must not be {@code null}
     * @param featureMap   per-advert features passed to {@code model.predictsNew}
     * @param thresholdMap per-advert filtering thresholds; missing entries fall back to 1.0, and {@code null} is
     *                     treated as an empty map. Per the original note, the caller fills thresholds for new
     *                     adverts with {@code ctr * cvr * 1.2}.
     * @return a new mutable list of recommendations tagged with channel id 1, best score first; empty if the
     *         model returns no predictions
     * @throws NullPointerException if {@code model} is {@code null}
     * @throws Exception            whatever {@code IModel.predictsNew} throws
     */
    public static List<AdvertRcmdDo> modelRcmds(IModel model, Map<Long, FeatureMapDo> featureMap, Map<Long, Double> thresholdMap) throws Exception {
        Objects.requireNonNull(model, "model");

        // Model prediction: one score per advert key.
        Map<Long, Double> predicts = model.predictsNew(featureMap);

        // Threshold filtering + ranking. The model's map is only read, never mutated, so an immutable or
        // cached map returned by the model is safe to use here.
        return filterAndRank(predicts, thresholdMap);
    }

    /**
     * Keeps the scores that reach their threshold and ranks them, without modifying either input map.
     *
     * @param predicts     score per advert key; {@code null} is treated as empty
     * @param thresholdMap threshold per advert key; {@code null} is treated as empty, and missing or {@code null}
     *                     entries fall back to {@link #DEFAULT_THRESHOLD}
     * @return a new mutable list of recommendations tagged with {@link #MODEL_CHANNEL_ID}, best score first
     */
    private static List<AdvertRcmdDo> filterAndRank(Map<Long, Double> predicts, Map<Long, Double> thresholdMap) {
        if (predicts == null || predicts.isEmpty()) {
            return new ArrayList<>();
        }
        final Map<Long, Double> thresholds =
                thresholdMap == null ? Collections.<Long, Double>emptyMap() : thresholdMap;

        return predicts.entrySet()
                .stream()
                .filter(entry -> passesThreshold(entry.getValue(), thresholdFor(thresholds, entry.getKey())))
                .sorted(BY_SCORE_DESC_THEN_KEY)
                .map(entry -> new AdvertRcmdDo(entry.getKey(), entry.getValue(), MODEL_CHANNEL_ID))
                .collect(Collectors.toList());
    }

    /** Returns the threshold for {@code key}, or {@link #DEFAULT_THRESHOLD} when it is absent or {@code null}. */
    private static double thresholdFor(Map<Long, Double> thresholds, Long key) {
        Double threshold = thresholds.get(key);
        return threshold == null ? DEFAULT_THRESHOLD : threshold;
    }

    /**
     * Decides whether a score survives filtering.
     *
     * @return {@code false} for a {@code null} or NaN score (a NaN would otherwise pass the {@code <} test and then
     *         sort to the top); otherwise {@code true} unless the score is strictly below the threshold
     */
    private static boolean passesThreshold(Double score, double threshold) {
        if (score == null || score.isNaN()) {
            return false;
        }
        // Drop only when strictly below the threshold (a score equal to the threshold is kept).
        return !(score < threshold);
    }

    /**
     * Manual smoke check of the filter + ranking step with hard-coded scores; prints the resulting
     * recommendations. Advert 59044 has no threshold entry, so it falls back to the default threshold of 1.0
     * and is dropped.
     */
    public static void main(String[] args) {
        Map<Long, Double> predicts = new HashMap<>();
        predicts.put(54228L, 0.16908212560386471);
        predicts.put(49968L, 0.1180555555555555);
        predicts.put(52921L, 0.067542706964520347);
        predicts.put(56238L, 0.0083333333333333367);
        predicts.put(59044L, 0.040380047505938238);

        Map<Long, Double> thresholdMap = new HashMap<>();
        thresholdMap.put(54228L, 0.0);
        thresholdMap.put(49968L, 0.0);
        thresholdMap.put(52921L, 0.0);
        thresholdMap.put(56238L, 0.0);

        // Same filtering and ranking as the production path in modelRcmds.
        List<AdvertRcmdDo> res = filterAndRank(predicts, thresholdMap);

        for (AdvertRcmdDo re : res) {
            System.out.println(re);
        }
    }

}
