package cn.com.duiba.nezha.alg.alg.adx.rta;
import cn.com.duiba.nezha.alg.alg.vo.adx.rta.AccountIdDo;
import cn.com.duiba.nezha.alg.alg.vo.adx.rta.RtaBidRet;
import cn.com.duiba.nezha.alg.alg.vo.adx.rta.RtaRoiControlModel;
import cn.com.duiba.nezha.alg.alg.vo.adx.rta.RtaRoiControlParams;
import cn.com.duiba.nezha.alg.api.model.E2ELocalTFModel;
import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import cn.com.duiba.nezha.alg.common.util.DataUtil;
import cn.com.duiba.nezha.alg.common.util.MathUtil;
import cn.com.duiba.nezha.alg.feature.parse.RtaFeatureParse;
import cn.com.duiba.nezha.alg.feature.vo.RtaFeatureDo;
import com.alibaba.fastjson.JSON;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;
import java.util.stream.Collectors;

/**
 * RTA algorithmic bidding helpers: builds features, runs the TF model, and combines the model
 * score with the per-account ROI-control factor into a clipped bid ratio.
 *
 * Created by pzx on 2022/1/12.
 */
public class RtaAlgoService {
    private static final Logger logger = LoggerFactory.getLogger(RtaAlgoService.class);

    /**
     * Key of the single feature entry written by {@link #getFeatureMap} and of the score read back
     * in {@link #predict}. There are no account-level features yet, so all accounts share one
     * feature entry stored under this placeholder key.
     */
    private static final Long SHARED_FEATURE_KEY = 1L;

    /** Score used when the model is unavailable or produced no value for the shared entry. */
    private static final double DEFAULT_SCORE = 1.0;

    /** Factor used when ROI control is not configured, disabled, or has no factor for an account. */
    private static final double DEFAULT_FACTOR = 1.0;

    /** Clip range applied to the raw model score. */
    private static final double SCORE_MIN = 0.799;
    private static final double SCORE_MAX = 1.201;

    /** Clip range applied to the final bid ratio ({@code factor * rtaScore}). */
    private static final double BID_RATIO_MIN = 0.5;
    private static final double BID_RATIO_MAX = 5.0;

    /** Fraction of {@link #predict} calls whose model output and features are logged at INFO. */
    private static final double PREDICT_LOG_SAMPLE_RATE = 0.0001;

    /**
     * Step 1: build features.
     * Builds the feature set used for biddable accounts (currently traffic/user-side features only).
     *
     * @param rtaFeatureDo static feature source, parsed by {@code RtaFeatureParse.generateFeatureMapStatic}
     * @param accountList  accounts with algorithmic bidding enabled; currently unused because no
     *                     account-level features exist yet
     * @return map from key to feature map; currently a single entry under the shared placeholder
     *         key, used for every account
     */
    public static Map<Long, Map<String, String>> getFeatureMap(RtaFeatureDo rtaFeatureDo, List<Long> accountList) {
        // 1. Traffic features: parse static features.
        Map<String, String> staticFeatureMap = RtaFeatureParse.generateFeatureMapStatic(rtaFeatureDo);

        // 2. Account features (dynamic): not implemented yet.

        // 3. Assemble the feature collection.
        Map<Long, Map<String, String>> featureMap = new HashMap<>();
        featureMap.put(SHARED_FEATURE_KEY, staticFeatureMap);
        return featureMap;
    }


    /**
     * Step 2: model prediction.
     * Predicts the account-level score. Because features are shared, one score is computed and
     * assigned to every account.
     *
     * @param featureMap  feature sets built by {@link #getFeatureMap}
     * @param model       TF model; if it, its bundle, or {@code featureMap} is missing, every account
     *                    gets the default score 1.0
     * @param accountList accounts with algorithmic bidding enabled; {@code null} ids are skipped and a
     *                    {@code null} list yields an empty result
     * @return account id -> score clipped to [0.799, 1.201] (or 1.0 on fallback)
     */
    public static Map<Long, Double> predict(Map<Long, Map<String, String>> featureMap, E2ELocalTFModel model, List<Long> accountList) {
        if (model == null || model.getBundle() == null || featureMap == null) {
            logger.error("RtaAlgoService.predict model or featureMap is null, Bundle{}",
                    model == null ? null : model.getBundle());
            return fillScores(accountList, DEFAULT_SCORE);
        }

        Map<Long, Double> result = model.predict(featureMap);
        // Missing result map or missing/null score for the shared entry falls back to the default.
        Double rawScore = result == null ? null : result.get(SHARED_FEATURE_KEY);
        Double score = MathUtil.stdwithBoundary(rawScore == null ? DEFAULT_SCORE : rawScore, SCORE_MIN, SCORE_MAX); // clip score
        Map<Long, Double> ret = fillScores(accountList, score);

        if (Math.random() <= PREDICT_LOG_SAMPLE_RATE) {
            logger.info("RtaAlgoService.predict score{}", JSON.toJSONString(result));
            logger.info("RtaAlgoService.predict featureMap{}", JSON.toJSONString(featureMap));
        }

        return ret;
    }

