package cn.com.duiba.nezha.alg.common.model.roipid;

import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import cn.com.duiba.nezha.alg.common.util.DataUtil;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;

public class RoiPidController2 {


    // Class-wide logger.
    private static final Logger logger = LoggerFactory.getLogger(cn.com.duiba.nezha.alg.common.model.roipid.RoiPidController2.class);

    // Controller gains: proportional, integral, derivative and feed-forward.
    // They are combined in getOutput(); checkSigns() keeps them non-negative.
    private double P = 0;
    private double I = 0;
    private double D = 0;
    private double F = 0;

    // Mutable controller state shared by every call on this instance.
    private boolean firstRun=true;   // true until getOutput() has run once after construction/reset()
    private double errorSum=0;       // accumulated error feeding the integral term
    private double lastOutput=0;     // previous controller output (used by ramp-rate limiting and the output filter)
    private double lastActual=0;     // previous measured value (used by the derivative term)
    private double lastFactor=0;     // factor of the entry at the start of init()/boot()/run()

    // Limits: 0 means "no limit" for maxIOutput; minOutput == maxOutput means "no output limit".
    private double maxIOutput=0;
    private double maxError=0;
    private double maxOutput=0;
    private double minOutput=0;

    // Optional output smoothing; both are disabled while they are 0.
    private double outputRampRate=0;
    private double outputFilter=0;

    /** Life-cycle marker with two states. Not referenced anywhere else in this file. */
    public enum LifeStatus{
        NORMAL,DIE
    }

    /**
     * Run-phase marker. Not referenced anywhere else in this file.
     * NOTE(review): STAMBLE looks like a misspelling of STABLE; it is kept because the constant is public.
     */
    public enum RunStatus{
        COLDBOOT,CHANGING,STAMBLE
    }

    /** Thresholds and default controller gains. The fields are package-private, non-final statics. */
    static class Constant{

        // Factor returned by getInitFactor when no DEFAULT entry is found.
        static double DEFAULT_FACTOR = 1.0;
        // Compared against sumClick in getPriceFactor2: entries below it are skipped.
        static double THRESHOLD1 = 20;//
        // Conversion count threshold at the media dimension (APP/SLOT/TAG ids).
        static double THRESHOLD2 = 12;
        // Conversion count threshold for the PID path at the config dimension (DEFAULT id).
        static double THRESHOLD3 = 12;

        // Controller parameters.
        static double DEFAULT_P = 0.5;
        static double DEFAULT_I = 0.01;
        static double DEFAULT_D = 0.5;
    }

    /**
     * Creates a controller using the default P/I/D gains from {@link Constant};
     * the feed-forward gain F keeps its initial value of 0.
     */
    public RoiPidController2(){
        // The three-argument constructor assigns the gains and normalises their signs.
        this(Constant.DEFAULT_P, Constant.DEFAULT_I, Constant.DEFAULT_D);
    }

    /**
     * Derives the initial price factor from the first entry whose id contains "DEFAULT".
     *
     * The raw factor is targetCpa divided by the observed CPA (sumFee / sumConv, or targetCpa
     * itself when there are no conversions). It is then shaped by the conversion count:
     * more conversions allow a wider range, and factors below 1 are pushed further down
     * with a power function.
     *
     * @param infos     statistics entries to scan
     * @param targetCpa target cost per conversion
     * @return the shaped factor, or {@code Constant.DEFAULT_FACTOR} when no DEFAULT entry exists
     */
    private double getInitFactor(List<StatInfo> infos, double targetCpa)
    {
        for (StatInfo candidate : infos) {
            if (!candidate.id.contains("DEFAULT")) {
                continue;
            }

            // Only the first DEFAULT entry is ever evaluated.
            double observedCpa = candidate.sumConv > 0 ? candidate.sumFee / candidate.sumConv : targetCpa;
            double ratio = targetCpa / observedCpa;

            if (candidate.sumConv >= 10) {
                if (ratio < 1) {
                    ratio = Math.pow(ratio, 2);
                }
                return constrain(ratio, 0.5, 2);
            }
            if (candidate.sumConv >= 5) {
                if (ratio < 1) {
                    ratio = Math.pow(ratio, 1.5);
                }
                return constrain(ratio, 0.7, 1.5);
            }
            if (candidate.sumConv >= 3) {
                return constrain(ratio, 0.8, 1.2);
            }

            // Fewer than 3 conversions: stay neutral until spend reaches 5x the target,
            // then fall back to targetCpa / sumFee limited to [0.5, 1.5].
            if (candidate.sumFee < 5 * targetCpa) {
                return 1;
            }
            return constrain(targetCpa / candidate.sumFee, 0.5, 1.5);
        }
        return Constant.DEFAULT_FACTOR;
    }

