package cn.com.duiba.boot.ext.autoconfigure.web;

import cn.com.duiba.wolf.threadpool.MonitorCallable;
import cn.com.duiba.wolf.threadpool.MonitorRunnable;
import org.apache.coyote.http11.Http11NioProtocol;
import org.apache.tomcat.util.net.AbstractEndpoint;
import org.apache.tomcat.util.threads.TaskQueue;
import org.apache.tomcat.util.threads.TaskThreadFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.context.embedded.ConfigurableEmbeddedServletContainer;
import org.springframework.boot.context.embedded.EmbeddedServletContainerCustomizer;
import org.springframework.boot.context.embedded.tomcat.TomcatEmbeddedServletContainerFactory;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.concurrent.*;

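/**
 * Customizes the embedded servlet container to add monitoring to the HTTP worker thread pool: if a task waits
 * in the queue for more than 1 second, an error log is printed to indicate that the thread pool is too small.
 * Currently only the embedded Tomcat factory is customized; other containers (e.g. Jetty) are left unchanged.
 */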
public class ThreadPoolServletContainerCustomizer implements EmbeddedServletContainerCustomizer {
    private static final Logger logger = LoggerFactory.getLogger(ThreadPoolServletContainerCustomizer.class);
    /**
     * Switches the embedded Tomcat connector protocol to {@link DuibaHttp11NioProtocol}.
     * Containers that are not a {@link TomcatEmbeddedServletContainerFactory} are left untouched.
     *
     * @param container the embedded servlet container being configured
     */
    @Override
    public void customize(ConfigurableEmbeddedServletContainer container) {
        if (!(container instanceof TomcatEmbeddedServletContainerFactory)) {
            return;
        }
        String protocolClassName = DuibaHttp11NioProtocol.class.getName();
        ((TomcatEmbeddedServletContainerFactory) container).setProtocol(protocolClassName);
    }

    /**
     * {@link Http11NioProtocol} that installs a {@code TomcatMonitorThreadPoolExecutor} as the connector's
     * worker pool when the protocol starts, and shuts that pool down when the protocol stops.
     * <p>
     * If the endpoint cannot be obtained reflectively, a warning is logged and the protocol starts
     * without the monitoring executor.
     */
    public static class DuibaHttp11NioProtocol extends Http11NioProtocol {

        /** Executor created by this protocol; {@code null} before the first start and after every stop. */
        private ThreadPoolExecutor http11NioExecutor;
        /** Termination timeout read from the endpoint each time {@link #createExecutor()} runs. */
        private long executorTerminationTimeoutMillis;
        /** Endpoint name read in {@link #createExecutor()}; used as thread-name prefix and in the shutdown warning. */
        private String endpointName;

        /**
         * Obtains this protocol's endpoint by reflectively invoking a no-argument {@code getEndpoint} method.
         * The lookup starts at {@link Http11NioProtocol} and walks up the superclass chain until a class
         * declaring the method is found.
         * NOTE(review): reflection is presumably used because direct access to {@code getEndpoint()} is not
         * possible on every supported Tomcat version - confirm.
         *
         * @return the endpoint returned by {@code getEndpoint()}
         * @throws InvocationTargetException if {@code getEndpoint()} itself throws
         * @throws IllegalAccessException    if the method cannot be accessed
         * @throws IllegalStateException     if no class in the hierarchy declares {@code getEndpoint()}
         */
        public AbstractEndpoint getEndpointInner() throws InvocationTargetException, IllegalAccessException {
            Method endpointMethod = null;
            Class<?> clazz = Http11NioProtocol.class;
            while (clazz != null && endpointMethod == null) {
                try {
                    // No explicit parameter types: avoids passing a bare null to a varargs parameter.
                    endpointMethod = clazz.getDeclaredMethod("getEndpoint");
                } catch (NoSuchMethodException ignored) {
                    // Not declared on this class; continue with the superclass.
                    clazz = clazz.getSuperclass();
                }
            }
            if (endpointMethod == null) {
                // Fail with a descriptive error instead of a NullPointerException on invoke().
                throw new IllegalStateException("No getEndpoint() method found on "
                        + Http11NioProtocol.class.getName() + " or any of its superclasses");
            }
            endpointMethod.setAccessible(true);
            return (AbstractEndpoint) endpointMethod.invoke(this);
        }

        /**
         * Reads the termination timeout and name from the endpoint and, if no executor currently exists,
         * creates the monitoring executor.
         *
         * @throws InvocationTargetException if the reflective endpoint lookup fails inside {@code getEndpoint()}
         * @throws IllegalAccessException    if the reflective endpoint lookup is denied
         */
        private void createExecutor() throws InvocationTargetException, IllegalAccessException {
            AbstractEndpoint endpoint = getEndpointInner();
            executorTerminationTimeoutMillis = endpoint.getExecutorTerminationTimeoutMillis();
            endpointName = endpoint.getName();
            if (http11NioExecutor == null) {
                // Using this TaskQueue ensures the pool grows to its maximum size before tasks start queuing.
                TaskQueue taskqueue = new TaskQueue();
                TaskThreadFactory tf = new TaskThreadFactory(endpointName + "-exec-", true, getThreadPriority());
                TomcatMonitorThreadPoolExecutor executor = new TomcatMonitorThreadPoolExecutor(
                        getMinSpareThreads(), getMaxThreads(), 60, TimeUnit.SECONDS, taskqueue, tf);
                http11NioExecutor = executor;
                taskqueue.setParent(executor);
            }
        }

        /**
         * Installs the monitoring executor (best effort) and then starts the protocol.
         * Any failure while creating or installing the executor is logged and does not prevent the start.
         *
         * @throws Exception if {@code super.start()} fails
         */
        @Override
        public void start() throws Exception {
            try {
                createExecutor();
                setExecutor(http11NioExecutor);
            } catch (Exception e) {
                logger.warn("getEndpoint 失败，不会给tomcat注入线程池监控功能", e);
            }

            super.start();
        }

        /**
         * Stops the protocol and then shuts down the executor created by this class.
         * The executor is shut down even if {@code super.stop()} throws, so its threads are not leaked.
         *
         * @throws Exception if {@code super.stop()} fails
         */
        @Override
        public void stop() throws Exception {
            try {
                super.stop();
            } finally {
                shutdownMonitorExecutor();
            }
        }

        /**
         * Shuts down the executor created by {@link #createExecutor()}, waits up to the configured
         * termination timeout, and clears the field so that a later {@link #start()} builds a fresh pool
         * instead of re-installing a terminated one.
         */
        private void shutdownMonitorExecutor() {
            ThreadPoolExecutor current = http11NioExecutor;
            http11NioExecutor = null;
            if (!(current instanceof org.apache.tomcat.util.threads.ThreadPoolExecutor)) {
                return;
            }
            // This is our own executor, so we are responsible for shutting it down.
            org.apache.tomcat.util.threads.ThreadPoolExecutor tpe =
                    (org.apache.tomcat.util.threads.ThreadPoolExecutor) current;
            tpe.shutdownNow();
            if (executorTerminationTimeoutMillis > 0) {
                try {
                    tpe.awaitTermination(executorTerminationTimeoutMillis, TimeUnit.MILLISECONDS);
                } catch (InterruptedException e) {
                    // Preserve the interrupt for the caller instead of silently dropping it.
                    Thread.currentThread().interrupt();
                }
                if (tpe.isTerminating()) {
                    getLog().warn(sm.getString("endpoint.warn.executorShutdown", endpointName));
                }
            }
            TaskQueue queue = (TaskQueue) tpe.getQueue();
            queue.setParent(null);
        }

    }

    /**
     * Tomcat thread pool that wraps every task passed to {@link #execute(Runnable)} in a
     * {@link MonitorRunnable} bound to this pool's work queue.
     * <p>
     * {@code submit(...)} is intentionally not overridden: the inherited implementations turn the task into
     * a future and hand it to {@link #execute(Runnable)}, so submitted tasks are wrapped there exactly once.
     * Wrapping inside {@code submit(...)} as well would place two monitor wrappers around the same task.
     */
    private static class TomcatMonitorThreadPoolExecutor extends org.apache.tomcat.util.threads.ThreadPoolExecutor {

        public TomcatMonitorThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue, RejectedExecutionHandler handler) {
            super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, handler);
        }

        public TomcatMonitorThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue, ThreadFactory threadFactory, RejectedExecutionHandler handler) {
            super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory, handler);
        }

        public TomcatMonitorThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue, ThreadFactory threadFactory) {
            super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory);
        }

        public TomcatMonitorThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue) {
            super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue);
        }

        /**
         * Wraps {@code command} in a {@link MonitorRunnable} tied to this pool's queue and executes it.
         * This is the single wrapping point for both {@code execute} and the inherited {@code submit} methods.
         *
         * @param command the task to run
         */
        @Override
        public void execute(Runnable command) {
            super.execute(new MonitorRunnable(command, getQueue()));
        }

    }
}
