/*
 * Decompiled with CFR 0.152.
 */
package cn.com.duiba.nezha.alg.model.tf;

import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import com.alibaba.fastjson.JSON;
import io.grpc.Channel;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import my.org.tensorflow.framework.DataType;
import my.org.tensorflow.framework.TensorProto;
import my.org.tensorflow.framework.TensorShapeProto;
import tensorflow.serving.Model;
import tensorflow.serving.Predict;
import tensorflow.serving.PredictionServiceGrpc;

/**
 * Static helpers for calling a TensorFlow Serving model over gRPC ({@code PredictionService.Predict}).
 *
 * <p>Channels are cached in {@link #channelMap} under the key {@code host + "-" + port} and are
 * recreated when the cached channel reports itself shut down or terminated.
 *
 * <p>NOTE(review): {@link #channelMap} is a plain {@link HashMap} accessed without synchronization,
 * so concurrent first-time channel creation from several threads is not safe — confirm whether
 * callers share this class across threads.
 */
public class TFServingUtil {
    /** Cache of gRPC channels keyed by {@code host + "-" + port}; created lazily on first use. */
    public static Map<String, ManagedChannel> channelMap;
    /** Name of the request input tensor used by {@link #predict(Map, String, int, String)}. */
    public static String DF_INPUT = "feat_vals";
    /** Model signature name used by {@link #predict(Map, String, int, String)}. */
    public static String DF_SIGNAtURE_NAME = "serving_default";
    /** Name of the response output tensor read by {@link #predict(Map, String, int, String)}. */
    public static String DF_OUTPUT = "prob";

    /** Builds the cache key used for a host/port pair. */
    private static String channelKey(String host, int port) {
        return host + "-" + port;
    }

    /** Returns {@link #channelMap}, creating it first if it has not been initialized yet. */
    private static Map<String, ManagedChannel> channels() {
        if (channelMap == null) {
            channelMap = new HashMap<String, ManagedChannel>();
        }
        return channelMap;
    }

    /**
     * Returns the cached channel for {@code host:port}, creating a new plaintext channel when none is
     * cached or the cached one is shut down or terminated.
     *
     * @param host server host name or IP
     * @param port server port
     * @return a channel for the given address
     */
    public static ManagedChannel getManagedChannel(String host, int port) {
        String key = channelKey(host, port);
        ManagedChannel channel = channels().get(key);
        if (channel == null || channel.isShutdown() || channel.isTerminated()) {
            setManagedChannel(host, port);
            channel = channels().get(key);
        }
        return channel;
    }

    /**
     * Creates a new plaintext channel for {@code host:port} and stores it in the cache, replacing any
     * existing entry for that address.
     *
     * @param host server host name or IP
     * @param port server port
     */
    public static void setManagedChannel(String host, int port) {
        String key = channelKey(host, port);
        System.out.println("TFServingUtil create new channel " + key);
        ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext(true).build();
        ManagedChannel previous = channels().put(key, channel);
        // Release the replaced channel instead of leaking it; shutdown() is an orderly shutdown.
        if (previous != null && !previous.isShutdown()) {
            previous.shutdown();
        }
    }

    /**
     * Sends one Predict request and returns the float values of the requested output tensor.
     *
     * <p>{@code data} is sent as a single DT_FLOAT tensor of shape
     * {@code [rowCount, data.size() / rowCount]}.
     *
     * @param rowCount number of rows (one per sample) packed into {@code data}
     * @param data row-major concatenation of all sample feature values
     * @param modelName served model name
     * @param channel channel to send the request on
     * @param signatureName model signature to invoke
     * @param input name of the input tensor
     * @param output name of the output tensor to read
     * @return the output tensor's float values, or an empty list when {@code channel} is null,
     *     {@code rowCount} is not positive, or {@code data.size()} is not a multiple of
     *     {@code rowCount}
     */
    private static List<Float> predict(int rowCount, List<Float> data, String modelName, ManagedChannel channel,
            String signatureName, String input, String output) {
        // Guard rowCount first: the modulo below would throw ArithmeticException for 0.
        if (channel == null || rowCount <= 0 || data.size() % rowCount != 0) {
            return new ArrayList<Float>();
        }
        int rowWidth = data.size() / rowCount;

        Model.ModelSpec.Builder modelSpec = Model.ModelSpec.newBuilder()
                .setName(modelName)
                .setSignatureName(signatureName);

        TensorShapeProto shape = TensorShapeProto.newBuilder()
                .addDim(TensorShapeProto.Dim.newBuilder().setSize((long) rowCount))
                .addDim(TensorShapeProto.Dim.newBuilder().setSize((long) rowWidth))
                .build();
        TensorProto inputTensor = TensorProto.newBuilder()
                .setDtype(DataType.DT_FLOAT)
                .setTensorShape(shape)
                .addAllFloatVal(data)
                .build();

        Predict.PredictRequest request = Predict.PredictRequest.newBuilder()
                .setModelSpec(modelSpec)
                .putInputs(input, inputTensor)
                .build();

        PredictionServiceGrpc.PredictionServiceBlockingStub stub = PredictionServiceGrpc.newBlockingStub((Channel) channel);
        Predict.PredictResponse response = stub.predict(request);
        return response.getOutputsOrThrow(output).getFloatValList();
    }

