package cn.com.duiba.boot.ext.autoconfigure.threadpool.wrapper;

import cn.com.duiba.wolf.threadpool.MonitorCallable;
import cn.com.duiba.wolf.threadpool.MonitorRunnable;
import com.alibaba.ttl.TransmittableThreadLocal;
import com.alibaba.ttl.TtlCallable;
import com.alibaba.ttl.TtlRunnable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.SmartLifecycle;

import java.util.Collection;
import java.util.List;
import java.util.concurrent.*;

/**
 * {@link ExecutorService} decorator around a {@link ThreadPoolExecutor} that is also a Spring
 * {@link SmartLifecycle} bean, so the underlying pool is shut down gracefully when the
 * application context stops.
 *
 * <p>Tasks passed to {@code execute}/{@code submit} are wrapped first in
 * {@code MonitorRunnable}/{@code MonitorCallable} (given this pool's queue) and then in
 * {@code TtlRunnable}/{@code TtlCallable}, which capture the submitting thread's
 * {@code TransmittableThreadLocal} values. Tasks passed to {@code invokeAll}/{@code invokeAny}
 * only receive the TTL wrapper, not the monitor wrapper.
 *
 * @author hwq
 */
public class ThreadPoolExecutorWrapper implements ExecutorService, SmartLifecycle {
    private static final Logger logger = LoggerFactory.getLogger(ThreadPoolExecutorWrapper.class);

    private final ThreadPoolExecutor threadPoolExecutor;
    /** Grace period (milliseconds) for graceful shutdown; after it expires the pool is force-closed. */
    private final long shutdownTimeout;

    /**
     * Creates a wrapper around an existing pool.
     *
     * @param threadPoolExecutor the underlying thread pool that actually runs the tasks
     * @param shutdownTimeout graceful-shutdown grace period in milliseconds; if the pool has not
     *                        terminated within this time it is force-closed
     */
    public ThreadPoolExecutorWrapper(ThreadPoolExecutor threadPoolExecutor, long shutdownTimeout) {
        this.threadPoolExecutor = threadPoolExecutor;
        this.shutdownTimeout = shutdownTimeout;
    }

    @Override
    public void shutdown() {
        threadPoolExecutor.shutdown();
    }

    /**
     * Shuts the pool down gracefully. New tasks are rejected, and already-submitted tasks get
     * up to {@code shutdownTimeout} milliseconds to finish. If the pool has not terminated by then
     * (or the waiting thread is interrupted), the pool is force-closed via {@code shutdownNow()}.
     */
    public void shutdownGracefully() {
        // Soft shutdown first: stop accepting new tasks and let queued/running ones complete.
        threadPoolExecutor.shutdown();
        try {
            // Use the awaitTermination result: isShutdown() is already true right after shutdown(),
            // so it cannot tell us whether the tasks have actually finished.
            if (threadPoolExecutor.awaitTermination(shutdownTimeout, TimeUnit.MILLISECONDS)) {
                return;
            }
            logger.warn("线程池超过{}ms仍未关闭，强制关闭。", shutdownTimeout);
        } catch (InterruptedException e) {
            logger.error("", e);
            // Preserve the interrupt status for whoever owns this thread.
            Thread.currentThread().interrupt();
        }
        // Timed out or interrupted: force-close (interrupts workers, drops queued tasks).
        threadPoolExecutor.shutdownNow();
    }

    /**
     * Misspelled alias of {@link #shutdownGracefully()}, kept for backward compatibility.
     *
     * @deprecated use {@link #shutdownGracefully()} instead
     */
    @Deprecated
    public void shotdownGracefully() {
        shutdownGracefully();
    }

    @Override
    public List<Runnable> shutdownNow() {
        return threadPoolExecutor.shutdownNow();
    }

    @Override
    public boolean isShutdown() {
        return threadPoolExecutor.isShutdown();
    }

    @Override
    public boolean isTerminated() {
        return threadPoolExecutor.isTerminated();
    }

