package cn.com.duiba.nezha.alg.alg.adx.rcmd2;

import cn.com.duiba.nezha.alg.alg.base.MathBase;
import cn.com.duiba.nezha.alg.api.model.AdxLocalTFModel;
import cn.com.duiba.nezha.alg.common.util.MathUtil;
import cn.com.duiba.nezha.alg.feature.vo.FeatureMapDo;
import cn.com.duiba.nezha.alg.model.IModel;
import cn.com.duiba.nezha.alg.model.tf.LocalTFModel;
import cn.com.duiba.nezha.alg.model.tf.LocalTFModelV2;
import lombok.Data;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

/**
 * Holder for the ADX recommendation prediction models, with one prediction method per score.
 *
 * <p>Every {@code predict*} method takes a map of key to {@link FeatureMapDo}. It returns a map of
 * key to prediction. If anything goes wrong it logs the error (with stack trace) and returns an
 * empty map rather than propagating the exception.
 */
@Data
public class Model {
    private static final Logger logger = LoggerFactory.getLogger(Model.class);

    /** Number of decimal places the ESMM-based predictions are rounded to (HALF_UP). */
    private static final int SCALE = 5;

    private IModel ctrModel;
    private IModel launchPvCoder;
    private IModel ARPUCoder;
    private LocalTFModel launchPvTFModel;  // TF model paired with launchPvCoder
    private LocalTFModel arpuTFModel;   // TF model paired with ARPUCoder

    private AdxLocalTFModel esmmTFModel;


    // v2 version
    private IModel clickPvCoder; // coder model for coupon clicks per PV
    private IModel adCpcCoder; // coder model for coupon CPC
    private LocalTFModelV2 clickPvTFModel;  // TF model paired with clickPvCoder
    private LocalTFModelV2 adCpcTFModel;   // TF model paired with adCpcCoder

    private IModel launchPvCoderV2;
    private LocalTFModelV2 launchPvTFModelV2;  // TF model paired with launchPvCoderV2
    private IModel ARPUCoderV2;
    private LocalTFModelV2 arpuTFModelV2;  // TF model paired with ARPUCoderV2


    /**
     * Predicts CTR via {@code ctrModel.predictsNew}.
     *
     * @param featureMap features per key
     * @return predictions per key, or an empty map if prediction fails
     */
    public <T> Map<T, Double> predictCtr(Map<T, FeatureMapDo> featureMap) throws Exception {
        try {
            return ctrModel.predictsNew(featureMap);
        } catch (Exception e) {
            // Throwable passed last (no placeholder) so SLF4J logs the full stack trace.
            logger.error("Model.predictCtr error, ctrModel {}", ctrModel, e);
        }
        return Collections.emptyMap();
    }

    /**
     * Predicts launch-PV with {@code launchPvCoder} and {@code launchPvTFModel}.
     *
     * @param featureMap features per key
     * @return predictions per key, or an empty map if prediction fails
     */
    public <T> Map<T, Double> predictLaunchPv(Map<T, FeatureMapDo> featureMap) throws Exception {
        try {
            return launchPvCoder.predictWithLocalTFNew(featureMap, launchPvTFModel);
        } catch (Exception e) {
            logger.error("AdxRecommend.predictLaunchPv error, launchPvCoder {}, launchPvTFModel {}",
                    launchPvCoder, launchPvTFModel, e);
        }
        return Collections.emptyMap();
    }

    /**
     * V2 launch-PV prediction with {@code launchPvCoderV2} and {@code launchPvTFModelV2}.
     *
     * @param featureMap features per key
     * @return predictions per key, or an empty map if prediction fails
     */
    public <T> Map<T, Double> predictLaunchPvV2(Map<T, FeatureMapDo> featureMap) throws Exception {
        try {
            return launchPvCoderV2.predictWithLocalTFV2(featureMap, launchPvTFModelV2);
        } catch (Exception e) {
            logger.error("AdxRecommend.predictLaunchPvV2 error, launchPvCoderV2 {}, launchPvTFModelV2 {}",
                    launchPvCoderV2, launchPvTFModelV2, e);
        }
        return Collections.emptyMap();
    }

