package cn.com.duiba.nezha.alg.model.tf;//package cn.com.duiba.nezha.compute.alg.tf;
//


import cn.com.duiba.nezha.alg.model.enums.PredictResultType;
import org.apache.commons.pool2.impl.GenericObjectPoolConfig;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Client for a TensorFlow Serving model identified by host, port and model name.
 *
 * <p>Prediction calls delegate to {@code TFServingClientUtil.predict}, passing this client's
 * host, port, model name, pool configuration and the list of output names to read.
 */
public class TFServingClient {

    /** Host of the serving endpoint. */
    public String host;

    /** Name of the model to query. */
    public String modelName;

    /** Port of the serving endpoint. */
    public int port;

    /**
     * Pool configuration handed to {@code TFServingClientUtil.predict}. The constructor and
     * {@link #setPoolConfig} substitute {@link #getDefPoolConfig()} defaults when given null.
     */
    public GenericObjectPoolConfig poolConfig;

    /*
     * Output names requested from the model.
     * NOTE(review): these public statics are not final, so reassigning one changes the output
     * name used by every client instance; they are kept non-final only to preserve the
     * existing public API. Consider making them final.
     */

    /** Output name read by {@link #predict(Map)}. */
    public static String DF_OUTPUT = "prob";

    /** Output name read into {@code PredictResultType.CTR} by {@link #predictMut(Map)}. */
    public static String ESMM_CTR_OUTPUT = "ctr_probabilities";
    /** Output name read into {@code PredictResultType.CVR} by {@link #predictMut(Map)}. */
    public static String ESMM_CVR_OUTPUT = "cvr_probabilities";


    /** @return the pool configuration used for prediction calls */
    public GenericObjectPoolConfig getPoolConfig() {
        return poolConfig;
    }

    /**
     * Sets the pool configuration used for prediction calls.
     *
     * @param poolConfig the configuration to use; when null, a fresh default configuration
     *     (see {@link #getDefPoolConfig()}) is used instead, matching the constructor's behavior
     */
    public void setPoolConfig(GenericObjectPoolConfig poolConfig) {
        this.poolConfig = poolConfig != null ? poolConfig : createDefaultPoolConfig();
    }


    /** @param host host of the serving endpoint */
    public void setHost(String host) {
        this.host = host;
    }

    /** @param modelName name of the model to query */
    public void setModelName(String modelName) {
        this.modelName = modelName;
    }

    /** @param port port of the serving endpoint */
    public void setPort(int port) {
        this.port = port;
    }

    /** @return host of the serving endpoint */
    public String getHost() {
        return this.host;
    }

    /** @return name of the model to query */
    public String getModelName() {
        return this.modelName;
    }

    /** @return port of the serving endpoint */
    public int getPort() {
        return this.port;
    }


    /**
     * Creates a client.
     *
     * @param host       host of the serving endpoint
     * @param port       port of the serving endpoint
     * @param modelName  name of the model to query
     * @param poolConfig pool configuration; when null, a default configuration is used
     */
    public TFServingClient(String host, int port, String modelName, GenericObjectPoolConfig poolConfig) {
        this.host = host;
        this.modelName = modelName;
        this.port = port;
        // Use the private static factory rather than the overridable getDefPoolConfig(): calling an
        // overridable method from a constructor lets a subclass override run before the subclass
        // has been initialized.
        this.poolConfig = poolConfig != null ? poolConfig : createDefaultPoolConfig();
    }

    /**
     * Returns a new default pool configuration.
     *
     * @return a freshly created configuration; callers may modify it freely
     */
    public GenericObjectPoolConfig getDefPoolConfig() {
        return createDefaultPoolConfig();
    }

    /** Builds the default connection-pool configuration used when none is supplied. */
    private static GenericObjectPoolConfig createDefaultPoolConfig() {
        GenericObjectPoolConfig config = new GenericObjectPoolConfig();
        // Maximum number of connections in the pool.
        config.setMaxTotal(50);
        // Minimum number of idle connections.
        config.setMinIdle(5);
        // Maximum number of idle connections.
        config.setMaxIdle(20);
        // Maximum time in milliseconds a caller blocks when the pool is exhausted; an exception is
        // thrown on timeout.
        config.setMaxWaitMillis(500);

        // Where returned objects are placed: true = front of the idle queue, false = back.
        config.setLifo(true);
        // Minimum idle time before an idle connection may be evicted; 30 minutes, which is also
        // the library default.
        config.setMinEvictableIdleTimeMillis(1000L * 60L * 30L);
        // Whether callers block when the pool is exhausted; true is also the library default.
        config.setBlockWhenExhausted(true);
        return config;
    }


    /**
     * Runs the model on {@code dataMap} and returns the {@link #DF_OUTPUT} output.
     *
     * @param dataMap feature vectors keyed by caller-chosen ids
     * @param <T>     id type
     * @return the map {@code TFServingClientUtil.predict} returned for {@link #DF_OUTPUT}, or null
     *     if its result has no entry for that output (presumably keyed by the same ids as
     *     {@code dataMap} — TODO confirm against TFServingClientUtil)
     * @throws Exception propagated from {@code TFServingClientUtil.predict}
     */
    public <T> Map<T, Double> predict(Map<T, List<Float>> dataMap) throws Exception {
        Map<String, Map<T, Double>> outputs = TFServingClientUtil.predict(dataMap, this.host, this.port, this.modelName, this.poolConfig, Arrays.asList(DF_OUTPUT));
        return outputs.get(DF_OUTPUT);
    }

    /**
     * Runs the model on {@code dataMap} and returns both the {@link #ESMM_CTR_OUTPUT} and
     * {@link #ESMM_CVR_OUTPUT} outputs.
     *
     * @param dataMap feature vectors keyed by caller-chosen ids
     * @param <T>     id type
     * @return a new map with {@code PredictResultType.CTR} and {@code PredictResultType.CVR} keys;
     *     each value is the corresponding output map from {@code TFServingClientUtil.predict}, or
     *     null if that output is missing from its result
     * @throws Exception propagated from {@code TFServingClientUtil.predict}
     */
    public <T> Map<PredictResultType, Map<T, Double>> predictMut(Map<T, List<Float>> dataMap) throws Exception {
        Map<String, Map<T, Double>> outputs = TFServingClientUtil.predict(dataMap, this.host, this.port, this.modelName, this.poolConfig, Arrays.asList(ESMM_CTR_OUTPUT, ESMM_CVR_OUTPUT));

        Map<PredictResultType, Map<T, Double>> ret = new HashMap<>();
        ret.put(PredictResultType.CTR, outputs.get(ESMM_CTR_OUTPUT));
        ret.put(PredictResultType.CVR, outputs.get(ESMM_CVR_OUTPUT));
        return ret;
    }
}
