package cn.com.duiba.nezha.alg.model;

import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import cn.com.duiba.nezha.alg.feature.vo.FeatureMapDo;
import cn.com.duiba.nezha.alg.model.enums.PredictResultType;
import cn.com.duiba.nezha.alg.model.tf.LocalTFModelV2;
import cn.com.duiba.nezha.alg.model.tf.TFServingClient;
import org.slf4j.LoggerFactory;

import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;

/**
 * DCVR (deep conversion rate) predictor for the conversion-postback model v3.0 project: when
 * conversions are reported back to the media, samples with a higher predicted deep conversion
 * rate are reported first.
 *
 * <p>The only difference from the DCVRv2 model class is the result key: predictions are stored
 * under the 「trade_dcvr」 key ({@link PredictResultType#TRADE_DCVR}) instead of 「dcvr」.
 *
 * <p>2024.03.13: added the telecom-carrier industry dcvr interface; its results are keyed by
 * 「opt_dcvr」 ({@link PredictResultType#OPT_DCVR}), distinct from the cash-on-delivery
 * 「trade_dcvr」.
 *
 * <p>Every prediction family (trade, opt, withholding, complaint-trade, complaint-private, sign)
 * exposes the same four entry points:
 * <ul>
 *   <li>{@code predictXxx} - fallback model ({@code dCvrModel}) only;</li>
 *   <li>{@code predictXxxWithTF} - TF-Serving via {@code coderModel}, falling back to
 *       {@code dCvrModel} when that yields nothing;</li>
 *   <li>{@code predictXxxWithLocalTF} - local TF model via {@code coderModel}, falling back to
 *       {@code dCvrModel} when that yields nothing;</li>
 *   <li>{@code predictXxxAndTF} - dispatcher: TF-Serving if a client is set, else local TF if a
 *       local model is set, else the fallback model.</li>
 * </ul>
 * Each of them returns a single-entry map keyed by the family's {@link PredictResultType}, or
 * {@code null} when no non-empty prediction could be produced (including when
 * {@code dCvrModel} is not configured).
 *
 * <p>NOTE(review): {@code coderModel} is assigned only by the two-argument constructor, so the
 * TF-Serving and local-TF paths throw {@link NullPointerException} on instances created with the
 * no-arg constructor.
 */

public class DCVRv3 implements Serializable {

    private static final org.slf4j.Logger logger = LoggerFactory.getLogger(DCVRv3.class);

    /** Fallback (degradation) dcvr model; may be null because not every dcvr model has one. */
    private IModel dCvrModel;
    /** Model that runs the TF-Serving / local-TF predictions; set only by the two-arg constructor. */
    private IModel coderModel;
    /** Locally loaded TF dcvr model used by the {@code WithLocalTF} paths. */
    private LocalTFModelV2 dcvrLocalTFModel;

    public LocalTFModelV2 getDcvrLocalTFModel() {
        return dcvrLocalTFModel;
    }

    public void setDcvrLocalTFModel(LocalTFModelV2 dcvrLocalTFModel) {
        this.dcvrLocalTFModel = dcvrLocalTFModel;
    }

    /**
     * TF-Serving client for the dcvr deep-learning model, used by the {@code WithTF} paths.
     */
    private TFServingClient dcvrServingClient;

    public IModel getDCvrModel() {
        return dCvrModel;
    }

    public void setDCvrModel(IModel dCvrModel) {
        this.dCvrModel = dCvrModel;
    }

    public TFServingClient getDcvrServingClient() {
        return dcvrServingClient;
    }

    public void setDcvrServingClient(TFServingClient dcvrServingClient) {
        this.dcvrServingClient = dcvrServingClient;
    }


    /**
     * Creates an empty predictor; models and clients must be supplied through the setters.
     * Note that {@code coderModel} cannot be set this way.
     */
    public DCVRv3() {

    }

    /**
     * Creates a predictor for the local-TF path.
     *
     * @param coderModel       model that runs the TF predictions
     * @param dcvrLocalTFModel locally loaded TF dcvr model
     */
    public DCVRv3(IModel coderModel, LocalTFModelV2 dcvrLocalTFModel) {
        this.coderModel = coderModel;
        this.dcvrLocalTFModel = dcvrLocalTFModel;
    }