    @Override
    public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
        return threadPoolExecutor.awaitTermination(timeout, unit);
    }

    /** Runs {@code command} on the pool, wrapped with the monitor and TTL decorators. */
    @Override
    public void execute(Runnable command) {
        threadPoolExecutor.execute(TtlRunnable.get(new MonitorRunnable(command, getQueue())));
    }

    /** Submits {@code task}, wrapped with the monitor and TTL decorators. */
    @Override
    public <T> Future<T> submit(Callable<T> task) {
        return threadPoolExecutor.submit(TtlCallable.get(new MonitorCallable(task, getQueue())));
    }

    /** Submits {@code task}, wrapped with the monitor and TTL decorators; the future yields {@code result}. */
    @Override
    public <T> Future<T> submit(Runnable task, T result) {
        return threadPoolExecutor.submit(TtlRunnable.get(new MonitorRunnable(task, getQueue())), result);
    }

    /** Submits {@code task}, wrapped with the monitor and TTL decorators. */
    @Override
    public Future<?> submit(Runnable task) {
        return threadPoolExecutor.submit(TtlRunnable.get(new MonitorRunnable(task, getQueue())));
    }

    /** Invokes all tasks; each is wrapped with the TTL decorator only (no monitor wrapper). */
    @Override
    public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks) throws InterruptedException {
        return threadPoolExecutor.invokeAll(TtlCallable.gets(tasks));
    }

    /** Invokes all tasks with a timeout; each is wrapped with the TTL decorator only (no monitor wrapper). */
    @Override
    public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks, long timeout, TimeUnit unit) throws InterruptedException {
        return threadPoolExecutor.invokeAll(TtlCallable.gets(tasks), timeout, unit);
    }

    /** Invokes any task; each is wrapped with the TTL decorator only (no monitor wrapper). */
    @Override
    public <T> T invokeAny(Collection<? extends Callable<T>> tasks) throws InterruptedException, ExecutionException {
        return threadPoolExecutor.invokeAny(TtlCallable.gets(tasks));
    }

    /** Invokes any task with a timeout; each is wrapped with the TTL decorator only (no monitor wrapper). */
    @Override
    public <T> T invokeAny(Collection<? extends Callable<T>> tasks, long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException {
        return threadPoolExecutor.invokeAny(TtlCallable.gets(tasks), timeout, unit);
    }

    /** @return whether the underlying pool is shutting down but has not yet terminated */
    public boolean isTerminating() {
        return threadPoolExecutor.isTerminating();
    }

    /** @return the underlying pool's work queue (live view, not a copy) */
    public BlockingQueue<Runnable> getQueue() {
        return threadPoolExecutor.getQueue();
    }

    /** @return the current number of threads in the underlying pool */
    public int getPoolSize() {
        return threadPoolExecutor.getPoolSize();
    }

    /** @return the core pool size of the underlying pool */
    public int getCorePoolSize() {
        return threadPoolExecutor.getCorePoolSize();
    }

    /** @return the approximate number of threads actively executing tasks */
    public int getActiveCount() {
        return threadPoolExecutor.getActiveCount();
    }

    /** @return the largest number of threads that have ever simultaneously been in the pool */
    public int getLargestPoolSize() {
        return threadPoolExecutor.getLargestPoolSize();
    }

    /** @return the maximum allowed number of threads in the underlying pool */
    public int getMaximumPoolSize() {
        return threadPoolExecutor.getMaximumPoolSize();
    }

    /** @return the approximate total number of tasks ever scheduled for execution */
    public long getTaskCount() {
        return threadPoolExecutor.getTaskCount();
    }

    /** @return the approximate total number of tasks that have completed execution */
    public long getCompletedTaskCount() {
        return threadPoolExecutor.getCompletedTaskCount();
    }

    @Override
    public boolean isAutoStartup() {
        return true;
    }

    /**
     * Stops the pool asynchronously. Several SmartLifecycle beans in the same phase may be
     * stopped concurrently, so the (potentially blocking) graceful shutdown runs on a separate
     * thread. {@code callback} is always invoked when it finishes, even if shutdown throws.
     *
     * @param callback Spring's completion callback, run after the shutdown attempt
     */
    @Override
    public void stop(final Runnable callback) {
        Thread shutdownThread = new Thread(new Runnable() {
            @Override
            public void run() {
                try {
                    shutdownGracefully();
                } finally {
                    callback.run();
                }
            }
        }, "ThreadPoolExecutorWrapper-shutdown");
        shutdownThread.start();
    }

    /** No-op: the pool is already running once constructed. */
    @Override
    public void start() {
    }

    /** No-op: for SmartLifecycle beans Spring calls {@link #stop(Runnable)} instead. */
    @Override
    public void stop() {
        // will not be invoked
    }

    /** @return {@code true} until shutdown of the underlying pool has been initiated */
    @Override
    public boolean isRunning() {
        return !threadPoolExecutor.isTerminating() && !threadPoolExecutor.isTerminated();
    }

    @Override
    public int getPhase() {
        return 0;
    }
}
