package cn.com.duiba.nezha.alg.model;

import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import cn.com.duiba.nezha.alg.feature.vo.FeatureMapDo;
import cn.com.duiba.nezha.alg.model.enums.PredictResultType;
import cn.com.duiba.nezha.alg.model.tf.LocalTFModelV2;
import cn.com.duiba.nezha.alg.model.tf.TFServingClient;
import org.slf4j.LoggerFactory;

import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;

/**
 * DCVR (deep conversion rate) predictor for the conversion call-back model v3.0 project: when
 * conversions are reported back to the media platform, samples with a higher predicted DCVR are
 * reported first.
 *
 * <p>The only difference from the {@code DCVRv2} model class is the result key: predictions are
 * stored under {@link PredictResultType#TRADE_DCVR} ("trade_dcvr") rather than "dcvr".
 *
 * <p>2024-03-13: added DCVR prediction entry points for the telecom-operator industry. Their
 * results are stored under {@link PredictResultType#OPT_DCVR} ("opt_dcvr"), distinct from the
 * cash-on-delivery "trade_dcvr" key.
 *
 * <p>Every {@code predict*} method returns {@code null} when no usable prediction is produced.
 *
 * <p>NOTE: this class is {@link Serializable} without an explicit {@code serialVersionUID}, so the
 * JVM derives it from the non-private members. Avoid changing non-private fields, constructors or
 * method signatures unless it is known that no serialized instances exist.
 */
public class DCVRv3 implements Serializable {

    private static final org.slf4j.Logger logger = LoggerFactory.getLogger(DCVRv3.class);

    /** Label used in log messages for the trade DCVR predictions. */
    private static final String TRADE_DCVR_LABEL = "trade_dcvr";

    /** Label used in log messages for the telecom-operator DCVR predictions. */
    private static final String OPT_DCVR_LABEL = "opt_dcvr";

    /**
     * Non-TF DCVR model. It is used directly by {@link #predictDCvr} / {@link #predictOptDCvr} and
     * as the degradation fallback when a TF prediction yields nothing. It may be {@code null},
     * because DCVR does not always have a degradation model.
     */
    private IModel dCvrModel;

    /**
     * Model whose {@code predictWithTFNew} / {@code predictWithLocalTFV2} drive the TF-based
     * predictions. It can only be set through {@link #DCVRv3(IModel, LocalTFModelV2)}.
     */
    private IModel coderModel;

    /** Locally loaded TF model used by the {@code *WithLocalTF} methods. */
    private LocalTFModelV2 dcvrLocalTFModel;

    /** DCVR deep-learning TF-Serving client used by the {@code *WithTF} methods. */
    private TFServingClient dcvrServingClient;

    public LocalTFModelV2 getDcvrLocalTFModel() {
        return dcvrLocalTFModel;
    }

    public void setDcvrLocalTFModel(LocalTFModelV2 dcvrLocalTFModel) {
        this.dcvrLocalTFModel = dcvrLocalTFModel;
    }

    public IModel getDCvrModel() {
        return dCvrModel;
    }

    public void setDCvrModel(IModel dCvrModel) {
        this.dCvrModel = dCvrModel;
    }

    public TFServingClient getDcvrServingClient() {
        return dcvrServingClient;
    }

    public void setDcvrServingClient(TFServingClient dcvrServingClient) {
        this.dcvrServingClient = dcvrServingClient;
    }

    /** Creates an instance with no models configured. */
    public DCVRv3() {

    }

    /**
     * Creates an instance that predicts with a locally loaded TF model.
     *
     * <p>The fallback model is not set here; use {@link #setDCvrModel(IModel)} if one exists.
     *
     * @param coderModel       model driving the TF-based predictions
     * @param dcvrLocalTFModel locally loaded TF model
     */
    public DCVRv3(IModel coderModel, LocalTFModelV2 dcvrLocalTFModel) {
        this.coderModel = coderModel;
        this.dcvrLocalTFModel = dcvrLocalTFModel;
    }

    // ------------------------------------------------------------------ trade DCVR ("trade_dcvr")

