package cn.com.duiba.nezha.alg.alg.adx.rcmd2;

import cn.com.duiba.nezha.alg.api.model.AdxLocalTFModel;
import cn.com.duiba.nezha.alg.api.model.E2ELocalTFModel;
import cn.com.duiba.nezha.alg.common.util.MathUtil;
import cn.com.duiba.nezha.alg.feature.vo.FeatureMapDo;
import cn.com.duiba.nezha.alg.model.IModel;
import cn.com.duiba.nezha.alg.model.tf.LocalTFModel;
import cn.com.duiba.nezha.alg.model.tf.LocalTFModelV2;
import lombok.Data;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;

@Data
public class Model {
    private static final Logger logger = LoggerFactory.getLogger(Model.class);

    /** Digits argument passed to {@code MathUtil.formatDouble} when post-processing ESMM outputs. */
    private static final int ESMM_SCORE_DIGITS = 8;

    // NOTE: several fields below start with an upper-case letter (ARPUCoder, ClickPvCoderV2, ...).
    // They are intentionally not renamed: Lombok @Data derives the public getter/setter names
    // from the field names, and callers elsewhere depend on those accessors.

    private IModel ctrModel;
    private IModel launchPvCoder;
    private IModel ARPUCoder;
    private LocalTFModel launchPvTFModel;  // TF model paired with launchPvCoder
    private LocalTFModel arpuTFModel;      // TF model paired with ARPUCoder

    private AdxLocalTFModel esmmTFModel;   // model used by predictCtrClick


    // ---- v2 models ----
    private IModel clickPvCoder;            // coder model for coupon clicks per PV
    private IModel adCpcCoder;              // coder model for coupon CPC
    private LocalTFModelV2 clickPvTFModel;  // TF model paired with clickPvCoder
    private LocalTFModelV2 adCpcTFModel;    // TF model paired with adCpcCoder

    private IModel launchPvCoderV2;
    private LocalTFModelV2 launchPvTFModelV2;  // TF model paired with launchPvCoderV2
    private IModel ARPUCoderV2;
    private LocalTFModelV2 arpuTFModelV2;      // TF model paired with ARPUCoderV2

    private E2ELocalTFModel userScoreModel;    // user scoring model

    private IModel userScoreCoderV2;
    private LocalTFModelV2 userScoreTFModelV2; // TF model for the v2 user scoring model

    // adCpc model v2
    private IModel adCpcCoderV2; // coder_004
    private LocalTFModelV2 adCpcTFModelV2;

    // clickPv model v2
    private IModel ClickPvCoderV2;
    private LocalTFModelV2 ClickPvTFModelV2;

    // ctr model v2
    private IModel CtrCoderV2;
    private LocalTFModelV2 CtrTFModelV2;


    /**
     * Baidu model - encoder (coder).
     * The ctr, click and cpc Baidu models all share this one coder.
     * ID: adx_coder_v004
     */
    private IModel allBaiduCoderV2; // coder_004

    /**
     * Baidu model - ctr.
     * Type: CTR_BAIDU_V2(15, "ctr-baidu-v2")
     * Name: mid-adx-ctr-baidu-v001
     */
    private LocalTFModelV2 ctrBaiduTFModelV2;

    /**
     * Baidu model - click.
     * Type: CLICKPV_BAIDU_V2(16, "clickpv-baidu-v2")
     * Name: mid-adx-click-baidu-v001
     */
    private LocalTFModelV2 clickPvBaiduTFModelV2;

    /**
     * Baidu model - cpc.
     * Type: CPC_BAIDU_V2(17, "cpc-baidu-v2")
     * Name: mid-adx-cpc-baidu-v001
     */
    private LocalTFModelV2 adCpcBaiduTFModelV2;

    /*
     * New-version models (53 features, from July 14):
     *   ctr:   ID1
     *   click: ID2
     *   cpc:   ID3 (already exists)
     *   coder: ID4
     */

    // ------------------------------------------------------------------------------------------
    // Shared helpers
    // ------------------------------------------------------------------------------------------

