package cn.com.duiba.nezha.alg.model;

import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import cn.com.duiba.nezha.alg.model.enums.MutModelType;
import cn.com.duiba.nezha.alg.model.enums.PredictResultType;
import cn.com.duiba.nezha.alg.model.tf.LocalTFModel;
import cn.com.duiba.nezha.alg.model.tf.TFServingClient;
import org.slf4j.LoggerFactory;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

public class LocalESMM implements Serializable {
    private static final long serialVersionUID = -316102112618444131L;

    private static final org.slf4j.Logger logger = LoggerFactory.getLogger(LocalESMM.class);

    /**
     * Initial capacity of the per-request parameter map built in {@link #getESMMParams(Map)}.
     * Kept as a non-static instance field so the serialized form of this class is unchanged.
     */
    private int PB_MAX_SIZE = 128;

    /**
     * Initial capacity for the small result maps, which hold at most the CTR and CVR entries.
     */
    private static final int RESULT_MAP_CAPACITY = 4;

    /**
     * CTR (click-through rate) prediction model.
     */
    private IModel ctrModel;

    /**
     * Back-end (conversion) prediction model; its output is reported under
     * {@link PredictResultType#CVR}.
     */
    private IModel cvrModel;

    /**
     * Local TensorFlow deep model used for CTR prediction; it also scores the
     * concatenated CTR+CVR inputs on the DEEP_ESMM path.
     */
    private LocalTFModel ctrLocalTFModel;

    /**
     * Local TensorFlow deep model used for CVR prediction.
     */
    private LocalTFModel cvrLocalTFModel;

    /**
     * Model type that selects the prediction path (e.g. CTR-only, DEEP_ESMM, or the default
     * two-model path); see {@link #predictCTRsAndCVRsWithTF(Map)}.
     */
    private MutModelType mutModelType;


    public LocalTFModel getCtrLocalTFModel() {
        return ctrLocalTFModel;
    }

    public void setCtrLocalTFModel(LocalTFModel ctrLocalTFModel) {
        this.ctrLocalTFModel = ctrLocalTFModel;
    }

    public LocalTFModel getCvrLocalTFModel() {
        return cvrLocalTFModel;
    }

    public void setCvrLocalTFModel(LocalTFModel cvrLocalTFModel) {
        this.cvrLocalTFModel = cvrLocalTFModel;
    }

    public MutModelType getMutModelType() {
        return mutModelType;
    }

    public void setMutModelType(MutModelType mutModelType) {
        this.mutModelType = mutModelType;
    }

    /**
     * Creates a predictor from its CTR/CVR feature models and their local TF models.
     *
     * @param ctrModel        model producing CTR predictions / CTR input parameters
     * @param cvrModel        model producing CVR predictions / CVR input parameters
     * @param ctrLocalTFModel local TF model used with {@code ctrModel} (and for DEEP_ESMM)
     * @param cvrLocalTFModel local TF model used with {@code cvrModel}
     * @param mutModelType    selects the prediction path
     */
    public LocalESMM(IModel ctrModel, IModel cvrModel, LocalTFModel ctrLocalTFModel, LocalTFModel cvrLocalTFModel, MutModelType mutModelType) {
        this.ctrModel = ctrModel;
        this.cvrModel = cvrModel;
        this.ctrLocalTFModel = ctrLocalTFModel;
        this.cvrLocalTFModel = cvrLocalTFModel;
        this.mutModelType = mutModelType;
    }

    /**
     * Main prediction entry point: CTR and (depending on {@link #mutModelType}) CVR prediction
     * using the local TensorFlow models.
     * <p>
     * Dispatch:
     * <ul>
     *   <li>{@code CTR} — single model, CTR only ({@link #predictCTRs(Map)}).</li>
     *   <li>{@code DEEP_ESMM} — CTR and CVR input parameters are concatenated per key and scored
     *       together by {@code ctrLocalTFModel.predictESMM}.</li>
     *   <li>anything else, including {@code null} — two independent models, CTR and CVR.</li>
     * </ul>
     *
     * @param featureMap per-key feature maps; each key's result is reported under the same key
     * @param <T>        key type
     * @return predictions grouped by result type; may be {@code null} when the CTR-only or
     *         two-model path produced no (or empty) predictions
     * @throws Exception propagated from the underlying models
     */
    public <T> Map<PredictResultType, Map<T, Double>> predictCTRsAndCVRsWithTF(Map<T, Map<String, String>> featureMap) throws Exception {
        if (MutModelType.CTR == mutModelType) {
            // Single model: CTR only.
            return predictCTRs(featureMap);
        }
        if (MutModelType.DEEP_ESMM == mutModelType) {
            // Dual-model deep ESMM: concatenated CTR+CVR inputs scored by the CTR local TF model.
            return ctrLocalTFModel.predictESMM(getESMMParams(featureMap));
        }
        // Default: dual-model regular deep prediction (separate CTR and CVR models).
        return predictCTRsAndCVRs(featureMap);
    }