    /**
     * Recomputes the price factor of every eligible entry and returns the entries that were updated.
     *
     * Entries are skipped (left untouched and not returned) when their click count is below
     * {@code Constant.THRESHOLD1}, when {@code checkParam}/{@code checkStart} reject them, or when
     * an APP/SLOT/TAG entry has fewer conversions than {@code Constant.THRESHOLD2}.
     * ACTIVITY entries always get factor 1. For returned entries, lastSumFee/lastSumConv are
     * overwritten with the current sums. The passed-in StatInfo objects are mutated in place.
     *
     * @param infos           statistics entries; the same instances are mutated and returned
     * @param orientTargetCpa target CPA used for the initial factor and the start/parameter checks
     * @param appTargetCpa    target CPA handed to the per-entry controllers
     * @param budget          forwarded to the per-entry controllers
     * @param mode            forwarded to the per-entry controllers
     * @return the updated entries, in input order
     */
    public List<StatInfo> getPriceFactor2(List<StatInfo> infos, double orientTargetCpa, double appTargetCpa, double budget, int mode)
    {
        final double initFactor = getInitFactor(infos, orientTargetCpa);
        final List<StatInfo> updated = new ArrayList<>();

        for (StatInfo info : infos) {
            if (info.sumClick < Constant.THRESHOLD1) {
                continue;
            }

            if (info.id.contains("ACTIVITY")) {
                // Activity entries are pinned to a neutral factor.
                info.factor = 1;
                info.lastSumFee = info.sumFee;
                info.lastSumConv = info.sumConv;
                updated.add(info);
                continue;
            }

            boolean usable = checkParam(info, orientTargetCpa)
                    && checkParam(info, appTargetCpa)
                    && checkStart(info, orientTargetCpa);
            if (!usable) {
                continue;
            }

            boolean mediaDimension = info.id.contains("APP") || info.id.contains("SLOT") || info.id.contains("TAG");
            if (mediaDimension) {
                if (info.sumConv < Constant.THRESHOLD2) {
                    continue;
                }
                double raw = getOnePriceFactor(info, appTargetCpa, budget, initFactor, mode);
                info.factor = Math.min(Math.max(raw, 0.2), 4);
            } else if (info.id.contains("DEFAULT")) {
                double raw = getDefaultPriceFactor(info, appTargetCpa, budget, initFactor, mode);
                info.factor = Math.min(Math.max(raw, 0.5), 2);
            }

            // Remember the sums this factor was computed from.
            info.lastSumFee = info.sumFee;
            info.lastSumConv = info.sumConv;
            updated.add(info);

            // Sampled logging: roughly one entry in ten.
            if (Math.random() < 0.1) {
                logger.info("getPriceFactor2 id{} factor {}", info.id, info.factor);
            }
        }
        return updated;
    }

    /**
     * Computes the factor for an entry that has no previous fee snapshot
     * (the only caller in this file, getOnePriceFactor, invokes it when lastSumFee == 0).
     *
     * Steps: pick a starting factor relative to initFactor, smooth it with 7-day statistics,
     * then, once at least 5 conversions exist, run one controller step and bound the result.
     * Mutates info.factor and the controller state fields.
     *
     * @param info       entry being adjusted; its factor is overwritten
     * @param targetCpa  target cost per conversion
     * @param budget     only forwarded to smoothInitWith7d
     * @param initFactor initial factor derived from the DEFAULT entry
     * @return the new value of info.factor
     */
    private double init(StatInfo info, double targetCpa, double budget, double initFactor) {

        lastFactor = info.factor;

        // 1. Start value
        if (info.lastSumFee == 0 && Math.abs(info.factor - 1) < 0.00001) // factor is (almost) 1: treated as "from default"
        {
            info.factor = initFactor;
        } else if (info.lastSumFee == 0) // factor differs from 1: treated as carried over ("from yesterday")
        {
            // Keep the carried-over factor within [0.6, 1.4] x initFactor.
            info.factor = info.factor < initFactor * 0.6 ? initFactor * 0.6 : info.factor;
            info.factor = info.factor > initFactor * 1.4 ? initFactor * 1.4 : info.factor;
            // Never start above 1 when the initial factor is below 1.
            if(info.factor > 1 && initFactor < 1)
                info.factor = 1;
        }

        double beforeFactor = info.factor;
        info.factor = smoothInitWith7d(info, targetCpa, budget, info.factor);
        // Only log smoothing steps that moved the factor by more than 0.1.
        if(Math.abs(beforeFactor - info.factor)>0.1)
            logger.info("pid controller smooth Id:{} beforeFactor:{} afterFactor:{} ", info.id, beforeFactor ,info.factor);

        // 2. Controller step (needs at least 5 conversions)
        if (info.sumConv >= 5) {
            double actual = info.sumFee / info.sumConv;
            setOutputLimits(targetCpa);
            // Seed the controller state from the previous snapshot (approximation).
            // NOTE(review): while firstRun is true, getOutput() overwrites lastOutput and lastActual.
            lastOutput = info.lastSumConv > 0 ? lastFactor * targetCpa - 1 : 0; // approximation
            lastActual = info.lastSumConv > 0 ? info.lastSumFee / info.lastSumConv : actual;
            maxError = 100;
            errorSum = constrain(info.lastSumConv * targetCpa - info.lastSumFee, -maxError, maxError);// approximation

            double output = getOutput(actual, targetCpa);
            // Note: adjust() scales lastFactor (the factor on entry), not the smoothed info.factor.
            info.factor = adjust(output,info.factor,lastFactor,targetCpa,0.9,1.1);

            // 3. Smooth: pull the factor to targetCpa/actual when it lies on the wrong side of it,
            // then bound the step relative to the factor on entry (0.8x downwards, 1.4x upwards).
            double rFactor = targetCpa / actual;
            info.factor = output < 0 && info.factor > rFactor ? rFactor : info.factor;
            info.factor = output > 0 && info.factor < rFactor ? rFactor : info.factor;
            info.factor = output < 0 ? Math.max(info.factor, lastFactor * 0.8) : Math.min(info.factor, lastFactor * 1.4);
        }

        return info.factor;
    }