    /**
     * Runs {@code coder.predictWithLocalTFNew} with a v1 {@link LocalTFModel}, and logs and
     * swallows any failure.
     *
     * @param tag        method name used in the log line
     * @param coder      coder model; a null coder surfaces as a logged NPE
     * @param tfModel    TF model passed to the coder
     * @param featureMap features keyed by sample id
     * @return the coder's predictions, or an immutable empty map if anything throws
     */
    private <T> Map<T, Double> predictWithCoder(String tag, IModel coder, LocalTFModel tfModel,
                                                Map<T, FeatureMapDo> featureMap) {
        try {
            return coder.predictWithLocalTFNew(featureMap, tfModel);
        } catch (Exception e) {
            // Exception is the last, unconsumed argument so SLF4J logs the full stack trace.
            logger.error("AdxRecommend.{} error, coder={}, tfModel={}", tag, coder, tfModel, e);
        }
        return Collections.emptyMap();
    }

    /**
     * Runs {@code coder.predictWithLocalTFV2} with a {@link LocalTFModelV2}, and logs and
     * swallows any failure.
     *
     * @param tag        method name used in the log line
     * @param coder      coder model; a null coder surfaces as a logged NPE
     * @param tfModel    TF model passed to the coder
     * @param featureMap features keyed by sample id
     * @return the coder's predictions, or an immutable empty map if anything throws
     */
    private <T> Map<T, Double> predictWithCoderV2(String tag, IModel coder, LocalTFModelV2 tfModel,
                                                  Map<T, FeatureMapDo> featureMap) {
        try {
            return coder.predictWithLocalTFV2(featureMap, tfModel);
        } catch (Exception e) {
            // Exception is the last, unconsumed argument so SLF4J logs the full stack trace.
            logger.error("AdxRecommend.{} error, coder={}, tfModel={}", tag, coder, tfModel, e);
        }
        return Collections.emptyMap();
    }

    /**
     * Converts the per-sample {@link FeatureMapDo}s into the raw map form accepted by
     * {@code esmmTFModel.predict} and {@code userScoreModel.predict}.
     *
     * @param featureMap features keyed by sample id
     * @return a new mutable map from sample id to {@code FeatureMapDo.getFeatureMap()}
     */
    private static <T> Map<T, Map<String, String>> toModelInput(Map<T, FeatureMapDo> featureMap) {
        Map<T, Map<String, String>> modelInput = new HashMap<>();
        for (Map.Entry<T, FeatureMapDo> entry : featureMap.entrySet()) {
            modelInput.put(entry.getKey(), entry.getValue().getFeatureMap());
        }
        return modelInput;
    }

    // ------------------------------------------------------------------------------------------
    // v1 predictions
    // ------------------------------------------------------------------------------------------

    /**
     * Predicts CTR with {@code ctrModel.predictsNew}.
     *
     * @param featureMap features keyed by sample id
     * @return predictions keyed by sample id, or an immutable empty map on any failure
     * @throws Exception never in practice: failures are caught and logged. The clause is kept
     *                   for source compatibility with existing callers.
     */
    public <T> Map<T, Double> predictCtr(Map<T, FeatureMapDo> featureMap) throws Exception {
        try {
            return ctrModel.predictsNew(featureMap);
        } catch (Exception e) {
            logger.error("Model.predictCtr error, ctrModel={}", ctrModel, e);
        }
        return Collections.emptyMap();
    }

    /**
     * Predicts launch-PV with {@code launchPvCoder} and {@code launchPvTFModel}.
     *
     * @return predictions keyed by sample id, or an immutable empty map on any failure
     * @throws Exception never in practice; kept for source compatibility
     */
    public <T> Map<T, Double> predictLaunchPv(Map<T, FeatureMapDo> featureMap) throws Exception {
        return predictWithCoder("predictLaunchPv", launchPvCoder, launchPvTFModel, featureMap);
    }

    /**
     * Predicts ARPU with {@code ARPUCoder} and {@code arpuTFModel}.
     *
     * @return predictions keyed by sample id, or an immutable empty map on any failure
     * @throws Exception never in practice; kept for source compatibility
     */
    public <T> Map<T, Double> predictARPU(Map<T, FeatureMapDo> featureMap) throws Exception {
        return predictWithCoder("predictARPU", ARPUCoder, arpuTFModel, featureMap);
    }

    // ------------------------------------------------------------------------------------------
    // v2 predictions (LocalTFModelV2)
    // ------------------------------------------------------------------------------------------

    /**
     * Predicts launch-PV with {@code launchPvCoderV2} and {@code launchPvTFModelV2}.
     *
     * @return predictions keyed by sample id, or an immutable empty map on any failure
     * @throws Exception never in practice; kept for source compatibility
     */
    public <T> Map<T, Double> predictLaunchPvV2(Map<T, FeatureMapDo> featureMap) throws Exception {
        return predictWithCoderV2("predictLaunchPvV2", launchPvCoderV2, launchPvTFModelV2, featureMap);
    }

