package cn.com.duiba.nezha.compute.common.model;

import java.util.ArrayList;
import java.util.HashMap;

/**
 * Created by jiali on 2017/6/9.
 */

/**
 * Bandit that picks one {@link MaterialInfo} out of a candidate list by posterior sampling.
 *
 * <p>Each call to {@link #selectMaterial} does four things:
 * <ul>
 *   <li>blends each candidate's current click/exposure window with its accumulated history into a
 *       per-call reward;</li>
 *   <li>updates the candidate's Beta(alpha, beta) parameters in place;</li>
 *   <li>draws one value per candidate via {@code BetaDistribution.BetaDist};</li>
 *   <li>returns the candidate with the largest draw.</li>
 * </ul>
 *
 * <p>Instances hold per-call working state in unsynchronized fields that every call overwrites.
 * Concurrent calls on one instance would therefore race.
 */
public class BayesianBandit {
    /** Per-call snapshot of each candidate's decayed reward, in list order. */
    private ArrayList<Double> rewards;
    /** Per-call snapshot of each candidate's decayed observation count, in list order. */
    private ArrayList<Double> counts;
    /** Per-call snapshot of each candidate's Beta alpha parameter, in list order. */
    private ArrayList<Double> alphas;
    /** Per-call snapshot of each candidate's Beta beta parameter, in list order. */
    private ArrayList<Double> betas;
    /** Number of candidates in the current call. */
    private int numMachines;
    /** Multiplicative decay applied to accumulated reward/count; set to {@link Constant#DECAY} per call. */
    private double decay = 1;

    /** Tuning constants for reward blending and history discounting. */
    static class Constant {
        /** Weight of the current-window CTR in the reward when the window has at least one click. */
        static final double RT_REWARD_WEIGHT = 0.03;
        /** Lower bound applied to every per-call reward. */
        static final double MIN_REWARD = 0.01;
        /** When a material's accumulated {@code lastClick} exceeds this, its history is discounted. */
        static final long MAX_CLICK = 1000;
        /** Divisor applied to {@code lastClick}/{@code lastExposure} when the history is discounted. */
        static final long DISCOUNT = 2;
        /**
         * Per-call decay of accumulated reward/count. The original intent was that observations older
         * than about 1000 calls stop counting. Note that 0.999^1000 is about 0.37, so they are
         * attenuated rather than eliminated.
         */
        static final double DECAY = 0.999;
    }

    public BayesianBandit() {
        clear();
    }