    // ------------------------------------------------------------------------------------
    // Shared helpers
    // ------------------------------------------------------------------------------------

    /**
     * Wraps a prediction map under {@code resultType}.
     *
     * @return a new mutable map holding {@code dcvrMap} under {@code resultType}, or
     *     {@code null} when {@code dcvrMap} does not pass {@link AssertUtil#isAllNotEmpty}
     */
    private static <T> Map<PredictResultType, Map<T, Double>> wrapResult(PredictResultType resultType,
                                                                         Map<T, Double> dcvrMap) {
        if (!AssertUtil.isAllNotEmpty(dcvrMap)) {
            return null;
        }
        Map<PredictResultType, Map<T, Double>> ret = new HashMap<>();
        ret.put(resultType, dcvrMap);
        return ret;
    }

    /**
     * Predicts with the fallback model {@code dCvrModel} only.
     *
     * @return the wrapped prediction, or {@code null} if the fallback model is not configured
     *     (some dcvr models have no degradation model) or its result is empty
     */
    private <T> Map<PredictResultType, Map<T, Double>> predictWithFallbackModel(Map<T, FeatureMapDo> featureMap,
                                                                                PredictResultType resultType) throws Exception {
        if (dCvrModel == null) {
            return null;
        }
        Map<T, Double> dcvrMap = dCvrModel.predictsNew(featureMap);
        return wrapResult(resultType, dcvrMap);
    }

    /**
     * Returns the primary (TF) prediction if it is non-empty; otherwise logs a "fusing" warning
     * and degrades to the fallback model.
     *
     * @param featureMap  features passed on to the fallback model
     * @param primaryMap  result of the TF-Serving / local-TF prediction
     * @param resultType  key under which the prediction is returned
     * @param fusingLabel label used in the "fusing" warning
     */
    private <T> Map<PredictResultType, Map<T, Double>> wrapOrFallback(Map<T, FeatureMapDo> featureMap,
                                                                      Map<T, Double> primaryMap,
                                                                      PredictResultType resultType,
                                                                      String fusingLabel) throws Exception {
        Map<PredictResultType, Map<T, Double>> ret = wrapResult(resultType, primaryMap);
        if (ret != null) {
            return ret;
        }
        logger.warn("{}: {} fusing", coderModel.getModelId(), fusingLabel);
        return predictWithFallbackModel(featureMap, resultType);
    }

    // ------------------------------------------------------------------------------------
    // Cash-on-delivery (trade) dcvr, keyed by TRADE_DCVR
    // ------------------------------------------------------------------------------------

    /**
     * Main prediction interface: trade dcvr from the fallback model only.
     *
     * @param featureMap features per sample key
     * @return {TRADE_DCVR: predictions}, or {@code null} if {@code dCvrModel} is unset or returns
     *     an empty result
     * @throws Exception propagated from the underlying model
     */
    public <T> Map<PredictResultType, Map<T, Double>> predictDCvr(Map<T, FeatureMapDo> featureMap) throws Exception {
        return predictWithFallbackModel(featureMap, PredictResultType.TRADE_DCVR);
    }

    /** Trade dcvr via TF-Serving, degrading to the fallback model on an empty result. */
    public <T> Map<PredictResultType, Map<T, Double>> predictDCvrWithTF(Map<T, FeatureMapDo> featureMap) throws Exception {
        Map<T, Double> dcvrMap = coderModel.predictWithTFNew(featureMap, dcvrServingClient);
        return wrapOrFallback(featureMap, dcvrMap, PredictResultType.TRADE_DCVR, "trade_dcvr");
    }

    /** Trade dcvr via the local TF model, degrading to the fallback model on an empty result. */
    public <T> Map<PredictResultType, Map<T, Double>> predictDCvrWithLocalTF(Map<T, FeatureMapDo> featureMap) throws Exception {
        Map<T, Double> dcvrMap = coderModel.predictWithLocalTFV3(featureMap, dcvrLocalTFModel);
        return wrapOrFallback(featureMap, dcvrMap, PredictResultType.TRADE_DCVR, "trade_dcvr");
    }

