package cn.com.duiba.nezha.compute.alg;

import cn.com.duiba.nezha.compute.api.dto.CorrectionInfo;
import cn.com.duiba.nezha.compute.api.dto.NezhaStatDto;
import cn.com.duiba.nezha.compute.common.util.AssertUtil;
import org.apache.log4j.Logger;

import java.util.List;

/**
 * Rectifies model prediction values (CTR / CVR) using observed statistics from {@link NezhaStatDto}.
 *
 * <p>Each {@link CorrectionInfo} receives two multiplicative factors:
 * <ul>
 *   <li><b>correction factor</b> - moves the prediction's statistical level towards the observed
 *       feedback average (see {@link #setCorrectionFactor(CorrectionInfo)});</li>
 *   <li><b>reconstruction factor</b> - damps predictions lying below the observed feedback average,
 *       reshaping the distribution (see {@link #setReconstructionFactor(CorrectionInfo)}).</li>
 * </ul>
 *
 * <p>Created by pc on 2017/9/12.
 */
public class ModelPredRectifier {
    private static final Logger logger = Logger.getLogger(ModelPredRectifier.class);

    /** Type code for CTR predictions, as returned by {@code CorrectionInfo#getType()}. */
    private static final Long TYPE_CTR = 1L;

    /** Type code for CVR predictions, as returned by {@code CorrectionInfo#getType()}. */
    private static final Long TYPE_CVR = 2L;

    /** Predictions below this value are not rectified: they are used as divisors when deriving factors. */
    private static final double MIN_PRE_VALUE = 0.0000000001;

    /** Centre of the sigmoid applied to (prediction / feedback average) during reconstruction. */
    private static final double RECONSTRUCTION_SIGMOID_CENTER = 0.5;

    /** Steepness multiplier of the sigmoid applied during reconstruction. */
    private static final double RECONSTRUCTION_SIGMOID_ZOOM = 10;


    /**
     * Batch statistical-level correction of model predictions.
     *
     * <p>For every non-null entry both factors are reset to 1.0 and the correction factor is then
     * recomputed; the reconstruction factor stays at 1.0.
     *
     * @param advertCorrectionInfoList entries updated in place; a list that {@code AssertUtil.isEmpty}
     *                                 reports as empty is a no-op
     * @throws Exception propagated from the per-entry helpers
     */
    public static void getCorrectionFactor(List<CorrectionInfo> advertCorrectionInfoList) throws Exception {
        applyToBatch(advertCorrectionInfoList, false);
    }

    /**
     * Batch statistical-level correction followed by distribution reconstruction.
     *
     * <p>For every non-null entry both factors are reset to 1.0, the correction factor is recomputed,
     * and then the reconstruction factor is recomputed on top of the corrected prediction.
     *
     * @param advertCorrectionInfoList entries updated in place; a list that {@code AssertUtil.isEmpty}
     *                                 reports as empty is a no-op
     * @throws Exception propagated from the per-entry helpers
     */
    public static void getCorrectionReconstructionFactor(List<CorrectionInfo> advertCorrectionInfoList) throws Exception {
        applyToBatch(advertCorrectionInfoList, true);
    }

    /** Resets both factors to 1.0 on each non-null entry, then recomputes them (reconstruction optional). */
    private static void applyToBatch(List<CorrectionInfo> infos, boolean withReconstruction) throws Exception {
        if (AssertUtil.isEmpty(infos)) {
            return;
        }
        for (CorrectionInfo info : infos) {
            if (info == null) {
                continue;
            }
            // Defaults first, so entries the setters skip still end up with neutral factors.
            info.setCorrectionFactor(1.0);
            info.setReconstructionFactor(1.0);
            setCorrectionFactor(info);
            if (withReconstruction) {
                setReconstructionFactor(info);
            }
        }
    }