    /**
     * Predicts ARPU with {@code ARPUCoderV2} and {@code arpuTFModelV2}.
     *
     * @return predictions keyed by sample id, or an immutable empty map on any failure
     * @throws Exception never in practice; kept for source compatibility
     */
    public <T> Map<T, Double> predictARPUV2(Map<T, FeatureMapDo> featureMap) throws Exception {
        return predictWithCoderV2("predictARPUV2", ARPUCoderV2, arpuTFModelV2, featureMap);
    }

    /**
     * Predicts coupon clicks per PV with {@code clickPvCoder} and {@code clickPvTFModel}.
     *
     * @return predictions keyed by sample id, or an immutable empty map on any failure
     * @throws Exception never in practice; kept for source compatibility
     */
    public <T> Map<T, Double> predictClickPv(Map<T, FeatureMapDo> featureMap) throws Exception {
        return predictWithCoderV2("predictClickPv", clickPvCoder, clickPvTFModel, featureMap);
    }

    /**
     * Predicts coupon CPC with {@code adCpcCoder} and {@code adCpcTFModel}.
     *
     * @return predictions keyed by sample id, or an immutable empty map on any failure
     * @throws Exception never in practice; kept for source compatibility
     */
    public <T> Map<T, Double> predictAdCpc(Map<T, FeatureMapDo> featureMap) throws Exception {
        return predictWithCoderV2("predictAdCpc", adCpcCoder, adCpcTFModel, featureMap);
    }

    // ------------------------------------------------------------------------------------------
    // ESMM / end-to-end models
    // ------------------------------------------------------------------------------------------

    /**
     * Scores the samples with {@code esmmTFModel}.
     *
     * <p>For every output name returned by the model, the result holds a map from sample id to
     * the first element of that sample's output array. Each value is passed through
     * {@code MathUtil.formatDouble(value, 8)} (presumably rounding to 8 decimal places).
     *
     * @param featureMap features keyed by sample id
     * @return output name -> (sample id -> score), or an immutable empty map on any failure
     */
    public <T> Map<String, Map<T, Double>> predictCtrClick(Map<T, FeatureMapDo> featureMap) {
        try {
            Map<String, Map<T, Float[]>> modelOutput = esmmTFModel.predict(toModelInput(featureMap));
            Map<String, Map<T, Double>> ret = new HashMap<>();
            for (Map.Entry<String, Map<T, Float[]>> output : modelOutput.entrySet()) {
                Map<T, Double> scores = new HashMap<>();
                for (Map.Entry<T, Float[]> sample : output.getValue().entrySet()) {
                    double v = MathUtil.formatDouble(sample.getValue()[0].doubleValue(), ESMM_SCORE_DIGITS);
                    scores.put(sample.getKey(), v);
                }
                ret.put(output.getKey(), scores);
            }
            return ret;
        } catch (Exception e) {
            logger.error("AdxRecommend.predictESMM error TFModel {}", esmmTFModel, e);
        }
        return Collections.emptyMap();
    }

    /**
     * Local smoke test for {@link #predictCtrClick}. It loads a model directory and scores two
     * samples that share the same single-feature {@link FeatureMapDo}.
     *
     * @param args optional; {@code args[0]} overrides the model directory. When absent, the
     *             original hard-coded developer path is used.
     */
    public static void main(String[] args) throws Exception {
        String modelDir = (args != null && args.length > 0)
                ? args[0]
                : "C:\\Users\\aaa\\Desktop\\adx\\model\\test";

        Model model = new Model();

        AdxLocalTFModel adxLocalTFModel = new AdxLocalTFModel();
        adxLocalTFModel.loadLatestModel(modelDir);
        model.setEsmmTFModel(adxLocalTFModel);

        HashMap<String, String> featureMap = new HashMap<>();
        featureMap.put("ft1", "aaa");
        FeatureMapDo featureMapDo = new FeatureMapDo();
        featureMapDo.setDynamicFeatureMap(featureMap);
        featureMapDo.setStaticFeatureMap(new HashMap<>());

        Map<Long, FeatureMapDo> sample = new HashMap<>();
        sample.put(123L, featureMapDo);
        sample.put(124L, featureMapDo);
        Map<String, Map<Long, Double>> result = model.predictCtrClick(sample);

        System.out.println(result);
    }

