package cn.com.duibaboot.ext.autoconfigure.web;

import cn.com.duibaboot.ext.autoconfigure.core.DuibaServerProperties;
import cn.com.duiba.wolf.threadpool.MonitorCallable;
import cn.com.duiba.wolf.threadpool.MonitorRunnable;
import cn.com.duiba.wolf.threadpool.NamedThreadFactory;
import org.apache.coyote.http11.Http11NioProtocol;
import org.apache.tomcat.util.net.AbstractEndpoint;
import org.apache.tomcat.util.threads.TaskQueue;
import org.apache.tomcat.util.threads.TaskThreadFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.context.embedded.ConfigurableEmbeddedServletContainer;
import org.springframework.boot.context.embedded.EmbeddedServletContainerCustomizer;
import org.springframework.boot.context.embedded.tomcat.TomcatEmbeddedServletContainerFactory;
import org.springframework.util.ReflectionUtils;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.concurrent.*;

/**
 * 自动配置Tomcat、Jetty等容器，增加http线程池监控，如果在队列中等待超过1秒，则打印error日志提示线程池过小<br/>
 * 如果是内部服务，使用固定大小的线程池,并拒绝在http线程池工作队列中等待过久的任务，让客户端可以尽快重试
 */
public class ThreadPoolServletContainerCustomizer implements EmbeddedServletContainerCustomizer {
    public static final String HTTP_FAIL_FAST_THREAD_NAME_PREFIX = "HttpFailFastThread";
    private static final Logger logger = LoggerFactory.getLogger(ThreadPoolServletContainerCustomizer.class);
    private static volatile DuibaServerProperties duibaServerProperties;

    /**
     * Creates the customizer and publishes the server properties for the nested protocol class.
     *
     * <p>The properties are kept in a <em>static</em> field because {@link DuibaHttp11NioProtocol}
     * is handed to the container by class name only (see {@link #customize}), so the protocol
     * instance cannot receive them through its own constructor. Every construction of this
     * customizer therefore overwrites the value shared by all protocol instances.
     *
     * @param duibaServerProperties server settings; read by the protocol to decide whether the
     *                              service runs in internal mode
     */
    public ThreadPoolServletContainerCustomizer(DuibaServerProperties duibaServerProperties){
        // Qualify with the class name: the target is a static field, not per-instance state.
        ThreadPoolServletContainerCustomizer.duibaServerProperties = duibaServerProperties;
    }

    /**
     * Switches an embedded Tomcat factory over to {@link DuibaHttp11NioProtocol}.
     * Containers that are not Tomcat factories are left untouched.
     *
     * @param container the embedded servlet container being configured
     */
    @Override
    public void customize(ConfigurableEmbeddedServletContainer container) {
        if (!(container instanceof TomcatEmbeddedServletContainerFactory)) {
            return;
        }
        String protocolClassName = DuibaHttp11NioProtocol.class.getName();
        ((TomcatEmbeddedServletContainerFactory) container).setProtocol(protocolClassName);
    }

    public static class DuibaHttp11NioProtocol extends Http11NioProtocol{

        private org.apache.tomcat.util.threads.ThreadPoolExecutor http11NioExecutor;
        private ThreadPoolExecutor httpFailFastExecutor;
        private ScheduledExecutorService scheduledExecutorService;
        private long executorTerminationTimeoutMillis;
        private String endpointName;

        /**
         * Returns the endpoint backing this protocol handler.
         *
         * <p>The direct {@code super.getEndpoint()} call may fail at runtime because the Tomcat
         * API differs considerably between versions; in that case the method is looked up and
         * invoked reflectively as a fallback.
         *
         * @return the endpoint of this protocol handler
         * @throws InvocationTargetException if the reflective {@code getEndpoint} call itself throws
         * @throws IllegalAccessException    if the reflective call is not permitted
         * @throws IllegalStateException     if no {@code getEndpoint} method can be found reflectively
         */
        public AbstractEndpoint getEndpointInner() throws InvocationTargetException, IllegalAccessException{
            try {
                return super.getEndpoint();
            }catch(Throwable e) {
                // Throwable on purpose: an API mismatch surfaces as a linkage Error, not an Exception.
                Method endpointMethod = ReflectionUtils.findMethod(Http11NioProtocol.class, "getEndpoint");
                if (endpointMethod == null) {
                    // findMethod returns null when nothing matches; fail with a clear message
                    // and keep the original failure as the cause instead of an anonymous NPE.
                    throw new IllegalStateException(
                            "No getEndpoint() method found on " + Http11NioProtocol.class.getName(), e);
                }
                endpointMethod.setAccessible(true);
                return (AbstractEndpoint) endpointMethod.invoke(this);
            }
        }

