package cn.com.duiba.nezha.alg.model;

import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import cn.com.duiba.nezha.alg.feature.vo.FeatureMapDo;
import cn.com.duiba.nezha.alg.model.enums.HyperParams;
import cn.com.duiba.nezha.alg.model.enums.MutModelType;
import cn.com.duiba.nezha.alg.model.enums.PredictResultType;
import cn.com.duiba.nezha.alg.model.tf.LocalTFModel;
import cn.com.duiba.nezha.alg.model.tf.LocalTFModelV2;
import cn.com.duiba.nezha.alg.model.tf.TFServingClient;
import com.alibaba.fastjson.JSON;
import org.slf4j.LoggerFactory;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Joint CTR / CVR predictor for the overall-market ("dapan") deep model that uses the new feature set.
 * <p>
 * This class was copied from an earlier ESMM implementation. The intended difference is that the
 * local TensorFlow models for both CTR and CVR are {@link LocalTFModelV2} and are invoked through
 * {@code IModel#predictWithLocalTFV2}.
 * <p>
 * When the primary prediction (TF-Serving or local TF) yields no usable result, the
 * {@code ctrFusingModel} / {@code cvrFusingModel} fallback ("fusing", i.e. circuit-breaker) models are used.
 */
public class ESMMV2 implements Serializable {

    private static final long serialVersionUID = -316102112618444131L;

    private static final org.slf4j.Logger logger = LoggerFactory.getLogger(ESMMV2.class);

    /**
     * For batches with fewer samples than this, the whole CVR prediction map is serialized into every
     * sample's code string. Larger batches only record each sample's own value, which keeps the strings small.
     */
    private static final int MAX_BATCH_SIZE_FOR_FULL_CVR_LOG = 40;

    /**
     * Initial capacity of the result maps built in this class.
     * It is kept as a (non-static) instance field so that the serialized form of this class does not change.
     */
    private int PB_MAX_SIZE = 128;

    /**
     * Algorithm strategy type (only used in log output here).
     */
    private Integer algType;

    /**
     * CTR prediction model (online FM).
     */
    private IModel ctrModel;

    /**
     * CVR prediction model (online FM).
     */
    private IModel cvrModel;

    /**
     * Fallback (degraded) CTR model, used when the primary prediction returns nothing.
     */
    private IModel ctrFusingModel;

    /**
     * Fallback (degraded) CVR model, used when the primary prediction returns nothing.
     */
    private IModel cvrFusingModel;

    /**
     * Local deep CTR model. When null, CTR falls back to TF-Serving in local-model mode.
     */
    private LocalTFModelV2 ctrLocalTFModel;

    /**
     * Local deep CVR model (dapan deep model with the new features).
     * When null, CVR falls back to TF-Serving in local-model mode.
     */
    private LocalTFModelV2 cvrLocalTFModel;

    /**
     * TF-Serving client for the CTR deep model. It is also used for the joint ESMM call.
     */
    private TFServingClient ctrTFServingClient;

    /**
     * TF-Serving client for the CVR deep model.
     */
    private TFServingClient cvrTFServingClient;

    /**
     * Model mode (e.g. CTR, DEEP_ESMM, DEEP_E2E_LOCAL, DEBIAS). It selects the prediction path.
     */
    private MutModelType mutModelType;

    public Integer getAlgType() {
        return algType;
    }

    public void setAlgType(Integer algType) {
        this.algType = algType;
    }

    public LocalTFModelV2 getCtrLocalTFModel() {
        return ctrLocalTFModel;
    }

    public void setCtrLocalTFModel(LocalTFModelV2 ctrLocalTFModel) {
        this.ctrLocalTFModel = ctrLocalTFModel;
    }

    public LocalTFModelV2 getCvrLocalTFModel() {
        return cvrLocalTFModel;
    }

    public void setCvrLocalTFModel(LocalTFModelV2 cvrLocalTFModel) {
        this.cvrLocalTFModel = cvrLocalTFModel;
    }

    public IModel getCtrModel() {
        return ctrModel;
    }

    public void setCtrModel(IModel ctrModel) {
        this.ctrModel = ctrModel;
    }

    public IModel getCvrModel() {
        return cvrModel;
    }

    public void setCvrModel(IModel cvrModel) {
        this.cvrModel = cvrModel;
    }

    public TFServingClient getCtrTFServingClient() {
        return ctrTFServingClient;
    }

    public void setCtrTFServingClient(TFServingClient ctrTFServingClient) {
        this.ctrTFServingClient = ctrTFServingClient;
    }

    public TFServingClient getCvrTFServingClient() {
        return cvrTFServingClient;
    }

    public void setCvrTFServingClient(TFServingClient cvrTFServingClient) {
        this.cvrTFServingClient = cvrTFServingClient;
    }

    public MutModelType getMutModelType() {
        return mutModelType;
    }

    public void setMutModelType(MutModelType mutModelType) {
        this.mutModelType = mutModelType;
    }

    public IModel getCtrFusingModel() {
        return ctrFusingModel;
    }

    public void setCtrFusingModel(IModel ctrFusingModel) {
        this.ctrFusingModel = ctrFusingModel;
    }

    public IModel getCvrFusingModel() {
        return cvrFusingModel;
    }

    public void setCvrFusingModel(IModel cvrFusingModel) {
        this.cvrFusingModel = cvrFusingModel;
    }

    /**
     * Full constructor (2020-07-29).
     *
     * @param ctrModel           online CTR model
     * @param cvrModel           online CVR model
     * @param ctrFusingModel     fallback CTR model
     * @param cvrFusingModel     fallback CVR model
     * @param ctrTFServingClient TF-Serving client for CTR (and for the joint ESMM call)
     * @param cvrTFServingClient TF-Serving client for CVR
     * @param ctrLocalTFModel    local V2 CTR model, may be null
     * @param cvrLocalTFModel    local V2 CVR model, may be null
     * @param mutModelType       model mode selecting the prediction path
     */
    public ESMMV2(IModel ctrModel,
                  IModel cvrModel,
                  IModel ctrFusingModel,
                  IModel cvrFusingModel,
                  TFServingClient ctrTFServingClient,
                  TFServingClient cvrTFServingClient,
                  LocalTFModelV2 ctrLocalTFModel,
                  LocalTFModelV2 cvrLocalTFModel,
                  MutModelType mutModelType) {

        this.ctrModel = ctrModel;
        this.cvrModel = cvrModel;
        this.ctrFusingModel = ctrFusingModel;
        this.cvrFusingModel = cvrFusingModel;

        this.ctrTFServingClient = ctrTFServingClient;
        this.cvrTFServingClient = cvrTFServingClient;

        this.ctrLocalTFModel = ctrLocalTFModel;
        this.cvrLocalTFModel = cvrLocalTFModel;

        this.mutModelType = mutModelType;
    }

    /**
     * Constructor without fallback or local models.
     * <p>
     * Note: this constructor does not set the fusing or local models. Set the fusing models through their
     * setters before any call can reach the fallback path; otherwise that path throws a NullPointerException.
     *
     * @param ctrModel           online CTR model
     * @param cvrModel           online CVR model
     * @param ctrTFServingClient TF-Serving client for CTR (and for the joint ESMM call)
     * @param cvrTFServingClient TF-Serving client for CVR
     * @param mutModelType       model mode selecting the prediction path
     */
    public ESMMV2(IModel ctrModel, IModel cvrModel, TFServingClient ctrTFServingClient, TFServingClient cvrTFServingClient, MutModelType mutModelType) {
        this.ctrModel = ctrModel;
        this.cvrModel = cvrModel;
        this.ctrTFServingClient = ctrTFServingClient;
        this.cvrTFServingClient = cvrTFServingClient;
        this.mutModelType = mutModelType;
    }

    /**
     * Main prediction entry point for string-map features: CTR only, or CTR + CVR with deep models.
     *
     * @param featureMap sample key -&gt; raw feature map
     * @param <T>        sample key type
     * @return CTR (and CVR) predictions per sample, or null when nothing could be predicted
     * @throws Exception propagated from the underlying models
     */
    public <T> Map<PredictResultType, Map<T, Double>> predictCTRsAndCVRsWithTF(Map<T, Map<String, String>> featureMap) throws Exception {
        if (MutModelType.CTR == mutModelType) {
            return predictCTRs(featureMap);
        }
        return predictCTRsAndCVRsWithTFDEEP(featureMap);
    }

    /**
     * Main prediction entry point (2020-07-29) for {@link FeatureMapDo} features:
     * CTR only, or CTR + CVR with deep models.
     *
     * @param featureMap sample key -&gt; feature object
     * @param <T>        sample key type
     * @return CTR (and CVR) predictions per sample, or null when nothing could be predicted
     * @throws Exception propagated from the underlying models
     */
    public <T> Map<PredictResultType, Map<T, Double>> predictCTRsAndCVRsWithTFNew(Map<T, FeatureMapDo> featureMap) throws Exception {
        if (MutModelType.CTR == mutModelType) {
            return predictCTRsNew(featureMap);
        }
        return predictCTRsAndCVRsWithTFDEEPNew(featureMap);
    }

    /**
     * Main prediction entry point (2020-10-29) with optional negative-sampling debiasing of CVR.
     * <p>
     * In {@code DEBIAS} mode, each CVR {@code p} is replaced by {@code p / (p + (1 - p) / w)}.
     * Here {@code w} is the {@link HyperParams#NCW} negative-sampling correction coefficient; it defaults
     * to 1.0 when absent, null or not a positive finite number.
     *
     * @author Brook
     * @param featureMap sample key -&gt; feature object
     * @param params     hyper-parameters; may be null
     * @param <T>        sample key type
     * @return CTR (and CVR) predictions per sample, or null when nothing could be predicted
     * @throws Exception propagated from the underlying models
     */
    public <T> Map<PredictResultType, Map<T, Double>> predictCTRsAndCVRsWithTFDeBias(Map<T, FeatureMapDo> featureMap, Map<HyperParams, Double> params) throws Exception {
        if (MutModelType.CTR == mutModelType) {
            return predictCTRsNew(featureMap);
        }

        Map<PredictResultType, Map<T, Double>> ret = predictCTRsAndCVRsWithTFDEEPNew(featureMap);

        if (MutModelType.DEBIAS == mutModelType && ret != null) {
            Map<T, Double> cvrPrd = ret.get(PredictResultType.CVR);
            if (cvrPrd != null) {
                // Replace the raw CVR with the corrected one. A new map is built so that the model's
                // own result map is not mutated while it is being iterated.
                ret.put(PredictResultType.CVR, correctNegativeSampling(cvrPrd, resolveSampleRate(params)));
            }
        }
        return ret;
    }

    /**
     * Reads the NCW negative-sampling correction coefficient.
     * Missing, null or invalid values (NaN, infinite, &lt;= 0) fall back to 1.0, which means no correction.
     */
    private double resolveSampleRate(Map<HyperParams, Double> params) {
        Double ncw = params == null ? null : params.get(HyperParams.NCW);
        if (ncw == null) {
            return 1.0;
        }
        if (ncw.isNaN() || ncw.isInfinite() || ncw <= 0) {
            // A zero rate would turn every CVR into 0 (or NaN); treat it as misconfiguration.
            logger.warn("invalid NCW sample rate {}, skip negative-sampling correction", ncw);
            return 1.0;
        }
        return ncw;
    }

    /**
     * Applies {@code p' = p / (p + (1 - p) / sampleRate)} to every prediction.
     * Returns a new map and leaves the input untouched; null predictions are passed through unchanged.
     */
    private <T> Map<T, Double> correctNegativeSampling(Map<T, Double> cvrPrd, double sampleRate) {
        Map<T, Double> corrected = new HashMap<>(PB_MAX_SIZE);
        for (Map.Entry<T, Double> entry : cvrPrd.entrySet()) {
            Double prob = entry.getValue();
            if (prob == null) {
                corrected.put(entry.getKey(), null);
            } else {
                corrected.put(entry.getKey(), prob / (prob + (1 - prob) / sampleRate));
            }
        }
        return corrected;
    }

    /**
     * Predicts CTR only, with string-map features.
     *
     * @param featureMap sample key -&gt; raw feature map
     * @param <T>        sample key type
     * @return a map holding only {@link PredictResultType#CTR}, or null when the CTR result is empty
     * @throws Exception propagated from the CTR model
     */
    public <T> Map<PredictResultType, Map<T, Double>> predictCTRs(Map<T, Map<String, String>> featureMap) throws Exception {
        Map<T, Double> ctrMap = ctrModel.predictWithTF(featureMap, ctrTFServingClient);
        return singleResult(PredictResultType.CTR, ctrMap);
    }

    /**
     * Predicts CTR only, with {@link FeatureMapDo} features.
     *
     * @param featureMap sample key -&gt; feature object
     * @param <T>        sample key type
     * @return a map holding only {@link PredictResultType#CTR}, or null when the CTR result is empty
     * @throws Exception propagated from the CTR model
     */
    public <T> Map<PredictResultType, Map<T, Double>> predictCTRsNew(Map<T, FeatureMapDo> featureMap) throws Exception {
        Map<T, Double> ctrMap = ctrModel.predictWithTFNew(featureMap, ctrTFServingClient);
        return singleResult(PredictResultType.CTR, ctrMap);
    }

    /**
     * Predicts CVR only, with string-map features.
     *
     * @param featureMap sample key -&gt; raw feature map
     * @param <T>        sample key type
     * @return a map holding only {@link PredictResultType#CVR}, or null when the CVR result is empty
     * @throws Exception propagated from the CVR model
     */
    public <T> Map<PredictResultType, Map<T, Double>> predictCVRs(Map<T, Map<String, String>> featureMap) throws Exception {
        Map<T, Double> cvrMap = cvrModel.predictWithTF(featureMap, cvrTFServingClient);
        return singleResult(PredictResultType.CVR, cvrMap);
    }

    /**
     * Deep CTR + CVR prediction with string-map features.
     * If the primary path yields nothing, it falls back to the fusing models.
     *
     * @param featureMap sample key -&gt; raw feature map
     * @param <T>        sample key type
     * @return CTR and CVR predictions per sample
     * @throws Exception propagated from the underlying models
     */
    public <T> Map<PredictResultType, Map<T, Double>> predictCTRsAndCVRsWithTFDEEP(Map<T, Map<String, String>> featureMap) throws Exception {
        Map<PredictResultType, Map<T, Double>> ret;

        if (MutModelType.DEEP_ESMM == mutModelType) {
            // ESMM deep model: a single TF-Serving call on the concatenated CTR + CVR inputs.
            ret = ctrTFServingClient.predictMut(getParams(featureMap));
        } else {
            // Separate CTR and CVR deep models.
            ret = predictCTRsAndCVRs(featureMap);
        }

        if (ret == null) {
            ret = predictFusingCTRsAndCVRs(featureMap);
            warnOnFusingResult(ret);
        }
        return ret;
    }

    /**
     * Deep CTR + CVR prediction with {@link FeatureMapDo} features. It supports the joint ESMM model,
     * local TF models and separate TF-Serving models, and falls back to the fusing models if the
     * primary path yields nothing.
     *
     * @param featureMap sample key -&gt; feature object
     * @param <T>        sample key type
     * @return CTR and CVR predictions per sample
     * @throws Exception propagated from the underlying models
     */
    public <T> Map<PredictResultType, Map<T, Double>> predictCTRsAndCVRsWithTFDEEPNew(Map<T, FeatureMapDo> featureMap) throws Exception {
        Map<PredictResultType, Map<T, Double>> ret;

        if (MutModelType.DEEP_ESMM == mutModelType) {
            // ESMM deep model: a single TF-Serving call on the concatenated CTR + CVR inputs.
            ret = ctrTFServingClient.predictMut(getParamsNew(featureMap));
        } else if (MutModelType.DEEP_E2E_LOCAL == mutModelType) {
            // Local (in-process) TF models.
            ret = predictCTRsAndCVRsLocalTFNew(featureMap);
        } else {
            // Separate CTR and CVR deep models.
            ret = predictCTRsAndCVRsNew(featureMap);
        }

        if (ret == null) {
            logger.warn("熔断模型预估 algType:{}", algType);
            ret = predictFusingCTRsAndCVRsNew(featureMap);
            warnOnFusingResult(ret);
        }
        return ret;
    }

    /**
     * Logs problems with a fallback (fusing) result: a null result, and every CTR prediction equal to 0.
     */
    private <T> void warnOnFusingResult(Map<PredictResultType, Map<T, Double>> fusingRet) {
        if (fusingRet == null) {
            logger.warn("{} predictFusing res=null", ctrModelName());
            return;
        }
        Map<T, Double> ctrMap = fusingRet.get(PredictResultType.CTR);
        if (ctrMap == null) {
            return;
        }
        for (Double predictCtr : ctrMap.values()) {
            if (predictCtr != null && predictCtr == 0) {
                logger.warn("{} predictFusing warn preCtr=0", ctrModelName());
            }
        }
    }

    /**
     * Returns the CTR TF-Serving model name for log messages. A null client is tolerated, because in
     * local-model mode the client may not be configured.
     */
    private String ctrModelName() {
        return ctrTFServingClient == null ? "null" : String.valueOf(ctrTFServingClient.modelName);
    }

    /**
     * Separate CTR and CVR TF-Serving prediction with string-map features.
     *
     * @return CTR and CVR maps, or null when either one is empty
     */
    private <T> Map<PredictResultType, Map<T, Double>> predictCTRsAndCVRs(Map<T, Map<String, String>> featureMap) throws Exception {
        Map<T, Double> ctrMap = ctrModel.predictWithTF(featureMap, ctrTFServingClient);
        Map<T, Double> cvrMap = cvrModel.predictWithTF(featureMap, cvrTFServingClient);
        return ctrCvrResult(ctrMap, cvrMap);
    }

    /**
     * Separate CTR and CVR TF-Serving prediction with {@link FeatureMapDo} features.
     *
     * @return CTR and CVR maps, or null when either one is empty
     */
    private <T> Map<PredictResultType, Map<T, Double>> predictCTRsAndCVRsNew(Map<T, FeatureMapDo> featureMap) throws Exception {
        Map<T, Double> ctrMap = ctrModel.predictWithTFNew(featureMap, ctrTFServingClient);
        Map<T, Double> cvrMap = cvrModel.predictWithTFNew(featureMap, cvrTFServingClient);
        return ctrCvrResult(ctrMap, cvrMap);
    }

    /**
     * CTR and CVR prediction with the local V2 TF models. For each model, TF-Serving is used instead
     * when its local model is not configured. When the local CVR model is used, its predictions are also
     * appended to each sample's code string.
     *
     * @return CTR and CVR maps, or null when either one is empty
     */
    private <T> Map<PredictResultType, Map<T, Double>> predictCTRsAndCVRsLocalTFNew(Map<T, FeatureMapDo> featureMap) throws Exception {
        Map<T, Double> ctrMap;
        if (ctrLocalTFModel == null) {
            ctrMap = ctrModel.predictWithTFNew(featureMap, ctrTFServingClient);
        } else {
            ctrMap = ctrModel.predictWithLocalTFV2(featureMap, ctrLocalTFModel);
        }

        Map<T, Double> cvrMap;
        if (cvrLocalTFModel == null) {
            cvrMap = cvrModel.predictWithTFNew(featureMap, cvrTFServingClient);
        } else {
            cvrMap = cvrModel.predictWithLocalTFV2(featureMap, cvrLocalTFModel);
            // Guard against a null CVR map. It then falls through to the emptiness check below
            // (and the caller's fusing fallback) instead of throwing an NPE here.
            if (cvrMap != null) {
                recordCvrPredictions(featureMap, cvrMap);
            }
        }
        return ctrCvrResult(ctrMap, cvrMap);
    }

    /**
     * Records CVR predictions by prefixing each sample's code string with
     * {@code "<full CVR map JSON or null>;<this sample's CVR>;"}.
     * The full map is serialized only for batches smaller than {@link #MAX_BATCH_SIZE_FOR_FULL_CVR_LOG}.
     * Samples with a null feature object are skipped.
     */
    private <T> void recordCvrPredictions(Map<T, FeatureMapDo> featureMap, Map<T, Double> cvrMap) {
        String cvrMapStr = null;
        if (featureMap.size() < MAX_BATCH_SIZE_FOR_FULL_CVR_LOG) {
            cvrMapStr = JSON.toJSONString(cvrMap);
        }

        for (Map.Entry<T, FeatureMapDo> entry : featureMap.entrySet()) {
            FeatureMapDo featureMapDo = entry.getValue();
            if (featureMapDo == null) {
                continue;
            }
            Double preValue = cvrMap.get(entry.getKey());
            featureMapDo.setCodeDoStr(cvrMapStr + ";" + preValue + ";" + featureMapDo.getCodeDoStr());
        }
    }

    /**
     * Wraps a single prediction map under {@code type}. Returns null when the map is empty according to
     * {@link AssertUtil#isAllNotEmpty}.
     */
    private <T> Map<PredictResultType, Map<T, Double>> singleResult(PredictResultType type, Map<T, Double> predictions) {
        if (!AssertUtil.isAllNotEmpty(predictions)) {
            return null;
        }
        Map<PredictResultType, Map<T, Double>> ret = new HashMap<>(PB_MAX_SIZE);
        ret.put(type, predictions);
        return ret;
    }

    /**
     * Combines CTR and CVR maps into one result. Returns null when either map is empty according to
     * {@link AssertUtil#isAllNotEmpty}.
     */
    private <T> Map<PredictResultType, Map<T, Double>> ctrCvrResult(Map<T, Double> ctrMap, Map<T, Double> cvrMap) {
        if (!AssertUtil.isAllNotEmpty(ctrMap, cvrMap)) {
            return null;
        }
        Map<PredictResultType, Map<T, Double>> ret = new HashMap<>(PB_MAX_SIZE);
        ret.put(PredictResultType.CTR, ctrMap);
        ret.put(PredictResultType.CVR, cvrMap);
        return ret;
    }

    /**
     * Ad-hoc manual check of {@link AssertUtil#isAllNotEmpty} with two maps that have different keys.
     */
    public static void main(String[] args) {
        HashMap<String, Double> ctrMap = new HashMap<>(10);
        ctrMap.put("1", 0.01);

        HashMap<String, Double> cvrMap = new HashMap<>(10);
        cvrMap.put("2", 0.22);

        System.out.println(AssertUtil.isAllNotEmpty(ctrMap, cvrMap));
    }

    /**
     * Fallback (fusing) CTR and CVR prediction with string-map features, without TF-Serving.
     * Throws a NullPointerException if the fusing models have not been set.
     *
     * @return a non-null map; its CTR/CVR values are whatever the fusing models return
     */
    private <T> Map<PredictResultType, Map<T, Double>> predictFusingCTRsAndCVRs(Map<T, Map<String, String>> featureMap) throws Exception {
        Map<PredictResultType, Map<T, Double>> ret = new HashMap<>(PB_MAX_SIZE);
        ret.put(PredictResultType.CTR, ctrFusingModel.predictWithTF(featureMap, null));
        ret.put(PredictResultType.CVR, cvrFusingModel.predictWithTF(featureMap, null));
        return ret;
    }

    /**
     * Fallback (fusing) CTR and CVR prediction with {@link FeatureMapDo} features, without TF-Serving.
     * Throws a NullPointerException if the fusing models have not been set.
     *
     * @return a non-null map; its CTR/CVR values are whatever the fusing models return
     */
    private <T> Map<PredictResultType, Map<T, Double>> predictFusingCTRsAndCVRsNew(Map<T, FeatureMapDo> featureMap) throws Exception {
        Map<PredictResultType, Map<T, Double>> ret = new HashMap<>(PB_MAX_SIZE);
        ret.put(PredictResultType.CTR, ctrFusingModel.predictWithTFNew(featureMap, null));
        ret.put(PredictResultType.CVR, cvrFusingModel.predictWithTFNew(featureMap, null));
        return ret;
    }

    /**
     * Builds the joint ESMM input per sample: CTR model parameters followed by CVR model parameters.
     *
     * @param featureMap sample key -&gt; raw feature map
     * @return sample key -&gt; concatenated input vector (empty when featureMap is empty)
     */
    private <T> Map<T, List<Float>> getParams(Map<T, Map<String, String>> featureMap) throws Exception {
        Map<T, List<Float>> ret = new HashMap<>(PB_MAX_SIZE);

        if (AssertUtil.isNotEmpty(featureMap)) {
            for (Map.Entry<T, Map<String, String>> entry : featureMap.entrySet()) {
                List<Float> ctrP = ctrModel.getParam(entry.getValue());
                List<Float> cvrP = cvrModel.getParam(entry.getValue());
                ret.put(entry.getKey(), paramsSplicing(ctrP, cvrP));
            }
        }
        return ret;
    }

    /**
     * Builds the joint ESMM input per sample from {@link FeatureMapDo} features: CTR model parameters
     * followed by CVR model parameters.
     *
     * @param featureMap sample key -&gt; feature object
     * @return sample key -&gt; concatenated input vector (empty when featureMap is empty)
     */
    private <T> Map<T, List<Float>> getParamsNew(Map<T, FeatureMapDo> featureMap) throws Exception {
        Map<T, List<Float>> ret = new HashMap<>(PB_MAX_SIZE);

        if (AssertUtil.isNotEmpty(featureMap)) {
            for (Map.Entry<T, FeatureMapDo> entry : featureMap.entrySet()) {
                List<Float> ctrP = ctrModel.getParam(entry.getValue());
                List<Float> cvrP = cvrModel.getParam(entry.getValue());
                ret.put(entry.getKey(), paramsSplicing(ctrP, cvrP));
            }
        }
        return ret;
    }

    /**
     * Concatenates the CTR and CVR parameter lists into a new list.
     * The inputs come from the models' {@code getParam} and are not modified; they may be cached or
     * unmodifiable.
     */
    private List<Float> paramsSplicing(List<Float> ctrP, List<Float> cvrP) {
        List<Float> spliced = new ArrayList<>(ctrP.size() + cvrP.size());
        spliced.addAll(ctrP);
        spliced.addAll(cvrP);
        return spliced;
    }
}