package cn.com.duiba.nezha.alg.model.tf;

import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import cn.com.duiba.nezha.alg.model.grpc.GrpcClientSingle;
import cn.com.duiba.nezha.alg.model.grpc.GrpcPool;
import io.grpc.ManagedChannel;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.pool2.impl.GenericObjectPoolConfig;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.TensorProto;
import org.tensorflow.framework.TensorShapeProto;
import tensorflow.serving.Model;
import tensorflow.serving.Predict;
import tensorflow.serving.PredictionServiceGrpc;

/* loaded from: input_file:cn/com/duiba/nezha/alg/model/tf/TFServingClientUtil.class */
/**
 * Helpers for calling a TensorFlow Serving {@code Predict} RPC over gRPC.
 *
 * <p>Clients are pooled: one {@link GrpcPool} is kept per {@code str + "-" + i} key in
 * {@link #grpcPoolMap} and is created lazily on first use.
 */
public class TFServingClientUtil {
    /**
     * Client pools keyed by {@code str + "-" + i} (presumably {@code host-port}). A
     * {@link ConcurrentHashMap} so {@link #getGrpcPool} can create each pool atomically.
     */
    public static Map<String, GrpcPool> grpcPoolMap = new ConcurrentHashMap<>();

    /** Default input-tensor key placed in the predict request. */
    public static String DF_INPUT = "feat_vals";

    /** Default signature name placed in the request's model spec. */
    public static String DF_SIGNAtURE_NAME = "serving_default";

    /**
     * Returns the shared pool for the given address, creating it on first use.
     *
     * <p>{@code computeIfAbsent} makes creation atomic, so concurrent first calls for the same key
     * build exactly one pool instead of racing and discarding extras. The config is only used when
     * the pool is first created; later calls with a different config get the existing pool.
     *
     * @param str server address handed to {@link GrpcPool} (presumably the host)
     * @param i server port handed to {@link GrpcPool} (presumably)
     * @param genericObjectPoolConfig pool configuration used only when the pool is created
     * @return the pool registered for {@code str + "-" + i}
     */
    public static GrpcPool getGrpcPool(String str, int i, GenericObjectPoolConfig genericObjectPoolConfig) {
        String key = str + "-" + i;
        return grpcPoolMap.computeIfAbsent(key, k -> new GrpcPool(str, i, genericObjectPoolConfig));
    }

    /**
     * Borrows a client from the pool for the given address and returns its channel.
     *
     * <p>NOTE(review): the borrowed {@link GrpcClientSingle} is never returned to the pool, and the
     * caller only receives the channel, so it cannot return it either. Presumably each call keeps one
     * pooled object checked out. Prefer {@link #getGrpcClientSingle} paired with
     * {@link #returnGrpcClientSingle}.
     *
     * @throws Exception whatever {@code borrowObject} throws
     */
    public static ManagedChannel getManagedChannel(String str, int i, GenericObjectPoolConfig genericObjectPoolConfig) throws Exception {
        return getGrpcPool(str, i, genericObjectPoolConfig).borrowObject().getChannel();
    }

    /**
     * Borrows a client from the pool for the given address. The caller must hand it back via
     * {@link #returnGrpcClientSingle}.
     *
     * @throws Exception whatever {@code borrowObject} throws
     */
    public static GrpcClientSingle getGrpcClientSingle(String str, int i, GenericObjectPoolConfig genericObjectPoolConfig) throws Exception {
        return getGrpcPool(str, i, genericObjectPoolConfig).borrowObject();
    }

    /**
     * Returns a client previously obtained from {@link #getGrpcClientSingle} to its pool.
     *
     * @throws Exception whatever {@code returnObject} throws
     */
    public static void returnGrpcClientSingle(String str, int i, GenericObjectPoolConfig genericObjectPoolConfig, GrpcClientSingle grpcClientSingle) throws Exception {
        getGrpcPool(str, i, genericObjectPoolConfig).returnObject(grpcClientSingle);
    }