        /**
         * Captures endpoint settings and makes sure a usable monitored HTTP worker pool exists.
         *
         * <p>A new pool is built when none exists yet, or when the existing one has already been
         * shut down (as {@code stop()} does). Without the second condition a stop/start cycle
         * would reinstall a terminated pool that rejects every task.
         *
         * <p>In internal mode the pool has a fixed size (core == max, no keep-alive); otherwise it
         * grows from the configured minimum spare threads up to the maximum with a 60 second
         * keep-alive.
         *
         * @throws NoSuchMethodException     declared for callers; kept for compatibility
         * @throws InvocationTargetException if the endpoint lookup fails reflectively
         * @throws IllegalAccessException    if the endpoint lookup is not permitted
         */
        private void createExecutor() throws NoSuchMethodException, InvocationTargetException, IllegalAccessException {
            AbstractEndpoint endpoint = getEndpointInner();
            executorTerminationTimeoutMillis = endpoint.getExecutorTerminationTimeoutMillis();
            endpointName = endpoint.getName();

            if (http11NioExecutor != null && !http11NioExecutor.isShutdown()) {
                // A live pool is already in place; keep it.
                return;
            }

            // Internal services use a fixed-size pool.
            boolean fixedSize = duibaServerProperties.isInternalMode();
            int maxThreads = getMaxThreads();
            int coreThreads = fixedSize ? maxThreads : getMinSpareThreads();
            long keepAliveSeconds = fixedSize ? 0 : 60;

            TaskThreadFactory tf = new TaskThreadFactory(endpointName + "-exec-", true, getThreadPriority());

            // Using this TaskQueue ensures the pool grows to its maximum size before tasks start queueing.
            TaskQueue taskqueue = new TaskQueue();
            TomcatMonitorThreadPoolExecutor executor = new TomcatMonitorThreadPoolExecutor(
                    coreThreads, maxThreads, keepAliveSeconds, TimeUnit.SECONDS, taskqueue, tf);
            taskqueue.setParent(executor);

            http11NioExecutor = executor;
        }

        /**
         * Installs the monitored thread pool (best effort) and then starts the protocol handler.
         * Any exception raised while installing is logged and does not prevent the start.
         *
         * @throws Exception if the underlying protocol handler fails to start
         */
        @Override
        public void start() throws Exception {
            try {
                installMonitoredExecutor();
            } catch (Exception e) {
                logger.warn("getEndpoint 失败，不会给tomcat注入线程池监控功能", e);
            }

            super.start();
        }

        /**
         * Creates the monitored executor, registers it with the protocol handler and, for
         * internal services, starts the background scan that rejects requests which waited in
         * the queue for too long (takes effect together with FailFastFilter).
         */
        private void installMonitoredExecutor() throws Exception {
            createExecutor();
            setExecutor(http11NioExecutor);

            if (!duibaServerProperties.isInternalMode()) {
                return;
            }
            startFailFastService();
        }

        /**
         * Stops the protocol handler and releases the thread pools owned by this class.
         *
         * <p>The pools are shut down in a {@code finally} block so they are released even when
         * {@code super.stop()} throws.
         *
         * @throws Exception if the underlying protocol handler fails to stop
         */
        @Override
        public void stop() throws Exception {
            try {
                super.stop();
            } finally {
                shutdownMonitoredExecutor();
                // stopFailFastService() is null-safe, so it is a no-op when the fail-fast
                // service was never started (i.e. outside internal mode).
                stopFailFastService();
            }
        }

        /**
         * Shuts down the HTTP worker pool created by this class, waits up to the endpoint's
         * executor termination timeout, and detaches the task queue from the pool.
         */
        private void shutdownMonitoredExecutor() {
            org.apache.tomcat.util.threads.ThreadPoolExecutor tpe = http11NioExecutor;
            if (tpe == null) {
                return;
            }
            //this is our internal one, so we need to shut it down
            tpe.shutdownNow();
            if (executorTerminationTimeoutMillis > 0) {
                try {
                    tpe.awaitTermination(executorTerminationTimeoutMillis, TimeUnit.MILLISECONDS);
                } catch (InterruptedException e) {
                    // Stop waiting, but keep the interrupt visible to the caller.
                    Thread.currentThread().interrupt();
                }
                if (tpe.isTerminating()) {
                    getLog().warn(sm.getString("endpoint.warn.executorShutdown", endpointName));
                }
            }
            BlockingQueue<Runnable> queue = tpe.getQueue();
            if (queue instanceof TaskQueue) {
                ((TaskQueue) queue).setParent(null);
            }
        }

        /**
         * Shuts down the fail-fast scheduler and the fail-fast worker pool, in that order.
         * Either may be absent (never started); absent ones are skipped.
         */
        private void stopFailFastService(){
            ExecutorService[] failFastServices = {scheduledExecutorService, httpFailFastExecutor};
            for (ExecutorService service : failFastServices) {
                if (service == null) {
                    continue;
                }
                service.shutdownNow();
            }
        }

        /** Maximum time (ms) a request may wait in the HTTP work queue before it is diverted. */
        private static final long MAX_QUEUE_WAIT_MILLIS = 50;

