package cn.com.duibaboot.ext.autoconfigure.perftest;

import java.lang.reflect.Field;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.sql.DataSource;

import cn.com.duiba.boot.perftest.PerfTestContext;
import cn.com.duibaboot.ext.autoconfigure.cloud.netflix.eureka.DiscoveryMetadataRegister;
import cn.com.duibaboot.ext.autoconfigure.core.SpecifiedBeanPostProcessor;
import cn.com.duibaboot.ext.autoconfigure.core.ttl.TransmittableExecutorBeanPostProcessor;
import com.alibaba.ttl.TransmittableThreadLocal;
import com.aliyun.openservices.ons.api.Consumer;
import com.mongodb.MongoClientURI;
import com.netflix.hystrix.HystrixCommand;
import feign.RequestInterceptor;
import feign.RequestTemplate;
import feign.hystrix.HystrixFeign;
import io.shardingjdbc.core.jdbc.adapter.AbstractDataSourceAdapter;
import io.shardingjdbc.core.jdbc.core.ShardingContext;
import io.shardingjdbc.core.jdbc.core.datasource.MasterSlaveDataSource;
import io.shardingjdbc.core.jdbc.core.datasource.ShardingDataSource;
import io.shardingjdbc.core.rule.MasterSlaveRule;
import org.apache.rocketmq.client.producer.DefaultMQProducer;
import org.aspectj.lang.annotation.Aspect;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnResource;
import org.springframework.boot.autoconfigure.condition.ConditionalOnWebApplication;
import org.springframework.boot.autoconfigure.mongo.MongoProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.EnableAspectJAutoProxy;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;
import org.springframework.core.env.Environment;
import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.jdbc.datasource.lookup.AbstractRoutingDataSource;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.config.ChannelRegistration;
import org.springframework.messaging.support.ExecutorChannelInterceptor;
import org.springframework.util.ReflectionUtils;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.servlet.config.annotation.InterceptorRegistration;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurerAdapter;
import org.springframework.web.servlet.handler.HandlerInterceptorAdapter;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.config.annotation.AbstractWebSocketMessageBrokerConfigurer;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.socket.server.support.WebSocketHandlerMapping;
import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler;
import org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsService;

/**
 * 线上压力测试自动配置
 *
 * Created by wenqi.huang on 2016/11/24.
 */
@Configuration
@ConditionalOnResource(resources = "/autoconfig/perftest.properties")
@ConditionalOnClass({TransmittableThreadLocal.class})
@EnableAspectJAutoProxy(proxyTargetClass = true)
public class PerfTestAutoConfiguration {

    /** Name of the URL parameter / cookie (and websocket session-attribute key) that marks a request as perf-test traffic. */
    private static final String PERF_TEST_KEY = "_duibaPerf";

    /**
     * Spring MVC integration: installs an interceptor that switches requests carrying
     * {@code _duibaPerf=1} (or {@code _duibaPerf=true}) as a URL parameter or cookie, or the Feign
     * perf-test header, into perf-test mode so that all database access goes to the shadow databases.
     */
    @Configuration
    @ConditionalOnResource(resources = "/autoconfig/perftest.properties")
    @ConditionalOnClass({TransmittableThreadLocal.class})
    @ConditionalOnWebApplication
    @Order(Ordered.HIGHEST_PRECEDENCE)//this HandlerInterceptor must come before all others
    public static class WebConfiguration extends WebMvcConfigurerAdapter {

        /**
         * Registers the perf-test interceptor at the head of the interceptor chain.
         *
         * <p>It must be first: if another interceptor did database or RPC work before the perf-test
         * flag was set, that work would not be treated as perf-test traffic. {@link InterceptorRegistry}
         * offers no public way to insert at a position, so its internal {@code registrations} list is
         * modified via reflection. Calling this more than once registers the interceptor only once.
         *
         * @throws RuntimeException if the internal {@code registrations} list cannot be accessed
         */
        @Override
        public void addInterceptors(InterceptorRegistry registry) {
            List<InterceptorRegistration> registrations = getRegistrations(registry);
            for (InterceptorRegistration r : registrations) {
                if (r instanceof InternalInterceptorRegistration) {
                    return; // already registered
                }
            }

            InternalInterceptorRegistration registration = new InternalInterceptorRegistration(pressureTestInterceptor());
            registration.addPathPatterns("/**");
            registrations.add(0, registration); // add(0, x) is valid for an empty list as well
        }