    /**
     * Sends one batched Predict request and collects the requested output tensors.
     *
     * <p>{@code features} is interpreted as {@code sampleCount} equal-length vectors laid end to end.
     * It is sent as a single DT_FLOAT tensor of shape
     * {@code [sampleCount, features.size() / sampleCount]}.
     *
     * @param sampleCount number of samples in the batch
     * @param features concatenated feature values of all samples
     * @param modelName model name for the model spec
     * @param channel channel to send the request on
     * @param signatureName signature name for the model spec
     * @param inputName key of the input tensor in the request
     * @param outputNames output tensors to read from the response
     * @return output name to that tensor's float values; empty (no request sent) when
     *     {@code channel} is null, {@code sampleCount <= 0}, or {@code features.size()} is not a
     *     multiple of {@code sampleCount}
     * @throws IllegalArgumentException from {@code getOutputsOrThrow} if a requested output is absent
     */
    private static Map<String, List<Float>> predict(int sampleCount, List<Float> features, String modelName,
            ManagedChannel channel, String signatureName, String inputName, List<String> outputNames) {
        Map<String, List<Float>> outputs = new HashMap<>();
        // Check sampleCount > 0 before the modulo, which would otherwise throw ArithmeticException.
        if (channel == null || sampleCount <= 0 || features.size() % sampleCount != 0) {
            return outputs;
        }
        int featureDim = features.size() / sampleCount;

        Model.ModelSpec.Builder modelSpec = Model.ModelSpec.newBuilder()
                .setName(modelName)
                .setSignatureName(signatureName);
        TensorShapeProto shape = TensorShapeProto.newBuilder()
                .addDim(TensorShapeProto.Dim.newBuilder().setSize(sampleCount))
                .addDim(TensorShapeProto.Dim.newBuilder().setSize(featureDim))
                .build();
        TensorProto inputTensor = TensorProto.newBuilder()
                .setDtype(DataType.DT_FLOAT)
                .setTensorShape(shape)
                .addAllFloatVal(features)
                .build();
        Predict.PredictRequest request = Predict.PredictRequest.newBuilder()
                .setModelSpec(modelSpec)
                .putInputs(inputName, inputTensor)
                .build();

        // NOTE(review): no deadline is set on the stub, so a stalled server blocks this call indefinitely.
        Predict.PredictResponse response = PredictionServiceGrpc.newBlockingStub(channel).predict(request);
        for (String outputName : outputNames) {
            outputs.put(outputName, response.getOutputsOrThrow(outputName).getFloatValList());
        }
        return outputs;
    }

    /**
     * Scores every entry of {@code map} with one batched Predict call and maps the scores back to
     * their keys.
     *
     * <p>The borrowed pool client is always returned, even if the RPC throws.
     *
     * @param map sample key to that sample's feature vector; all vectors are expected to have the
     *     same length, otherwise no request is sent
     * @param str server address (presumably the host)
     * @param i server port (presumably)
     * @param str2 model name for the model spec
     * @param genericObjectPoolConfig pool configuration used if the pool does not exist yet
     * @param list names of the output tensors to read
     * @param <T> sample key type
     * @return output name to (sample key to score). Empty when {@code list} is null or the
     *     {@code AssertUtil.isAllNotEmpty} check on the other arguments fails. An output's inner map
     *     is empty when no request was sent. Null scores are skipped.
     * @throws Exception from borrowing/returning the client or from the RPC
     */
    public static <T> Map<String, Map<T, Double>> predict(Map<T, List<Float>> map, String str, int i, String str2, GenericObjectPoolConfig genericObjectPoolConfig, List<String> list) throws Exception {
        Map<String, Map<T, Double>> result = new HashMap<>();
        if (list == null || !AssertUtil.isAllNotEmpty(new Object[]{map, str, Integer.valueOf(i), str2})) {
            return result;
        }

        // Freeze the key order so the idx-th score can be mapped back to the idx-th key.
        List<T> keys = new ArrayList<>(map.keySet());
        int sampleCount = keys.size();
        List<Float> features = new ArrayList<>();
        for (T key : keys) {
            features.addAll(map.get(key));
        }

        GrpcPool pool = getGrpcPool(str, i, genericObjectPoolConfig);
        GrpcClientSingle client = pool.borrowObject();
        Map<String, List<Float>> predictions;
        try {
            predictions = predict(sampleCount, features, str2, client.getChannel(), DF_SIGNAtURE_NAME, DF_INPUT, list);
        } finally {
            // Hand the client back even when the RPC fails, so failures do not drain the pool.
            pool.returnObject(client);
        }

        for (String outputName : list) {
            Map<T, Double> scoreByKey = new HashMap<>();
            List<Float> scores = predictions.get(outputName);
            // null when the request was not sent; leave this output's map empty instead of throwing.
            // NOTE(review): assumes each output yields one float per sample, in key order — TODO confirm
            // against the model signature.
            if (scores != null) {
                for (int idx = 0; idx < sampleCount; idx++) {
                    Float score = scores.get(idx);
                    if (score != null) {
                        scoreByKey.put(keys.get(idx), score.doubleValue());
                    }
                }
            }
            result.put(outputName, scoreByKey);
        }
        return result;
    }

    /**
     * Builds 10 random 302-element feature vectors keyed {@code "0-"} to {@code "9-"} and then
     * discards them; no request is sent.
     */
    public static void main(String[] strArr) throws Exception {
        Map<String, List<Float>> samples = new HashMap<>();
        for (int s = 0; s < 10; s++) {
            List<Float> features = new ArrayList<>();
            for (int f = 0; f < 302; f++) {
                features.add(Float.valueOf((float) (Math.random() * 1.0E-4d)));
            }
            samples.put(s + "-", features);
        }
    }
}