    /**
     * Computes and stores the statistical-level correction factor of one prediction.
     *
     * <p>The prediction is corrected with {@link #correctionValue} using the type-specific average
     * predicted and average observed values. The stored factor is the ratio of corrected to original
     * prediction, clamped to the type-specific factor bounds. The factor is left untouched when:
     * <ul>
     *   <li>{@code info} is null;</li>
     *   <li>the prediction is null or below {@value #MIN_PRE_VALUE};</li>
     *   <li>no statistics are attached.</li>
     * </ul>
     *
     * @param info entry to update in place
     * @throws Exception propagated from the helpers
     */
    public static void setCorrectionFactor(CorrectionInfo info) throws Exception {
        if (info == null) {
            logger.warn("setCorrectionFactor input invalid,with info = null");
            return;
        }
        Double currentPreValue = info.getCurrentPreValue();
        if (currentPreValue == null) {
            logger.warn("setCorrectionFactor input invalid,with info.getCurrentPreValue() = null");
            return;
        }
        // currentPreValue is the divisor of the final ratio; skip near-zero and negative predictions.
        if (currentPreValue < MIN_PRE_VALUE) {
            return;
        }
        NezhaStatDto nezhaStatDto = info.getNezhaStatDto();
        if (nezhaStatDto == null) {
            return;
        }

        // Type-specific statistics and bounds. For a null or unknown type everything stays null.
        // The helpers then return their input unchanged, giving a ratio of 1.0, assuming
        // AssertUtil.isAnyEmpty treats null as empty.
        Long type = info.getType();
        Double avgPreValue = null;
        Double avgFeedbackValue = null;
        Double vLowerLimit = null; // bounds on the corrected prediction value
        Double vUpperLimit = null;
        Double fLowerLimit = null; // bounds on the resulting correction factor
        Double fUpperLimit = null;

        if (TYPE_CTR.equals(type)) {
            avgPreValue = nezhaStatDto.getPreCtrAvg();
            avgFeedbackValue = nezhaStatDto.getStatCtrAvg();
            vLowerLimit = 0.01;
            vUpperLimit = 0.99;
            fLowerLimit = 0.25;
            fUpperLimit = 4.0;
        } else if (TYPE_CVR.equals(type)) {
            avgPreValue = nezhaStatDto.getPreCvrAvg();
            avgFeedbackValue = nezhaStatDto.getStatCvrAvg();
            vLowerLimit = 0.0005;
            vUpperLimit = 0.99;
            fLowerLimit = 0.1;
            fUpperLimit = 4.0;
        }

        Double newCurrentPreValue = correctionValue(currentPreValue, avgPreValue, avgFeedbackValue, vLowerLimit, vUpperLimit);
        info.setCorrectionFactor(noiseSmoother(newCurrentPreValue / currentPreValue, fLowerLimit, fUpperLimit));
    }

    /**
     * Computes and stores the distribution-reconstruction factor of one prediction.
     *
     * <p>Works on the corrected prediction ({@code currentPreValue * correctionFactor}; a missing
     * correction factor counts as 1.0). For CTR, a corrected prediction below the observed CTR average
     * is damped by {@link #reconstructionValue} with bounds [0.5, 1.0]. The stored factor is
     * reconstructed / corrected, clamped to the same bounds.
     *
     * <p>The factor is left untouched when:
     * <ul>
     *   <li>{@code info} is null;</li>
     *   <li>the prediction is null;</li>
     *   <li>no statistics are attached;</li>
     *   <li>the type is CVR;</li>
     *   <li>the corrected prediction is below {@value #MIN_PRE_VALUE}, which would otherwise produce a
     *       0/0 = NaN factor.</li>
     * </ul>
     *
     * @param info entry to update in place
     * @throws Exception propagated from the helpers
     */
    public static void setReconstructionFactor(CorrectionInfo info) throws Exception {
        if (info == null) {
            logger.warn("setReconstructionFactor input invalid,with info = null");
            return;
        }
        Double currentPreValue = info.getCurrentPreValue();
        if (currentPreValue == null) {
            logger.warn("setReconstructionFactor input invalid,with info.getCurrentPreValue() = null");
            return;
        }
        NezhaStatDto nezhaStatDto = info.getNezhaStatDto();
        if (nezhaStatDto == null) {
            return;
        }

        Long type = info.getType();
        if (TYPE_CVR.equals(type)) {
            // NOTE(review): reconstruction is disabled for CVR (bounds [0.8, 1.0] were configured here
            // but never applied); the factor is left untouched. Confirm before re-enabling.
            return;
        }

        Double avgFeedbackValue = null;
        Double lowerLimit = null;
        Double upperLimit = null;
        if (TYPE_CTR.equals(type)) {
            avgFeedbackValue = nezhaStatDto.getStatCtrAvg();
            lowerLimit = 0.5;
            upperLimit = 1.0;
        }

        // Reconstruction operates on the corrected prediction; treat a missing factor as neutral
        // instead of failing with an NPE on unboxing.
        Double correctionFactor = info.getCorrectionFactor();
        double newCurrentPreValue = currentPreValue * (correctionFactor == null ? 1.0 : correctionFactor);

        // newCurrentPreValue is the divisor below; a zero value would produce 0/0 = NaN.
        if (newCurrentPreValue < MIN_PRE_VALUE) {
            return;
        }

        Double reconstructedValue = reconstructionValue(newCurrentPreValue, avgFeedbackValue, lowerLimit, upperLimit);
        info.setReconstructionFactor(noiseSmoother(reconstructedValue / newCurrentPreValue, lowerLimit, upperLimit));
    }