        /**
         * Reads {@code InterceptorRegistry.registrations} reflectively. The field is looked up on
         * {@link InterceptorRegistry} itself rather than on {@code registry.getClass()}, so a subclass
         * of the registry does not cause a {@link NoSuchFieldException}.
         */
        @SuppressWarnings("unchecked")
        private static List<InterceptorRegistration> getRegistrations(InterceptorRegistry registry) {
            try {
                Field field = InterceptorRegistry.class.getDeclaredField("registrations");
                field.setAccessible(true);
                return (List<InterceptorRegistration>) field.get(registry);
            } catch (ReflectiveOperationException e) {
                throw new RuntimeException(e);
            }
        }

        /** Marker type used to detect that the perf-test interceptor has already been registered. */
        private static class InternalInterceptorRegistration extends InterceptorRegistration {

            public InternalInterceptorRegistration(HandlerInterceptor interceptor) {
                super(interceptor);
            }
        }

        /**
         * Creates the interceptor that sets the perf-test flag for the current request and clears it
         * once the request is done on the current thread.
         */
        private static HandlerInterceptor pressureTestInterceptor() {

            return new HandlerInterceptorAdapter() {
                @Override
                public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object o) throws Exception {
                    // Reset first so a value left over from an earlier request on this thread is never reused
                    PerfTestContext._setPerfTestMode(false);

                    if (isPerfTestReq(request)) {
                        PerfTestContext._setPerfTestMode(true);
                    }

                    return true;
                }

                // postHandle is not always called (preHandle threw or returned false), so clean up in afterCompletion
                @Override
                public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {
                    PerfTestContext._setPerfTestMode(false);
                }

                // When a handler starts async processing, afterCompletion is NOT invoked on the initial
                // request thread; this callback is invoked instead. Clear the flag here as well so it does
                // not leak onto the container thread. The async dispatch runs preHandle again.
                @Override
                public void afterConcurrentHandlingStarted(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
                    PerfTestContext._setPerfTestMode(false);
                }
            };
        }

    }

    /**
     * Decides whether an HTTP request is perf-test (load-test) traffic.
     *
     * <p>The request qualifies when any of these holds, checked in this order:
     * <ol>
     *   <li>the {@code _duibaPerf} request parameter is {@code "1"} or {@code "true"};</li>
     *   <li>a cookie named {@code _duibaPerf} has the value {@code "1"} or {@code "true"};</li>
     *   <li>the {@code DubboPerfTestFilter.IS_PERF_TEST_MODE} header equals {@code "true"}
     *       (Feign clients add this header to perf-test calls).</li>
     * </ol>
     *
     * @param request the incoming servlet request
     * @return {@code true} if the request must be handled in perf-test mode
     */
    private static boolean isPerfTestReq(HttpServletRequest request){
        if (isPerfFlagValue(request.getParameter(PERF_TEST_KEY))) {
            return true;
        }

        Cookie[] cookies = request.getCookies();
        if (cookies != null) {
            for (Cookie c : cookies) {
                if (PERF_TEST_KEY.equals(c.getName()) && isPerfFlagValue(c.getValue())) {
                    return true;
                }
            }
        }

        // Header propagated by Feign for perf-test RPC calls
        return "true".equals(request.getHeader(DubboPerfTestFilter.IS_PERF_TEST_MODE));
    }

    /** @return {@code true} if the parameter/cookie value switches perf-test mode on ({@code "1"} or {@code "true"}). */
    private static boolean isPerfFlagValue(String value) {
        return "1".equals(value) || "true".equals(value);
    }

    /**
     * Adds a perf-test marker to websocket sessions. Only the spring-websocket SockJS +
     * message-broker style is supported.
     *
     * <p>After the {@link WebSocketHandlerMapping} is initialized, a {@link HandshakeInterceptor} is
     * prepended to the handshake interceptors of every SockJS service it maps. During the handshake
     * that interceptor checks the HTTP request with {@code isPerfTestReq}; for perf-test traffic it
     * stores {@code _duibaPerf=true} in the websocket session attributes.
     *
     * <p>Handlers that are not {@link SockJsHttpRequestHandler}s backed by a
     * {@link TransportHandlingSockJsService} are skipped. This covers, for example, plain websocket
     * endpoints registered without SockJS. Previously they caused a {@link ClassCastException} at startup.
     *
     * @return the post-processor
     */
    @ConditionalOnClass(WebSocketHandlerMapping.class)
    @Bean
    public SpecifiedBeanPostProcessor<WebSocketHandlerMapping> webSocketHandlerMappingSpecifiedBeanPostProcessor(){
        return new SpecifiedBeanPostProcessor<WebSocketHandlerMapping>() {
            @Override
            public Class<WebSocketHandlerMapping> getBeanType() {
                return WebSocketHandlerMapping.class;
            }

            @Override
            public Object postProcessBeforeInitialization(WebSocketHandlerMapping bean, String beanName) throws BeansException {
                return bean;
            }

            @Override
            public Object postProcessAfterInitialization(WebSocketHandlerMapping bean, String beanName) throws BeansException {
                // The interceptor is stateless, so one instance can be shared by all SockJS services
                HandshakeInterceptor perfTestInterceptor = new HandshakeInterceptor() {
                    @Override
                    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
                        if(request instanceof ServletServerHttpRequest){
                            ServletServerHttpRequest req = (ServletServerHttpRequest)request;
                            if(isPerfTestReq(req.getServletRequest())){
                                attributes.put(PERF_TEST_KEY, true);
                            }
                        }
                        return true;
                    }

                    @Override
                    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception exception) {
                    }
                };

                for (Object handler : bean.getUrlMap().values()) {
                    if (!(handler instanceof SockJsHttpRequestHandler)) {
                        continue; // e.g. plain websocket endpoint without SockJS: not supported, skip
                    }
                    Object sockJsService = ((SockJsHttpRequestHandler) handler).getSockJsService();
                    if (sockJsService instanceof TransportHandlingSockJsService) {
                        // Insert first so the marker is set before any other handshake interceptor runs
                        ((TransportHandlingSockJsService) sockJsService).getHandshakeInterceptors().add(0, perfTestInterceptor);
                    }
                }
                return bean;
            }

            @Override
            public int getOrder() {
                return 0;
            }
        };
    }

    /**
     * Applies the websocket perf-test marker while inbound STOMP messages are handled. Only the
     * spring-websocket SockJS + message-broker style is supported.
     *
     * <p>Before a message from the client inbound channel is handled, the perf-test flag of the
     * current thread is set from the {@code _duibaPerf} session attribute recorded at handshake time.
     * After handling, the flag is cleared.
     *
     * @return a configurer that only contributes the inbound-channel interceptor
     */
    @Bean
    @ConditionalOnClass(AbstractWebSocketMessageBrokerConfigurer.class)
    public AbstractWebSocketMessageBrokerConfigurer bootWebSocketMessageBrokerConfigurer(){
        return new AbstractWebSocketMessageBrokerConfigurer() {
            @Override
            public void registerStompEndpoints(StompEndpointRegistry registry) {
                //do nothing
            }

            @Override
            public void configureClientInboundChannel(ChannelRegistration registration) {
                registration.interceptors(new ExecutorChannelInterceptor(){

                    @Override
                    public Message<?> preSend(Message<?> message, MessageChannel channel) {
                        return message;
                    }
                    @Override
                    public void postSend(Message<?> message, MessageChannel channel, boolean sent) {
                    }
                    @Override
                    public void afterSendCompletion(Message<?> message, MessageChannel channel, boolean sent, Exception ex) {
                    }
                    @Override
                    public boolean preReceive(MessageChannel channel) {
                        return true;
                    }
                    @Override
                    public Message<?> postReceive(Message<?> message, MessageChannel channel) {
                        return message;
                    }
                    @Override
                    public void afterReceiveCompletion(Message<?> message, MessageChannel channel, Exception ex) {
                    }

                    // Called before a sockjs (websocket) message is handled
                    @Override
                    public Message<?> beforeHandle(Message<?> message, MessageChannel channel, MessageHandler handler) {
                        Object sessionAttr = message.getHeaders().get(SimpMessageHeaderAccessor.SESSION_ATTRIBUTES);
                        boolean perfSession = sessionAttr instanceof Map
                                && Boolean.TRUE.equals(((Map<?, ?>) sessionAttr).get(PERF_TEST_KEY));
                        // Set the flag explicitly (true OR false) so a stale value on a reused worker
                        // thread can never leak into a normal session's message handling
                        PerfTestContext._setPerfTestMode(perfSession);
                        return message;
                    }
                    // Called after a sockjs (websocket) message has been handled
                    @Override
                    public void afterMessageHandled(Message<?> message, MessageChannel channel, MessageHandler handler, Exception ex) {
                        PerfTestContext._setPerfTestMode(false);
                    }
                });
            }

        };
    }

    /**
     * Replaces every thread pool in the context with a TransmittableThreadLocal-aware wrapper, so
     * that TransmittableThreadLocal values of the submitting thread are visible to the pooled tasks.
     */
    @Configuration
    @ConditionalOnResource(resources = "/autoconfig/perftest.properties")
    @ConditionalOnClass({TransmittableThreadLocal.class})
    // Intentionally NOT guarded by @ConditionalOnBean(Executor.class) or
    // @ConditionalOnMissingBean(TransmittableExecutorBeanPostProcessor.class): with those conditions the
    // BeanPostProcessor may not be created, depending on the user's component-scan base packages
    // (observed when cn.com.duiba is scanned).
    public static class TransmittableThreadLocalConfigurer{

        @Bean
        public SpecifiedBeanPostProcessor perfTestExecutorServiceBeanPostProcessor(){
            TransmittableExecutorBeanPostProcessor processor = new TransmittableExecutorBeanPostProcessor();
            return processor;
        }
    }
    /**
     * Registers the aspect that intercepts spring-data-mongodb operations so that perf-test traffic
     * uses shadow collections. Active only when MongoTemplate and AspectJ are on the classpath.
     */
    @Configuration
    @ConditionalOnResource(resources = "/autoconfig/perftest.properties")
    @ConditionalOnClass({TransmittableThreadLocal.class, MongoTemplate.class, Aspect.class})
    public static class SpringDataMongodbPerfConfiguration{

        @Bean
        public SpringDataMongodbPerfAspect getSpringDataMongodbPerfAspect(){
            SpringDataMongodbPerfAspect aspect = new SpringDataMongodbPerfAspect();
            return aspect;
        }
    }
    /**
     * Registers the perf-test aspect for the Aliyun ONS message-queue client. Active only when the
     * ONS {@link Consumer} API and AspectJ are on the classpath. What the aspect intercepts is
     * defined in OnsPerfAspect (presumably producer/consumer calls, so that perf-test markers travel
     * with messages — confirm there).
     */
    @Configuration
    @ConditionalOnResource(resources = "/autoconfig/perftest.properties")
    @ConditionalOnClass({TransmittableThreadLocal.class, Consumer.class,Aspect.class})
    public static class OnsPerfConfiguration{
        @Bean
        public OnsPerfAspect getOnsPerfAspect(){
            return new OnsPerfAspect();
        }
    }
    /**
     * Registers the perf-test aspect for the Apache RocketMQ client. Active only when
     * {@link DefaultMQProducer} and AspectJ are on the classpath. What the aspect intercepts is
     * defined in RocketMqPerfAspect (presumably producer/consumer calls, so that perf-test markers
     * travel with messages — confirm there).
     */
    @Configuration
    @ConditionalOnResource(resources = "/autoconfig/perftest.properties")
    @ConditionalOnClass({TransmittableThreadLocal.class, DefaultMQProducer.class,Aspect.class})
    public static class RocketMqPerfConfiguration{
        @Bean
        public RocketMqPerfAspect getRocketMqPerfAspect(){
            return new RocketMqPerfAspect();
        }
    }

    /**
     * Wraps every DataSource bean in the Spring context in an auto-routing data source
     * (PerfTestRoutingDataSource) so that perf-test traffic can be routed to the shadow databases.
     * sharding-jdbc data sources are not wrapped. Instead, the data sources they use internally are
     * validated to already be PerfTestRoutingDataSource instances.
     */
    @Configuration
    @ConditionalOnResource(resources = "/autoconfig/perftest.properties")
    @ConditionalOnClass({AbstractRoutingDataSource.class,TransmittableThreadLocal.class})
    //@ConditionalOnBean({DataSource.class})
    public static class DataSourceConfigurer{

        /**
         * Post-processor that, after initialization, replaces each DataSource bean with a
         * PerfTestRoutingDataSource wrapping it. Beans that already are a PerfTestRoutingDataSource,
         * or that are sharding-jdbc data sources, are returned unchanged.
         */
        @Bean
        public SpecifiedBeanPostProcessor perfTestDataSourcePostProcessor(){
            return new SpecifiedBeanPostProcessor<DataSource>() {
                // Passed to every PerfTestRoutingDataSource created below
                @Autowired
                Environment environment;
                @Override
                public int getOrder() {
                    return -2;
                }

                @Override
                public Class<DataSource> getBeanType() {
                    return DataSource.class;
                }

                @Override
                public Object postProcessBeforeInitialization(DataSource bean, String beanName) throws BeansException {
                    return bean;
                }

                @Override
                public Object postProcessAfterInitialization(DataSource bean, String beanName) throws BeansException {
                    // sharding-jdbc data sources are only validated (throws if invalid), never wrapped
                    if(isBeanInstanceOfShardingJdbcDataSource(bean, beanName)){
                        return bean;
                    }
                    // Avoid wrapping a routing data source a second time
                    if(!(bean instanceof PerfTestRoutingDataSource)){
                        PerfTestRoutingDataSource ts = new PerfTestRoutingDataSource(bean, environment);
                        // The wrapper is created here with `new`, so its init callback is invoked manually
                        ts.afterPropertiesSet();
                        bean = ts;
                    }

                    return bean;
                }

            };
        }

        /**
         * Checks whether the bean is a sharding-jdbc data source. If it is, verifies that every data
         * source it uses internally is already a PerfTestRoutingDataSource, and throws otherwise,
         * asking for the inner data sources to be declared as Spring beans. Inner beans are created
         * first, so properly declared ones will already have been converted to PerfTestRoutingDataSource.
         *
         * @param bean     the data source bean being post-processed
         * @param beanName its bean name, used in error messages
         * @return {@code true} if the bean is a (validated) sharding-jdbc data source; {@code false} if
         *         sharding-jdbc is not on the classpath or the bean is not one of its data sources
         * @throws IllegalStateException if an inner data source is not a PerfTestRoutingDataSource,
         *         or the bean is a sharding-jdbc data source type that is not supported yet
         */
        private boolean isBeanInstanceOfShardingJdbcDataSource(DataSource bean, String beanName){
            // Must stay first: bail out when sharding-jdbc is absent, before any sharding-jdbc type
            // below is touched (presumably to avoid NoClassDefFoundError)
            try {
                Class.forName("io.shardingjdbc.core.jdbc.adapter.AbstractDataSourceAdapter");
            } catch (ClassNotFoundException e) {
                return false;
            }

            if(!(bean instanceof AbstractDataSourceAdapter)){
                return false;
            }

            // It is a sharding-jdbc data source: perform extra validation
            if(bean instanceof ShardingDataSource){
                ShardingDataSource ds = (ShardingDataSource)bean;
                // shardingContext is read reflectively to reach the configured data source map
                Field field = ReflectionUtils.findField(ShardingDataSource.class, "shardingContext");
                field.setAccessible(true);
                ShardingContext shardingContext = (ShardingContext)ReflectionUtils.getField(field, ds);
                Map<String, DataSource> dataSourceMap = shardingContext.getShardingRule().getDataSourceMap();
                for(Map.Entry<String, DataSource> entry : dataSourceMap.entrySet()){
                    DataSource innerDs = entry.getValue();
                    if(innerDs instanceof MasterSlaveDataSource){
                        processMasterSlaveDataSource((MasterSlaveDataSource)innerDs, beanName);
                    }
                    else if(!(innerDs instanceof PerfTestRoutingDataSource)){
                        onShardingJdbcError(beanName);
                    }
                }
            }else if(bean instanceof MasterSlaveDataSource){
                processMasterSlaveDataSource((MasterSlaveDataSource)bean, beanName);
            }else{
                throw new IllegalStateException("[NOTIFYME]sharding jdbc新增的数据源暂时不支持，如遇到此问题，请联系架构组添加支持");
            }

            return true;
        }

        /**
         * Verifies that the master and all slave data sources of a sharding-jdbc master/slave data
         * source are PerfTestRoutingDataSource instances, and throws otherwise.
         * The parameter is declared as DataSource rather than MasterSlaveDataSource to avoid
         * class-not-found errors (presumably when sharding-jdbc is not on the classpath).
         */
        private void processMasterSlaveDataSource(DataSource ds, String beanName){
            MasterSlaveDataSource ds1 = (MasterSlaveDataSource)ds;
            MasterSlaveRule oriRule = ds1.getMasterSlaveRule();
            DataSource oriMd = oriRule.getMasterDataSource();

            boolean throwException = false;
            if(!(oriMd instanceof PerfTestRoutingDataSource)){
                throwException = true;
            }
            // Only inspect the slaves when the master is valid
            if(!throwException) {
                Map<String, DataSource> slaveDataSourceMap = oriRule.getSlaveDataSourceMap();
                for (Map.Entry<String, DataSource> entry : slaveDataSourceMap.entrySet()) {
                    DataSource innerDs = entry.getValue();
                    if (!(innerDs instanceof PerfTestRoutingDataSource)) {
                        throwException = true;
                        break;
                    }
                }
            }

            if(throwException){
                onShardingJdbcError(beanName);
            }
        }

        /** Always throws an IllegalStateException asking for the inner data sources of {@code beanName} to be Spring beans. */
        private void onShardingJdbcError(String beanName){
            throw new IllegalStateException("请把id为["+beanName+"]的sharding-jdbc数据源内部使用的数据源注册为spring的bean，以让线上压测框架有机会处理内部的数据源（内部数据源必须为dbcp2）");
        }

    }

    /**
     * Records the original MongoDB connection URI of every {@link MongoProperties} bean in
     * {@code SpringDataMongodbPerfAspect.OriginDataBaseUri}, keyed by the database name parsed from
     * that URI. Presumably the aspect uses this to route perf-test traffic to a shadow database.
     */
    @Configuration
    @ConditionalOnResource(resources = "/autoconfig/perftest.properties")
    @ConditionalOnClass({MongoProperties.class,TransmittableThreadLocal.class})
    public static class MongoDbDataSourceConfigurer{

        @Bean
        public SpecifiedBeanPostProcessor perfTestMongoDbDataSourcePostProcessor(){
            return new SpecifiedBeanPostProcessor<MongoProperties>() {
                @Override
                public Class<MongoProperties> getBeanType() {
                    return MongoProperties.class;
                }

                @Override
                public int getOrder() {
                    return -2;
                }

                @Override
                public Object postProcessBeforeInitialization(MongoProperties bean, String beanName) throws BeansException {
                    return bean;
                }

                @Override
                public Object postProcessAfterInitialization(MongoProperties bean, String beanName) throws BeansException {
                    // NOTE(review): getUri() may be null (e.g. if only host/port are configured); parsing would
                    // then fail here — confirm whether such configurations need to be supported.
                    String originUri = bean.getUri();
                    String databaseName = new MongoClientURI(originUri).getDatabase();
                    SpringDataMongodbPerfAspect.OriginDataBaseUri.put(databaseName, originUri);
                    return bean;
                }
            };
        }
    }

    /**
     * Publishes Eureka metadata announcing that this cloud application supports perf testing.
     *
     * @return a register that puts {@code IS_PERF_TEST_SUPPORTTED_KEY -> "1"} into the app metadata
     */
    @Bean
    public DiscoveryMetadataRegister perftest_DiscoveryMetadataRegister(){
        final String supportedFlag = "1";
        return new DiscoveryMetadataRegister(){
            @Override
            public void registerMetadata(Map<String, String> metadata) {
                metadata.put(DubboPerfTestRegistryFactoryWrapper.IS_PERF_TEST_SUPPORTTED_KEY, supportedFlag);
            }
        };
    }

    /**
     * 配置Feign请求拦截器，在请求前配置http头表示这是压测请求
     */
    @Configuration
    @ConditionalOnResource(resources = "/autoconfig/perftest.properties")
    @ConditionalOnClass({ HystrixCommand.class, HystrixFeign.class })
    public class PerfTestFeignConfiguration{

        @Bean
        public RequestInterceptor perfTestFeignRequestInterceptor(){
            return new RequestInterceptor() {
                @Override
                public void apply(RequestTemplate template) {
                    if (PerfTestContext.isCurrentInPerfTestMode()) {//表示线上压测流量
                        //在RibbonCustomAutoConfiguration.CustomZoneAvoidanceRule中判断对应服务端是否支持压测，如果不支持压测，会拒绝当前rpc请求
                        template.header(DubboPerfTestFilter.IS_PERF_TEST_MODE, "true");//设置http请求头，传递给服务端，表示是压测请求
                    }
                }
            };
        }

    }
}
