package cn.com.duiba.nezha.alg.alg.adxhd.model;

import cn.com.duiba.nezha.alg.feature.vo.FeatureMapDo;
import cn.com.duiba.nezha.alg.model.DeepModelV2;
import cn.com.duiba.nezha.alg.model.tf.LocalTFModelV2;
import cn.com.duiba.nezha.alg.model.tf.OnnxModel;
import lombok.Data;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Collections;
import java.util.Map;


/**
 * Bundles the encoder/ONNX model pairs used for CTR and CVR prediction.
 *
 * <p>Each prediction is delegated to the corresponding {@link DeepModelV2}
 * encoder via {@code predictWithOnnx}. Failures are logged and turned into an
 * empty result map instead of being propagated to the caller.
 */
@Data
public class AHOnnxModel {
    private static final Logger logger = LoggerFactory.getLogger(AHOnnxModel.class);

    /** Manual smoke-test entry point; only prints a marker string. */
    public static void main(String[] args) {
        System.out.println("test");
    }

    /**
     * CTR (click) model.
     */
    private DeepModelV2 ctrCodeModel;    // click model: feature encoder
    private OnnxModel ctrOnnxModel;  // click model: ONNX model

    /**
     * CVR (conversion) model.
     */
    private DeepModelV2 cvrCodeModel;    // conversion model: feature encoder
    private OnnxModel cvrOnnxModel;   // conversion model: ONNX model


    /**
     * Predicts CTR scores for the given feature maps.
     *
     * @param featureMap features keyed by a caller-defined identifier
     * @param <T>        key type of the input and result maps
     * @return the map returned by the encoder's {@code predictWithOnnx}, or an
     *         empty map if prediction throws
     * @throws Exception declared for backward compatibility; exceptions raised
     *                   during prediction are caught and logged here
     */
    public <T> Map<T, Double> predictCtr(Map<T, FeatureMapDo> featureMap) throws Exception {
        try {
            // Trace line kept at DEBUG so it does not flood INFO logs on every request.
            logger.debug("panhangyi predict ctr!");
            return ctrCodeModel.predictWithOnnx(featureMap, ctrOnnxModel);
        } catch (Exception e) {
            // Throwable goes last (not bound to a placeholder) so SLF4J logs the stack trace.
            logger.error("AHOnnxModel.predictCtr error, ctrModel={}", modelIdOf(ctrCodeModel), e);
        }
        return Collections.emptyMap();
    }

    /**
     * Predicts CVR scores for the given feature maps.
     *
     * @param featureMap features keyed by a caller-defined identifier
     * @param <T>        key type of the input and result maps
     * @return the map returned by the encoder's {@code predictWithOnnx}, or an
     *         empty map if prediction throws
     * @throws Exception declared for backward compatibility; exceptions raised
     *                   during prediction are caught and logged here
     */
    public <T> Map<T, Double> predictCvr(Map<T, FeatureMapDo> featureMap) throws Exception {
        try {
            return cvrCodeModel.predictWithOnnx(featureMap, cvrOnnxModel);
        } catch (Exception e) {
            // Throwable goes last (not bound to a placeholder) so SLF4J logs the stack trace.
            logger.error("AHOnnxModel.predictCvr error, cvrModel={}", modelIdOf(cvrCodeModel), e);
        }
        return Collections.emptyMap();
    }

    /**
     * Returns the model id for logging, tolerating an unset encoder so that the
     * error handler itself cannot fail with a NullPointerException.
     */
    private static Object modelIdOf(DeepModelV2 model) {
        if (model == null) {
            return null;
        }
        return model.getModelId();
    }

}