package cn.com.duiba.nezha.compute.common.model;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;



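/**
 * Bandit that picks one material per call by drawing a sample from each candidate's Beta
 * posterior and choosing the largest draw.
 *
 * Created by jiali on 2017/6/9.
 */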

public class BayesianBandit {
    /** Per-arm reward values collected during the most recent {@link #selectMaterial} call. */
    private List<Double> rewards;
    /** Per-arm (decayed) counts collected during the most recent call. */
    private List<Double> counts;
    /** Per-arm Beta alpha parameters, as stored on each material after the update step. */
    private List<Double> alphas;
    /** Per-arm Beta beta parameters, as stored on each material after the update step. */
    private List<Double> betas;
    /** Number of candidate materials (arms) in the most recent call. */
    private int numMachines;
    /** Multiplicative decay applied to each arm's reward and count on every update. */
    private double decay = 1;

    /** Tuning constants for the bandit. */
    static class Constant {
        /** Not referenced within this class. */
        static double RT_REWARD_WEIGHT = 0;
        /** Lower bound applied to every per-round reward. */
        static double MIN_REWARD = 0.01;
        /** Once a material's accumulated lastClick exceeds this, its history is discounted. */
        static long MAX_CLICK = 1000;
        /** Divisor applied to lastClick/lastExposure when MAX_CLICK is exceeded. */
        static long DISCOUNT = 2;
        /**
         * Per-update decay of reward and count. The original author's note says observations
         * older than ~1000 updates are considered stale (for reference, 0.999^1000 is about 0.37).
         */
        static double DECAY = 0.999;
    }

    /**
     * Per-material statistics gathered in the first pass of {@link #selectMaterial}. The global
     * (appId == -1) fields and the app-level fields are filled from different MaterialInfo entries
     * that share the same materialId.
     */
    static class Info {
        long id;
        /** Historical CTR of the global (appId == -1) entry. */
        double gctr;
        /** Historical exposure of the global entry. */
        double gexp;
        /** Historical CTR of the app-level entry. */
        double hctr;
        /** Historical exposure of the app-level entry. */
        double hexp;
        /** CTR of the app-level entry over its current click/exposure lists. */
        double rctr;
        /** Total exposure of the app-level entry over its current exposure list. */
        double rexp;
        /** The global (appId == -1) entry itself, or null if none was supplied. */
        MaterialInfo gMaterial;
    }

    /**
     * Selects one material by Beta-posterior sampling and returns it with its global counterpart.
     *
     * <p>Side effects: updates lastClick/lastExposure on every material. For every app-level
     * material (appId != -1) with positive exposure in its current exposure list, it also updates
     * reward/count/alpha/beta.
     *
     * @param materialList candidates; app-level entries (appId != -1) and global entries
     *     (appId == -1) may share a materialId
     * @param appid not used by this method
     * @return a two-element list: the selected material, then the global (appId == -1) material
     *     with the same materialId. The second element is {@code null} if no such global entry was
     *     supplied. Returns an empty list when {@code materialList} is null or empty.
     */
    public List<MaterialInfo> selectMaterial(ArrayList<MaterialInfo> materialList, Long appid) {
        List<MaterialInfo> result = new ArrayList<>();
        if (materialList == null || materialList.isEmpty()) {
            return result;
        }

        // 1. init
        clear();
        this.decay = Constant.DECAY;
        numMachines = materialList.size();

        Map<Long, Info> infoById = new HashMap<>();
        // Floors keep the normalising divisions below away from zero.
        double maxG = 0.01;
        double maxH = 0.01;

        // 2. refresh history and gather per-material info
        for (MaterialInfo m : materialList) {
            double click = sum(m.click);
            double exposure = sum(m.exposure);

            refreshHistory(m, exposure, click);

            double hisCtr = getCtr(m.lastExposure, m.lastClick);
            Info info = infoFor(infoById, m.materialId);
            info.id = m.materialId;
            if (m.appId == -1) {
                info.gctr = hisCtr;
                info.gexp = m.lastExposure;
                info.gMaterial = m;
                maxG = Math.max(hisCtr, maxG);
            } else {
                info.hctr = hisCtr;
                info.hexp = m.lastExposure;
                info.rctr = getCtr(exposure, click);
                info.rexp = exposure;
                // Only app-level entries with enough history contribute to the normaliser.
                if (info.hexp > 50) {
                    maxH = Math.max(hisCtr, maxH);
                }
            }
        }

        // 3. update posteriors and collect per-arm parameters
        for (MaterialInfo m : materialList) {
            double exposure = sum(m.exposure);
            if (exposure > 0 && m.appId != -1) {
                updatePosterior(m, infoById.get(m.materialId), maxG, maxH);
            }
            rewards.add(m.reward);
            counts.add(m.count);
            alphas.add(m.alpha);
            betas.add(m.beta);
        }

        // 4. select
        MaterialInfo selected = materialList.get(selectMachine());
        result.add(selected);
        result.add(infoById.get(selected.materialId).gMaterial);
        return result;
    }

    public BayesianBandit() {
        // Original author's note on this constructor: "next week" (intent unclear).
        clear();
    }

    /** Resets the per-arm parameter lists. */
    private void clear() {
        rewards = new ArrayList<Double>();
        counts = new ArrayList<Double>();
        alphas = new ArrayList<Double>();
        betas = new ArrayList<Double>();
    }

    /**
     * Folds the CTR of the current click/exposure lists into the material's running history as
     * one pseudo-observation. Before doing so, it divides the history by DISCOUNT once lastClick
     * exceeds MAX_CLICK.
     *
     * <p>Original author's note: this is not updated when no coupon is issued (presumably
     * because this path only runs on coupon issuance; TODO confirm).
     */
    private void refreshHistory(MaterialInfo m, double exposure, double click) {
        if (m.lastClick > Constant.MAX_CLICK) {
            m.lastClick = m.lastClick / Constant.DISCOUNT;
            m.lastExposure = m.lastExposure / Constant.DISCOUNT;
        }
        m.lastClick = m.lastClick + getCtr(exposure, click);
        m.lastExposure = m.lastExposure + 1;
    }

    /** Returns the Info registered for {@code id}, creating and registering one if absent. */
    private static Info infoFor(Map<Long, Info> infoById, Long id) {
        Info info = infoById.get(id);
        if (info == null) {
            info = new Info();
            infoById.put(id, info);
        }
        return info;
    }

    /**
     * Applies one decayed update to an app-level material's reward/count and recomputes its Beta
     * parameters.
     */
    private void updatePosterior(MaterialInfo m, Info info, double maxG, double maxH) {
        // Blend from the global CTR toward the material's own history as its exposure grows,
        // reaching full confidence at 100. Dividing by 100.0 avoids truncation if lastExposure
        // is an integral type.
        double confidence = Math.min(m.lastExposure / 100.0, 1.0);
        double reward = (1 - confidence) * info.gctr * 0.5 / maxG
                + confidence * Math.min(info.hctr * 0.5 / maxH, 0.5);
        reward = Math.max(reward, Constant.MIN_REWARD);

        // Seed an arm with a pseudo-count of 10 and a prior reward derived from the global CTR.
        if (m.count < 10) {
            m.count = 10;
            m.reward = info.gctr * 0.5 * m.count / maxG;
        }

        m.reward = m.reward * decay + reward;
        m.count = m.count * decay + 1.0;
        m.alpha = 1.0 + m.reward;
        m.beta = 1.0 + (m.count - m.reward);
    }

    /**
     * Draws one value per arm from {@code BetaDistribution.BetaDist(alpha, beta)} and returns the
     * index of the largest draw. BetaDistribution.BetaDist is assumed to return a random sample
     * from Beta(alpha, beta); TODO confirm. Ties keep the earliest index, and index 0 is
     * returned if no draw exceeds 0.
     */
    private int selectMachine() {
        int best = 0;
        double maxTheta = 0;
        for (int i = 0; i < numMachines; i++) {
            double theta = BetaDistribution.BetaDist(alphas.get(i), betas.get(i));
            if (theta > maxTheta) {
                maxTheta = theta;
                best = i;
            }
        }
        return best;
    }

    /** Returns {@code clk / exp}, or 0 when {@code exp} is not positive. */
    private double getCtr(double exp, double clk) {
        return exp > 0 ? clk / exp : 0;
    }

    /** Sums the values of {@code list}, treating a null list or null elements as 0. */
    private double sum(List<Long> list) {
        double total = 0;
        if (list == null) {
            return total;
        }
        for (Long val : list) {
            if (val != null) {
                total += val;
            }
        }
        return total;
    }
}