    /**
     * Cold-start variant of the factor update. It is not called from anywhere in this file.
     *
     * With conversions it runs one controller step; without conversions it caps the factor at
     * initFactor (or 0.9 x initFactor once spend exceeds twice the target). The result is always
     * kept within [0.8, 1.2] x initFactor. Mutates info.factor and the controller state fields.
     *
     * @param info       entry being adjusted; its factor is overwritten
     * @param targetCpa  target cost per conversion
     * @param budget     unused in this method
     * @param initFactor initial factor derived from the DEFAULT entry
     * @return the new value of info.factor
     */
    private double boot(StatInfo info, double targetCpa, double budget, double initFactor)
    {
        lastFactor = info.factor;

        if(info.sumConv > 0) {
            double actual = info.sumFee / info.sumConv;
            setOutputLimits(targetCpa);

            // Seed the controller state from the previous snapshot (approximation).
            lastOutput = info.lastSumConv > 0 ? lastFactor * targetCpa - 1 : 0; // approximation
            lastActual = info.lastSumConv > 0 ? info.lastSumFee / info.lastSumConv : actual;
            maxError = 100;
            errorSum = constrain(info.lastSumConv * targetCpa - info.lastSumFee, -maxError, maxError);// approximation

            double output = getOutput(actual, targetCpa);
            info.factor = adjust(output, info.factor, lastFactor, targetCpa, 0.9, 1.1);

        }
        else if(info.sumFee > 2 * targetCpa)// no conversions yet, but spend already exceeds 2x the target
        {
            info.factor = Math.min(initFactor * 0.9,info.factor);
        }
        else
        {
            info.factor = Math.min(initFactor,info.factor);
        }

        // updated on 2019-10-28 for stable cpa
        info.factor = Math.max(info.factor, initFactor * 0.8);
        info.factor = Math.min(info.factor, initFactor * 1.2);

        return  info.factor;
    }