    /**
     * Predicts ARPU with {@code ARPUCoder} and {@code arpuTFModel}.
     *
     * @param featureMap features per key
     * @return predictions per key, or an empty map if prediction fails
     */
    public <T> Map<T, Double> predictARPU(Map<T, FeatureMapDo> featureMap) throws Exception {
        try {
            return ARPUCoder.predictWithLocalTFNew(featureMap, arpuTFModel);
        } catch (Exception e) {
            logger.error("AdxRecommend.predictARPU error, ARPUCoder {}, arpuTFModel {}",
                    ARPUCoder, arpuTFModel, e);
        }
        return Collections.emptyMap();
    }

    /**
     * V2 ARPU prediction with {@code ARPUCoderV2} and {@code arpuTFModelV2}.
     *
     * @param featureMap features per key
     * @return predictions per key, or an empty map if prediction fails
     */
    public <T> Map<T, Double> predictARPUV2(Map<T, FeatureMapDo> featureMap) throws Exception {
        try {
            return ARPUCoderV2.predictWithLocalTFV2(featureMap, arpuTFModelV2);
        } catch (Exception e) {
            logger.error("AdxRecommend.predictARPUV2 error, ARPUCoderV2 {}, arpuTFModelV2 {}",
                    ARPUCoderV2, arpuTFModelV2, e);
        }
        return Collections.emptyMap();
    }

    /**
     * Predicts coupon clicks per PV with {@code clickPvCoder} and {@code clickPvTFModel}.
     *
     * @param featureMap features per key
     * @return predictions per key, or an empty map if prediction fails
     */
    public <T> Map<T, Double> predictClickPv(Map<T, FeatureMapDo> featureMap) throws Exception {
        try {
            return clickPvCoder.predictWithLocalTFV2(featureMap, clickPvTFModel);
        } catch (Exception e) {
            logger.error("AdxRecommend.predictClickPv error, clickPvCoder {}, clickPvTFModel {}",
                    clickPvCoder, clickPvTFModel, e);
        }
        return Collections.emptyMap();
    }

    /**
     * Predicts coupon CPC with {@code adCpcCoder} and {@code adCpcTFModel}.
     *
     * @param featureMap features per key
     * @return predictions per key, or an empty map if prediction fails
     */
    public <T> Map<T, Double> predictAdCpc(Map<T, FeatureMapDo> featureMap) throws Exception {
        try {
            return adCpcCoder.predictWithLocalTFV2(featureMap, adCpcTFModel);
        } catch (Exception e) {
            logger.error("AdxRecommend.predictAdCpc error, adCpcCoder {}, adCpcTFModel {}",
                    adCpcCoder, adCpcTFModel, e);
        }
        return Collections.emptyMap();
    }

    /**
     * Predicts CTR with the ESMM TF model.
     *
     * <p>Uses the first element of the "ctr" output, rounded to {@value #SCALE} decimals (HALF_UP).
     *
     * @param featureMap features per key
     * @return rounded CTR per key, or an empty map if prediction fails
     */
    public <T> Map<T, Double> predictCtr2(Map<T, FeatureMapDo> featureMap) {
        try {
            Map<T, Map<String, Float[]>> modelOutput = esmmTFModel.predict(toModelInput(featureMap));

            Map<T, Double> ret = new HashMap<>();
            for (Map.Entry<T, Map<String, Float[]>> entry : modelOutput.entrySet()) {
                BigDecimal ctr = firstOutput(entry.getValue(), "ctr");
                ret.put(entry.getKey(), ctr.setScale(SCALE, RoundingMode.HALF_UP).doubleValue());
            }
            return ret;
        } catch (Exception e) {
            logger.error("AdxRecommend.predictCtr2 error, TFModel {}", esmmTFModel, e);
        }
        return Collections.emptyMap();
    }

