package cn.com.duiba.sso.api.tool.threadpool;

import com.google.common.collect.Maps;
import org.apache.commons.lang.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.stereotype.Component;

import java.util.Map;
import java.util.Set;
import java.util.concurrent.*;

/**
 * Builds {@link ExecutorService} instances from {@link ThreadPoolProperties} and caches them by
 * thread name, so repeated {@link #build} calls with the same name return the same pool.
 *
 * <p>As a Spring {@link DisposableBean}, {@link #destroy()} shuts down every cached pool when the
 * container destroys this bean. The cache is static, so it is shared by all instances of this class.
 *
 * Created by liuyao on 2017/8/12.
 */
@Component
public class ThreadPoolBuilder implements DisposableBean {

    /** Pools keyed by thread name. A ConcurrentMap so that {@link #build} can register atomically. */
    private static final ConcurrentMap<String, ExecutorService> threadPoolCacheMap = Maps.newConcurrentMap();

    protected static final Logger logger = LoggerFactory.getLogger(ThreadPoolBuilder.class);
    /**
     * Graceful-shutdown timeout in milliseconds (default 3000 ms); pools still running after this
     * are force-closed with {@link ExecutorService#shutdownNow()}.
     */
    private long shutdownTimeout = 3000;

    /**
     * Returns the pool cached under {@code properties.getThreadName()}, creating and caching a new
     * one if none exists yet.
     *
     * <p>A non-positive queue size yields a {@link SynchronousQueue} (direct hand-off); otherwise a
     * bounded {@link ArrayBlockingQueue} is used. The maximum size is never smaller than the core
     * size. Idle non-core threads are reclaimed after 60 seconds.
     *
     * <p>NOTE(review): pools stay cached after {@link #destroy()}, so a later call with the same
     * name returns an already shut-down pool.
     *
     * @param properties pool configuration; the thread name is also the cache key
     * @return the cached or newly created pool
     */
    public static ExecutorService build(ThreadPoolProperties properties){
        String threadName = properties.getThreadName();
        ExecutorService cached = threadPoolCacheMap.get(threadName);
        if (cached != null) {
            return cached;
        }

        BlockingQueue<Runnable> queue;
        int queueSize = properties.getQueueSize();
        if (queueSize <= 0) {
            queue = new SynchronousQueue<>();
        } else {
            queue = new ArrayBlockingQueue<>(queueSize);
        }

        // ThreadPoolExecutor rejects maxSize < coreSize, so clamp it.
        int maxSize = Math.max(properties.getMaxSize(), properties.getCoreSize());
        ThreadPoolExecutor threadPool = new ThreadPoolExecutor(properties.getCoreSize(),
                maxSize,
                60L, TimeUnit.SECONDS,
                queue,
                new NameThreadFactory(threadName),
                new AbortPolicyWithReport());

        // Atomic registration: if another thread won the race, keep its pool and discard ours.
        // Ours has not run any task yet, so shutting it down is cheap.
        ExecutorService existing = threadPoolCacheMap.putIfAbsent(threadName, threadPool);
        if (existing != null) {
            threadPool.shutdown();
            return existing;
        }
        return threadPool;
    }

    /**
     * Gracefully shuts down all cached pools.
     *
     * <p>The default pool (obtained via {@code build(new ThreadPoolProperties())}) runs one shutdown
     * task per other pool. This method blocks until all of those tasks have finished.
     *
     * <p>The default pool itself is shut down from a JVM shutdown hook. If the JVM is already
     * shutting down and a hook cannot be registered, it is shut down directly instead.
     *
     * @throws Exception if interrupted while waiting for the other pools to close
     */
    @Override
    public void destroy() throws Exception {
        if (threadPoolCacheMap.isEmpty()) {
            return;
        }

        // Use the default pool to take care of closing all the others.
        // Presumably new ThreadPoolProperties() carries DEFAULT_THREAD_NAME — TODO confirm; the
        // identity check below keeps the latch count correct either way.
        ExecutorService defaultPool = build(new ThreadPoolProperties());

        // Snapshot the pools to close so the latch count always equals the number of tasks,
        // even if build() registers new pools concurrently.
        Map<String, ExecutorService> otherPools = Maps.newHashMap();
        for (Map.Entry<String, ExecutorService> entry : threadPoolCacheMap.entrySet()) {
            ExecutorService pool = entry.getValue();
            if (StringUtils.equals(entry.getKey(), ThreadPoolProperties.DEFAULT_THREAD_NAME)
                    || pool == defaultPool) {
                continue;
            }
            otherPools.put(entry.getKey(), pool);
        }

        CountDownLatch latch = new CountDownLatch(otherPools.size());
        for (ExecutorService pool : otherPools.values()) {
            ShutdownThreadPoolTask task = new ShutdownThreadPoolTask(pool, latch);
            try {
                defaultPool.submit(task);
            } catch (RejectedExecutionException e) {
                // The default pool is saturated or already shut down. Run the task here so the
                // pool still gets closed and the latch still counts down.
                task.run();
            }
        }
        // Only close the default pool after every other pool has been closed.
        latch.await();

        // Close the default pool.
        ShutdownThreadPoolTask task = new ShutdownThreadPoolTask(defaultPool, null);
        try {
            Runtime.getRuntime().addShutdownHook(new Thread(task));
        } catch (IllegalStateException e) {
            // The JVM is already shutting down (e.g. this bean is being destroyed from a shutdown
            // hook), so a new hook cannot be registered; close the default pool right away.
            task.run();
        }
    }

    /**
     * Shuts down one pool. It first calls {@link ExecutorService#shutdown()} and waits up to
     * {@code shutdownTimeout} ms, then force-closes the pool with
     * {@link ExecutorService#shutdownNow()} if it has not terminated. The optional latch is always
     * counted down when the task finishes.
     */
    private class ShutdownThreadPoolTask implements Runnable{

        private final ExecutorService executorService;

        /** May be null when nobody waits for this task. */
        private final CountDownLatch latch;

        public ShutdownThreadPoolTask(ExecutorService executorService,CountDownLatch latch){
            this.executorService = executorService;
            this.latch = latch;
        }

        @Override
        public void run() {
            try{
                // Try a graceful shutdown first: no new tasks, let submitted ones finish.
                executorService.shutdown();
                try {
                    if (!executorService.awaitTermination(shutdownTimeout, TimeUnit.MILLISECONDS)) {
                        // Still running after the timeout: force-close the pool.
                        logger.warn("线程池超过{}ms仍未关闭，强制关闭。", shutdownTimeout);
                        executorService.shutdownNow();
                    }
                } catch (InterruptedException e) {
                    logger.warn("等待线程池关闭时被中断，强制关闭。", e);
                    executorService.shutdownNow();
                    // Preserve the interrupt status for the caller / owning thread.
                    Thread.currentThread().interrupt();
                }
            }finally {
                if(latch!=null){
                    latch.countDown();
                }
            }
        }
    }
}
