package cn.com.duibaboot.ext.autoconfigure.threadpool.wrapper;

import cn.com.duiba.wolf.threadpool.NamedThreadFactory;
import cn.com.duibaboot.ext.autoconfigure.threadpool.policy.AbortPolicyWithReport;
import cn.com.duibaboot.ext.autoconfigure.threadpool.properties.ScheduledThreadPoolProperties;
import cn.com.duibaboot.ext.autoconfigure.threadpool.properties.ThreadPoolProperties;
import cn.com.duibaboot.ext.autoconfigure.threadpool.proxy.ProfileCallable;
import cn.com.duibaboot.ext.autoconfigure.threadpool.proxy.ProfileRunnable;
import com.alibaba.ttl.TtlCallable;
import com.alibaba.ttl.TtlRunnable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.SmartLifecycle;
import org.springframework.core.task.AsyncTaskExecutor;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;

import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.*;

/**
 * Named, bounded thread pool that wraps every submitted task with TTL context propagation
 * ({@link TtlRunnable}/{@link TtlCallable}) before handing it to the underlying executor.
 * Single-task submissions ({@code execute}/{@code submit}) are additionally wrapped in
 * {@link ProfileRunnable}/{@link ProfileCallable}.
 *
 * <p>Since 2020-02-13 this class extends Spring's {@link ThreadPoolTaskExecutor} so that
 * {@code org.springframework.cloud.sleuth.instrument.async.ExecutorBeanPostProcessor} can wrap it,
 * which lets the pool participate in Sleuth tracing.
 *
 * <p>When the supplied properties are {@link ScheduledThreadPoolProperties}, the backing executor is
 * a {@link ScheduledThreadPoolExecutor}; otherwise it is the executor created by
 * {@link ThreadPoolTaskExecutor}.
 *
 * <p>The pool is not torn down by {@link #destroy()}; it is shut down gracefully through the
 * {@link SmartLifecycle#stop(Runnable)} callback instead.
 *
 * @author hwq
 */
