package cn.com.duibaboot.ext.autoconfigure.core.ttl;

import cn.com.duibaboot.ext.autoconfigure.core.SpecifiedBeanPostProcessor;
import cn.com.duibaboot.ext.autoconfigure.threadpool.wrapper.ScheduledThreadPoolExecutorWrapper;
import cn.com.duibaboot.ext.autoconfigure.threadpool.wrapper.ThreadPoolExecutorWrapper;
import com.alibaba.ttl.TtlRunnable;
import com.alibaba.ttl.threadpool.TtlExecutors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeansException;
import org.springframework.core.task.TaskDecorator;
import org.springframework.core.task.TaskExecutor;
import org.springframework.util.ReflectionUtils;

import java.lang.reflect.Field;
import java.util.concurrent.*;

/**
 * Spring bean post-processor that makes {@link Executor} beans propagate TransmittableThreadLocal
 * values to the threads that run submitted tasks.
 * <p>
 * TransmittableThreadLocal ensures that child threads (including pooled worker threads) see the
 * ThreadLocal values of the thread that submitted the task.
 * <p>
 * Two strategies are used:
 * <ul>
 *   <li>Spring {@link TaskExecutor} beans that have a {@code taskDecorator} field of type
 *       {@link TaskDecorator} (declared on the class or a superclass) get that field replaced,
 *       before initialization, with a decorator that wraps each task in a {@link TtlRunnable}.
 *       The bean instance itself is kept. TaskExecutors without such a field are left as-is.</li>
 *   <li>Every other {@link Executor} (except the project's own thread-pool wrapper types) is
 *       replaced, after initialization, by the matching {@link TtlExecutors} wrapper.</li>
 * </ul>
 * All thread pools must be registered in the Spring context; executors created outside of it
 * cannot be wrapped.
 *
 * Created by wenqi.huang on 2016/11/25.
 */
public class TransmittableExecutorBeanPostProcessor implements SpecifiedBeanPostProcessor<Executor> {

    @Override
    public Class<Executor> getBeanType() {
        return Executor.class;
    }

    /**
     * Installs a TTL-aware {@link TaskDecorator} on Spring {@link TaskExecutor} beans.
     * <p>
     * This runs before initialization so the decorator is already in the field when the
     * executor's own initialization reads it (NOTE(review): confirm this for every TaskExecutor
     * type in use). A previously configured decorator is kept and applied first. If the field
     * already holds a {@link TtlTaskDecorator}, it is left untouched, so processing the same
     * instance twice does not stack decorators.
     *
     * @param bean     the executor bean
     * @param beanName the bean name (not used)
     * @return the same bean instance, possibly with its {@code taskDecorator} field replaced
     */
    @Override
    public Object postProcessBeforeInitialization(Executor bean, String beanName) throws BeansException {
        if (!(bean instanceof TaskExecutor)) {
            return bean;
        }
        Field field = ReflectionUtils.findField(bean.getClass(), "taskDecorator", TaskDecorator.class);
        if (field == null) {
            return bean;
        }
        ReflectionUtils.makeAccessible(field);
        TaskDecorator current = (TaskDecorator) ReflectionUtils.getField(field, bean);
        if (!(current instanceof TtlTaskDecorator)) {
            ReflectionUtils.setField(field, bean, new TtlTaskDecorator(current));
        }
        return bean;
    }

    /**
     * Replaces plain JDK executors with their TTL-aware wrappers.
     *
     * @param bean     the initialized executor bean
     * @param beanName the bean name (not used)
     * @return the original bean for Spring {@link TaskExecutor}s and already-wrapped project
     *         thread pools; otherwise a {@link TtlExecutors} wrapper of the same interface kind
     *         ({@link ScheduledExecutorService}, {@link ExecutorService} or {@link Executor})
     */
    @Override
    public Object postProcessAfterInitialization(Executor bean, String beanName) throws BeansException {
        if (bean instanceof TaskExecutor) {
            // Spring executors are handled through their taskDecorator field before
            // initialization, so the bean instance is returned unchanged here.
            return bean;
        }
        if (bean instanceof ScheduledThreadPoolExecutorWrapper
                || bean instanceof ThreadPoolExecutorWrapper) {
            // Already handled by the project's own wrappers; do not wrap again.
            return bean;
        }
        // Check the most specific interface first so the wrapper keeps the richest API.
        if (bean instanceof ScheduledExecutorService) {
            return TtlExecutors.getTtlScheduledExecutorService((ScheduledExecutorService) bean);
        }
        if (bean instanceof ExecutorService) {
            return TtlExecutors.getTtlExecutorService((ExecutorService) bean);
        }
        return TtlExecutors.getTtlExecutor(bean);
    }

    @Override
    public int getOrder() {
        return 0;
    }

    /**
     * {@link TaskDecorator} that applies an optional original decorator and then wraps the
     * result in a {@link TtlRunnable}.
     */
    private static final class TtlTaskDecorator implements TaskDecorator {

        /** Decorator that was configured before this one was installed; may be {@code null}. */
        private final TaskDecorator original;

        TtlTaskDecorator(TaskDecorator original) {
            this.original = original;
        }

        @Override
        public Runnable decorate(Runnable runnable) {
            Runnable decorated = (original == null) ? runnable : original.decorate(runnable);
            if (decorated instanceof TtlRunnable) {
                // TtlRunnable.get(Runnable) throws IllegalStateException for input that is
                // already a TtlRunnable (e.g. the original decorator wraps with TTL itself),
                // so pass it through instead of failing every task submission.
                return decorated;
            }
            return TtlRunnable.get(decorated);
        }
    }

}
