package cn.com.duiba.live.normal.service.api.util;

import cn.com.duiba.boot.perftest.InternalPerfTestContext;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.Assert;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;

/**
 * 支持压测、并发、批量消费的Queue
 * @author lizhi
 * @date 2020/9/19 4:58 PM
 */
@Slf4j
public class BatchConsumeBlockingQueue<T> {

    /**
     * 一次批量消费的个数
     */
    private final int batchSize;

    /**
     * 一次批量消费的阈值，当已入队的数据个数达到此阈值时，触发批量消费（若消费回调阻塞，则无法消费）。
     * 如果此阈值 <= 0则不会触发个数阈值消费（这种情况下数据只会在时间到达maxWaitTimeMillis阈值后才会被消费，这意味着数据可能会在队列中存在比较长的时间）
     */
    private final int maxWaitSize;

    /**
     * 数据最大延迟消费时间的阈值(ms)，当有数据在队列中存在时间超过此阈值时，触发批量消费（若消费回调阻塞，则无法消费）。
     * 如果此阈值 <= 0则不会触发延时消费（这种情况下数据只会在堆积到达maxWaitSize阈值后才会被消费，这意味着数据可能会在队列中存在非常长的时间）
     */
    private final int maxWaitTimeMillis;

    /**
     * 批量消费回调，注意你需要自己确保消息消费成功，如果失败，我们不会再重试消费。
     * 如果回调处理时间过长，将会阻塞阈值消费，数据可能会在队列中存在比较长的时间
     */
    private final Consumer<List<T>> batchConsumer;

    /**
     * 给正常流量使用的queue
     */
    private final LinkedBlockingQueue<T> queue4Normal;

    /**
     * 给压测流量使用的queue
     */
    private final LinkedBlockingQueue<T> queue4PerfTest;

    /**
     * 最近一次消费时间(正常流量队列)
     */
    private volatile long lastConsumeTimeMillis4Normal = 0;

    /**
     * 最近一次消费时间(压测流量队列)
     */
    private volatile long lastConsumeTimeMillis4PerfTest = 0;

    /**
     * 定时触发批量消费线程
     */
    private Thread scheduleThread;

    /**
     * 消费线程(正常流量队列)
     */
    private final Thread batchConsumeNormalThread;

    /**
     * 消费线程(压测流量队列)
     */
    private final Thread batchConsumePerfThread;

    /**
     * 线程是否开启
     */
    private volatile boolean isStarted = false;

    /**
     * 消费锁(正常流量队列)
     */
    private final ReentrantLock consumeNormalLock = new ReentrantLock();

    /**
     * 等待消费(正常流量队列)
     */
    private final Condition normalCondition = consumeNormalLock.newCondition();

    /**
     * 消费锁(压测流量队列)
     */
    private final ReentrantLock consumePerfLock = new ReentrantLock();

    /**
     * 等待消费(压测流量队列)
     */
    private final Condition perfCondition = consumePerfLock.newCondition();

    private volatile String lastPerfTestSceneId = null;

    private volatile boolean isLastTestCluster = false;

    /**
     * 构造一个支持压测、并发、批量消费的Queue
     * @param capacity BlockingQueue的最大允许容量, 如果<=0 视为无界队列
     * @param batchSize 一次批量消费的个数
     * @param maxWaitSize 一次批量消费的阈值，当已入队的数据个数达到此阈值时，触发批量消费（若消费回调阻塞，则无法消费）
     *                  如果此阈值 <= 0则不会触发个数阈值消费（这种情况下数据只会在时间到达maxWaitTimeMillis阈值后才会被消费，这意味着数据可能会在队列中存在比较长的时间）
     * @param maxWaitTimeMillis 数据最大延迟消费时间的阈值(ms)，当有数据在队列中存在时间超过此阈值时，触发批量消费（若消费回调阻塞，则无法消费）。
     *                          如果此阈值 <= 0则不会触发延时消费（这种情况下数据只会在堆积到达maxWaitSize阈值后才会被消费，这意味着数据可能会在队列中存在非常长的时间）
     * @param batchConsumer 批量消费回调，注意你需要自己确保消息消费成功，如果失败，我们不会再重试消费。
     *                      如果回调处理时间过长，将会阻塞阈值消费，数据可能会在队列中存在比较长的时间
     */
    public BatchConsumeBlockingQueue(int capacity, int batchSize, int maxWaitSize, int maxWaitTimeMillis, Consumer<List<T>> batchConsumer) {
        Assert.notNull(batchConsumer, "batchConsumer must not be null");
        Assert.isTrue(maxWaitSize <= 0 || batchSize <= maxWaitSize, "batchSize must be less than or equal to maxWaitSize");
        Assert.isTrue(maxWaitSize > 0 || maxWaitTimeMillis > 0, "maxWaitSize and maxWaitTimeMillis cannot both be less than or equal to zero");
        if(capacity <= 0){
            queue4Normal = new LinkedBlockingQueue<>();
            queue4PerfTest = new LinkedBlockingQueue<>();
        }else {
            queue4Normal = new LinkedBlockingQueue<>(capacity);
            queue4PerfTest = new LinkedBlockingQueue<>(capacity);
        }
        this.batchSize = batchSize;
        this.maxWaitSize = maxWaitSize;
        this.maxWaitTimeMillis = maxWaitTimeMillis;
        this.batchConsumer = batchConsumer;
        if(maxWaitTimeMillis > 0 ) {
            scheduleThread = new Thread(this::schedule);
        }
        batchConsumeNormalThread = new Thread(this::batchConsumeNormalQueue);
        batchConsumePerfThread = new Thread(this::batchConsumePerfQueue);
    }

