package cn.com.duiba.nezha.alg.model.tf;//package cn.com.duiba.nezha.compute.alg.tf;
//


import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import cn.com.duiba.nezha.alg.model.grpc.GrpcClientSingle;
import cn.com.duiba.nezha.alg.model.grpc.GrpcPool;
import com.alibaba.fastjson.JSON;
import io.grpc.Channel;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import org.apache.commons.pool2.impl.GenericObjectPoolConfig;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.TensorProto;
import org.tensorflow.framework.TensorShapeProto;
import tensorflow.serving.Model;
import tensorflow.serving.Predict;
import tensorflow.serving.PredictionServiceGrpc;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;

/**
 * Static helpers for calling a TensorFlow Serving {@code Predict} endpoint over gRPC,
 * using one {@link GrpcPool} of clients per {@code host-port} pair.
 *
 * <p>Pools are created lazily on first use and cached in {@link #grpcPoolMap} for the
 * lifetime of the class; nothing in this class ever closes or evicts them.
 */
public class TFServingClientUtil {


    /** Cache of pools keyed by {@code host + "-" + port}. */
    public static Map<String, GrpcPool> grpcPoolMap = new ConcurrentHashMap<>();

    /** Name of the input tensor sent by the public {@code predict} overload. */
    public static String DF_INPUT = "feat_vals";

    /** Signature name sent by the public {@code predict} overload. */
    public static String DF_SIGNAtURE_NAME = "serving_default";

    /** Per-call deadline applied to the blocking stub, in milliseconds. */
    private static final long PREDICT_DEADLINE_MILLIS = 500L;

//    public static String DF_OUTPUT = "prob";


    /**
     * Returns the pool for the given endpoint, creating it on first use.
     *
     * <p>The pool is keyed by host and port only: {@code poolConfig} is used when the pool
     * is first created and ignored on every later call for the same endpoint.
     *
     * @param host       server host
     * @param port       server port
     * @param poolConfig configuration used if the pool has to be created
     * @return the cached (or newly created) pool for {@code host-port}
     */
    public static GrpcPool getGrpcPool(String host, int port, GenericObjectPoolConfig poolConfig) {
        String key = host + "-" + port;
        // computeIfAbsent is atomic on ConcurrentHashMap, so concurrent first calls for the same
        // endpoint cannot each build (and then orphan) their own pool.
        return grpcPoolMap.computeIfAbsent(key, k -> new GrpcPool(host, port, poolConfig));
    }

    /**
     * Borrows a client from the endpoint's pool and returns its channel.
     *
     * <p>NOTE(review): the borrowed {@link GrpcClientSingle} is not handed to the caller, so the
     * caller has no way to return it to the pool through this method. Prefer
     * {@link #getGrpcClientSingle} together with {@link #returnGrpcClientSingle}.
     *
     * @throws Exception if borrowing from the pool fails
     */
    public static ManagedChannel getManagedChannel(String host, int port, GenericObjectPoolConfig poolConfig) throws Exception {
        return getGrpcPool(host, port, poolConfig).borrowObject().getChannel();
    }

    /**
     * Borrows a client from the endpoint's pool. The caller is expected to hand it back via
     * {@link #returnGrpcClientSingle}.
     *
     * @throws Exception if borrowing from the pool fails
     */
    public static GrpcClientSingle getGrpcClientSingle(String host, int port, GenericObjectPoolConfig poolConfig) throws Exception {
        return getGrpcPool(host, port, poolConfig).borrowObject();
    }

    /**
     * Returns a previously borrowed client to the endpoint's pool.
     *
     * @throws Exception if the pool rejects the returned object
     */
    public static void returnGrpcClientSingle(String host, int port, GenericObjectPoolConfig poolConfig, GrpcClientSingle grpcClientSingle) throws Exception {
        getGrpcPool(host, port, poolConfig).returnObject(grpcClientSingle);
    }