    /**
     * CTR-only prediction.
     *
     * @param featureMap per-key feature maps
     * @param <T>        key type
     * @return a map containing only {@link PredictResultType#CTR}, or {@code null} when the CTR
     *         model returned no/empty predictions
     * @throws Exception propagated from the CTR model
     */
    public <T> Map<PredictResultType, Map<T, Double>> predictCTRs(Map<T, Map<String, String>> featureMap) throws Exception {
        Map<T, Double> ctrMap = ctrModel.predictWithLocalTF(featureMap, ctrLocalTFModel);
        if (!AssertUtil.isAllNotEmpty(ctrMap)) {
            return null;
        }

        Map<PredictResultType, Map<T, Double>> ret = new HashMap<>(RESULT_MAP_CAPACITY);
        ret.put(PredictResultType.CTR, ctrMap);
        return ret;
    }


    /**
     * CTR and CVR prediction with two independent models.
     *
     * @param featureMap per-key feature maps
     * @param <T>        key type
     * @return a map with {@link PredictResultType#CTR} and {@link PredictResultType#CVR}, or
     *         {@code null} unless both models returned non-empty predictions
     * @throws Exception propagated from either model
     */
    private <T> Map<PredictResultType, Map<T, Double>> predictCTRsAndCVRs(Map<T, Map<String, String>> featureMap) throws Exception {
        // Both models are always invoked, even if the CTR result turns out empty.
        Map<T, Double> ctrMap = ctrModel.predictWithLocalTF(featureMap, ctrLocalTFModel);
        Map<T, Double> cvrMap = cvrModel.predictWithLocalTF(featureMap, cvrLocalTFModel);
        if (!AssertUtil.isAllNotEmpty(ctrMap, cvrMap)) {
            return null;
        }

        Map<PredictResultType, Map<T, Double>> ret = new HashMap<>(RESULT_MAP_CAPACITY);
        ret.put(PredictResultType.CTR, ctrMap);
        ret.put(PredictResultType.CVR, cvrMap);
        return ret;
    }


    /**
     * Builds the DEEP_ESMM model input for every key: the CTR model's parameters followed by
     * the CVR model's parameters.
     *
     * @param featureMap per-key feature maps
     * @param <T>        key type
     * @return per-key concatenated parameter lists; empty when {@code featureMap} is empty
     * @throws Exception propagated from {@code getParam} of either model
     */
    private <T> Map<T, List<Float>> getESMMParams(Map<T, Map<String, String>> featureMap) throws Exception {
        Map<T, List<Float>> ret = new HashMap<>(PB_MAX_SIZE);

        if (AssertUtil.isNotEmpty(featureMap)) {
            for (Map.Entry<T, Map<String, String>> entry : featureMap.entrySet()) {
                Map<String, String> features = entry.getValue();
                List<Float> ctrP = ctrModel.getParam(features);
                List<Float> cvrP = cvrModel.getParam(features);
                ret.put(entry.getKey(), paramsSplicing(ctrP, cvrP));
            }
        }
        return ret;
    }


    /**
     * Concatenates CTR and CVR parameters into a new list (CTR first, then CVR).
     * <p>
     * A fresh list is built instead of appending to {@code ctrP}, so that lists owned by the
     * models (possibly cached or unmodifiable) are never mutated.
     *
     * @param ctrP CTR parameters, must not be {@code null}
     * @param cvrP CVR parameters, must not be {@code null}
     * @return a new mutable list containing {@code ctrP} followed by {@code cvrP}
     * @throws NullPointerException if either argument is {@code null}
     */
    private List<Float> paramsSplicing(List<Float> ctrP, List<Float> cvrP) {
        Objects.requireNonNull(ctrP, "ctr params must not be null");
        Objects.requireNonNull(cvrP, "cvr params must not be null");

        List<Float> merged = new ArrayList<>(ctrP.size() + cvrP.size());
        merged.addAll(ctrP);
        merged.addAll(cvrP);
        return merged;
    }


}