package cn.com.duiba.nezha.alg.model;

import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import cn.com.duiba.nezha.alg.model.enums.MutModelType;
import cn.com.duiba.nezha.alg.model.enums.PredictResultType;
import cn.com.duiba.nezha.alg.model.tf.TFServingClient;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
 * Holder for the models used to predict click-through rate (CTR), conversion rate (CVR)
 * and back-end conversion rate, in an ESMM-style setup.
 *
 * <p>Depending on {@link #mutModelType}, CTR and CVR are produced either by a single joint
 * TF-Serving call ({@link MutModelType#DEEP_ESMM}) or by the CTR and CVR models independently.
 */
public class ESMM implements Serializable {
    private static final long serialVersionUID = -316102112618444131L;

    /**
     * Click-through-rate (CTR) prediction model.
     */
    private IModel ctrModel;

    /**
     * Conversion-rate (CVR) prediction model.
     */
    private IModel cvrModel;

    /**
     * Back-end conversion prediction model, used by {@link #predictBCVRs(Map)}.
     */
    private IModel bCvrModel;

    /**
     * TF-Serving client for deep-learning CTR prediction; also the client used for the joint
     * call when {@link #mutModelType} is {@link MutModelType#DEEP_ESMM}.
     */
    private TFServingClient ctrTFServingClient;

    /**
     * TF-Serving client for deep-learning CVR prediction (used only on the non-DEEP_ESMM path).
     */
    private TFServingClient cvrTFServingClient;

    /**
     * Model type that selects the prediction path in {@link #predictCTRsAndCVRsWithTF(Map)}.
     */
    private MutModelType mutModelType;

    /**
     * Creates the ESMM holder.
     *
     * @param ctrModel           CTR prediction model
     * @param cvrModel           CVR prediction model
     * @param bCvrModel          back-end conversion prediction model
     * @param ctrTFServingClient TF-Serving client for CTR (and for the joint DEEP_ESMM call)
     * @param cvrTFServingClient TF-Serving client for CVR
     * @param mutModelType       model type selecting the prediction path
     */
    public ESMM(IModel ctrModel, IModel cvrModel, IModel bCvrModel, TFServingClient ctrTFServingClient, TFServingClient cvrTFServingClient, MutModelType mutModelType) {
        this.ctrModel = ctrModel;
        this.cvrModel = cvrModel;
        this.bCvrModel = bCvrModel;
        this.ctrTFServingClient = ctrTFServingClient;
        this.cvrTFServingClient = cvrTFServingClient;
        this.mutModelType = mutModelType;
    }

    public IModel getCtrModel() {
        return ctrModel;
    }

    public void setCtrModel(IModel ctrModel) {
        this.ctrModel = ctrModel;
    }

    public IModel getCvrModel() {
        return cvrModel;
    }

    public void setCvrModel(IModel cvrModel) {
        this.cvrModel = cvrModel;
    }

    public IModel getBCvrModel() {
        return bCvrModel;
    }

    public void setBCvrModel(IModel bCvrModel) {
        this.bCvrModel = bCvrModel;
    }

    public TFServingClient getCtrTFServingClient() {
        return ctrTFServingClient;
    }

    public void setCtrTFServingClient(TFServingClient ctrTFServingClient) {
        this.ctrTFServingClient = ctrTFServingClient;
    }

    public TFServingClient getCvrTFServingClient() {
        return cvrTFServingClient;
    }

    public void setCvrTFServingClient(TFServingClient cvrTFServingClient) {
        this.cvrTFServingClient = cvrTFServingClient;
    }

    public MutModelType getMutModelType() {
        return mutModelType;
    }

    public void setMutModelType(MutModelType mutModelType) {
        this.mutModelType = mutModelType;
    }

    /**
     * Main prediction entry point: predicts CTR and CVR for every key of {@code featureMap},
     * using deep learning where configured.
     *
     * <p>When {@link #mutModelType} is {@link MutModelType#DEEP_ESMM}, the CTR and CVR feature
     * vectors of each entry are concatenated and sent in one {@code predictMut} call through
     * {@link #ctrTFServingClient}; {@link #cvrTFServingClient} is not used on that path.
     * Otherwise CTR and CVR are predicted separately by {@link #ctrModel} and
     * {@link #cvrModel}, each with its own TF-Serving client.
     *
     * @param featureMap per-key feature maps
     * @param <T>        type of the keys identifying each item to score
     * @return predictions keyed by {@link PredictResultType}, each mapping item key to score
     * @throws Exception propagated from the models or the TF-Serving client
     */
    public <T> Map<PredictResultType, Map<T, Double>> predictCTRsAndCVRsWithTF(Map<T, Map<String, String>> featureMap) throws Exception {
        if (MutModelType.DEEP_ESMM == mutModelType) {
            // Deep ESMM: one joint call on the concatenated CTR+CVR inputs.
            return ctrTFServingClient.predictMut(getParams(featureMap));
        }
        // Any other model type: independent CTR and CVR predictions.
        return predictCTRsAndCVRs(featureMap);
    }

    /**
     * Back-end conversion prediction; delegates to {@link #bCvrModel}.
     *
     * @param featureMap per-key feature maps
     * @param <T>        type of the keys identifying each item to score
     * @return item key to predicted score, as returned by {@code bCvrModel.predicts}
     * @throws Exception propagated from the model
     */
    public <T> Map<T, Double> predictBCVRs(Map<T, Map<String, String>> featureMap) throws Exception {
        return bCvrModel.predicts(featureMap);
    }

    /**
     * Predicts CTR and CVR independently, each model using its own TF-Serving client.
     *
     * @param featureMap per-key feature maps
     * @param <T>        type of the keys identifying each item to score
     * @return a map with {@link PredictResultType#CTR} and {@link PredictResultType#CVR} entries
     * @throws Exception propagated from the models
     */
    private <T> Map<PredictResultType, Map<T, Double>> predictCTRsAndCVRs(Map<T, Map<String, String>> featureMap) throws Exception {
        Map<PredictResultType, Map<T, Double>> ret = new HashMap<>();
        ret.put(PredictResultType.CTR, ctrModel.predictWithTF(featureMap, ctrTFServingClient));
        ret.put(PredictResultType.CVR, cvrModel.predictWithTF(featureMap, cvrTFServingClient));
        return ret;
    }

    /**
     * Builds the joint model input for each key: the CTR model's feature vector followed by
     * the CVR model's feature vector.
     *
     * @param featureMap per-key feature maps
     * @param <T>        type of the keys identifying each item to score
     * @return key to concatenated feature vector; empty when {@code AssertUtil.isNotEmpty}
     *         reports {@code featureMap} as empty
     * @throws Exception propagated from {@code IModel.getParam}
     */
    private <T> Map<T, List<Float>> getParams(Map<T, Map<String, String>> featureMap) throws Exception {
        Map<T, List<Float>> ret = new HashMap<>();

        if (AssertUtil.isNotEmpty(featureMap)) {
            for (Map.Entry<T, Map<String, String>> entry : featureMap.entrySet()) {
                List<Float> ctrP = ctrModel.getParam(entry.getValue());
                List<Float> cvrP = cvrModel.getParam(entry.getValue());
                ret.put(entry.getKey(), paramsSplicing(ctrP, cvrP));
            }
        }
        return ret;
    }

    /**
     * Concatenates the CTR and CVR feature vectors, CTR features first.
     *
     * <p>Always builds a new list: the inputs come straight from {@code IModel.getParam} and
     * may be shared, cached or unmodifiable, so they must not be appended to in place.
     *
     * @param ctrP CTR-model feature vector (must not be null)
     * @param cvrP CVR-model feature vector (must not be null)
     * @return a new mutable list containing {@code ctrP} followed by {@code cvrP}
     * @throws NullPointerException if either vector is null
     */
    private List<Float> paramsSplicing(List<Float> ctrP, List<Float> cvrP) {
        Objects.requireNonNull(ctrP, "ctr params must not be null");
        Objects.requireNonNull(cvrP, "cvr params must not be null");

        List<Float> spliced = new ArrayList<>(ctrP.size() + cvrP.size());
        spliced.addAll(ctrP);
        spliced.addAll(cvrP);
        return spliced;
    }
}