package cn.com.duiba.nezha.alg.model;

import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import cn.com.duiba.nezha.alg.feature.coder.FeatureNewCoder;
import cn.com.duiba.nezha.alg.feature.vo.CodeResult;
import cn.com.duiba.nezha.alg.feature.vo.FeatureMapDo;
import cn.com.duiba.nezha.alg.model.enums.PredictResultType;
import cn.com.duiba.nezha.alg.model.tf.TFServingClient;
import com.alibaba.fastjson.JSON;
import org.slf4j.LoggerFactory;

import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class DCVR implements Serializable {
    private static final org.slf4j.Logger logger = LoggerFactory.getLogger(DCVR.class);
    // Initial capacity of the result maps built below.
    // NOTE(review): field shape kept as-is (not static final) so the default serialVersionUID of
    // this Serializable class does not change.
    private int PB_MAX_SIZE = 128;

    /**
     * Local DCVR prediction model (invoked via {@code predictsNew}). It is also the fallback
     * when the TF-Serving path yields no usable result.
     */
    private IModel dCvrModel;

    /**
     * Model whose {@code predictWithTFNew} drives the TF-Serving prediction path.
     * This is {@code null} when the instance is built with {@link #DCVR(IModel)}.
     */
    private IModel coderModel;

    /**
     * TF-Serving client for the deep-learning DCVR model. When {@code null},
     * {@link #predictDCvrAndTF(Map)} uses only the local model.
     */
    private TFServingClient dcvrServingClient;

    /**
     * @return the local DCVR model
     */
    public IModel getDCvrModel() {
        return dCvrModel;
    }

    /**
     * Replaces the local DCVR model.
     *
     * @param dCvrModel the new local model
     */
    public void setDCvrModel(IModel dCvrModel) {
        this.dCvrModel = dCvrModel;
    }

    /**
     * @return the TF-Serving client used by {@link #predictDCvrAndTF(Map)}, possibly {@code null}
     */
    public TFServingClient getDcvrServingClient() {
        return dcvrServingClient;
    }

    /**
     * @param dcvrServingClient the TF-Serving client; {@code null} disables the TF path in
     *                          {@link #predictDCvrAndTF(Map)}
     */
    public void setDcvrServingClient(TFServingClient dcvrServingClient) {
        this.dcvrServingClient = dcvrServingClient;
    }


    /**
     * Creates an instance that predicts only with the local model.
     *
     * @param dcvrModel the local DCVR model
     */
    public DCVR(IModel dcvrModel) {
        this.dCvrModel = dcvrModel;
    }

    /**
     * Creates an instance that can use the TF-Serving path and fall back to the local model.
     *
     * @param dCvrModel         the local DCVR model (fallback)
     * @param coderModel        the model used for the TF-Serving path
     * @param dcvrServingClient the TF-Serving client
     */
    public DCVR(IModel dCvrModel, IModel coderModel, TFServingClient dcvrServingClient) {
        this.dCvrModel = dCvrModel;
        this.coderModel = coderModel;
        this.dcvrServingClient = dcvrServingClient;
    }

    /**
     * Main prediction entry point using the local model only.
     *
     * @param featureMap features keyed by caller-defined keys
     * @param <T>        key type of the feature/result maps
     * @return a map with a single {@link PredictResultType#DCVR} entry holding the per-key scores,
     *         or {@code null} when the model's result is rejected by {@code AssertUtil.isAllNotEmpty}
     *         (presumably a null or empty map)
     * @throws Exception whatever the underlying model throws
     */
    public <T> Map<PredictResultType, Map<T, Double>> predictDCvr(Map<T, FeatureMapDo> featureMap) throws Exception {
        Map<T, Double> dcvrMap = dCvrModel.predictsNew(featureMap);
        return wrapDcvr(dcvrMap);
    }

    /**
     * Predicts through TF-Serving and falls back to the local model when the TF path yields no
     * usable result. If no TF model is configured, this goes straight to the local model.
     *
     * @param featureMap        features keyed by caller-defined keys
     * @param dcvrServingClient the TF-Serving client to use for this call
     * @param <T>               key type of the feature/result maps
     * @return a map with a single {@link PredictResultType#DCVR} entry, or {@code null} when both
     *         paths yield no usable result
     * @throws Exception whatever the underlying models throw
     */
    public <T> Map<PredictResultType, Map<T, Double>> predictDCvrWithTF(Map<T, FeatureMapDo> featureMap, TFServingClient dcvrServingClient) throws Exception {
        if (coderModel == null) {
            // Built via DCVR(IModel): there is no TF model. Use the local model instead of
            // failing with a NullPointerException.
            return predictDCvr(featureMap);
        }

        Map<T, Double> tfMap = coderModel.predictWithTFNew(featureMap, dcvrServingClient);
        Map<PredictResultType, Map<T, Double>> ret = wrapDcvr(tfMap);
        if (ret != null) {
            return ret;
        }

        // The TF path produced nothing usable: fuse (circuit-break) to the local model.
        logger.warn("{}: dcvr fusing", coderModel.getModelId());
        return predictDCvr(featureMap);
    }


    /**
     * Uses the TF-Serving path when a serving client is configured, otherwise the local model.
     *
     * @param featureMap features keyed by caller-defined keys
     * @param <T>        key type of the feature/result maps
     * @return see {@link #predictDCvrWithTF(Map, TFServingClient)} and {@link #predictDCvr(Map)}
     * @throws Exception whatever the underlying models throw
     */
    public <T> Map<PredictResultType, Map<T, Double>> predictDCvrAndTF(Map<T, FeatureMapDo> featureMap) throws Exception {
        if (dcvrServingClient != null) {
            return predictDCvrWithTF(featureMap, dcvrServingClient);
        }
        return predictDCvr(featureMap);
    }

    /**
     * Wraps per-key scores under the {@link PredictResultType#DCVR} key.
     *
     * @return the wrapped map, or {@code null} when {@code AssertUtil.isAllNotEmpty} rejects the scores
     */
    private <T> Map<PredictResultType, Map<T, Double>> wrapDcvr(Map<T, Double> dcvrMap) {
        if (!AssertUtil.isAllNotEmpty(dcvrMap)) {
            return null;
        }
        Map<PredictResultType, Map<T, Double>> ret = new HashMap<>(PB_MAX_SIZE);
        ret.put(PredictResultType.DCVR, dcvrMap);
        return ret;
    }

}