/*
 * Decompiled with CFR 0.152.
 */
package cn.com.duiba.nezha.alg.model.tf;

import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import cn.com.duiba.nezha.alg.model.grpc.GrpcClientSingle;
import cn.com.duiba.nezha.alg.model.grpc.GrpcPool;
import io.grpc.Channel;
import io.grpc.ManagedChannel;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import my.com.google.protobuf.ByteString;
import my.org.tensorflow.framework.DataType;
import my.org.tensorflow.framework.TensorProto;
import my.org.tensorflow.framework.TensorShapeProto;
import org.apache.commons.pool2.impl.GenericObjectPoolConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import tensorflow.serving.Model;
import tensorflow.serving.Predict;
import tensorflow.serving.PredictionServiceGrpc;

public class TFServingClientUtil {
    public static Map<String, GrpcPool> grpcPoolMap = new ConcurrentHashMap<String, GrpcPool>();
    public static String DF_INPUT = "feat_vals";
    public static String DF_SIGNAtURE_NAME = "serving_default";
    private static final Logger logger = LoggerFactory.getLogger(TFServingClientUtil.class);

    public static GrpcPool getGrpcPool(String host, int port, GenericObjectPoolConfig poolConfig) {
        String key = host + "-" + port;
        if (!grpcPoolMap.containsKey(key)) {
            TFServingClientUtil.init(host, port, poolConfig, key);
        }
        return grpcPoolMap.get(key);
    }

    public static synchronized void init(String host, int port, GenericObjectPoolConfig poolConfig, String key) {
        if (!grpcPoolMap.containsKey(key)) {
            grpcPoolMap.put(key, new GrpcPool(host, port, poolConfig));
        }
    }

    public static ManagedChannel getManagedChannel(String host, int port, GenericObjectPoolConfig poolConfig) throws Exception {
        return TFServingClientUtil.getGrpcPool(host, port, poolConfig).borrowObject().getChannel();
    }

    public static GrpcClientSingle getGrpcClientSingle(String host, int port, GenericObjectPoolConfig poolConfig) throws Exception {
        return TFServingClientUtil.getGrpcPool(host, port, poolConfig).borrowObject();
    }

    public static void returnGrpcClientSingle(String host, int port, GenericObjectPoolConfig poolConfig, GrpcClientSingle grpcClientSingle) throws Exception {
        TFServingClientUtil.getGrpcPool(host, port, poolConfig).returnObject(grpcClientSingle);
    }