    /**
     * Scores every entry of {@code dataMap} with one Predict call and maps each key to the
     * corresponding output value.
     *
     * <p>Uses {@link #DF_SIGNAtURE_NAME}, {@link #DF_INPUT} and {@link #DF_OUTPUT}. The i-th output
     * value is assigned to the i-th key; keys with no (or a null) corresponding output value are
     * omitted from the result.
     *
     * @param dataMap per-key feature vectors; all vectors must have the same length
     * @param host server host name or IP
     * @param port server port
     * @param modelName served model name
     * @return key to predicted value; empty when the {@code AssertUtil.isAllNotEmpty} check on the
     *     arguments fails
     * @throws IllegalArgumentException if a feature list is null or the lists differ in length
     * @throws Exception propagated from the gRPC call
     */
    public static <T> Map<T, Double> predict(Map<T, List<Float>> dataMap, String host, int port, String modelName) throws Exception {
        Map<T, Double> ret = new HashMap<T, Double>();
        if (!AssertUtil.isAllNotEmpty(new Object[]{dataMap, host, port, modelName})) {
            return ret;
        }
        // Snapshot the key order so request rows and response values line up by index.
        List<T> keyList = new ArrayList<T>(dataMap.keySet());
        int rowCount = keyList.size();
        int rowWidth = -1;
        List<Float> data = new ArrayList<Float>();
        for (T key : keyList) {
            List<Float> row = dataMap.get(key);
            if (row == null) {
                throw new IllegalArgumentException("TFServingUtil null feature list for key " + key);
            }
            // Ragged rows would be silently re-sliced into a wrong [rows, width] tensor.
            if (rowWidth < 0) {
                rowWidth = row.size();
            } else if (row.size() != rowWidth) {
                throw new IllegalArgumentException("TFServingUtil feature list for key " + key
                        + " has size " + row.size() + ", expected " + rowWidth);
            }
            data.addAll(row);
        }
        ManagedChannel channel = getManagedChannel(host, port);
        List<Float> preValueList = predict(rowCount, data, modelName, channel, DF_SIGNAtURE_NAME, DF_INPUT, DF_OUTPUT);
        // The response may hold fewer values than keys (or none); never index past its end.
        int usable = Math.min(rowCount, preValueList.size());
        for (int i = 0; i < usable; ++i) {
            Float value = preValueList.get(i);
            if (value != null) {
                ret.put(keyList.get(i), value.doubleValue());
            }
        }
        return ret;
    }

    /** Manual smoke test: scores 10 random 302-wide feature vectors against a fixed server. */
    public static void main(String[] args) throws Exception {
        String host = "47.96.221.208";
        int port = 9000;
        ManagedChannel channel = getManagedChannel(host, port);
        Map<String, List<Float>> dataMap = new HashMap<String, List<Float>>();
        for (int i = 0; i < 10; ++i) {
            List<Float> row = new ArrayList<Float>();
            for (int j = 0; j < 302; ++j) {
                row.add((float) (Math.random() * 1.0E-4));
            }
            dataMap.put(i + "-", row);
        }
        Map<String, Double> ret = predict(dataMap, host, port, "model-1");
        System.out.println("RET= : " + JSON.toJSONString(ret));
        channel.shutdown();
    }
}