    /**
     * 定时触发批量消费
     */
    private void schedule() {
        while (true) {
            try {
                Thread.sleep(maxWaitTimeMillis);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
            if (Thread.currentThread().isInterrupted()) {
                break;
            }
            if (!isStarted) {
                break;
            }
            //通知消费线程，已经达到等待时间
            signalNormal();
            signalPerf();
        }
    }

    /**
     * 循环批量消费队列(正常流量队列)
     */
    private void batchConsumeNormalQueue() {
        //消费正常队列，若有压测标，从线程中移除
        InternalPerfTestContext.markAsNormal();
        while (true) {
            //批量消费正常流量队列消息
            try {
                batchConsume(queue4Normal, consumeNormalLock, normalCondition);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            } catch(Throwable e){
                log.error("触发消费时发生异常", e);
            }
            if (Thread.currentThread().isInterrupted()) {
                break;
            }
            if (!isStarted && queue4Normal.isEmpty()) {
                //队列停止了，并且队列空了，停止循环
                break;
            }
        }
    }

    /**
     * 循环批量消费队列(压测流量队列)
     */
    private void batchConsumePerfQueue() {
        while (true) {
            //批量消费压测流量队列消息
            InternalPerfTestContext.markAsPerfTest(lastPerfTestSceneId, isLastTestCluster);
            try {
                batchConsume(queue4PerfTest, consumePerfLock, perfCondition);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            } catch(Throwable e){
                log.error("触发消费时发生异常", e);
            } finally {
                InternalPerfTestContext.markAsNormal();
            }
            if (Thread.currentThread().isInterrupted()) {
                break;
            }
            if (!isStarted) {
                break;
            }
        }
    }

    /**
     * 批量消费队列
     * @throws InterruptedException 线程中断
     */
    private void batchConsume(BlockingQueue<T> queue, ReentrantLock lock, Condition condition) throws InterruptedException {
        List<T> list;
        lock.lockInterruptibly();
        try {
            if (isStarted) {
                //不足批量消费数量，且等待时间也不足
                while (needWait(queue)) {
                    if (!isStarted) {
                        //停止的时候，不用等待
                        break;
                    }
                    condition.await();
                    //每次添加时 释放等待
                    //每隔 间隔时间 释放等待
                }
            }
            list = drainTo(queue);
        } finally {
            lock.unlock();
        }
        if (list.isEmpty()) {
            return;
        }
        batchConsumer.accept(list);
    }

    private void signal() {
        if (InternalPerfTestContext.isCurrentInPerfTestMode()){
            signalPerf();
        } else {
            signalNormal();
        }
    }

    /**
     * 释放消费等待(正常流量队列)
     */
    private void signalNormal() {
        ReentrantLock lock = this.consumeNormalLock;
        lock.lock();
        try {
            normalCondition.signal();
        } finally {
            lock.unlock();
        }
    }

    /**
     * 释放消费等待(压测流量队列)
     */
    private void signalPerf() {
        ReentrantLock lock = this.consumePerfLock;
        lock.lock();
        try {
            perfCondition.signal();
        } finally {
            lock.unlock();
        }
    }

    /**
     * 从队列中获取 batchSize 个对象
     * @param queue 队列
     * @return 对象集合
     */
    private List<T> drainTo(BlockingQueue<T> queue) {
        List<T> list = new ArrayList<>();
        queue.drainTo(list, batchSize);
        setLastConsumeTimeMillis();
        return list;
    }

    /**
     * 根据当前是不是压测决定使用不同的内部queue
     */
    private BlockingQueue<T> determineQueue(){
        if(!isStarted){
            throw new IllegalStateException("not started yet!");
        }
        boolean isCurrentInPerfTestMode = InternalPerfTestContext.isCurrentInPerfTestMode();
        if (isCurrentInPerfTestMode) {
            //记住最近一次压测的场景id和是否独立容器压测的信息，后面设置回去
            lastPerfTestSceneId = InternalPerfTestContext.getCurrentSceneId();
            isLastTestCluster = InternalPerfTestContext.isTestCluster();
        }
        return isCurrentInPerfTestMode ? queue4PerfTest : queue4Normal;
    }

    /**
     * 启动延时消费线程，每隔maxWaitTimeMillis自动扫描一次有没有入队过久的消息，并触发消费
     * 启动正常流量队列消费线程
     * 启动压测流量队列消费线程
     */
    public synchronized void start(){
        if (isStarted) {
            return;
        }
        isStarted = true;
        if(maxWaitTimeMillis > 0) {
            scheduleThread.start();
        }
        batchConsumeNormalThread.start();
        batchConsumePerfThread.start();
    }

    /**
     * 停止延时消费线程
     * 停止压测流量队列消费线程
     */
    public synchronized void stop(){
        if (!isStarted) {
            return;
        }
        isStarted = false;
        if(maxWaitTimeMillis > 0) {
            scheduleThread.interrupt();
        }
        //通知正常队列消费
        signalNormal();
        //关闭时影子库的数据就不管了
        batchConsumePerfThread.interrupt();
    }

    /**
     * true-不满足消费条件，等待；false-满足消费条件，不等待
     */
    private boolean needWait(BlockingQueue<T> queue) {
        //不足批量消费数量，且等待时间也不足
        return sizeWait(queue) && timeOutWait(queue);
    }

    /**
     * true-队列元素个数小于批量消费阈值，阈值批量消费需要等待；false-队列元素个数大于等于阈值，阈值批量消费不需要等待
     */
    private boolean sizeWait(BlockingQueue<T> queue) {
        if (maxWaitSize <= 0) {
            return true;
        }
        return queue.size() < maxWaitSize;
    }

    /**
     * true-队列为空或者等待时间不足，需要等待；false-队列不为空且等待时间足够，不需要等待。
     */
    private boolean timeOutWait(BlockingQueue<T> queue) {
        if (maxWaitTimeMillis <= 0) {
            return true;
        }
        return queue.isEmpty() || getWaitTime() < maxWaitTimeMillis;
    }

    /**
     * 获取当前等待时间
     */
    private long getWaitTime() {
        //等待时间
        return System.currentTimeMillis() - getLastConsumeTimeMillis();
    }

    /**
     * 设置最近消费时间
     */
    private void setLastConsumeTimeMillis(){
        if (InternalPerfTestContext.isCurrentInPerfTestMode()){
            lastConsumeTimeMillis4PerfTest = System.currentTimeMillis();
        }else{
            lastConsumeTimeMillis4Normal = System.currentTimeMillis();
        }
    }

    /**
     * 获取最近消费时间
     */
    private long getLastConsumeTimeMillis(){
        return InternalPerfTestContext.isCurrentInPerfTestMode() ? lastConsumeTimeMillis4PerfTest : lastConsumeTimeMillis4Normal;
    }

    /**
     * 将元素添加到队列
     * @param t 元素
     * @return 添加成功
     * @throws IllegalStateException 队列已满 或者 队列未启用
     * @throws ClassCastException 元素类型错误
     * @throws NullPointerException 指定的元素为空
     * @throws IllegalArgumentException 指定元素的某些属性阻止将其添加到此队列
     */
    public boolean add(T t) {
        final BlockingQueue<T> queue = determineQueue();
        boolean added = queue.add(t);
        if(added){
            signal();
        }
        return added;
    }

    public boolean offer(T t) {
        BlockingQueue<T> queue = determineQueue();
        boolean added =  queue.offer(t);
        if(added){
            signal();
        }
        return added;
    }

    public boolean offer(T t, long timeout, TimeUnit unit) throws InterruptedException {
        BlockingQueue<T> queue = determineQueue();
        boolean added = queue.offer(t, timeout, unit);
        if(added){
            signal();
        }
        return added;
    }

    public void put(T t) throws InterruptedException {
        BlockingQueue<T> queue = determineQueue();
        queue.put(t);
        signal();
    }

    public int size() {
        return determineQueue().size();
    }

    public boolean isEmpty() {
        return determineQueue().isEmpty();
    }

    public boolean contains(Object o) {
        return determineQueue().contains(o);
    }

    public Iterator<T> iterator() {
        return determineQueue().iterator();
    }

    public Object[] toArray() {
        return determineQueue().toArray();
    }

    @Override
    public String toString() {
        return determineQueue().toString();
    }
}