    /** Trade dcvr dispatcher: TF-Serving, else local TF, else the fallback model. */
    public <T> Map<PredictResultType, Map<T, Double>> predictDCvrAndTF(Map<T, FeatureMapDo> featureMap) throws Exception {
        if (dcvrServingClient != null) {
            logger.info("dcvrServingClient");
            return predictDCvrWithTF(featureMap);
        }

        if (dcvrLocalTFModel != null) {
            logger.info("dcvrLocalTFModel");
            return predictDCvrWithLocalTF(featureMap);
        }
        return predictDCvr(featureMap);
    }

    // ------------------------------------------------------------------------------------
    // Telecom-carrier industry (opt) dcvr, keyed by OPT_DCVR
    // ------------------------------------------------------------------------------------

    /**
     * Main prediction interface for the carrier industry: opt dcvr from the fallback model only.
     *
     * @param featureMap features per sample key
     * @return {OPT_DCVR: predictions}, or {@code null} if {@code dCvrModel} is unset or returns
     *     an empty result
     * @throws Exception propagated from the underlying model
     */
    public <T> Map<PredictResultType, Map<T, Double>> predictOptDCvr(Map<T, FeatureMapDo> featureMap) throws Exception {
        return predictWithFallbackModel(featureMap, PredictResultType.OPT_DCVR);
    }

    /** Opt dcvr via TF-Serving, degrading to the fallback model on an empty result. */
    public <T> Map<PredictResultType, Map<T, Double>> predictOptDCvrWithTF(Map<T, FeatureMapDo> featureMap) throws Exception {
        Map<T, Double> dcvrMap = coderModel.predictWithTFNew(featureMap, dcvrServingClient);
        return wrapOrFallback(featureMap, dcvrMap, PredictResultType.OPT_DCVR, "opt_dcvr");
    }

    /** Opt dcvr via the local TF model, degrading to the fallback model on an empty result. */
    public <T> Map<PredictResultType, Map<T, Double>> predictOptDCvrWithLocalTF(Map<T, FeatureMapDo> featureMap) throws Exception {
        // NOTE(review): this is the only family calling predictWithLocalTFV2; all others use
        // predictWithLocalTFV3. Kept as-is - confirm whether the opt model really needs V2.
        Map<T, Double> dcvrMap = coderModel.predictWithLocalTFV2(featureMap, dcvrLocalTFModel);
        return wrapOrFallback(featureMap, dcvrMap, PredictResultType.OPT_DCVR, "opt_dcvr");
    }

    /** Opt dcvr dispatcher: TF-Serving, else local TF, else the fallback model. */
    public <T> Map<PredictResultType, Map<T, Double>> predictOptDCvrAndTF(Map<T, FeatureMapDo> featureMap) throws Exception {
        if (dcvrServingClient != null) {
            return predictOptDCvrWithTF(featureMap);
        }

        if (dcvrLocalTFModel != null) {
            return predictOptDCvrWithLocalTF(featureMap);
        }
        return predictOptDCvr(featureMap);
    }

    // ------------------------------------------------------------------------------------
    // Withholding (auto-debit) dcvr, keyed by WH_DCVR
    // ------------------------------------------------------------------------------------

    /** Withholding dcvr from the fallback model only; {@code null} if unset or empty. */
    public <T> Map<PredictResultType, Map<T, Double>> predictWhDCvr(Map<T, FeatureMapDo> featureMap) throws Exception {
        return predictWithFallbackModel(featureMap, PredictResultType.WH_DCVR);
    }

    /** Withholding dcvr via TF-Serving, degrading to the fallback model on an empty result. */
    public <T> Map<PredictResultType, Map<T, Double>> predictWhDCvrWithTF(Map<T, FeatureMapDo> featureMap) throws Exception {
        Map<T, Double> dcvrMap = coderModel.predictWithTFNew(featureMap, dcvrServingClient);
        return wrapOrFallback(featureMap, dcvrMap, PredictResultType.WH_DCVR, "代扣trade_dcvr");
    }