    /**
     * Regular factor update: one controller step followed by range limiting.
     *
     * The measured CPA is sumFee / sumConv, or 2 x sumFee when there are no conversions.
     * Mutates info.factor and the controller state fields.
     *
     * @param info       entry being adjusted; its factor is overwritten
     * @param targetCpa  target cost per conversion
     * @param budget     unused in this method
     * @param initFactor initial factor derived from the DEFAULT entry; anchors the allowed range
     * @return the new value of info.factor
     */
    private double run(StatInfo info, double targetCpa, double budget, double initFactor)
    {
        lastFactor = info.factor;
        // 1. Controller step
        double actual = info.sumConv > 0 ? info.sumFee / info.sumConv : info.sumFee * 2;
        setOutputLimits(targetCpa);

        // Seed the controller state from the previous snapshot (approximation).
        // NOTE(review): while firstRun is true, getOutput() overwrites lastOutput and lastActual.
        lastOutput = info.lastSumConv > 0 ? lastFactor * targetCpa - 1 : 0; // approximation
        lastActual = info.lastSumConv > 0 ? info.lastSumFee / info.lastSumConv : actual;
        maxError = 100;
        errorSum = constrain(info.lastSumConv * targetCpa - info.lastSumFee,-maxError,maxError);// approximation

        double output = getOutput(actual, targetCpa);
        // Factors in (0.6, 1) may drop by up to 20% per step; all others by at most 10%.
        if(info.factor > 0.6 && info.factor < 1)
            info.factor = adjust(output,info.factor,lastFactor,targetCpa,0.8,1.1);
        else
            info.factor = adjust(output,info.factor,lastFactor,targetCpa,0.9,1.1);

        // 2. Smooth
        double rFactor = targetCpa / actual;
        // High target CPA with little spend: do not fall below half the initial factor.
        if(targetCpa >= 3000 && info.sumFee < 5 * targetCpa) {
            info.factor = Math.max(info.factor, initFactor * 0.5);
        }

        // Overall range: [0.2, 2] x initFactor.
        info.factor = Math.max(info.factor, initFactor * 0.2);


        info.factor = Math.min(info.factor, initFactor * 2);
        // With enough conversions, pull the factor to targetCpa/actual when it lies on the wrong
        // side of it, then bound the step relative to the factor on entry.
        if(info.sumConv >= 5 && info.sumConv < 10) {
            info.factor = output < 0 && info.factor > rFactor ? rFactor : info.factor;
            info.factor = output > 0 && info.factor < rFactor ? rFactor : info.factor;
            info.factor = output < 0 ? Math.max(info.factor, lastFactor * 0.85) : Math.min(info.factor, lastFactor * 1.3);
        }
        else if(info.sumConv >= 10) {
            info.factor = output < 0 && info.factor > rFactor ? rFactor : info.factor;
            info.factor = output > 0 && info.factor < rFactor ? rFactor : info.factor;
            info.factor = output < 0 ? Math.max(info.factor, lastFactor * 0.7) : Math.min(info.factor, lastFactor * 1.4);
        }
        //info.factor = smoothWith7d(info, targetCpa, budget, info.factor);
        return  info.factor;
    }



    /**
     * Relation between the 7-day bias (targetCpa / 7-day CPA) and the initial factor:
     * low = both below 1, high = both above 1, contradict = they disagree,
     * unkown = not classified.
     * NOTE(review): "unkown" is a misspelling of "unknown"; it is kept because the constant is public.
     */
    public enum InitStatus {
        low, high, contradict, unkown;
    }


    /**
     * Blends the initial factor with the 7-day bias (targetCpa / 7-day CPA).
     *
     * The initial factor is returned unchanged when any of the four reference values
     * (7-day CPA, previous-snapshot CPA, current CPC, 7-day CPC) is unavailable, when the two CPCs
     * differ by 30% or more of the current CPC, or when the 7-day confidence is not above 0.8.
     * Otherwise:
     * - bias and factor both below 1: the smaller of the two;
     * - bias and factor both above 1: the larger of the two;
     * - any other combination: a mix weighted by the confidence of the previous snapshot.
     *
     * @param info       entry providing current, previous-snapshot and 7-day statistics
     * @param targetCpa  target cost per conversion
     * @param budget     unused in this method
     * @param initFactor factor to be smoothed
     * @return the smoothed factor
     */
    private  double smoothInitWith7d(StatInfo info,double targetCpa, double budget, double initFactor)
    {
        final double minConfidence = 0.8;
        final double maxCpcGapRatio = 0.3;
        final double neutralFactor = 1.0;

        // -1 marks "not available".
        double cpa7d = info.conv7d > 0 ? info.fee7d / info.conv7d : -1;
        double cpa1d = info.lastSumConv > 0 ? info.lastSumFee / info.lastSumConv : -1;
        double cpc1d = info.sumClick > 0 ? (info.sumFee - info.lastSumFee) / info.sumClick : -1;
        double cpc7d = info.click7d > 0 ? info.fee7d / info.click7d : -1;
        if (cpa7d == -1 || cpa1d == -1 || cpc1d == -1 || cpc7d == -1) {
            return initFactor;
        }

        // The 7-day confidence is based on fee7d / targetCpa; the 1-day one on the
        // previous snapshot's conversion count. Both saturate at 10.
        double confidence7d = getConfidence(info.fee7d / targetCpa, 10);
        double confidence1d = getConfidence(info.lastSumConv, 10);
        double bias7d = targetCpa / cpa7d;

        boolean cpcComparable = Math.abs(cpc1d - cpc7d) < cpc1d * maxCpcGapRatio;
        if (!(cpcComparable && confidence7d > minConfidence)) {
            return initFactor;
        }

        if (bias7d < neutralFactor && initFactor < neutralFactor) {
            return Math.min(bias7d, initFactor);
        }
        if (bias7d > neutralFactor && initFactor > neutralFactor) {
            return Math.max(bias7d, initFactor);
        }
        // The two signals disagree: weight them by the 1-day confidence.
        return (1 - confidence1d) * bias7d + confidence1d * initFactor;
    }

