package cn.com.duibaboot.ext.autoconfigure.batch;

import javax.annotation.Resource;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.atomic.AtomicLong;

/**
 * Manages requests that are staged before batch execution.
 * <p>
 * Requests are grouped into buckets keyed by their full method name. Requests submitted in
 * perf-test mode go to a separate bucket whose key is prefixed with {@code "PerfTest:"}.
 * Consumers pull requests out in per-bucket batches via {@link #poll(int)}.
 * <p>
 * Note: the bucket map is {@code static}, so it is shared by every {@code ReqBucket} instance
 * loaded by the same class loader. Each {@link Bucket}, however, reads its limits from the
 * {@link BatchProperties} of the {@code ReqBucket} instance that created it.
 */
public class ReqBucket {

    /**
     * If the oldest request still queued in a bucket has been waiting longer than this many
     * milliseconds, new submissions to that bucket are rejected (the consumer is falling behind).
     */
    private static final long MAX_HEAD_WAIT_MILLIS = 10L;

    private static final ConcurrentMap<String, Bucket> buckets = new ConcurrentHashMap<>();

    @Resource
    private BatchProperties properties;

    /**
     * Submits a request to the bucket matching its method name and perf-test mode, creating the
     * bucket on first use.
     * <p>
     * The state is set to {@code WAITING} <em>before</em> the request is enqueued. Once it is in
     * the queue, a concurrent {@link #poll(int)} may immediately set it to {@code RUNNING}. That
     * update must not be overwritten afterwards. If the submission is rejected, the state is left
     * as {@code WAITING}.
     *
     * @param req the request to stage
     * @throws RejectedExecutionException if the bucket's oldest request has waited too long or the
     *                                    bucket has reached its configured maximum size
     */
    public void submit(ReqContext req) {
        String bucketId = req.getFullMethodName();
        if (req.isCurrentInPerfTestMode()) {
            bucketId = "PerfTest:" + bucketId;
        }
        Bucket bucket = buckets.computeIfAbsent(bucketId, k -> new Bucket());
        req.setState(ReqContext.WAITING);
        bucket.submit(req);
    }

    /**
     * Pulls up to {@code num} requests from every bucket.
     *
     * @param num maximum number of requests to take from each bucket; values {@code <= 0} take
     *            nothing
     * @return a map from bucket id to the (non-empty) batch taken from that bucket. Buckets with
     *         nothing queued are omitted, so the map is empty when there is no work.
     */
    public Map<String, List<ReqContext>> poll(int num) {
        Map<String, List<ReqContext>> reqs = new HashMap<>();
        // ConcurrentHashMap never holds null values, so no null check on the bucket is needed.
        for (Map.Entry<String, Bucket> entry : buckets.entrySet()) {
            List<ReqContext> batch = entry.getValue().poll(num);
            if (!batch.isEmpty()) {
                reqs.put(entry.getKey(), batch);
            }
        }
        return reqs;
    }

    /**
     * A single bucket: a lock-free FIFO queue of staged requests plus a counter of admitted,
     * not-yet-polled requests used to enforce {@code properties.getMaxqueue()}.
     */
    class Bucket {

        /**
         * Requests admitted by {@link #submit} and not yet removed by {@link #poll}. It may briefly
         * run ahead of {@code queue.size()} while a submit is between reserving and enqueueing.
         */
        private final AtomicLong size = new AtomicLong(0L);

        private final ConcurrentLinkedQueue<ReqContext> queue = new ConcurrentLinkedQueue<>();

        /**
         * Checks whether a request could be submitted right now, without reserving capacity.
         *
         * @throws RejectedExecutionException if the oldest queued request has waited longer than
         *                                    {@code MAX_HEAD_WAIT_MILLIS}, or the bucket is full
         */
        public void canSubmit() {
            checkHeadWait();
            if (size.get() >= properties.getMaxqueue()) {
                throw new RejectedExecutionException("Queue is Full :" + properties.getMaxqueue());
            }
        }

        /**
         * Adds a request to the tail of the queue.
         *
         * @param req the request to enqueue
         * @throws RejectedExecutionException if the oldest queued request has waited too long, or
         *                                    the bucket is full
         */
        public void submit(ReqContext req) {
            checkHeadWait();
            // Reserve the slot atomically. A separate get()-then-increment would let concurrent
            // submitters all pass the limit check and push the bucket past maxqueue.
            if (size.incrementAndGet() > properties.getMaxqueue()) {
                size.decrementAndGet();
                throw new RejectedExecutionException("Queue is Full :" + properties.getMaxqueue());
            }
            queue.add(req);
        }

        /**
         * Removes up to {@code num} requests from the head of the queue and marks each one
         * {@code RUNNING}.
         *
         * @param num maximum number of requests to take; values {@code <= 0} take nothing
         * @return the removed requests in FIFO order; empty if the queue is empty or {@code num <= 0}
         */
        public List<ReqContext> poll(int num) {
            if (num <= 0) {
                // new ArrayList<>(negative) would throw IllegalArgumentException.
                return new ArrayList<>();
            }
            List<ReqContext> reqs = new ArrayList<>(num);
            for (int i = 0; i < num; i++) {
                ReqContext req = queue.poll();
                if (req == null) {
                    break;
                }
                reqs.add(req);
                req.setState(ReqContext.RUNNING);
            }
            int batchNum = reqs.size();
            if (batchNum > 0) {
                size.addAndGet(-batchNum);
            }
            return reqs;
        }

        /**
         * Rejects new work when the head of the queue has been waiting too long.
         * NOTE(review): assumes {@code ReqContext.getTimestamp()} is epoch millis, matching
         * {@code System.currentTimeMillis()}. Confirm against ReqContext.
         *
         * @throws RejectedExecutionException if the head request's wait exceeds
         *                                    {@code MAX_HEAD_WAIT_MILLIS}
         */
        private void checkHeadWait() {
            ReqContext head = queue.peek();
            if (head != null) {
                long wait = System.currentTimeMillis() - head.getTimestamp();
                if (wait > MAX_HEAD_WAIT_MILLIS) {
                    throw new RejectedExecutionException("Queue wait:" + wait + "ms");
                }
            }
        }
    }

}