    /**
     * Predicts user scores with {@code userScoreModel}, feeding it each sample's
     * {@code FeatureMapDo.getFeatureMap()}.
     *
     * @param featureMap features keyed by sample id
     * @return scores keyed by sample id, or an immutable empty map on any failure
     */
    public <T> Map<T, Double> predictUserScore(Map<T, FeatureMapDo> featureMap) {
        try {
            return userScoreModel.predict(toModelInput(featureMap));
        } catch (Exception e) {
            logger.error("AdxRecommend.predictUserScore error, userScoreModel={}", userScoreModel, e);
        }
        return Collections.emptyMap();
    }

    /**
     * Predicts user scores with {@code userScoreCoderV2} and {@code userScoreTFModelV2}.
     *
     * @return scores keyed by sample id, or an immutable empty map on any failure
     * @throws Exception never in practice; kept for source compatibility
     */
    public <T> Map<T, Double> predictUserScoreV2(Map<T, FeatureMapDo> featureMap) throws Exception {
        return predictWithCoderV2("predictUserScoreV2", userScoreCoderV2, userScoreTFModelV2, featureMap);
    }

    /**
     * Predicts coupon CPC with {@code adCpcCoderV2} and {@code adCpcTFModelV2}.
     *
     * @return predictions keyed by sample id, or an immutable empty map on any failure
     * @throws Exception never in practice; kept for source compatibility
     */
    public <T> Map<T, Double> predictAdCpcV2(Map<T, FeatureMapDo> featureMap) throws Exception {
        return predictWithCoderV2("predictAdCpcV2", adCpcCoderV2, adCpcTFModelV2, featureMap);
    }

    /**
     * Predicts CTR with {@code CtrCoderV2} and {@code CtrTFModelV2}.
     *
     * @return predictions keyed by sample id, or an immutable empty map on any failure
     * @throws Exception never in practice; kept for source compatibility
     */
    public <T> Map<T, Double> predictCtr_V2(Map<T, FeatureMapDo> featureMap) throws Exception {
        return predictWithCoderV2("predictCtr_V2", CtrCoderV2, CtrTFModelV2, featureMap);
    }

    /**
     * Predicts clicks per PV with {@code ClickPvCoderV2} and {@code ClickPvTFModelV2}.
     *
     * @return predictions keyed by sample id, or an immutable empty map on any failure
     * @throws Exception never in practice; kept for source compatibility
     */
    public <T> Map<T, Double> predictClickPv_V2(Map<T, FeatureMapDo> featureMap) throws Exception {
        return predictWithCoderV2("predictClickPv_V2", ClickPvCoderV2, ClickPvTFModelV2, featureMap);
    }

    // ------------------------------------------------------------------------------------------
    // Baidu models (all share allBaiduCoderV2)
    // ------------------------------------------------------------------------------------------

    /**
     * Predicts Baidu CTR with the shared {@code allBaiduCoderV2} and {@code ctrBaiduTFModelV2}.
     *
     * @return predictions keyed by sample id, or an immutable empty map on any failure
     * @throws Exception never in practice; kept for source compatibility
     */
    public <T> Map<T, Double> predictCtrBaidu(Map<T, FeatureMapDo> featureMap) throws Exception {
        return predictWithCoderV2("predictCtrBaidu", allBaiduCoderV2, ctrBaiduTFModelV2, featureMap);
    }

    /**
     * Predicts Baidu clicks per PV with the shared {@code allBaiduCoderV2} and
     * {@code clickPvBaiduTFModelV2}.
     *
     * @return predictions keyed by sample id, or an immutable empty map on any failure
     * @throws Exception never in practice; kept for source compatibility
     */
    public <T> Map<T, Double> predictClickPvBaidu(Map<T, FeatureMapDo> featureMap) throws Exception {
        return predictWithCoderV2("predictClickPvBaidu", allBaiduCoderV2, clickPvBaiduTFModelV2, featureMap);
    }

    /**
     * Predicts Baidu CPC with the shared {@code allBaiduCoderV2} and {@code adCpcBaiduTFModelV2}.
     *
     * @return predictions keyed by sample id, or an immutable empty map on any failure
     * @throws Exception never in practice; kept for source compatibility
     */
    public <T> Map<T, Double> predictAdCpcBaidu(Map<T, FeatureMapDo> featureMap) throws Exception {
        return predictWithCoderV2("predictAdCpcBaidu", allBaiduCoderV2, adCpcBaiduTFModelV2, featureMap);
    }

}