    /**
     * Linear confidence: {@code cur / max} while {@code cur} is below {@code max}, otherwise 1.
     *
     * @param cur observed amount
     * @param max amount at which confidence saturates
     * @return the confidence value
     */
    private double getConfidence(double cur, double max)
    {
        if (cur < max) {
            return cur / max;
        }
        return 1;
    }

    /**
     * Computes the factor for a single non-default entry.
     *
     * Resets the controller first, adjusts the target CPA according to the initial factor and
     * mode, then delegates to {@code init} (no previous fee snapshot) or {@code run}.
     *
     * @param info       entry being adjusted
     * @param targetCpa  target cost per conversion before adjustment
     * @param budget     forwarded to init/run
     * @param initFactor initial factor derived from the DEFAULT entry
     * @param mode       modes 1 and 3 relax the target when the initial factor is in (0, 1)
     * @return the factor computed by init or run
     */
    private double getOnePriceFactor(StatInfo info, double targetCpa,double budget, double initFactor, int mode)
    {
        reset();

        // A low initial factor tightens the target by 15%.
        if (initFactor < 0.9) {
            targetCpa *= 0.85;
        }

        boolean relaxingMode = mode == 1 || mode == 3;
        if (relaxingMode && initFactor < 1 && initFactor > 0) {
            // Relax the target by sqrt(1 / initFactor), capped at 1.25.
            targetCpa = targetCpa * Math.min(1.25, Math.sqrt(1/initFactor));
        }

        return info.lastSumFee == 0
                ? init(info, targetCpa, budget, initFactor)
                : run(info, targetCpa, budget, initFactor);
    }


    /**
     * Computes the factor for the DEFAULT entry. One of three paths is taken:
     *
     * 1. Exploration: when exploration maps are configured, spend is below 30000 and the current
     *    factor is below 0.65, the factor is drawn with {@code flowSplit}.
     * 2. Controller: with at least {@code Constant.THRESHOLD3} conversions, the target CPA is
     *    adjusted and {@code run} is invoked. Unlike getOnePriceFactor, this path does not call
     *    {@code reset()} first.
     * 3. Otherwise: the initial factor limited to [0.8, 1.2] x the previous factor
     *    (or exactly 1 when mode is 8 or higher).
     *
     * @param info       DEFAULT entry being adjusted; its factor is overwritten
     * @param targetCpa  target cost per conversion
     * @param budget     forwarded to run
     * @param initFactor initial factor; replaced by its square root for modes 1 and 3 when in (0, 1)
     * @param mode       operating mode
     * @return the new value of info.factor
     */
    private double getDefaultPriceFactor(StatInfo info, double targetCpa,double budget, double initFactor, int mode)
    {
        final boolean relaxingMode = mode == 1 || mode == 3;

        // Without a previous fee snapshot the previous factor counts as neutral.
        final double previousFactor = info.lastSumFee == 0 ? 1 : info.factor;

        if (relaxingMode && initFactor < 1 && initFactor > 0) {
            initFactor = Math.sqrt(initFactor);
        }

        final boolean explorationConfigured = info.PIDFactorExplorationDo != null
                && info.PIDFactorExplorationDo.factorExploreMap != null
                && info.PIDFactorExplorationDo.factorFlowRateMap != null;

        // Explore prices while spend is still low.
        if (explorationConfigured && info.sumFee < 30000 && info.factor < 0.65) {
            Map<String, Double> exploreFactors = info.PIDFactorExplorationDo.factorExploreMap;
            Map<String, Double> flowRates = info.PIDFactorExplorationDo.factorFlowRateMap;
            Map<String, String> drawn = flowSplit(flowRates, exploreFactors, initFactor);
            info.factor = DataUtil.string2Double(drawn.get("factor"));
            logger.info("pid explore factor id {} factor {}", info.id, info.factor);
            return info.factor;
        }

        if (info.sumConv >= Constant.THRESHOLD3) {
            if (initFactor < 0.9) {
                targetCpa *= 0.9;
            }
            if (relaxingMode && initFactor < 1 && initFactor > 0) {
                targetCpa = targetCpa * Math.min(1.25, Math.sqrt(1/initFactor));
            }
            return run(info, targetCpa, budget, initFactor);
        }

        double limited = Math.max(previousFactor * 0.8, initFactor);
        limited = Math.min(previousFactor * 1.2, limited);
        info.factor = mode >= 8 ? 1 : limited;
        return info.factor;
    }