    /**
     * Sends one blocking {@code Predict} request and collects the requested float outputs.
     *
     * <p>The flattened {@code data} is sent as a {@code DT_FLOAT} tensor of shape
     * {@code [featureNums, data.size() / featureNums]}.
     *
     * @param featureNums   size of the first tensor dimension (the caller passes the number of
     *                      entries in its data map)
     * @param data          flattened values; its size must be a multiple of {@code featureNums}
     * @param modelName     model name placed in the model spec
     * @param channel       channel to call through
     * @param signatureName signature name placed in the model spec
     * @param input         name of the input tensor
     * @param output        names of the output tensors to read from the response
     * @return output name to float values; EMPTY when any argument is unusable (null channel,
     *         null data/output, non-positive {@code featureNums}, or a data size that is not a
     *         multiple of {@code featureNums}) and no request is sent
     */
    private static Map<String, List<Float>> predict(int featureNums,
                                                    List<Float> data,
                                                    String modelName,
                                                    ManagedChannel channel,
                                                    String signatureName,
                                                    String input,
                                                    List<String> output) {
        Map<String, List<Float>> ret = new HashMap<>();
        // featureNums <= 0 is rejected up front: the modulo below would otherwise divide by zero.
        if (channel == null || data == null || output == null
                || featureNums <= 0 || data.size() % featureNums != 0) {
            return ret;
        }

        // Blocking stub with a per-call deadline.
        PredictionServiceGrpc.PredictionServiceBlockingStub stub = PredictionServiceGrpc.newBlockingStub(channel)
                .withDeadlineAfter(PREDICT_DEADLINE_MILLIS, TimeUnit.MILLISECONDS);

        // Build the request; preset the model name and signature name.
        Predict.PredictRequest.Builder predictRequestBuilder = Predict.PredictRequest.newBuilder();
        Model.ModelSpec.Builder modelSpecBuilder = Model.ModelSpec.newBuilder();
        modelSpecBuilder.setName(modelName);
        modelSpecBuilder.setSignatureName(signatureName);
        predictRequestBuilder.setModelSpec(modelSpecBuilder);

        // No model version is set on the request.
        // NOTE(review): the original comment suggested tensorProtoBuilder.setVersionNumber for
        // pinning a version; presumably the model version belongs on ModelSpec instead - confirm.
        int featureSize = data.size() / featureNums;

        TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();
        tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(featureNums));
        tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(featureSize));

        TensorProto.Builder tensorProtoBuilder = TensorProto.newBuilder();
        tensorProtoBuilder.setDtype(DataType.DT_FLOAT);
        tensorProtoBuilder.setTensorShape(tensorShapeBuilder.build());
        tensorProtoBuilder.addAllFloatVal(data);
        predictRequestBuilder.putInputs(input, tensorProtoBuilder.build());

        // Call the server and read the requested outputs.
        Predict.PredictResponse predictResponse = stub.predict(predictRequestBuilder.build());
        for (String op : output) {
            // getOutputsOrThrow fails if the response lacks the requested output.
            TensorProto tensorProto = predictResponse.getOutputsOrThrow(op);
            ret.put(op, tensorProto.getFloatValList());
        }
        return ret;
    }


    /**
     * Scores every entry of {@code dataMap} in a single request and maps the results back to
     * the entries' keys.
     *
     * <p>The value lists are concatenated in key-iteration order and sent under the input name
     * {@link #DF_INPUT} with signature {@link #DF_SIGNAtURE_NAME}. The i-th float of each
     * requested output is assigned to the i-th key.
     *
     * @param dataMap    key to feature values; all lists are expected to have the same length
     * @param host       server host
     * @param port       server port
     * @param modelName  model name
     * @param poolConfig pool configuration used if the endpoint's pool has to be created
     * @param output     names of the output tensors to read
     * @param <T>        key type
     * @return output name to (key to predicted value). Empty when the arguments fail the
     *         {@code AssertUtil.isAllNotEmpty} check or {@code output} is null. An output for
     *         which no values were obtained maps to an empty sub-map.
     * @throws Exception if borrowing/returning the pooled client or the remote call fails
     */
    public static <T> Map<String, Map<T, Double>> predict(Map<T, List<Float>> dataMap,
                                                          String host,
                                                          int port,
                                                          String modelName,
                                                          GenericObjectPoolConfig poolConfig, List<String> output) throws Exception {
        Map<String, Map<T, Double>> ret = new HashMap<>();

        if (output == null || !AssertUtil.isAllNotEmpty(dataMap, host, port, modelName)) {
            return ret;
        }

        // Freeze the key order once so inputs and outputs line up by index.
        List<T> keyList = new ArrayList<>(dataMap.keySet());
        int featureNums = keyList.size();

        // Flatten the per-key lists into one row-major list.
        List<Float> data = new ArrayList<>();
        for (T key : keyList) {
            data.addAll(dataMap.get(key));
        }

        GrpcPool grpcPool = getGrpcPool(host, port, poolConfig);
        GrpcClientSingle grpcClientSingle = grpcPool.borrowObject();
        try {
            Map<String, List<Float>> preValueMap = predict(featureNums, data, modelName,
                    grpcClientSingle.getChannel(), DF_SIGNAtURE_NAME, DF_INPUT, output);

            for (String op : output) {
                Map<T, Double> retSub = new HashMap<>();
                // The inner predict returns an empty map when it refuses to send the request
                // (e.g. null channel or ragged input), so the list may legitimately be absent.
                List<Float> preValueList = preValueMap.get(op);
                if (preValueList != null) {
                    // Never read past the values actually returned by the server.
                    int limit = Math.min(featureNums, preValueList.size());
                    for (int i = 0; i < limit; i++) {
                        Float valueF = preValueList.get(i);
                        if (valueF != null) {
                            retSub.put(keyList.get(i), (double) valueF);
                        }
                    }
                }
                ret.put(op, retSub);
            }
        } finally {
            // Always hand the client back, whether the call succeeded or threw.
            grpcPool.returnObject(grpcClientSingle);
        }

        return ret;
    }


    /**
     * Manual smoke-test entry point: builds 10 random rows of 302 floats each. The actual
     * predict call is disabled, so running this performs no network I/O.
     */
    public static void main(String[] args) throws Exception {

        // generate data
        Map<String, List<Float>> dataMap = new HashMap<>();
        for (int i = 0; i < 10; i++) {
            List<Float> data = new ArrayList<>();
            for (int j = 0; j < 302; j++) {
                data.add((float) (Math.random() * 0.0001));
            }

            dataMap.put(i + "-", data);
        }

//
//        Map<String, Double> ret = predict(dataMap, "47.96.221.208", 9000, "model-1");
//
//        System.out.println("RET= : " + JSON.toJSONString(ret));

    }
}