    /**
     * Updates every candidate's statistics and selects one by sampling its Beta posterior.
     *
     * <p>Side effects: mutates {@code lastClick}, {@code lastExposure}, {@code reward},
     * {@code count}, {@code alpha} and {@code beta} on the materials in {@code materialList}.
     *
     * @param materialList candidates. Entries with {@code appId == -1} also supply the "global" CTR
     *     for their material id, which is blended into other entries with the same id.
     * @param appid currently unused
     * @return the selected material, or {@code null} if {@code materialList} is null or empty
     */
    public MaterialInfo selectMaterial(ArrayList<MaterialInfo> materialList, Long appid) {
        if (materialList == null || materialList.isEmpty()) {
            // Nothing to choose from. Without this guard the code fell through to get(0) and threw.
            return null;
        }

        // 1. init
        clear();
        this.decay = Constant.DECAY;
        numMachines = materialList.size();

        // 2. global: refresh the history of appId == -1 entries and record their CTR by material id
        HashMap<Long, Double> gCtrMap = new HashMap<>();
        for (MaterialInfo m : materialList) {
            if (m.appId == -1) {
                updateHistory(m, sum(m.exposure), sum(m.click));
                gCtrMap.put(m.getMaterialId(), getCtr(m.lastExposure, m.lastClick));
            }
        }

        // 3. update
        for (MaterialInfo m : materialList) {
            double click = sum(m.click);
            double exposure = sum(m.exposure);

            if (exposure > 0) {
                if (m.lastClick == 0) {
                    // No click history yet: restart the decayed reward/count accumulators.
                    m.reward = 0;
                    m.count = 0;
                }

                // NOTE(review): appId == -1 entries were already updated in step 2, so their history
                // advances twice per call here. Kept as-is; confirm whether that is intended.
                updateHistory(m, exposure, click);

                double rtCtr = getCtr(exposure, click); // the recent-window CTR tends to be underestimated
                double hisCtr = getCtr(m.lastExposure, m.lastClick);
                Double globalCtr = gCtrMap.get(m.getMaterialId());
                double gCtr = globalCtr != null ? globalCtr : 0;
                // Trust the material's own history in proportion to its length (full trust from 500
                // on); otherwise fall back towards the global CTR. 500.0 avoids integer division
                // should lastExposure be an integral type.
                double confidence = Math.min(m.lastExposure / 500.0, 1.0);
                hisCtr = confidence * hisCtr + (1 - confidence) * gCtr;

                double reward = click > 0
                        ? Constant.RT_REWARD_WEIGHT * rtCtr + (1 - Constant.RT_REWARD_WEIGHT) * hisCtr
                        : hisCtr;
                reward = Math.max(reward, Constant.MIN_REWARD);
                m.reward = m.reward * decay + reward;
                m.count = m.count * decay + 1.0;
                m.alpha = 1.0 + m.reward;
                m.beta = 1.0 + (m.count - m.reward);
            }

            rewards.add(m.reward);
            counts.add(m.count);
            alphas.add(m.alpha);
            betas.add(m.beta);
        }

        // 4. select
        return materialList.get(selectMachine());
    }

    /**
     * Folds the current window's CTR into a material's running history.
     *
     * <p>The update works as follows:
     * <ul>
     *   <li>If {@code lastClick} exceeds {@link Constant#MAX_CLICK}, both {@code lastClick} and
     *       {@code lastExposure} are first divided by {@link Constant#DISCOUNT}.</li>
     *   <li>The window CTR is then added to {@code lastClick}.</li>
     *   <li>{@code lastExposure} is incremented by one.</li>
     * </ul>
     * As a result, {@code lastClick / lastExposure} is a discounted average of per-call window CTRs.
     */
    private void updateHistory(MaterialInfo m, double exposure, double click) {
        if (m.lastClick > Constant.MAX_CLICK) {
            m.lastClick = m.lastClick / Constant.DISCOUNT;
            m.lastExposure = m.lastExposure / Constant.DISCOUNT;
        }
        m.lastClick = m.lastClick + getCtr(exposure, click);
        m.lastExposure = m.lastExposure + 1;
    }

    /** Replaces the per-call working lists with fresh empty ones. */
    private void clear() {
        rewards = new ArrayList<Double>();
        counts = new ArrayList<Double>();
        alphas = new ArrayList<Double>();
        betas = new ArrayList<Double>();
    }

    /**
     * Draws one value per candidate from {@code BetaDistribution.BetaDist(alpha, beta)} and returns
     * the index of the largest draw.
     *
     * <p>Presumably BetaDist is a random Beta sample, as the original comment described; verify this.
     * Ties keep the earlier index, and index 0 is returned if no draw is positive.
     */
    private int selectMachine() {
        int best = 0;
        double maxTheta = 0;
        for (int i = 0; i < numMachines; i++) {
            double theta = BetaDistribution.BetaDist(alphas.get(i), betas.get(i));
            if (theta > maxTheta) {
                maxTheta = theta;
                best = i;
            }
        }
        return best;
    }

    /** Returns {@code clk / exp}, or 0 when {@code exp} is not positive. */
    private double getCtr(double exp, double clk) {
        return exp > 0 ? clk / exp : 0;
    }

    /** Sums the values of {@code list}; throws NullPointerException if an element is null. */
    private double sum(ArrayList<Long> list) {
        double sum = 0;
        for (Long val : list) {
            sum += val;
        }
        return sum;
    }
}