    /**
     * Turns a controller output into a new factor by scaling the previous factor.
     *
     * The multiplier is {@code 1 + output / targetCpa}, floored at {@code negtive} for negative
     * outputs and capped at {@code positive} otherwise.
     *
     * @param output     controller output
     * @param factor     ignored; the result depends only on {@code lastFactor}
     * @param lastFactor factor to scale
     * @param targetCpa  target cost per conversion, used to normalise the output
     * @param negtive    smallest multiplier allowed for a negative output
     * @param positive   largest multiplier allowed for a non-negative output
     * @return the scaled factor
     */
    private double adjust(double output, double factor, double lastFactor, double targetCpa, double negtive, double positive)
    {
        final double rawMultiplier = 1 + output / targetCpa;
        final double multiplier = output < 0
                ? Math.max(rawMultiplier, negtive)
                : Math.min(rawMultiplier, positive);
        return lastFactor * multiplier;
    }


    /**
     * Validates the numeric inputs of an entry.
     *
     * @param info      entry to validate
     * @param targetCpa target cost per conversion
     * @return false when sumFee, factor or targetCpa is not positive, or when sumConv,
     *         lastSumFee or lastSumConv is negative; true otherwise
     */
    private boolean checkParam(StatInfo info, double targetCpa)
    {
        final boolean invalid = info.sumFee <= 0
                || info.sumConv < 0
                || info.lastSumFee < 0
                || info.lastSumConv <0
                || info.factor <= 0
                || targetCpa <= 0;
        return !invalid;
    }

    /**
     * Decides whether an entry has enough data to be adjusted.
     *
     * @param info      entry to inspect
     * @param targetCpa unused in this method
     * @return false while spend is at most 3000 and there is less than one conversion; true otherwise
     */
    private boolean checkStart(StatInfo info, double targetCpa)
    {
        final boolean tooEarly = info.sumFee <= 3000 && info.sumConv < 1;
        return !tooEarly;
    }


        /**
         * Tells whether an entry is still in cold start: fewer than 5 conversions and spend
         * below 5 x the target CPA. Not called from anywhere in this file.
         *
         * @param info      entry to inspect
         * @param targetCpa target cost per conversion
         * @return true when the entry is in cold start
         */
        private boolean checkColdboot(StatInfo info, double targetCpa)
        {
            return info.sumConv < 5 && info.sumFee/targetCpa < 5;
        }

    /**
     * Creates a controller with explicit P/I/D gains; negative gains are flipped to positive.
     *
     * @param p proportional gain
     * @param i integral gain
     * @param d derivative gain
     */
    private RoiPidController2(double p, double i, double d){
        this.P = p;
        this.I = i;
        this.D = d;
        checkSigns();
    }

    /**
     * Creates a controller with explicit P/I/D/F gains; negative gains are flipped to positive.
     *
     * @param p proportional gain
     * @param i integral gain
     * @param d derivative gain
     * @param f feed-forward gain
     */
    private RoiPidController2(double p, double i, double d, double f){
        this.P = p;
        this.I = i;
        this.D = d;
        this.F = f;
        checkSigns();
    }

    /**
     * Changes the integral gain while keeping the integral contribution consistent:
     * the accumulated error is rescaled by oldGain / newGain (doubling I halves it),
     * and maxError is recomputed when an integral output limit is set.
     *
     * @param i new integral gain
     */
    private void setI(double i){
        final double previousGain = this.I;
        if (previousGain != 0) {
            this.errorSum = this.errorSum * previousGain / i;
        }
        if (this.maxIOutput != 0) {
            this.maxError = this.maxIOutput / i;
        }
        this.I = i;
        checkSigns();
    }


    /**
     * Replaces the P, I and D gains.
     *
     * @param p proportional gain
     * @param i integral gain
     * @param d derivative gain
     */
    private void setPID(double p, double i, double d){
        this.P = p;
        this.D = d;
        // The integral gain has extra bookkeeping, so it goes through its dedicated setter.
        setI(i);
        checkSigns();
    }
    /**
     * Replaces the P, I, D and F gains.
     *
     * @param p proportional gain
     * @param i integral gain
     * @param d derivative gain
     * @param f feed-forward gain
     */
    private void setPID(double p, double i, double d,double f){
        this.P = p;
        this.D = d;
        this.F = f;
        // The integral gain has extra bookkeeping, so it goes through its dedicated setter.
        setI(i);
        checkSigns();
    }

    /**
     * Sets the largest magnitude the integral term may contribute and, when the integral gain
     * is non-zero, derives the matching bound for the accumulated error.
     *
     * @param maximum limit for the integral output
     */
    private void setMaxIOutput(double maximum){
        this.maxIOutput = maximum;
        if (this.I == 0) {
            return;
        }
        this.maxError = this.maxIOutput / this.I;
    }