    /**
     * Predicts CVR with the ESMM TF model as {@code ctcvr / ctr}, rounded to {@value #SCALE} decimals.
     *
     * <p>A key whose CTR is zero gets a CVR of 0, the same as {@link #predictCtrCvr}. Previously a
     * zero CTR threw {@link ArithmeticException} and the whole batch was dropped.
     *
     * @param featureMap features per key
     * @return CVR per key, or an empty map if prediction fails
     */
    public <T> Map<T, Double> predictCvr(Map<T, FeatureMapDo> featureMap) {
        try {
            Map<T, Map<String, Float[]>> modelOutput = esmmTFModel.predict(toModelInput(featureMap));

            Map<T, Double> ret = new HashMap<>();
            for (Map.Entry<T, Map<String, Float[]>> entry : modelOutput.entrySet()) {
                Map<String, Float[]> outputs = entry.getValue();
                ret.put(entry.getKey(), cvrOf(firstOutput(outputs, "ctcvr"), firstOutput(outputs, "ctr")));
            }
            return ret;
        } catch (Exception e) {
            logger.error("AdxRecommend.predictCVR error TFModel{}", esmmTFModel, e);
        }
        return Collections.emptyMap();
    }

    /**
     * Predicts both CTR and CVR with a single ESMM TF model call.
     *
     * <p>Each key maps to {"ctr": rounded ctr, "cvr": ctcvr / ctr}. Both values are rounded to
     * {@value #SCALE} decimals, and cvr is 0 when ctr is zero.
     *
     * @param featureMap features per key
     * @return ctr/cvr per key, or an empty map if prediction fails
     */
    public <T> Map<T, Map<String, Double>> predictCtrCvr(Map<T, FeatureMapDo> featureMap) {
        try {
            Map<T, Map<String, Float[]>> modelOutput = esmmTFModel.predict(toModelInput(featureMap));

            Map<T, Map<String, Double>> ret = new HashMap<>();
            for (Map.Entry<T, Map<String, Float[]>> entry : modelOutput.entrySet()) {
                Map<String, Float[]> outputs = entry.getValue();
                BigDecimal ctcvrBD = firstOutput(outputs, "ctcvr");
                BigDecimal ctrBD = firstOutput(outputs, "ctr");

                HashMap<String, Double> perSampleMap = new HashMap<>(5);
                perSampleMap.put("ctr", ctrBD.setScale(SCALE, RoundingMode.HALF_UP).doubleValue());
                perSampleMap.put("cvr", cvrOf(ctcvrBD, ctrBD));
                ret.put(entry.getKey(), perSampleMap);
            }
            return ret;
        } catch (Exception e) {
            logger.error("AdxRecommend.predictESMM error TFModel {}", esmmTFModel, e);
        }
        return Collections.emptyMap();
    }

    /** Extracts each entry's raw feature map into the input shape expected by the ESMM TF model. */
    private static <T> Map<T, Map<String, String>> toModelInput(Map<T, FeatureMapDo> featureMap) {
        Map<T, Map<String, String>> modelInput = new HashMap<>();
        for (Map.Entry<T, FeatureMapDo> entry : featureMap.entrySet()) {
            modelInput.put(entry.getKey(), entry.getValue().getFeatureMap());
        }
        return modelInput;
    }

    /**
     * Reads the first element of the named model output as an exact decimal.
     *
     * <p>Goes through {@code Float.toString} so the decimal matches the float's printed value.
     */
    private static BigDecimal firstOutput(Map<String, Float[]> outputs, String name) {
        return new BigDecimal(outputs.get(name)[0].toString());
    }

    /** Returns ctcvr / ctr rounded to {@link #SCALE} decimals (HALF_UP), or 0 when ctr is zero. */
    private static double cvrOf(BigDecimal ctcvr, BigDecimal ctr) {
        return ctr.signum() == 0 ? 0d : ctcvr.divide(ctr, SCALE, RoundingMode.HALF_UP).doubleValue();
    }
}
