package cn.com.duiba.nezha.alg.alg.adx.rcmd2;

import cn.com.duiba.nezha.alg.api.model.AdxLocalTFModel;
import cn.com.duiba.nezha.alg.feature.vo.FeatureMapDo;
import cn.com.duiba.nezha.alg.model.IModel;
import cn.com.duiba.nezha.alg.model.tf.LocalTFModel;
import lombok.Data;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

/**
 * Holder for the CTR, launch-PV, ARPU and ESMM models together with thin prediction wrappers.
 *
 * <p>Every {@code predict*} method catches any {@link Exception} raised while predicting, logs it
 * (including the stack trace) and returns an immutable empty map instead of propagating it.
 */
@Data
public class Model {
    private static final Logger logger = LoggerFactory.getLogger(Model.class);

    /** Name of the ESMM output read as the CTR score. */
    private static final String CTR_KEY = "ctr";
    /** Name of the ESMM output read as the CTCVR score. */
    private static final String CTCVR_KEY = "ctcvr";

    private IModel ctrModel;
    private IModel launchPvCoder;
    private IModel ARPUCoder;
    private LocalTFModel launchPvTFModel;  // TF model paired with launchPvCoder
    private LocalTFModel arpuTFModel;   // TF model paired with ARPUCoder

    private AdxLocalTFModel esmmTFModel;

    /**
     * Predicts CTR with {@code ctrModel}.
     *
     * @param featureMap features per sample key
     * @return the result of {@code ctrModel.predictsNew}, or an empty map if prediction failed
     * @throws Exception retained in the signature for existing callers; exceptions raised while
     *                   predicting are caught and logged here
     */
    public <T> Map<T, Double> predictCtr(Map<T, FeatureMapDo> featureMap) throws Exception {
        try {
            return ctrModel.predictsNew(featureMap);
        } catch (Exception e) {
            // The throwable must be the last argument without a placeholder so SLF4J logs its stack trace.
            logger.error("Model.predictCtr error, ctrModel {}", ctrModel, e);
        }
        return Collections.emptyMap();
    }

    /**
     * Predicts launch PV with {@code launchPvCoder} and {@code launchPvTFModel}.
     *
     * @param featureMap features per sample key
     * @return the result of {@code launchPvCoder.predictWithLocalTFNew}, or an empty map if
     *         prediction failed
     * @throws Exception retained in the signature for existing callers; exceptions raised while
     *                   predicting are caught and logged here
     */
    public <T> Map<T, Double> predictLaunchPv(Map<T, FeatureMapDo> featureMap) throws Exception {
        try {
            return launchPvCoder.predictWithLocalTFNew(featureMap, launchPvTFModel);
        } catch (Exception e) {
            logger.error("Model.predictLaunchPv error, launchPvCoder {}, launchPvTFModel {}",
                    launchPvCoder, launchPvTFModel, e);
        }
        return Collections.emptyMap();
    }

    /**
     * Predicts ARPU with {@code ARPUCoder} and {@code arpuTFModel}.
     *
     * @param featureMap features per sample key
     * @return the result of {@code ARPUCoder.predictWithLocalTFNew}, or an empty map if
     *         prediction failed
     * @throws Exception retained in the signature for existing callers; exceptions raised while
     *                   predicting are caught and logged here
     */
    public <T> Map<T, Double> predictARPU(Map<T, FeatureMapDo> featureMap) throws Exception {
        try {
            return ARPUCoder.predictWithLocalTFNew(featureMap, arpuTFModel);
        } catch (Exception e) {
            logger.error("Model.predictARPU error, ARPUCoder {}, arpuTFModel {}",
                    ARPUCoder, arpuTFModel, e);
        }
        return Collections.emptyMap();
    }

    /**
     * Predicts CTR with {@code esmmTFModel}, taking the first element of each sample's
     * {@code "ctr"} output.
     *
     * @param featureMap features per sample key
     * @return CTR per sample key, or an empty map if prediction failed for any sample
     */
    public <T> Map<T, Double> predictCtr2(Map<T, FeatureMapDo> featureMap) {
        try {
            Map<T, Double> ret = new HashMap<>();
            for (Map.Entry<T, Map<String, Float[]>> sample : runEsmm(featureMap).entrySet()) {
                ret.put(sample.getKey(), firstScore(sample.getValue(), CTR_KEY));
            }
            return ret;
        } catch (Exception e) {
            logger.error("Model.predictCtr2 error, esmmTFModel {}", esmmTFModel, e);
        }
        return Collections.emptyMap();
    }

    /**
     * Predicts with {@code esmmTFModel}, taking the first element of each sample's
     * {@code "ctcvr"} output.
     *
     * @param featureMap features per sample key
     * @return the {@code "ctcvr"} score per sample key, or an empty map if prediction failed for
     *         any sample
     */
    public <T> Map<T, Double> predictCvr(Map<T, FeatureMapDo> featureMap) {
        try {
            Map<T, Double> ret = new HashMap<>();
            for (Map.Entry<T, Map<String, Float[]>> sample : runEsmm(featureMap).entrySet()) {
                ret.put(sample.getKey(), firstScore(sample.getValue(), CTCVR_KEY));
            }
            return ret;
        } catch (Exception e) {
            logger.error("Model.predictCvr error, esmmTFModel {}", esmmTFModel, e);
        }
        return Collections.emptyMap();
    }

    /**
     * Predicts with {@code esmmTFModel} and returns both scores per sample.
     *
     * @param featureMap features per sample key
     * @return per sample key, a map holding the {@code "ctcvr"} and {@code "ctr"} scores; an empty
     *         map if prediction failed for any sample
     */
    public <T> Map<T, Map<String, Double>> predictCtrCvr(Map<T, FeatureMapDo> featureMap) {
        try {
            Map<T, Map<String, Double>> ret = new HashMap<>();
            for (Map.Entry<T, Map<String, Float[]>> sample : runEsmm(featureMap).entrySet()) {
                Map<String, Double> perSampleMap = new HashMap<>(5);
                perSampleMap.put(CTCVR_KEY, firstScore(sample.getValue(), CTCVR_KEY));
                perSampleMap.put(CTR_KEY, firstScore(sample.getValue(), CTR_KEY));
                // Previously the per-sample map was built but never stored, so the result was always empty.
                ret.put(sample.getKey(), perSampleMap);
            }
            return ret;
        } catch (Exception e) {
            logger.error("Model.predictCtrCvr error, esmmTFModel {}", esmmTFModel, e);
        }
        return Collections.emptyMap();
    }

    /** Converts each sample's {@code FeatureMapDo} to its raw feature map and runs {@code esmmTFModel.predict}. */
    private <T> Map<T, Map<String, Float[]>> runEsmm(Map<T, FeatureMapDo> featureMap) throws Exception {
        Map<T, Map<String, String>> modelInput = new HashMap<>();
        for (Map.Entry<T, FeatureMapDo> entry : featureMap.entrySet()) {
            modelInput.put(entry.getKey(), entry.getValue().getFeatureMap());
        }
        Map<T, Map<String, Float[]>> modelOutput = esmmTFModel.predict(modelInput);
        return modelOutput;
    }

    /** Returns the first element of the named output, failing with a descriptive message when it is absent. */
    private static double firstScore(Map<String, Float[]> outputs, String name) {
        Float[] values = outputs == null ? null : outputs.get(name);
        if (values == null || values.length == 0 || values[0] == null) {
            throw new IllegalStateException("ESMM output is missing a value for '" + name + "'");
        }
        return values[0].doubleValue();
    }
}