    /**
     * Sets symmetric output limits of {@code [-output, output]}.
     *
     * @param output magnitude of the limit
     */
    private void setOutputLimits(double output){
        final double lowerLimit = -output;
        setOutputLimits(lowerLimit, output);
    }

    /**
     * Sets the output limits. The call is ignored when {@code maximum} is below {@code minimum}.
     * The integral output limit is tightened to the output range when it is unset (0) or wider
     * than that range.
     *
     * @param minimum lower output limit
     * @param maximum upper output limit
     */
    private void setOutputLimits(double minimum,double maximum){
        if (maximum < minimum) {
            return;
        }
        this.maxOutput = maximum;
        this.minOutput = minimum;

        // The integral term must not be able to swing further than the whole output range.
        final boolean integralLimitUnset = this.maxIOutput == 0;
        if (integralLimitUnset || this.maxIOutput > (maximum - minimum)) {
            setMaxIOutput(maximum - minimum);
        }
    }

    /**
     * Performs one controller step and returns the control signal.
     *
     * The signal is the sum of the feed-forward, proportional, integral and derivative terms.
     * It is then limited by the ramp rate, limited by the output range and blended with the
     * previous output, each only when configured. Updates firstRun, lastActual, errorSum and
     * lastOutput.
     *
     * @param actual   measured value
     * @param setpoint target value
     * @return the limited and filtered controller output
     */
    private double getOutput(double actual, double setpoint){
        double output;
        double Poutput;
        double Ioutput;
        double Doutput;
        double Foutput;

        double error = setpoint - actual;

        Foutput = F * setpoint;
        Poutput = P * error;

        // On the first step there is no history: use the current measurement as the previous one
        // (so the derivative term below is zero) and seed lastOutput with P + F.
        if(firstRun){
            lastActual = actual;
            lastOutput = Poutput + Foutput;
            firstRun = false;
        }

        // D: negative sign; damps the control system and prevents abrupt changes and overshoot.
        Doutput = - D * (actual - lastActual);
        lastActual = actual;

        // I: integral term, computed from the error accumulated before this step.
        // 1. maxIOutput limits the weight of the I term.
        // 2. prevent windup by not increasing errorSum if we're already running against our max Ioutput
        // 3. prevent windup by not increasing errorSum if output is output= maxOutput
        Ioutput = I * errorSum;
        if(maxIOutput != 0){
            Ioutput = constrain(Ioutput, -maxIOutput, maxIOutput);
        }

        output = Foutput + Poutput + Ioutput + Doutput;

        // Error smoothing
        if(minOutput != maxOutput && !bounded(output, minOutput,maxOutput) ){ // output outside its limits: restart the accumulated error from the current error for smoother behaviour. Once P is large enough, I starts to act.
            errorSum = error;
        }
        else if(outputRampRate!=0 && !bounded(output, lastOutput-outputRampRate,lastOutput+outputRampRate) ){
            // Output would change faster than the ramp rate allows: restart the accumulated error too.
            errorSum = error;
        }
        else if(maxIOutput!=0){
            errorSum = constrain(errorSum + error,-maxError,maxError);
        }
        else{
            errorSum += error;
        }

        // Control signal smoothing
        if(outputRampRate!=0){
            output = constrain(output, lastOutput - outputRampRate,lastOutput + outputRampRate);
        }
        if(minOutput!=maxOutput){
            output = constrain(output, minOutput,maxOutput);
        }
        if(outputFilter!=0){
            // Weighted blend of the previous output and the new one.
            output = lastOutput*outputFilter+output*(1-outputFilter);
        }

        lastOutput = output;
        return output;
    }

    /** Clears the controller history: the next step counts as the first one and the accumulated error is zeroed. */
    private void reset(){
        this.errorSum = 0;
        this.firstRun = true;
    }

    /**
     * Sets the largest change allowed between two consecutive outputs; 0 disables the limit.
     *
     * @param rate maximum output change per step
     */
    private void setOutputRampRate(double rate){
        this.outputRampRate = rate;
    }


    /**
     * Sets the weight of the previous output in the output blend. Only 0 (filter off) or a value
     * strictly between 0 and 1 is accepted; any other value is ignored.
     *
     * @param strength filter strength
     */
    private void setOutputFilter(double strength){
        final boolean accepted = strength == 0 || bounded(strength, 0, 1);
        if (!accepted) {
            return;
        }
        this.outputFilter = strength;
    }

    /**
     * Limits a value to the range {@code [min, max]}; the upper bound is checked first.
     *
     * @param value value to limit
     * @param min   lower bound
     * @param max   upper bound
     * @return {@code max} if value exceeds it, {@code min} if value is below it, otherwise value
     */
    private double constrain(double value, double min, double max){
        return value > max ? max : (value < min ? min : value);
    }

