package cn.com.duiba.nezha.alg.model.util;

import org.apache.commons.collections.CollectionUtils;
import org.slf4j.LoggerFactory;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.TensorInfo;

import java.io.File;
import java.nio.file.Paths;
import java.util.Collection;
import java.util.List;
import java.util.Objects;

import static java.nio.charset.StandardCharsets.UTF_8;
import static org.apache.commons.lang.StringUtils.isNumeric;

/**
 * Static helpers for loading TensorFlow SavedModels and converting between Java values and
 * TensorFlow {@link Tensor}s.
 */
public class OneTFUtils {


    private static final org.slf4j.Logger logger = LoggerFactory.getLogger(OneTFUtils.class);

    /**
     * Loads the SavedModel stored in {@code modelPath/version} using the {@code "serve"} tag.
     *
     * @param modelPath base directory that holds the versioned model sub-directories
     * @param version   name of the version sub-directory to load
     * @return the loaded bundle; it is not closed here, so the caller is responsible for closing it
     */
    public static SavedModelBundle loadModel(String modelPath, String version) {
        String path = Paths.get(modelPath, version).toString();
        logger.info("local model path ={}", path);
        return SavedModelBundle.load(path, "serve");
    }

    /**
     * Returns the latest model version under {@code path}, i.e. the largest entry name that is
     * purely numeric.
     *
     * <p>Entries whose names are not numeric, or whose numeric value does not fit in a
     * {@code long}, are skipped.
     *
     * @param path directory to scan
     * @return the highest version found, or {@code null} if {@code path} is not a directory or
     *     contains no numeric entries
     * @throws NullPointerException if {@code path} is a directory whose contents cannot be listed
     * @throws Exception declared for backward compatibility with existing callers
     */
    public static Long getLastVersion(String path) throws Exception {
        File dir = new File(path);
        if (!dir.isDirectory()) {
            return null;
        }
        // File.listFiles() returns null on an I/O error even for an existing directory.
        File[] children = Objects.requireNonNull(dir.listFiles(), "cannot list directory: " + path);

        Long latest = null;
        for (File child : children) {
            String name = child.getName();
            if (!isNumeric(name)) {
                continue;
            }
            long version;
            try {
                version = Long.parseLong(name);
            } catch (NumberFormatException e) {
                // A digit-only name that overflows long must not abort the whole scan.
                logger.warn("getLastVersion skipping entry that is not a valid long: {}", name);
                continue;
            }
            if (latest == null || latest < version) {
                latest = version;
            }
        }
        return latest;
    }

    /**
     * Converts a list of strings into a 1-D {@link Tensor} whose element type is taken from
     * {@code inputInfo}.
     *
     * <p>Only {@code DT_FLOAT} and {@code DT_STRING} are supported. For {@code DT_FLOAT},
     * {@code null} elements become {@code 0.0f}. Strings that fail to parse are logged and also
     * become {@code 0.0f}. For {@code DT_STRING}, {@code null} elements become {@code ""}, and every
     * value is encoded as UTF-8.
     *
     * @param input     the string values to convert
     * @param inputInfo description of the expected input; only its dtype is read
     * @return the created tensor (to be closed by the caller), or {@code null} for any other dtype
     */
    public static Tensor<?> listToTensor(List<String> input, TensorInfo inputInfo) {
        DataType dtype = inputInfo.getDtype();
        int inputSize = input.size();
        if (dtype == DataType.DT_FLOAT) {
            float[] arrayInput = new float[inputSize];
            for (int i = 0; i < inputSize; i++) {
                String s = input.get(i);
                float value = 0.0f;
                if (s != null) {
                    try {
                        value = Float.parseFloat(s);
                    } catch (NumberFormatException e) {
                        // Unparseable values fall back to 0.0f rather than failing the request.
                        logger.warn("listToTensor exception", e);
                    }
                }
                arrayInput[i] = value;
            }
            return Tensors.create(arrayInput);
        } else if (dtype == DataType.DT_STRING) {
            byte[][] arrayInput = new byte[inputSize][];
            for (int i = 0; i < inputSize; i++) {
                String s = input.get(i);
                arrayInput[i] = (s != null ? s : "").getBytes(UTF_8);
            }
            return Tensors.create(arrayInput);
        } else {
            return null;
        }
    }

    /**
     * Copies a rank-2 FLOAT tensor into a {@code float[rows][cols]} array.
     *
     * <p>This is intended for model outputs of shape {@code (None, 1)}, but any 2-D shape is
     * copied.
     *
     * @param tensorResult the tensor to read
     * @return the copied values, or {@code null} if the tensor's data type is not FLOAT
     * @throws IllegalArgumentException if a FLOAT tensor is not rank 2
     */
    public static float[][] tensorToArray(Tensor<?> tensorResult) {
        if (org.tensorflow.DataType.FLOAT != tensorResult.dataType()) {
            return null;
        }
        long[] tensorShape = tensorResult.shape();
        if (tensorShape.length != 2) {
            throw new IllegalArgumentException(
                    "tensorToArray expects a rank-2 tensor but got rank " + tensorShape.length);
        }
        float[][] listResult = new float[(int) tensorShape[0]][(int) tensorShape[1]];
        tensorResult.copyTo(listResult);
        return listResult;
    }

    /**
     * Closes {@code closeable} if it is non-null, without propagating any failure.
     *
     * @param closeable resource to close; may be {@code null}
     */
    public static void closeQuitely(AutoCloseable closeable) {
        if (closeable == null) {
            return;
        }
        try {
            closeable.close();
        } catch (Exception e) {
            // Best-effort cleanup: never propagate, but keep a trace for debugging.
            logger.debug("closeQuitely ignored close failure", e);
        }
    }

    /**
     * Closes every element of {@code closeables} via {@link #closeQuitely(AutoCloseable)}.
     * A failure to close one element does not stop the others from being closed.
     *
     * @param closeables resources to close; may be {@code null} or empty, and may contain nulls
     * @param <T>        element type
     */
    public static <T extends AutoCloseable> void closeQuitely(Collection<T> closeables) {
        if (CollectionUtils.isNotEmpty(closeables)) {
            for (T closeable : closeables) {
                closeQuitely(closeable);
            }
        }
    }
}