    /** Withholding dcvr via the local TF model, degrading to the fallback model on an empty result. */
    public <T> Map<PredictResultType, Map<T, Double>> predictWhDCvrWithLocalTF(Map<T, FeatureMapDo> featureMap) throws Exception {
        Map<T, Double> dcvrMap = coderModel.predictWithLocalTFV3(featureMap, dcvrLocalTFModel);
        return wrapOrFallback(featureMap, dcvrMap, PredictResultType.WH_DCVR, "代扣trade_dcvr");
    }

    /** Withholding dcvr dispatcher: TF-Serving, else local TF, else the fallback model. */
    public <T> Map<PredictResultType, Map<T, Double>> predictWhDCvrAndTF(Map<T, FeatureMapDo> featureMap) throws Exception {
        if (dcvrServingClient != null) {
            logger.info("代扣dcvrServingClient");
            return predictWhDCvrWithTF(featureMap);
        }

        if (dcvrLocalTFModel != null) {
            logger.info("代扣dcvrLocalTFModel");
            return predictWhDCvrWithLocalTF(featureMap);
        }
        return predictWhDCvr(featureMap);
    }

    // ------------------------------------------------------------------------------------
    // Short-payment industry: customer-complaint industry model, keyed by KESU_TRADE_DCVR
    // ------------------------------------------------------------------------------------

    /** Complaint-trade dcvr from the fallback model only; {@code null} if unset or empty. */
    public <T> Map<PredictResultType, Map<T, Double>> predictKesuTradeDCvr(Map<T, FeatureMapDo> featureMap) throws Exception {
        return predictWithFallbackModel(featureMap, PredictResultType.KESU_TRADE_DCVR);
    }

    /** Complaint-trade dcvr via TF-Serving, degrading to the fallback model on an empty result. */
    public <T> Map<PredictResultType, Map<T, Double>> predictKesuTradeDCvrWithTF(Map<T, FeatureMapDo> featureMap) throws Exception {
        Map<T, Double> dcvrMap = coderModel.predictWithTFNew(featureMap, dcvrServingClient);
        return wrapOrFallback(featureMap, dcvrMap, PredictResultType.KESU_TRADE_DCVR, "客诉Trade_dcvr");
    }

    /** Complaint-trade dcvr via the local TF model, degrading to the fallback model on an empty result. */
    public <T> Map<PredictResultType, Map<T, Double>> predictKesuTradeDCvrWithLocalTF(Map<T, FeatureMapDo> featureMap) throws Exception {
        Map<T, Double> dcvrMap = coderModel.predictWithLocalTFV3(featureMap, dcvrLocalTFModel);
        return wrapOrFallback(featureMap, dcvrMap, PredictResultType.KESU_TRADE_DCVR, "客诉Trade_dcvr");
    }

    /** Complaint-trade dcvr dispatcher: TF-Serving, else local TF, else the fallback model. */
    public <T> Map<PredictResultType, Map<T, Double>> predictKesuTradeDCvrAndTF(Map<T, FeatureMapDo> featureMap) throws Exception {
        if (dcvrServingClient != null) {
            logger.info("客诉tradeDcvrServingClient");
            return predictKesuTradeDCvrWithTF(featureMap);
        }

        if (dcvrLocalTFModel != null) {
            logger.info("客诉tradeDcvrLocalTFModel");
            return predictKesuTradeDCvrWithLocalTF(featureMap);
        }
        return predictKesuTradeDCvr(featureMap);
    }

    // ------------------------------------------------------------------------------------
    // Short-payment industry: customer-complaint customer model, keyed by KESU_PRIVATE_DCVR
    // ------------------------------------------------------------------------------------

    /** Complaint-private dcvr from the fallback model only; {@code null} if unset or empty. */
    public <T> Map<PredictResultType, Map<T, Double>> predictKesuPrivateDCvr(Map<T, FeatureMapDo> featureMap) throws Exception {
        return predictWithFallbackModel(featureMap, PredictResultType.KESU_PRIVATE_DCVR);
    }

