package cn.com.duiba.nezha.alg.alg.correct;

import cn.com.duiba.nezha.alg.alg.vo.advert.AdBidParamsDo;
import cn.com.duiba.nezha.alg.common.util.MathUtil;
import lombok.Data;

import java.util.Arrays;

/**
 * Isotonic-regression calibrator: maps a raw score onto a calibrated value with a piecewise-linear
 * function defined by the ({@code boundaries[i]}, {@code predictions[i]}) pairs.
 *
 * <p>Lombok {@code @Data} generates getters/setters, {@code equals}/{@code hashCode} and
 * {@code toString} for the instance fields, so instances are mutable.
 */
@Data
public class IsotonicCalibrator {
    /** Redis key prefix shared by every calibration model of this class. */
    private static final String REDIS_KEY_PREFIX = "NZ_CALIBRATION";

    /** Added to the denominator in {@link #calibrate(double)} so a raw score of 0 does not divide by zero. */
    private static final double EPSILON = 1e-12;

    private long targetConvertCount; // number of conversions in this dimension; TODO: the field name should be revised
    private String dim; // dimension this calibration applies to
    private String redisKey; // Redis key under which this model is stored
    private String updateTime; // model update time

    private double[] boundaries; // isotonic-regression boundary points (x values), must be non-decreasing
    private double[] predictions; // isotonic-regression outputs at the corresponding boundary points

    // NOTE(review): the calibration-factor bounds below are not applied by any method in this class;
    // presumably callers clamp the result of calibrate() themselves — confirm.
    private Double lowerBound; // lower bound of the calibration factor
    private Double upperBound; // upper bound of the calibration factor

    /**
     * Builds the Redis key of the model for a slot / AB-test / advert combination.
     *
     * @param slotId   slot id, forwarded to {@code AdBidParamsDo.concatSlotAlgPrefix}
     * @param abTestId AB-test id, forwarded to {@code AdBidParamsDo.concatSlotAlgPrefix}
     * @param advertId advert id, may be {@code null}
     * @return the key built by {@link #getModelRedisKeyWithSlotAlgPrefix(String, Long)}
     */
    public static String getModelRedisKey(Long slotId, String abTestId, Long advertId) {
        String slotAlgPrefix = AdBidParamsDo.concatSlotAlgPrefix(slotId, abTestId);
        return getModelRedisKeyWithSlotAlgPrefix(slotAlgPrefix, advertId);
    }

    /**
     * Builds the Redis key from an already-computed slot/algorithm prefix.
     *
     * <p>Returns {@code NZ_CALIBRATION} when {@code slotAlgPrefix} is {@code null} (the advert id is then
     * ignored), {@code NZ_CALIBRATION_<prefix>} when {@code advertId} is {@code null}, and
     * {@code NZ_CALIBRATION_<prefix>_<advertId>} otherwise.
     *
     * @param slotAlgPrefix slot/algorithm prefix, may be {@code null}
     * @param advertId      advert id, may be {@code null}
     * @return the Redis key
     */
    public static String getModelRedisKeyWithSlotAlgPrefix(String slotAlgPrefix, Long advertId) {
        if (slotAlgPrefix == null) {
            return REDIS_KEY_PREFIX;
        }
        if (advertId == null) {
            return REDIS_KEY_PREFIX + "_" + slotAlgPrefix;
        }
        return REDIS_KEY_PREFIX + "_" + slotAlgPrefix + "_" + advertId;
    }

    /**
     * Checks that the model is usable by {@link #calibrateRatio(double)}.
     *
     * <p>Requires both arrays to be non-null, equally sized, and non-empty. They must contain no NaN,
     * and both must be non-decreasing.
     *
     * @throws IllegalStateException if any of the above does not hold
     */
    public void valid() {
        if (boundaries == null || predictions == null) {
            throw new IllegalStateException("boundaries and predictions must not be null");
        }
        if (boundaries.length != predictions.length) {
            throw new IllegalStateException("the length of boundaries and predictions must be same");
        }
        if (boundaries.length == 0) {
            throw new IllegalStateException("boundaries and predictions must not be empty");
        }
        for (int i = 0; i < boundaries.length; i++) {
            // NaN compares false with every operator, so it would slip through the ordering check below
            // and break the binary search in calibrateRatio.
            if (Double.isNaN(boundaries[i]) || Double.isNaN(predictions[i])) {
                throw new IllegalStateException("boundaries and predictions must not contain NaN");
            }
            if (i > 0 && (boundaries[i - 1] > boundaries[i] || predictions[i - 1] > predictions[i])) {
                throw new IllegalStateException("Isotonic need monotone boundaries and predictions");
            }
        }
    }

    /** Evaluates the line through (x1, y1) and (x2, y2) at x; callers guarantee x1 < x2. */
    private double linearInterpolation(double x1, double y1, double x2, double y2, double x) {
        return y1 + (y2 - y1) * (x - x1) / (x2 - x1);
    }

    /**
     * Returns the calibrated value for a raw score. Despite the name, this is the isotonic output itself,
     * not a ratio.
     *
     * <p>An exact boundary match returns its prediction. Scores below the first or above the last
     * boundary are clamped to the first or last prediction. Scores in between are linearly
     * interpolated. {@code Arrays.binarySearch} orders NaN after every value, so a NaN score maps to
     * the last prediction.
     *
     * @param testData raw score to calibrate
     * @return the calibrated value
     * @throws IllegalStateException if the model has no boundaries or predictions
     */
    public double calibrateRatio(double testData) {
        if (boundaries.length == 0 || predictions.length == 0) {
            throw new IllegalStateException("boundaries and predictions must not be empty");
        }

        int foundIndex = Arrays.binarySearch(boundaries, testData);
        if (foundIndex >= 0) {
            return predictions[foundIndex];
        }

        // Not found: binarySearch returns -(insertionPoint) - 1.
        int insertIndex = -foundIndex - 1;
        if (insertIndex == 0) {
            return predictions[0];
        }
        if (insertIndex == boundaries.length) {
            return predictions[predictions.length - 1];
        }
        // Strictly between boundaries[insertIndex - 1] and boundaries[insertIndex], so the divisor is non-zero.
        return linearInterpolation(boundaries[insertIndex - 1], predictions[insertIndex - 1],
                boundaries[insertIndex], predictions[insertIndex], testData);
    }

    /**
     * Returns the calibration factor: the calibrated value divided by the raw score.
     *
     * @param testData raw score
     * @return {@code calibrateRatio(testData) / (testData + 1e-12)}
     */
    public double calibrate(double testData) {
        double rawOutput = calibrateRatio(testData);
        return rawOutput / (testData + EPSILON);
    }

}