    /**
     * Main trade-DCVR prediction entry point using only the non-TF model.
     *
     * @param featureMap per-sample features, keyed by a caller-chosen key
     * @return predictions under {@link PredictResultType#TRADE_DCVR}, or {@code null} when the
     *         non-TF model is absent or yields no usable result
     * @throws Exception propagated from the underlying model
     */
    public <T> Map<PredictResultType, Map<T, Double>> predictDCvr(Map<T, FeatureMapDo> featureMap) throws Exception {
        return predictWithDCvrModel(featureMap, PredictResultType.TRADE_DCVR);
    }

    /**
     * Predicts trade DCVR through TF Serving, and falls back to the non-TF model when that yields
     * nothing usable.
     *
     * @param featureMap per-sample features, keyed by a caller-chosen key
     * @return predictions under {@link PredictResultType#TRADE_DCVR}, or {@code null} if both the
     *         TF prediction and the fallback yield nothing usable
     * @throws Exception propagated from the underlying models
     */
    public <T> Map<PredictResultType, Map<T, Double>> predictDCvrWithTF(Map<T, FeatureMapDo> featureMap) throws Exception {
        Map<T, Double> primary = servingPrediction(featureMap);
        return primaryOrFallback(primary, featureMap, PredictResultType.TRADE_DCVR, TRADE_DCVR_LABEL);
    }

    /**
     * Predicts trade DCVR with the locally loaded TF model, and falls back to the non-TF model when
     * that yields nothing usable.
     *
     * @param featureMap per-sample features, keyed by a caller-chosen key
     * @return predictions under {@link PredictResultType#TRADE_DCVR}, or {@code null} if both the
     *         TF prediction and the fallback yield nothing usable
     * @throws Exception propagated from the underlying models
     */
    public <T> Map<PredictResultType, Map<T, Double>> predictDCvrWithLocalTF(Map<T, FeatureMapDo> featureMap) throws Exception {
        Map<T, Double> primary = localTFPrediction(featureMap);
        return primaryOrFallback(primary, featureMap, PredictResultType.TRADE_DCVR, TRADE_DCVR_LABEL);
    }

    /**
     * Dispatches a trade-DCVR prediction to the first configured backend. The order is TF Serving,
     * then the local TF model, then the non-TF model.
     *
     * @param featureMap per-sample features, keyed by a caller-chosen key
     * @return predictions under {@link PredictResultType#TRADE_DCVR}, or {@code null}
     * @throws Exception propagated from the underlying models
     */
    public <T> Map<PredictResultType, Map<T, Double>> predictDCvrAndTF(Map<T, FeatureMapDo> featureMap) throws Exception {
        if (dcvrServingClient != null) {
            return predictDCvrWithTF(featureMap);
        }
        if (dcvrLocalTFModel != null) {
            return predictDCvrWithLocalTF(featureMap);
        }
        return predictDCvr(featureMap);
    }

    // ------------------------------------------------------- telecom-operator DCVR ("opt_dcvr")

    /**
     * Main telecom-operator DCVR prediction entry point using only the non-TF model.
     *
     * @param featureMap per-sample features, keyed by a caller-chosen key
     * @return predictions under {@link PredictResultType#OPT_DCVR}, or {@code null} when the
     *         non-TF model is absent or yields no usable result
     * @throws Exception propagated from the underlying model
     */
    public <T> Map<PredictResultType, Map<T, Double>> predictOptDCvr(Map<T, FeatureMapDo> featureMap) throws Exception {
        return predictWithDCvrModel(featureMap, PredictResultType.OPT_DCVR);
    }

    /**
     * Predicts telecom-operator DCVR through TF Serving, and falls back to the non-TF model when
     * that yields nothing usable.
     *
     * @param featureMap per-sample features, keyed by a caller-chosen key
     * @return predictions under {@link PredictResultType#OPT_DCVR}, or {@code null} if both the
     *         TF prediction and the fallback yield nothing usable
     * @throws Exception propagated from the underlying models
     */
    public <T> Map<PredictResultType, Map<T, Double>> predictOptDCvrWithTF(Map<T, FeatureMapDo> featureMap) throws Exception {
        Map<T, Double> primary = servingPrediction(featureMap);
        return primaryOrFallback(primary, featureMap, PredictResultType.OPT_DCVR, OPT_DCVR_LABEL);
    }