public class ThreadPoolExecutorWrapper
        extends ThreadPoolTaskExecutor
        implements AsyncTaskExecutor, ExecutorService, SmartLifecycle {

    private static final Logger logger = LoggerFactory.getLogger(ThreadPoolExecutorWrapper.class);

    private final ThreadPoolProperties poolProperties;

    /**
     * Holds the executor built by {@link #initializeExecutor} in the scheduled case.
     * Spring's {@link ThreadPoolTaskExecutor#getThreadPoolExecutor()} only knows about executors
     * created by its own {@code initializeExecutor}, so the scheduled one is tracked here.
     */
    private volatile ScheduledThreadPoolExecutor scheduledExecutor;

    /**
     * Creates the pool configuration. The executor itself is built later by Spring's
     * initialization ({@code afterPropertiesSet}/{@code initialize}).
     *
     * @param threadPoolName pool name; used for the thread factory and as the thread-name prefix,
     *                       which {@link #getThreadPoolName()} also returns
     * @param poolProperties sizing and shutdown settings
     * @throws NullPointerException if either argument is {@code null}
     */
    public ThreadPoolExecutorWrapper(String threadPoolName, ThreadPoolProperties poolProperties) {
        // Validate first so nothing is configured from a bad argument.
        Objects.requireNonNull(threadPoolName, "threadPoolName");
        this.poolProperties = Objects.requireNonNull(poolProperties, "poolProperties");

        setThreadNamePrefix(threadPoolName);
        setRejectedExecutionHandler(new AbortPolicyWithReport());
        setThreadFactory(new NamedThreadFactory(threadPoolName));
        setCorePoolSize(poolProperties.getCoreSize());
        // Never let the max pool size fall below the core size.
        setMaxPoolSize(Math.max(poolProperties.getMaxSize(), poolProperties.getCoreSize()));
        setQueueCapacity(poolProperties.getQueueSize());
    }

    /**
     * Builds a {@link ScheduledThreadPoolExecutor} for scheduled pools, otherwise defers to Spring.
     * The scheduled executor is remembered so {@link #getThreadPoolExecutor()} can return it.
     */
    @Override
    protected ExecutorService initializeExecutor(
            ThreadFactory threadFactory, RejectedExecutionHandler rejectedExecutionHandler) {
        if (poolProperties instanceof ScheduledThreadPoolProperties) {
            ScheduledThreadPoolExecutor executor = new ScheduledThreadPoolExecutor(
                    poolProperties.getCoreSize(), threadFactory, rejectedExecutionHandler);
            this.scheduledExecutor = executor;
            return executor;
        }
        return super.initializeExecutor(threadFactory, rejectedExecutionHandler);
    }

    /**
     * Returns the backing executor: the scheduled executor if one was built by
     * {@link #initializeExecutor}, otherwise the executor managed by the parent class.
     *
     * @throws IllegalStateException if the pool has not been initialized yet
     */
    @Override
    public ThreadPoolExecutor getThreadPoolExecutor() {
        ScheduledThreadPoolExecutor scheduled = this.scheduledExecutor;
        return scheduled != null ? scheduled : super.getThreadPoolExecutor();
    }

    /**
     * Graceful shutdown. The pool stops accepting new tasks and then waits up to
     * {@code poolProperties.getShutdownTimeout()} milliseconds for running and queued tasks
     * to finish. If the pool still has not terminated after that wait, or the wait was
     * interrupted, it is forced down with {@code shutdownNow()}.
     */
    public void shutdownGracefully() {
        ThreadPoolExecutor executor = getThreadPoolExecutor();
        // Soft shutdown first: reject new tasks, let existing ones run.
        executor.shutdown();

        Long shutdownTimeout = poolProperties.getShutdownTimeout();

        try {
            executor.awaitTermination(shutdownTimeout, TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            logger.error("", e);
            Thread.currentThread().interrupt();
        }
        // isShutdown() is always true after shutdown(); isTerminated() tells whether
        // tasks are still running after the timeout, in which case we force the pool down.
        if (!executor.isTerminated()) {
            // Message: "pool not closed after {}ms, forcing shutdown."
            logger.warn("线程池超过{}ms仍未关闭，强制关闭。", shutdownTimeout);
            executor.shutdownNow();
        }
    }

    @Override
    public List<Runnable> shutdownNow() {
        return getThreadPoolExecutor().shutdownNow();
    }

    @Override
    public boolean isShutdown() {
        return getThreadPoolExecutor().isShutdown();
    }

    @Override
    public boolean isTerminated() {
        return getThreadPoolExecutor().isTerminated();
    }

    @Override
    public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
        return getThreadPoolExecutor().awaitTermination(timeout, unit);
    }

    @Override
    public void execute(Runnable command) {
        getThreadPoolExecutor().execute(wrapRunnable(command));
    }

    @Override
    public <T> Future<T> submit(Callable<T> task) {
        return getThreadPoolExecutor().submit(wrapCallable(task));
    }

    @Override
    public <T> Future<T> submit(Runnable task, T result) {
        return getThreadPoolExecutor().submit(wrapRunnable(task), result);
    }

    /**
     * Executes the task immediately; {@code startTimeout} is ignored.
     */
    @Override
    public void execute(Runnable task, long startTimeout) {
        getThreadPoolExecutor().execute(wrapRunnable(task));
    }

    @Override
    public Future<?> submit(Runnable task) {
        return getThreadPoolExecutor().submit(wrapRunnable(task));
    }

    /** Bulk submissions are wrapped with TTL propagation only (no profiling wrapper). */
    @Override
    public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks) throws InterruptedException {
        return getThreadPoolExecutor().invokeAll(TtlCallable.gets(tasks));
    }

    @Override
    public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks, long timeout, TimeUnit unit) throws InterruptedException {
        return getThreadPoolExecutor().invokeAll(TtlCallable.gets(tasks), timeout, unit);
    }

    @Override
    public <T> T invokeAny(Collection<? extends Callable<T>> tasks) throws InterruptedException, ExecutionException {
        return getThreadPoolExecutor().invokeAny(TtlCallable.gets(tasks));
    }

    @Override
    public <T> T invokeAny(Collection<? extends Callable<T>> tasks, long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException {
        return getThreadPoolExecutor().invokeAny(TtlCallable.gets(tasks), timeout, unit);
    }

    public boolean isTerminating() {
        return getThreadPoolExecutor().isTerminating();
    }

    public BlockingQueue<Runnable> getQueue() {
        return getThreadPoolExecutor().getQueue();
    }

    public int getPoolSize() {
        return getThreadPoolExecutor().getPoolSize();
    }

    public int getCorePoolSize() {
        return getThreadPoolExecutor().getCorePoolSize();
    }

    public int getActiveCount() {
        return getThreadPoolExecutor().getActiveCount();
    }

    public int getLargestPoolSize() {
        return getThreadPoolExecutor().getLargestPoolSize();
    }

    public int getMaximumPoolSize() {
        return getThreadPoolExecutor().getMaximumPoolSize();
    }

    public long getTaskCount() {
        return getThreadPoolExecutor().getTaskCount();
    }

    public long getCompletedTaskCount() {
        return getThreadPoolExecutor().getCompletedTaskCount();
    }

    @Override
    public boolean isAutoStartup() {
        return true;
    }

    /**
     * SmartLifecycle beans that share a phase may be stopped concurrently, so the graceful
     * shutdown runs on a separate thread. {@code callback} is always invoked, even if the
     * shutdown throws.
     *
     * @param callback completion callback supplied by the lifecycle processor
     */
    @Override
    public void stop(final Runnable callback) {
        Thread shutdownThread = new Thread(new Runnable() {
            @Override
            public void run() {
                try {
                    shutdownGracefully();
                } finally {
                    callback.run();
                }
            }
        }, getThreadPoolName() + "-shutdown");
        shutdownThread.start();
    }

    @Override
    public void start() {
        // Nothing to do: the executor is created during bean initialization.
    }

    @Override
    public void stop() {
        // Not expected to be called; shutdown goes through stop(Runnable).
    }

    /**
     * Running until shutdown has been requested. This single read is equivalent to
     * "neither terminating nor terminated", since termination implies shutdown.
     */
    @Override
    public boolean isRunning() {
        return !getThreadPoolExecutor().isShutdown();
    }

    /**
     * Low phase so that GracefulCloseLifeCycle is stopped before this pool
     * (presumably GracefulCloseLifeCycle uses a higher phase — confirm).
     */
    @Override
    public int getPhase() {
        return -2;
    }

    /** Returns the pool name, which is stored as the thread-name prefix. */
    public String getThreadPoolName() {
        return getThreadNamePrefix();
    }

    protected ThreadPoolExecutor getInnerThreadPoolExecutor() {
        return getThreadPoolExecutor();
    }

    @Override
    public void destroy() {
        // Intentionally empty: shutdown is driven by the SmartLifecycle stop(Runnable) callback.
    }

    /** Wraps a runnable with profiling and TTL context propagation. */
    private Runnable wrapRunnable(Runnable task) {
        return TtlRunnable.get(new ProfileRunnable(task, getQueue(), getThreadPoolName()));
    }

    /** Wraps a callable with profiling and TTL context propagation. */
    @SuppressWarnings("unchecked")
    private <T> Callable<T> wrapCallable(Callable<T> task) {
        return TtlCallable.get(new ProfileCallable(task, getQueue(), getThreadPoolName()));
    }
}
