package cn.com.duiba.boot.ext.autoconfigure.test;

import cn.com.duiba.boot.ext.autoconfigure.core.ttl.TransmittableExecutorBeanPostProcessor;
import com.alibaba.ttl.TransmittableThreadLocal;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.jdbc.datasource.lookup.AbstractRoutingDataSource;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.servlet.ModelAndView;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurerAdapter;
import org.springframework.web.servlet.handler.HandlerInterceptorAdapter;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.sql.DataSource;
import java.net.MalformedURLException;
import java.net.URL;
import java.util.concurrent.Executor;

/**
 * 线上压力测试自动配置
 *
 * Created by wenqi.huang on 2016/11/24.
 */
/**
 * Auto-configuration for online pressure (load) testing.
 *
 * <p>Requests carrying the parameter {@code _test=1} or {@code _test=true} are flagged in a
 * {@link TransmittableThreadLocal}. The flag is meant to route all database access of such
 * requests to shadow databases (via the routing data source installed below).
 *
 * Created by wenqi.huang on 2016/11/24.
 */
@Configuration
@ConditionalOnClass({TransmittableThreadLocal.class})
public class PressureTestAutoConfiguration {

    /**
     * Holder for the per-request pressure-test flag.
     */
    public static class TransmittableThreadLocalHolder{
        /**
         * {@code true} while the current request is pressure-test traffic; unset otherwise.
         * Presumably read by {@code TestRoutingDataSource} to pick the shadow database — confirm there.
         */
        protected static final TransmittableThreadLocal<Boolean> threadLocal2PressureTest = new TransmittableThreadLocal<Boolean>();
    }

    /**
     * Registers a Spring MVC interceptor that flags requests containing {@code _test=1} or
     * {@code _test=true}, so that their database access is routed to the shadow databases.
     */
    @Configuration
    @ConditionalOnClass({TransmittableThreadLocal.class})
    public static class WebConfiguration extends WebMvcConfigurerAdapter {

        @Override
        public void addInterceptors(InterceptorRegistry registry) {
            registry.addInterceptor(pressureTestInterceptor()).addPathPatterns("/**");
        }

        /**
         * Builds the interceptor that sets the pressure-test flag for the duration of a request.
         *
         * @return interceptor that sets the flag in preHandle and always clears it afterwards
         */
        private HandlerInterceptor pressureTestInterceptor(){

            return new HandlerInterceptorAdapter() {
                @Override
                public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
                    // Start from a clean state in case an earlier request on this pooled thread left the flag set.
                    TransmittableThreadLocalHolder.threadLocal2PressureTest.remove();
                    String test = request.getParameter("_test");
                    // String.equals(null) is false, so no explicit null check is needed.
                    if ("1".equals(test) || "true".equals(test)) {
                        TransmittableThreadLocalHolder.threadLocal2PressureTest.set(true);
                    }

                    return true;
                }

                @Override
                public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {
                    // Unlike postHandle, afterCompletion also runs when the handler throws, and only after
                    // view rendering, so the flag stays valid for the whole request and never leaks.
                    TransmittableThreadLocalHolder.threadLocal2PressureTest.remove();
                }

                @Override
                public void afterConcurrentHandlingStarted(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
                    // For async requests Spring calls this instead of postHandle/afterCompletion on the
                    // original thread; clear the flag so it does not stay on the released servlet thread.
                    TransmittableThreadLocalHolder.threadLocal2PressureTest.remove();
                }

            };
        }

    }

    /**
     * Wraps every thread pool (executor) bean with TransmittableThreadLocal support. This lets the
     * pressure-test flag follow tasks submitted to those pools.
     *
     * <p>Intentionally NOT guarded by {@code @ConditionalOnBean(Executor.class)} or
     * {@code @ConditionalOnMissingBean(TransmittableExecutorBeanPostProcessor.class)}. With those
     * conditions the BeanPostProcessor was sometimes not created. Whether it was created depended on
     * the user's {@code @ComponentScan} path; it failed when {@code cn.com.duiba} was scanned.
     */
    @Configuration
    @ConditionalOnClass({TransmittableThreadLocal.class})
    public static class TransmittableThreadLocalConfigurer{

        /**
         * Declared static so Spring can create this post-processor without instantiating the
         * enclosing configuration class early.
         *
         * @return post-processor that wraps executor beans
         */
        @Bean
        public static BeanPostProcessor executorServiceBeanPostProcessor(){
            return new TransmittableExecutorBeanPostProcessor();
        }

    }

    /**
     * Replaces every DataSource bean in the Spring context with an auto-routing data source. This lets
     * pressure-test traffic be directed to shadow databases automatically.
     * (Not guarded by {@code @ConditionalOnBean(DataSource.class)}.)
     */
    @Configuration
    @ConditionalOnClass({AbstractRoutingDataSource.class,TransmittableThreadLocal.class})
    public static class DataSourceConfigurer{

        /**
         * Declared static so Spring can create this post-processor without instantiating the
         * enclosing configuration class early.
         *
         * @return post-processor that wraps each DataSource bean in a {@code TestRoutingDataSource}
         */
        @Bean
        public static BeanPostProcessor dataSourcePostProcessor(){
            return new BeanPostProcessor() {
                @Override
                public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
                    return bean;
                }

                @Override
                public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
                    // Skip beans that are already routing data sources to avoid double wrapping.
                    if(bean instanceof DataSource && !(bean instanceof TestRoutingDataSource)){
                        TestRoutingDataSource ts = new TestRoutingDataSource((DataSource)bean);
                        // The wrapper is created outside the container, so initialize it manually.
                        ts.afterPropertiesSet();
                        bean = ts;
                    }

                    return bean;
                }
            };
        }

    }

}