    private static Map<String, List<Float>> predict(int featureNums, List<Float> data, String modelName, ManagedChannel channel, String signatureName, String input, List<String> output) {
        HashMap<String, List<Float>> ret = new HashMap<String, List<Float>>();
        if (channel != null && data.size() % featureNums == 0) {
            PredictionServiceGrpc.PredictionServiceBlockingStub stub = (PredictionServiceGrpc.PredictionServiceBlockingStub)PredictionServiceGrpc.newBlockingStub((Channel)channel).withDeadlineAfter(500L, TimeUnit.MILLISECONDS);
            Predict.PredictRequest.Builder predictRequestBuilder = Predict.PredictRequest.newBuilder();
            Model.ModelSpec.Builder modelSpecBuilder = Model.ModelSpec.newBuilder();
            modelSpecBuilder.setName(modelName);
            modelSpecBuilder.setSignatureName(signatureName);
            predictRequestBuilder.setModelSpec(modelSpecBuilder);
            int totalSize = data.size();
            int featureSize = totalSize / featureNums;
            TensorProto.Builder tensorProtoBuilder = TensorProto.newBuilder();
            tensorProtoBuilder.setDtype(DataType.DT_FLOAT);
            TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();
            tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize((long)featureNums));
            tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize((long)featureSize));
            tensorProtoBuilder.setTensorShape(tensorShapeBuilder.build());
            tensorProtoBuilder.addAllFloatVal(data);
            predictRequestBuilder.putInputs(input, tensorProtoBuilder.build());
            Predict.PredictResponse predictResponse = stub.predict(predictRequestBuilder.build());
            for (String op : output) {
                TensorProto tensorProto = predictResponse.getOutputsOrThrow(op);
                ret.put(op, tensorProto.getFloatValList());
            }
        }
        return ret;
    }

    private static Map<String, List<Float>> predictString(int featureNums, List<String> data, String modelName, ManagedChannel channel, String signatureName, String input, List<String> output) {
        HashMap<String, List<Float>> ret = new HashMap<String, List<Float>>();
        if (channel != null && data.size() % featureNums == 0) {
            PredictionServiceGrpc.PredictionServiceBlockingStub stub = PredictionServiceGrpc.newBlockingStub((Channel)channel);
            Predict.PredictRequest.Builder predictRequestBuilder = Predict.PredictRequest.newBuilder();
            Model.ModelSpec.Builder modelSpecBuilder = Model.ModelSpec.newBuilder();
            modelSpecBuilder.setName(modelName);
            modelSpecBuilder.setSignatureName(signatureName);
            predictRequestBuilder.setModelSpec(modelSpecBuilder);
            int totalSize = data.size();
            int featureSize = totalSize / featureNums;
            TensorProto.Builder tensorProtoBuilder = TensorProto.newBuilder();
            tensorProtoBuilder.setDtype(DataType.DT_STRING);
            TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();
            tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(1L));
            tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize((long)totalSize));
            tensorProtoBuilder.setTensorShape(tensorShapeBuilder.build());
            ArrayList<ByteString> dataByte = new ArrayList<ByteString>();
            for (int i = 0; i < data.size(); ++i) {
                dataByte.add(ByteString.copyFromUtf8((String)data.get(i)));
            }
            tensorProtoBuilder.addAllStringVal(dataByte);
            predictRequestBuilder.putInputs(input, tensorProtoBuilder.build());
            Predict.PredictResponse predictResponse = stub.predict(predictRequestBuilder.build());
            for (String op : output) {
                TensorProto tensorProto = predictResponse.getOutputsOrThrow(op);
                ret.put(op, tensorProto.getFloatValList());
            }
        }
        return ret;
    }

    public static <T> Map<String, Map<T, Double>> predict(Map<T, List<Float>> dataMap, String host, int port, String modelName, GenericObjectPoolConfig poolConfig, List<String> output) throws Exception {
        HashMap<String, Map<T, Double>> ret = new HashMap<String, Map<T, Double>>();
        if (AssertUtil.isAllNotEmpty((Object[])new Object[]{dataMap, host, port, modelName})) {
            ArrayList<T> keyList = new ArrayList<T>(dataMap.keySet());
            int featureNums = keyList.size();
            ArrayList<Float> data = new ArrayList<Float>();
            for (Object key : keyList) {
                data.addAll((Collection)dataMap.get(key));
            }
            GrpcPool grpcPool = TFServingClientUtil.getGrpcPool(host, port, poolConfig);
            GrpcClientSingle grpcClientSingle = grpcPool.borrowObject();
            try {
                Map<String, List<Float>> preValueMap = TFServingClientUtil.predict(featureNums, data, modelName, grpcClientSingle.getChannel(), DF_SIGNAtURE_NAME, DF_INPUT, output);
                for (String op : output) {
                    HashMap retSub = new HashMap();
                    List<Float> preValueList = preValueMap.get(op);
                    for (int i = 0; i < featureNums; ++i) {
                        Object key = keyList.get(i);
                        Float valueF = preValueList.get(i);
                        if (valueF == null) continue;
                        retSub.put(key, Double.valueOf(valueF.floatValue()));
                    }
                    ret.put(op, retSub);
                }
            }
            catch (Exception e) {
                throw e;
            }
            finally {
                grpcPool.returnObject(grpcClientSingle);
            }
        }
        return ret;
    }

    public static <T> Map<String, Map<T, Double>> predictString(Map<T, String> dataMap, String host, int port, String modelName, GenericObjectPoolConfig poolConfig, List<String> output) throws Exception {
        HashMap<String, Map<T, Double>> ret = new HashMap<String, Map<T, Double>>();
        if (AssertUtil.isAllNotEmpty((Object[])new Object[]{dataMap, host, port, modelName})) {
            ArrayList<T> keyList = new ArrayList<T>(dataMap.keySet());
            int featureNums = keyList.size();
            ArrayList<String> data = new ArrayList<String>();
            for (Object key : keyList) {
                data.add(dataMap.get(key));
            }
            Long t1 = System.currentTimeMillis();
            GrpcPool grpcPool = TFServingClientUtil.getGrpcPool(host, port, poolConfig);
            GrpcClientSingle grpcClientSingle = grpcPool.borrowObject();
            Long t2 = System.currentTimeMillis();
            Long diffTF1 = t2 - t1;
            try {
                Long t3 = System.currentTimeMillis();
                Map<String, List<Float>> preValueMap = TFServingClientUtil.predictString(featureNums, data, modelName, grpcClientSingle.getChannel(), DF_SIGNAtURE_NAME, DF_INPUT, output);
                Long t4 = System.currentTimeMillis();
                Long diffTF2 = t4 - t3;
                for (String op : output) {
                    HashMap retSub = new HashMap();
                    List<Float> preValueList = preValueMap.get(op);
                    for (int i = 0; i < featureNums; ++i) {
                        Object key = keyList.get(i);
                        Float valueF = preValueList.get(i);
                        if (valueF != null) {
                            retSub.put(key, Double.valueOf(valueF.floatValue()));
                            continue;
                        }
                        String logInfo = modelName + " single null value in predict res";
                        logger.warn(logInfo);
                    }
                    ret.put(op, retSub);
                }
            }
            catch (Exception e) {
                throw e;
            }
            finally {
                grpcPool.returnObject(grpcClientSingle);
            }
        }
        return ret;
    }

    public static void main(String[] args) throws Exception {
        HashMap dataMap = new HashMap();
        for (int i = 0; i < 10; ++i) {
            ArrayList<Float> data = new ArrayList<Float>();
            for (int j = 0; j < 302; ++j) {
                data.add(Float.valueOf((float)(Math.random() * 1.0E-4)));
            }
            dataMap.put(i + "-", data);
        }
    }
}