    /**
     * Predicts telecom-operator DCVR with the locally loaded TF model, and falls back to the non-TF
     * model when that yields nothing usable.
     *
     * @param featureMap per-sample features, keyed by a caller-chosen key
     * @return predictions under {@link PredictResultType#OPT_DCVR}, or {@code null} if both the
     *         TF prediction and the fallback yield nothing usable
     * @throws Exception propagated from the underlying models
     */
    public <T> Map<PredictResultType, Map<T, Double>> predictOptDCvrWithLocalTF(Map<T, FeatureMapDo> featureMap) throws Exception {
        Map<T, Double> primary = localTFPrediction(featureMap);
        return primaryOrFallback(primary, featureMap, PredictResultType.OPT_DCVR, OPT_DCVR_LABEL);
    }

    /**
     * Dispatches a telecom-operator DCVR prediction to the first configured backend. The order is
     * TF Serving, then the local TF model, then the non-TF model.
     *
     * @param featureMap per-sample features, keyed by a caller-chosen key
     * @return predictions under {@link PredictResultType#OPT_DCVR}, or {@code null}
     * @throws Exception propagated from the underlying models
     */
    public <T> Map<PredictResultType, Map<T, Double>> predictOptDCvrAndTF(Map<T, FeatureMapDo> featureMap) throws Exception {
        if (dcvrServingClient != null) {
            return predictOptDCvrWithTF(featureMap);
        }
        if (dcvrLocalTFModel != null) {
            return predictOptDCvrWithLocalTF(featureMap);
        }
        return predictOptDCvr(featureMap);
    }

    // ------------------------------------------------------------------------ private helpers

    /** Runs the TF-Serving prediction; returns {@code null} when no coder model is configured. */
    private <T> Map<T, Double> servingPrediction(Map<T, FeatureMapDo> featureMap) throws Exception {
        if (coderModel == null) {
            return null;
        }
        return coderModel.predictWithTFNew(featureMap, dcvrServingClient);
    }

    /** Runs the local-TF prediction; returns {@code null} when no coder model is configured. */
    private <T> Map<T, Double> localTFPrediction(Map<T, FeatureMapDo> featureMap) throws Exception {
        if (coderModel == null) {
            return null;
        }
        return coderModel.predictWithLocalTFV2(featureMap, dcvrLocalTFModel);
    }

    /**
     * Wraps {@code primary} under {@code resultType} if it is usable. Otherwise logs the
     * degradation and falls back to the non-TF model.
     */
    private <T> Map<PredictResultType, Map<T, Double>> primaryOrFallback(Map<T, Double> primary,
                                                                         Map<T, FeatureMapDo> featureMap,
                                                                         PredictResultType resultType,
                                                                         String label) throws Exception {
        Map<PredictResultType, Map<T, Double>> ret = wrapOrNull(resultType, primary);
        if (ret != null) {
            return ret;
        }
        // "fusing" (circuit-breaking): the TF prediction returned nothing usable, so degrade.
        logger.warn("{}: {} fusing", coderModel == null ? null : coderModel.getModelId(), label);
        return predictWithDCvrModel(featureMap, resultType);
    }

    /** Predicts with the non-TF model; returns {@code null} when it is absent or yields nothing usable. */
    private <T> Map<PredictResultType, Map<T, Double>> predictWithDCvrModel(Map<T, FeatureMapDo> featureMap,
                                                                            PredictResultType resultType) throws Exception {
        // DCVR does not always have a degradation model, so dCvrModel may legitimately be null.
        if (dCvrModel == null) {
            return null;
        }
        Map<T, Double> dcvrMap = dCvrModel.predictsNew(featureMap);
        return wrapOrNull(resultType, dcvrMap);
    }

    /**
     * Returns a new mutable map holding {@code dcvrMap} under {@code resultType}. Returns
     * {@code null} when {@code dcvrMap} is null or fails {@code AssertUtil.isAllNotEmpty}.
     */
    private static <T> Map<PredictResultType, Map<T, Double>> wrapOrNull(PredictResultType resultType,
                                                                         Map<T, Double> dcvrMap) {
        if (dcvrMap == null || !AssertUtil.isAllNotEmpty(dcvrMap)) {
            return null;
        }
        Map<PredictResultType, Map<T, Double>> ret = new HashMap<>();
        ret.put(resultType, dcvrMap);
        return ret;
    }
}