    /**
     * Assigns {@code score} to every non-null account id; a {@code null} list yields an empty map.
     */
    private static Map<Long, Double> fillScores(List<Long> accountList, Double score) {
        Map<Long, Double> scores = new HashMap<>();
        if (accountList == null) {
            return scores;
        }
        for (Long account : accountList) {
            if (account != null) {
                scores.put(account, score);
            }
        }
        return scores;
    }


    /**
     * Step 3: algorithmic bidding.
     * Computes the bid for each account as {@code roiFactor * rtaScore}, clipped to [0.5, 5.0].
     * Accounts are processed independently: a failure for one account is logged and only that
     * account is omitted, instead of discarding the whole result.
     *
     * @param accountIdDos    accounts to bid for; {@code null} list or elements are ignored. A
     *                        {@code null} rtaScore is treated as the neutral score 1.0 for the bid
     *                        computation (the raw value is still reported in {@code rtaScore}).
     * @param groupTag        A/B experiment bucket tag (currently unused)
     * @param roiControlModel ROI-control model supplying the per-account factor; may be {@code null}
     * @return bid results, each with a bid ratio in [0.5, 5.0]; never {@code null}
     */
    public static List<RtaBidRet> bidding(List<AccountIdDo> accountIdDos, Integer groupTag, RtaRoiControlModel roiControlModel) {
        List<RtaBidRet> rtaBidRets = new ArrayList<>();
        if (accountIdDos == null) {
            return rtaBidRets;
        }

        for (AccountIdDo accountIdDo : accountIdDos) {
            if (accountIdDo == null) {
                continue;
            }
            try {
                rtaBidRets.add(bidForAccount(accountIdDo, roiControlModel));
            } catch (Exception e) {
                // Isolate failures per account so one bad entry does not drop every bid.
                logger.error("RtaAlgoService.bidding error", e);
            }
        }

        return rtaBidRets;
    }

    /**
     * Builds the bid result for one account from its rtaScore and ROI factor.
     */
    private static RtaBidRet bidForAccount(AccountIdDo accountIdDo, RtaRoiControlModel roiControlModel) {
        Long accountId = accountIdDo.getAccountId();
        Double factor = getRoiFactor(accountId, roiControlModel);
        Double rtaScore = accountIdDo.getRtaScore();
        double effectiveScore = rtaScore == null ? DEFAULT_SCORE : rtaScore;

        RtaBidRet ret = new RtaBidRet();
        ret.setAccountId(accountId);
        ret.setFactor(factor);
        ret.setRtaScore(rtaScore);

        Double bidRatio = MathUtil.stdwithBoundary(factor * effectiveScore, BID_RATIO_MIN, BID_RATIO_MAX);
        ret.setBidRatio(bidRatio);
        return ret;
    }


    /**
     * Returns the ROI-control factor for an account.
     *
     * @param accountId       account to look up
     * @param roiControlModel ROI-control model; {@code null} is replaced by an empty model
     * @return the configured factor when params exist, {@code cConf} is {@code true} and the factor
     *         is non-null; otherwise 1.0. Never {@code null}.
     */
    public static Double getRoiFactor(Long accountId, RtaRoiControlModel roiControlModel) {
        // orElseGet: only allocate a fallback model when none was supplied.
        RtaRoiControlModel model = Optional.ofNullable(roiControlModel).orElseGet(RtaRoiControlModel::new);
        RtaRoiControlParams roiControlParams = model.getRoiControlParams(accountId);

        // Check params for null before dereferencing: arguments are evaluated eagerly, so passing
        // roiControlParams.getCConf() to a null-check helper would itself throw an NPE.
        if (roiControlParams == null || !Boolean.TRUE.equals(roiControlParams.getCConf())) {
            return DEFAULT_FACTOR;
        }
        Double rFactor = roiControlParams.getRFactor();
        return rFactor == null ? DEFAULT_FACTOR : rFactor;
    }

}
