package cn.com.duiba.nezha.alg.model;

import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import cn.com.duiba.nezha.alg.feature.vo.FeatureMapDo;
import cn.com.duiba.nezha.alg.model.enums.MutModelType;
import cn.com.duiba.nezha.alg.model.enums.PredictResultType;
import cn.com.duiba.nezha.alg.model.tf.OneLocalTfModel;
import lombok.AllArgsConstructor;
import lombok.Data;
import org.slf4j.LoggerFactory;

import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;

/**
 * Aggregate model that bundles both CTR (click-through rate) and CVR (conversion rate) prediction.
 *
 * <p>For each of the two targets, the TensorFlow model is used when it is present, otherwise the
 * FM model is used. If that primary prediction throws, the corresponding "fusing" (fallback) FM
 * model is used instead.
 */
@AllArgsConstructor
@Data
public class NormalModel implements Serializable {
    private static final long serialVersionUID = 22348095L;
    private static final org.slf4j.Logger logger = LoggerFactory.getLogger(NormalModel.class);

    // Algorithm strategy type; within this class it is only written to the fallback log message.
    private Integer algType;
    // CTR prediction model (in-house FM).
    private IModel ctrFMModel;
    // CVR prediction model (in-house FM).
    private IModel cvrFMModel;


    // Fallback ("fusing") CTR prediction model (in-house FM), used when the primary CTR prediction throws.
    private IModel ctrFusingFMModel;
    // Fallback ("fusing") CVR prediction model (in-house FM), used when the primary CVR prediction throws.
    private IModel cvrFusingFMModel;

    // TensorFlow model for CTR; takes precedence over ctrFMModel when non-null.
    private OneLocalTfModel ctrTfModel;
    // TensorFlow model for CVR; takes precedence over cvrFMModel when non-null.
    private OneLocalTfModel cvrTfModel;

    // Model type; not read by any method of this class.
    private MutModelType mutModelType;

    /**
     * Predicts CTR and CVR for every sample in {@code featureMap}.
     *
     * @param featureMap features per sample key
     * @param <T>        type of the sample key
     * @return a map holding the CTR result under {@link PredictResultType#CTR} and the CVR result
     *         under {@link PredictResultType#CVR}; {@code null} when
     *         {@code AssertUtil.isAllNotEmpty(ctrMap, cvrMap)} is false (presumably when either
     *         result is null or empty - confirm against AssertUtil)
     * @throws Exception if the fallback model of either target fails as well
     */
    public <T> Map<PredictResultType, Map<T, Double>> predict(Map<T, FeatureMapDo> featureMap) throws Exception {
        Map<T, Double> ctrMap = predictCtr(featureMap);
        Map<T, Double> cvrMap = predictCvr(featureMap);
        if (!AssertUtil.isAllNotEmpty(ctrMap, cvrMap)) {
            // Existing contract: callers receive null rather than a partially filled result.
            return null;
        }
        Map<PredictResultType, Map<T, Double>> result = new HashMap<>();
        result.put(PredictResultType.CTR, ctrMap);
        result.put(PredictResultType.CVR, cvrMap);
        return result;
    }

    /**
     * Predicts CTR per sample, falling back to {@code ctrFusingFMModel} when the primary model throws.
     *
     * @param featureMap features per sample key
     * @param <T>        type of the sample key
     * @return CTR prediction per sample key
     * @throws Exception if the fallback model fails as well
     */
    private <T> Map<T, Double> predictCtr(Map<T, FeatureMapDo> featureMap) throws Exception {
        return predictWithFallback("ctr", ctrTfModel, ctrFMModel, ctrFusingFMModel, featureMap);
    }

    /**
     * Predicts CVR per sample, falling back to {@code cvrFusingFMModel} when the primary model throws.
     *
     * @param featureMap features per sample key
     * @param <T>        type of the sample key
     * @return CVR prediction per sample key
     * @throws Exception if the fallback model fails as well
     */
    private <T> Map<T, Double> predictCvr(Map<T, FeatureMapDo> featureMap) throws Exception {
        return predictWithFallback("cvr", cvrTfModel, cvrFMModel, cvrFusingFMModel, featureMap);
    }

    /**
     * Shared primary/fallback prediction flow for one target.
     *
     * @param tag         target name written at the start of the fallback log message
     * @param tfModel     TensorFlow model; used when non-null
     * @param fmModel     FM model; used when {@code tfModel} is null
     * @param fusingModel fallback model; used when the primary prediction throws
     * @param featureMap  features per sample key
     * @param <T>         type of the sample key
     * @return prediction per sample key
     * @throws Exception the fallback failure, carrying the primary failure as a suppressed exception
     */
    private <T> Map<T, Double> predictWithFallback(String tag,
                                                   OneLocalTfModel tfModel,
                                                   IModel fmModel,
                                                   IModel fusingModel,
                                                   Map<T, FeatureMapDo> featureMap) throws Exception {
        try {
            if (tfModel != null) {
                Map<T, Map<String, String>> normalFeatureMap = processFeatureMap(featureMap);
                return tfModel.predict(normalFeatureMap);
            }
            return fmModel.predictsNew(featureMap);
        } catch (Exception primaryError) {
            // Trailing Throwable is logged with its stack trace; the rendered text matches the old message.
            logger.warn("{} fusing.. alg type = {}", tag, algType, primaryError);
            try {
                return fusingModel.predictsNew(featureMap);
            } catch (Exception fallbackError) {
                // Keep the root cause visible to the caller instead of only in the log.
                // addSuppressed rejects self-suppression, hence the identity check.
                if (fallbackError != primaryError) {
                    fallbackError.addSuppressed(primaryError);
                }
                throw fallbackError;
            }
        }
    }

    /**
     * Converts the feature map into the nested String map that is passed to the TensorFlow model.
     *
     * @param featureMap features per sample key, as {@link FeatureMapDo}
     * @param <T>        type of the sample key
     * @return map from sample key to the value of {@code FeatureMapDo#getFeatureMap()} for that sample
     */
    private <T> Map<T, Map<String, String>> processFeatureMap(Map<T, FeatureMapDo> featureMap) {
        // Pre-size for the known number of entries so the map does not rehash while filling.
        int capacity = Math.max(16, (int) (featureMap.size() / 0.75f) + 1);
        Map<T, Map<String, String>> normalFeatureMap = new HashMap<>(capacity);
        for (Map.Entry<T, FeatureMapDo> featureMapDoEntry : featureMap.entrySet()) {
            normalFeatureMap.put(featureMapDoEntry.getKey(), featureMapDoEntry.getValue().getFeatureMap());
        }
        return normalFeatureMap;
    }
}