    /**
     * Tests whether a value lies strictly between two bounds.
     *
     * @param value value to test
     * @param min   exclusive lower bound
     * @param max   exclusive upper bound
     * @return true when {@code min < value < max}
     */
    private boolean bounded(double value, double min, double max){
        final boolean aboveMin = min < value;
        return aboveMin && value < max;
    }

    /** Flips any negative gain (P, I, D, F) to its positive counterpart. */
    private void checkSigns(){
        P = P < 0 ? -P : P;
        I = I < 0 ? -I : I;
        D = D < 0 ? -D : D;
        F = F < 0 ? -F : F;
    }




    /**
     * Traffic split: randomly picks a factor level according to the configured flow rates and
     * returns that level together with its exploration factor.
     *
     * Cumulative flow rates form consecutive buckets; a uniform random draw in [0, 1) selects
     * the bucket it falls into. When the draw falls outside every bucket (for example because
     * the rates sum to less than 1), or when either map is empty, the defaults are returned:
     * level {@code FactorLevel.HIGH} and factor {@code defaultValue}.
     *
     * @param flowRateMap  flow rate per level code; missing or negative rates count as 0
     * @param factorMap    exploration factor per level code
     * @param defaultValue factor used when no level is drawn or the drawn level has no factor
     * @return a map with the keys "level" and "factor", both as strings
     */
    public static Map<String, String> flowSplit(Map<String, Double> flowRateMap, Map<String, Double> factorMap, Double defaultValue) {

        Map<String, String> retMap = new HashMap<>();
        String level = PIDFactorExploration.FactorLevel.HIGH.getCode();
        Double factor = defaultValue;

        if (AssertUtil.isAllNotEmpty(flowRateMap, factorMap)) {

            PIDFactorExploration.FactorLevel[] levels = PIDFactorExploration.FactorLevel.values();

            // cumulative[k] .. cumulative[k + 1] is the bucket of the level whose code is k.
            Double[] cumulative = new Double[levels.length + 1];
            double sumRate = 0.0;
            cumulative[0] = sumRate;
            for (PIDFactorExploration.FactorLevel factorLevel : levels) {
                String key = factorLevel.getCode();
                int i = DataUtil.toInt(DataUtil.string2Long(key));
                sumRate += nullToDefault(flowRateMap.get(key), 0.0);
                cumulative[i + 1] = sumRate;
            }

            double random = Math.random();

            // Stop one slot early: each bucket reads cumulative[i + 1], so the last index has no
            // upper bound of its own.
            for (int i = 0; i + 1 < cumulative.length; i++) {
                Double lower = cumulative[i];
                Double upper = cumulative[i + 1];
                // Slots stay null when no level code maps to them; skip instead of unboxing null.
                if (lower == null || upper == null) {
                    continue;
                }
                if (random >= lower && random < upper) {
                    level = DataUtil.Integer2String(i);
                    factor = factorMap.get(level);
                    if (factor == null) {
                        factor = defaultValue;
                    }
                    break;
                }
            }
        }
        retMap.put("level", level);
        retMap.put("factor", DataUtil.double2String(factor));
        return retMap;
    }


    /**
     * Returns {@code value} only when it is non-null and strictly greater than
     * {@code defaultValue}; otherwise returns {@code defaultValue}. The default therefore also
     * acts as a lower bound, not just as a null replacement.
     *
     * @param value        candidate value, may be null
     * @param defaultValue fallback and lower bound
     * @return the larger of the two, or {@code defaultValue} when {@code value} is null
     */
    public static Double nullToDefault(Double value, Double defaultValue) {
        if (value == null) {
            return defaultValue;
        }
        return value > defaultValue ? value : defaultValue;
    }


    /**
     * Manual smoke run: feeds one DEFAULT entry through {@code getPriceFactor2} and prints the
     * resulting factors to standard error.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        RoiPidController2 controller = new RoiPidController2();
        RoiPidController legacyController = new RoiPidController();

        // Sample entry with a previous snapshot, so the regular update path is exercised.
        StatInfo sample = new StatInfo();
        sample.id = "58315_134218_DEFAULT";
        sample.sumFee = 344531;
        sample.sumConv = 495;
        sample.lastSumConv = 485;
        sample.lastSumFee = 339267;
        sample.sumClick = 100;
        sample.factor = 0.9;
        sample.parentFactor = 1.0;

        List<StatInfo> input = new ArrayList<>();
        input.add(sample);

        double targetCpa = 625.0;
        double budget = 5000000;
        List<StatInfo> updatedInfos = controller.getPriceFactor2(input, targetCpa, targetCpa, budget,4);
        for (StatInfo updated : updatedInfos) {
            Double updatedFactor = updated.factor;
            System.err.println(updatedFactor);
        }
    }

}