    /**
     * Clamps {@code point} into {@code [lowerLimit, upperLimit]} to limit the adjustment range.
     *
     * <p>A NaN {@code point} is returned unchanged, because both comparisons are false.
     *
     * @param point      value to clamp; returned as-is when null
     * @param lowerLimit lower bound
     * @param upperLimit upper bound
     * @return the clamped value; if {@code AssertUtil.isAnyEmpty} reports either bound as empty, a
     *         warning is logged and {@code point} is returned unchanged
     * @throws Exception propagated from {@code AssertUtil}, if any
     */
    public static Double noiseSmoother(Double point, Double lowerLimit, Double upperLimit) throws Exception {
        Double ret = point;

        if (AssertUtil.isAnyEmpty(upperLimit, lowerLimit)) {
            logger.warn("noiseSmoother input invalid,with upperLimit=" + upperLimit + ",lowerLimit=" + lowerLimit);
            return ret;
        }

        if (point != null) {
            ret = point > upperLimit ? upperLimit : (point < lowerLimit ? lowerLimit : point);
        }

        return ret;
    }


    /**
     * Shifts a prediction so that its level matches the observed feedback average.
     *
     * <p>Two estimates are computed, each clamped to {@code [lowerLimit, upperLimit]}:
     * <ul>
     *   <li>multiplicative: {@code currentPreValue * avgFeedbackValue / avgPreValue}
     *       (log-normal, equal-variance assumption);</li>
     *   <li>additive: {@code currentPreValue + avgFeedbackValue - avgPreValue}
     *       (normal, equal-variance assumption).</li>
     * </ul>
     *
     * <p>When the model over-predicts on average ({@code avgPreValue > avgFeedbackValue}), the larger
     * estimate is used. Otherwise the smaller one is used. In both cases this picks the milder of the
     * two adjustments.
     *
     * @param currentPreValue  the prediction to correct
     * @param avgPreValue      average predicted value
     * @param avgFeedbackValue average observed value
     * @param lowerLimit       lower bound for the corrected value
     * @param upperLimit       upper bound for the corrected value
     * @return the corrected value; {@code currentPreValue} unchanged if {@code AssertUtil.isAnyEmpty}
     *         reports any argument as empty, or if any of the three values is not positive
     * @throws Exception propagated from {@code AssertUtil}, if any
     */
    public static Double correctionValue(Double currentPreValue, Double avgPreValue, Double avgFeedbackValue, Double lowerLimit, Double upperLimit) throws Exception {
        Double ret = currentPreValue;

        if (AssertUtil.isAnyEmpty(currentPreValue, avgPreValue, avgFeedbackValue, upperLimit, lowerLimit)) {
            return ret;
        }

        if (currentPreValue > 0.0 && avgPreValue > 0.0 && avgFeedbackValue > 0.0) {
            // Log-normal distribution, equal-variance assumption: scale by the ratio of averages.
            Double lndValue = noiseSmoother(currentPreValue * avgFeedbackValue / avgPreValue, lowerLimit, upperLimit);

            // Normal distribution, equal-variance assumption: shift by the difference of averages.
            Double ndValue = noiseSmoother(currentPreValue + avgFeedbackValue - avgPreValue, lowerLimit, upperLimit);

            // Cross adjustment: keep whichever estimate moves the prediction less.
            if (avgPreValue > avgFeedbackValue) {
                ret = Math.max(lndValue, ndValue);
            } else {
                ret = Math.min(lndValue, ndValue);
            }
        }

        return ret;
    }