    /** Complaint-private dcvr via TF-Serving, degrading to the fallback model on an empty result. */
    public <T> Map<PredictResultType, Map<T, Double>> predictKesuPrivateDCvrWithTF(Map<T, FeatureMapDo> featureMap) throws Exception {
        Map<T, Double> dcvrMap = coderModel.predictWithTFNew(featureMap, dcvrServingClient);
        return wrapOrFallback(featureMap, dcvrMap, PredictResultType.KESU_PRIVATE_DCVR, "客诉private_dcvr");
    }

    /** Complaint-private dcvr via the local TF model, degrading to the fallback model on an empty result. */
    public <T> Map<PredictResultType, Map<T, Double>> predictKesuPrivateDCvrWithLocalTF(Map<T, FeatureMapDo> featureMap) throws Exception {
        Map<T, Double> dcvrMap = coderModel.predictWithLocalTFV3(featureMap, dcvrLocalTFModel);
        return wrapOrFallback(featureMap, dcvrMap, PredictResultType.KESU_PRIVATE_DCVR, "客诉private_dcvr");
    }

    /** Complaint-private dcvr dispatcher: TF-Serving, else local TF, else the fallback model. */
    public <T> Map<PredictResultType, Map<T, Double>> predictKesuPrivateDCvrAndTF(Map<T, FeatureMapDo> featureMap) throws Exception {
        if (dcvrServingClient != null) {
            logger.info("客诉privateDcvrServingClient");
            return predictKesuPrivateDCvrWithTF(featureMap);
        }

        if (dcvrLocalTFModel != null) {
            logger.info("客诉privateDcvrLocalTFModel");
            return predictKesuPrivateDCvrWithLocalTF(featureMap);
        }
        return predictKesuPrivateDCvr(featureMap);
    }

    // ------------------------------------------------------------------------------------
    // Cash-on-delivery industry: delivery sign-off model, keyed by SIGN_DCVR
    // ------------------------------------------------------------------------------------

    /** Sign-off dcvr from the fallback model only; {@code null} if unset or empty. */
    public <T> Map<PredictResultType, Map<T, Double>> predictSignDCvr(Map<T, FeatureMapDo> featureMap) throws Exception {
        return predictWithFallbackModel(featureMap, PredictResultType.SIGN_DCVR);
    }

    /** Sign-off dcvr via TF-Serving, degrading to the fallback model on an empty result. */
    public <T> Map<PredictResultType, Map<T, Double>> predictSignDCvrWithTF(Map<T, FeatureMapDo> featureMap) throws Exception {
        Map<T, Double> dcvrMap = coderModel.predictWithTFNew(featureMap, dcvrServingClient);
        return wrapOrFallback(featureMap, dcvrMap, PredictResultType.SIGN_DCVR, "签收sign_dcvr");
    }

    /** Sign-off dcvr via the local TF model, degrading to the fallback model on an empty result. */
    public <T> Map<PredictResultType, Map<T, Double>> predictSignDCvrWithLocalTF(Map<T, FeatureMapDo> featureMap) throws Exception {
        Map<T, Double> dcvrMap = coderModel.predictWithLocalTFV3(featureMap, dcvrLocalTFModel);
        return wrapOrFallback(featureMap, dcvrMap, PredictResultType.SIGN_DCVR, "签收sign_dcvr");
    }

    /** Sign-off dcvr dispatcher: TF-Serving, else local TF, else the fallback model. */
    public <T> Map<PredictResultType, Map<T, Double>> predictSignDCvrAndTF(Map<T, FeatureMapDo> featureMap) throws Exception {
        if (dcvrServingClient != null) {
            logger.info("签收SignDcvrServingClient");
            return predictSignDCvrWithTF(featureMap);
        }

        if (dcvrLocalTFModel != null) {
            logger.info("签收SignDcvrLocalTFModel");
            return predictSignDCvrWithLocalTF(featureMap);
        }
        return predictSignDCvr(featureMap);
    }

}
