package cn.com.duiba.nezha.alg.model;

import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import com.alibaba.fastjson.JSON;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.TensorProto;
import org.tensorflow.framework.TensorShapeProto;
import tensorflow.serving.Model;
import tensorflow.serving.Predict;
import tensorflow.serving.PredictionServiceGrpc;

/* loaded from: input_file:cn/com/duiba/nezha/alg/model/TFServingUtil.class */
/**
 * Static helpers for scoring feature vectors against a TensorFlow Serving model over gRPC.
 *
 * <p>gRPC channels are cached per {@code host-port} in {@link #channelMap} and reused across
 * calls; a cached channel that has been shut down or terminated is rebuilt on the next lookup.
 */
public class TFServingUtil {
    /**
     * Channel cache keyed by {@code host + "-" + port}; created lazily on first use. Prefer
     * {@link #getManagedChannel(String, int)} over touching this map directly: the accessor
     * methods synchronize on this class, direct map access does not.
     */
    public static Map<String, ManagedChannel> channelMap;

    /** Default name of the input tensor the feature matrix is bound to. */
    public static String DF_INPUT = "feat_vals";
    /** Default model signature invoked by {@link #predict(Map, String, int, String)}. */
    public static String DF_SIGNAtURE_NAME = "serving_default";
    /** Default name of the output tensor whose float values are returned as scores. */
    public static String DF_OUTPUT = "prob";

    /**
     * Returns the cached plaintext channel to {@code host:port}. A new channel is built when
     * none is cached, or when the cached one has been shut down or terminated.
     *
     * <p>Synchronized because the cache is a plain static {@link HashMap}. Unsynchronized
     * concurrent access could corrupt it, or build duplicate channels that are never shut down.
     *
     * @param host serving host name or IP
     * @param port serving gRPC port
     * @return a channel that is neither shut down nor terminated at the time of the call
     */
    public static synchronized ManagedChannel getManagedChannel(String host, int port) {
        String key = channelKey(host, port);
        if (channelMap == null) {
            channelMap = new HashMap<>();
        }
        ManagedChannel channel = channelMap.get(key);
        if (channel == null || channel.isShutdown() || channel.isTerminated()) {
            setManagedChannel(host, port);
            channel = channelMap.get(key);
        }
        return channel;
    }

    /**
     * Builds a new plaintext channel to {@code host:port} and stores it in the cache. Any
     * channel previously cached under the same key is replaced but NOT shut down.
     *
     * @param host serving host name or IP
     * @param port serving gRPC port
     */
    public static synchronized void setManagedChannel(String host, int port) {
        // Initialize here too, so direct callers do not hit a null map.
        if (channelMap == null) {
            channelMap = new HashMap<>();
        }
        channelMap.put(channelKey(host, port),
                ManagedChannelBuilder.forAddress(host, port).usePlaintext(true).build());
    }

    /** Cache key for a host/port pair. */
    private static String channelKey(String host, int port) {
        return host + "-" + port;
    }

    /**
     * Sends one Predict RPC and returns the float values of the requested output tensor.
     * {@code features} is packed as a float tensor of shape
     * {@code [batchSize, features.size() / batchSize]}.
     *
     * @param batchSize     number of rows (samples) packed into {@code features}
     * @param features      all samples' feature values concatenated row after row
     * @param modelName     name of the served model
     * @param channel       channel to the serving host
     * @param signatureName model signature to invoke
     * @param inputName     name of the input tensor
     * @param outputName    name of the output tensor to read
     * @return the output tensor's float values. Returns an empty list if {@code channel} is
     *         null, {@code batchSize} is not positive, or {@code features} cannot be split
     *         evenly into {@code batchSize} rows.
     */
    private static List<Float> predict(int batchSize, List<Float> features, String modelName,
                                       ManagedChannel channel, String signatureName,
                                       String inputName, String outputName) {
        // batchSize <= 0 would make the modulo below divide by zero.
        if (channel == null || batchSize <= 0 || features.size() % batchSize != 0) {
            return new ArrayList<>();
        }
        int featureCount = features.size() / batchSize;

        Model.ModelSpec.Builder spec = Model.ModelSpec.newBuilder()
                .setName(modelName)
                .setSignatureName(signatureName);

        TensorShapeProto shape = TensorShapeProto.newBuilder()
                .addDim(TensorShapeProto.Dim.newBuilder().setSize(batchSize))
                .addDim(TensorShapeProto.Dim.newBuilder().setSize(featureCount))
                .build();

        TensorProto input = TensorProto.newBuilder()
                .setDtype(DataType.DT_FLOAT)
                .setTensorShape(shape)
                .addAllFloatVal(features)
                .build();

        Predict.PredictRequest request = Predict.PredictRequest.newBuilder()
                .setModelSpec(spec)
                .putInputs(inputName, input)
                .build();

        PredictionServiceGrpc.PredictionServiceBlockingStub stub =
                PredictionServiceGrpc.newBlockingStub(channel);
        return stub.predict(request).getOutputsOrThrow(outputName).getFloatValList();
    }