    /**
     * Damps a (corrected) prediction that falls below the observed feedback average.
     *
     * <p>If {@code 0 < newCurrentPreValue < avgFeedbackValue}, let
     * {@code ratio = newCurrentPreValue / avgFeedbackValue}. The result is {@code newCurrentPreValue}
     * multiplied by {@code sigmoid((ratio - 0.5) * 10)} mapped into {@code [lowerLimit, upperLimit]}.
     * Otherwise the input is returned unchanged.
     *
     * @param newCurrentPreValue the corrected prediction
     * @param avgFeedbackValue   average observed value
     * @param lowerLimit         lower bound of the damping multiplier
     * @param upperLimit         upper bound of the damping multiplier
     * @return the reconstructed value, or {@code newCurrentPreValue} unchanged if {@code AssertUtil.isAnyEmpty}
     *         reports any argument as empty
     * @throws Exception propagated from {@code AssertUtil}, if any
     */
    public static Double reconstructionValue(Double newCurrentPreValue, Double avgFeedbackValue, Double lowerLimit, Double upperLimit) throws Exception {
        Double ret = newCurrentPreValue;

        if (AssertUtil.isAnyEmpty(newCurrentPreValue, avgFeedbackValue, upperLimit, lowerLimit)) {
            return ret;
        }

        if (newCurrentPreValue > 0.0 && avgFeedbackValue > 0.0 && newCurrentPreValue < avgFeedbackValue) {
            // Suppress: the further below the average, the closer the multiplier gets to lowerLimit.
            double ratio = newCurrentPreValue / avgFeedbackValue;
            ret = newCurrentPreValue * sigmoidWithZoomAndIntervalMap(
                    ratio - RECONSTRUCTION_SIGMOID_CENTER, lowerLimit, upperLimit, RECONSTRUCTION_SIGMOID_ZOOM);
        }

        return ret;
    }


    /**
     * Maps a scaled sigmoid onto the interval {@code [lowerLimit, upperLimit]}.
     *
     * @param x          input
     * @param lowerLimit value approached as {@code x * zoom} goes to negative infinity
     * @param upperLimit value approached as {@code x * zoom} goes to positive infinity
     * @param zoom       multiplier applied to {@code x} before the sigmoid
     * @return {@code lowerLimit + (upperLimit - lowerLimit) * sigmoid(x * zoom)}
     */
    public static Double sigmoidWithZoomAndIntervalMap(double x, double lowerLimit, double upperLimit, double zoom) {
        return lowerLimit + (upperLimit - lowerLimit) * sigmoidWithZoom(x, zoom);
    }


    /**
     * Logistic sigmoid of a scaled input.
     *
     * @param x    input
     * @param zoom multiplier applied to {@code x}
     * @return {@code sigmoid(x * zoom)}
     */
    public static Double sigmoidWithZoom(double x, double zoom) {
        return sigmoid(x * zoom);
    }


    /**
     * Logistic sigmoid.
     *
     * @param x input
     * @return {@code 1 / (1 + e^-x)}, in the open interval (0, 1) for finite {@code x}
     */
    public static Double sigmoid(double x) {
        return 1 / (1 + Math.exp(-x));
    }


}
