package cn.com.duibaboot.ext.autoconfigure.batch;

import cn.com.duiba.wolf.threadpool.NamedThreadFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.cloud.context.environment.EnvironmentChangeEvent;
import org.springframework.context.event.EventListener;

import javax.annotation.PostConstruct;
import javax.annotation.Resource;
import java.util.List;
import java.util.Map;
import java.util.concurrent.*;

/**
 * Request packing/merging processor.
 * <ol>
 *   <li>A scheduled task continuously pulls pending requests from the {@link ReqBucket}.</li>
 *   <li>The pulled requests, grouped per bucket, are submitted as batches to the execution thread pool.</li>
 * </ol>
 *
 * @author houwen
 */
public class ReqPacker {

    private static final Logger logger = LoggerFactory.getLogger(ReqPacker.class);

    /** Single-threaded scheduler that drains {@link #reqBucket} every millisecond. */
    private final ScheduledExecutorService scheduled = Executors.newSingleThreadScheduledExecutor(new NamedThreadFactory("merge-request-scheduled"));

    /** Pool that executes merged batches; created in {@link #init()} and resized in {@link #refresh()}. */
    private ThreadPoolExecutor processPool;

    @Resource
    private ReqBucket reqBucket;

    @Resource
    private BatchHandler handler;

    @Resource
    private BatchProperties properties;

    /**
     * Re-applies {@code properties.getPoolsize()} to the execution pool when the environment changes.
     * <p>
     * Both the core and the maximum size are set to the configured value, so the pool stays
     * fixed-size, just as {@link Executors#newFixedThreadPool} created it. The two setters are ordered so that
     * {@code corePoolSize <= maximumPoolSize} holds after every call; otherwise
     * {@link ThreadPoolExecutor} throws {@link IllegalArgumentException} (JDK 9+).
     */
    @EventListener(EnvironmentChangeEvent.class)
    public void refresh() {
        int newSize = properties.getPoolsize();
        if (newSize < 1) {
            // A fixed pool needs at least one thread; keep the current sizing instead of failing.
            logger.warn("merge-request-thread-pool refresh ignored, invalid poolSize:{}", newSize);
            return;
        }
        if (newSize > processPool.getMaximumPoolSize()) {
            // Growing: raise the ceiling first so the new core size never exceeds the max.
            processPool.setMaximumPoolSize(newSize);
            processPool.setCorePoolSize(newSize);
        } else {
            // Shrinking (or unchanged): lower the core first so the new max never drops below it.
            processPool.setCorePoolSize(newSize);
            processPool.setMaximumPoolSize(newSize);
        }
        logger.warn("merge-request-thread-pool refresh poolSize:{}", properties.getPoolsize());
    }

    /**
     * Creates the execution pool and starts the polling task.
     * <p>
     * The polling task runs every 1 ms. Each run takes up to {@code properties.getBatchsize()}
     * requests from the bucket and hands any non-empty result to {@link #process(Map)}. Exceptions
     * are logged so that one failure does not cancel the periodic task.
     */
    @PostConstruct
    public void init() {
        processPool = (ThreadPoolExecutor) Executors.newFixedThreadPool(properties.getPoolsize(), new NamedThreadFactory("merge-request"));
        scheduled.scheduleAtFixedRate(() -> {
            try {
                Map<String, List<ReqContext>> reqs = reqBucket.poll(properties.getBatchsize());
                if (!reqs.isEmpty()) {
                    process(reqs);
                }
            } catch (Exception e) {
                // scheduleAtFixedRate suppresses further runs if a task throws, so never let it escape.
                logger.error("ReqBucket.poll()", e);
            }
        }, 1, 1, TimeUnit.MILLISECONDS);
        logger.warn("merge-request-scheduled started poolSize:{} batchSize:{}", properties.getPoolsize(), properties.getBatchsize());
    }

    /**
     * Submits each bucket's batched requests to the execution pool.
     *
     * @param batch requests keyed by bucket id; empty lists are skipped
     */
    private void process(Map<String, List<ReqContext>> batch) {
        for (Map.Entry<String, List<ReqContext>> entry : batch.entrySet()) {
            try {
                List<ReqContext> reqs = entry.getValue();
                if (!reqs.isEmpty()) {
                    doBatch(entry.getKey(), reqs);
                }
            } catch (RejectedExecutionException rejected) {
                // Wake the waiting request threads so they fall back to the original synchronous path.
                doNotify(entry.getValue(), rejected);
            }
        }
    }

    /**
     * Submits one bucket's batch of requests to the execution pool.
     * <p>
     * The task runs {@code handler.doBatch} and always notifies the waiting request threads afterwards.
     * If the handler threw, the exception is passed along so it can be recorded on the requests.
     *
     * @param bucketId id of the bucket the requests were grouped under
     * @param reqs     non-empty list of requests belonging to that bucket
     * @throws RejectedExecutionException if the pool refuses the task
     */
    private void doBatch(String bucketId, List<ReqContext> reqs) {
        processPool.submit(() -> {
            Exception exception = null;
            try {
                handler.doBatch(bucketId, reqs);
            } catch (Exception e) {
                exception = e;
                logger.error(bucketId, e);
            } finally {
                doNotify(reqs, exception);
            }
        });
    }

    /**
     * Notifies the request threads waiting on each {@link ReqContext}.
     * <p>
     * If {@code exception} is non-null, it is stored as the result of every request that has no result yet.
     * One-way requests are then skipped. Every other request has its state set to
     * {@link ReqContext#NOTIFY} and gets a {@code notify()` while its monitor is held.
     *
     * @param reqs      requests to notify
     * @param exception failure of the batch call, or {@code null} on success
     */
    private void doNotify(List<ReqContext> reqs, Exception exception) {
        for (ReqContext req : reqs) {
            // If the batch call failed, surface the exception to the caller unless a result was already set.
            if (exception != null && req.getVolatileResult() == null) {
                req.setResult(exception);
            }
            // oneway() requests have no waiting thread, so skip the notify.
            if (req.getAnnotation().oneway()) {
                continue;
            }
            synchronized (req) {
                req.setState(ReqContext.NOTIFY);
                req.notify();
            }
        }
    }

}