    /**
     * Scores every entry of {@code map} in a single batched Predict call. The call uses the
     * default signature and the default input/output tensor names.
     *
     * @param map       sample key to its feature vector. All vectors must have the same length.
     * @param host      serving host name or IP
     * @param port      serving gRPC port
     * @param modelName name of the served model
     * @param <T>       sample key type
     * @return sample key to score. Keys whose score is missing from the response are omitted.
     *         The map is empty if any argument is empty per {@code AssertUtil.isAllNotEmpty},
     *         or if {@code map} has no entries.
     * @throws IllegalArgumentException if a feature vector is null, or if the vectors differ
     *         in length. Mismatched lengths would otherwise be packed into a misaligned tensor.
     * @throws Exception propagated from the RPC or argument checks
     */
    public static <T> Map<T, Double> predict(Map<T, List<Float>> map, String host, int port,
                                             String modelName) throws Exception {
        Map<T, Double> scores = new HashMap<>();
        if (!AssertUtil.isAllNotEmpty(new Object[]{map, host, Integer.valueOf(port), modelName})
                || map == null || map.isEmpty()) {
            return scores;
        }

        // Fix the key order once, so response row k maps back to keys.get(k).
        List<T> keys = new ArrayList<>(map.keySet());
        int batchSize = keys.size();
        List<Float> flattened = new ArrayList<>();
        int featureCount = -1;
        for (T key : keys) {
            List<Float> features = map.get(key);
            if (features == null) {
                throw new IllegalArgumentException("null feature list for key " + key);
            }
            if (featureCount < 0) {
                featureCount = features.size();
            } else if (features.size() != featureCount) {
                throw new IllegalArgumentException("feature list for key " + key + " has length "
                        + features.size() + ", expected " + featureCount);
            }
            flattened.addAll(features);
        }

        List<Float> predictions = predict(batchSize, flattened, modelName,
                getManagedChannel(host, port), DF_SIGNAtURE_NAME, DF_INPUT, DF_OUTPUT);

        // The response may hold fewer values than keys, e.g. the empty fallback list.
        // Never index past its end.
        int available = Math.min(batchSize, predictions.size());
        for (int k = 0; k < available; k++) {
            Float score = predictions.get(k);
            if (score != null) {
                scores.put(keys.get(k), score.doubleValue());
            }
        }
        return scores;
    }

    /** Ad-hoc smoke test: scores 10 random 302-dimensional vectors against a fixed host. */
    public static void main(String[] strArr) throws Exception {
        getManagedChannel("47.96.221.208", 9000);
        Map<String, List<Float>> samples = new HashMap<>();
        for (int i = 0; i < 10; i++) {
            List<Float> features = new ArrayList<>();
            for (int j = 0; j < 302; j++) {
                features.add((float) (Math.random() * 1.0E-4d));
            }
            samples.put(i + "-", features);
        }
        System.out.println("RET= : " + JSON.toJSONString(predict(samples, "47.96.221.208", 9000, "model-1")));
    }
}