        /**
         * Starts the fail-fast machinery: a small worker pool whose threads are named with
         * {@link #HTTP_FAIL_FAST_THREAD_NAME_PREFIX}, and a single-threaded scheduler that scans
         * the HTTP work queue every 10 ms (after an initial 1 s delay).
         *
         * <p>Requests that waited too long in the queue are moved to the fail-fast pool. Per the
         * original design notes, FailFastFilter recognises the thread-name prefix and throws
         * RejectedExecutionException, so the HTTP REST RPC client can safely retry on another
         * server (the call was never actually processed).
         */
        private void startFailFastService(){
            httpFailFastExecutor = new ThreadPoolExecutor(1, 3, 60, TimeUnit.SECONDS, new ArrayBlockingQueue<>(100),
                    new NamedThreadFactory(HTTP_FAIL_FAST_THREAD_NAME_PREFIX));
            scheduledExecutorService = Executors.newSingleThreadScheduledExecutor(new NamedThreadFactory(
                    "HttpFailFastScheduledThread"));

            scheduledExecutorService.scheduleAtFixedRate(new Runnable() {
                @Override
                public void run() {
                    try {
                        moveStaleRequestsToFailFastPool();
                    } catch (RuntimeException e) {
                        // An exception escaping a scheduleAtFixedRate task cancels all later
                        // runs; swallow and log so the scan keeps running.
                        logger.warn("[NOTIFYME]http fail-fast scan failed, will retry on next run", e);
                    }
                }
            }, 1000, 10, TimeUnit.MILLISECONDS);
        }

        /**
         * Performs one scan: while the head of the HTTP work queue has waited longer than
         * {@link #MAX_QUEUE_WAIT_MILLIS}, removes it and hands it to the fail-fast pool.
         */
        private void moveStaleRequestsToFailFastPool() {
            // Read the fields once so a concurrent stop/restart cannot swap them mid-scan.
            org.apache.tomcat.util.threads.ThreadPoolExecutor httpExecutor = http11NioExecutor;
            ThreadPoolExecutor failFastExecutor = httpFailFastExecutor;
            if (httpExecutor == null || failFastExecutor == null) {
                return;
            }
            //we have idle threads, so return
            if (httpExecutor.getSubmittedCount() < httpExecutor.getMaximumPoolSize()) {
                return;
            }

            BlockingQueue<Runnable> tomcatWorkQueue = httpExecutor.getQueue();
            Runnable head;
            while ((head = tomcatWorkQueue.peek()) != null) {
                if (!(head instanceof MonitorRunnable)) {
                    logger.warn("[NOTIFYME]will never happens here");// will not happen
                    return;
                }

                long waitedMillis = System.currentTimeMillis() - ((MonitorRunnable) head).getSubmitTimeMillis();
                if (waitedMillis <= MAX_QUEUE_WAIT_MILLIS) {
                    // The head has not timed out yet, so there is no need to look at the ones behind it.
                    return;
                }

                // remove() returns false when a worker thread took the task in the meantime;
                // in that case simply look at the new head.
                if (tomcatWorkQueue.remove(head)) {
                    try {
                        failFastExecutor.execute(head);
                    } catch (RejectedExecutionException e) {
                        logger.warn("[NOTIFYME]put http thread into HttpFailFastThread failed,this will never happen", e);
                    }
                }
            }
        }

    }

    /**
     * Tomcat thread pool that wraps every task handed to it in a {@link MonitorRunnable} or
     * {@link MonitorCallable}, passing the pool's own work queue to the wrapper.
     * The wrappers are defined elsewhere; presumably they record the submit time so queue
     * waiting can be monitored (the fail-fast scan reads {@code getSubmitTimeMillis()}).
     */
    private static class TomcatMonitorThreadPoolExecutor extends org.apache.tomcat.util.threads.ThreadPoolExecutor {

        /** Pool with default thread factory and default rejection handling. */
        public TomcatMonitorThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue) {
            super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue);
        }

        /** Pool with a custom thread factory. */
        public TomcatMonitorThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue, ThreadFactory threadFactory) {
            super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory);
        }

        /** Pool with a custom rejection handler. */
        public TomcatMonitorThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue, RejectedExecutionHandler handler) {
            super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, handler);
        }

        /** Pool with both a custom thread factory and a custom rejection handler. */
        public TomcatMonitorThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue, ThreadFactory threadFactory, RejectedExecutionHandler handler) {
            super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory, handler);
        }

        /** Executes {@code command} wrapped in a {@link MonitorRunnable}. */
        @Override
        public void execute(Runnable command) {
            MonitorRunnable monitored = new MonitorRunnable(command, getQueue());
            super.execute(monitored);
        }

        /** Submits {@code task} wrapped in a {@link MonitorRunnable}. */
        @Override
        public Future<?> submit(Runnable task) {
            MonitorRunnable monitored = new MonitorRunnable(task, getQueue());
            return super.submit(monitored);
        }

        /** Submits {@code task} wrapped in a {@link MonitorRunnable}, completing with {@code result}. */
        @Override
        public <T> Future<T> submit(Runnable task, T result) {
            MonitorRunnable monitored = new MonitorRunnable(task, getQueue());
            return super.submit(monitored, result);
        }

        /** Submits {@code task} wrapped in a {@link MonitorCallable}. */
        @Override
        public <T> Future<T> submit(Callable<T> task) {
            MonitorCallable monitored = new MonitorCallable(task, getQueue());
            return super.submit(monitored);
        }

    }

}
