package cn.com.duibaboot.ext.autoconfigure.perftest;

import cn.com.duiba.boot.perftest.PerfTestContext;
import cn.com.duiba.boot.perftest.PerfTestUtils;
import cn.com.duibaboot.ext.autoconfigure.core.SpecifiedBeanPostProcessor;
import com.alibaba.ttl.TransmittableThreadLocal;
import org.springframework.beans.BeansException;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnResource;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.config.ChannelRegistration;
import org.springframework.messaging.support.ExecutorChannelInterceptor;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.config.annotation.AbstractWebSocketMessageBrokerConfigurer;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.socket.server.support.WebSocketHandlerMapping;
import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler;
import org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsService;

import java.util.Collection;
import java.util.Map;

/**
 * Auto-configuration for online (production) WebSocket performance testing.
 *
 * <p>Two pieces cooperate:
 * <ol>
 *   <li>a bean post-processor that prepends a {@link HandshakeInterceptor} to every SockJS
 *       endpoint, so that a handshake recognised by {@code PerfTestUtils.isPerfTestRequest}
 *       stores {@code PerfTestUtils.PERF_TEST_KEY} in the handshake attributes;</li>
 *   <li>a client-inbound channel interceptor that reads that key from the message's session
 *       attributes and switches {@link PerfTestContext} on before the message is handled and
 *       off again afterwards.</li>
 * </ol>
 *
 * Created by wenqi.huang on 2016/11/24.
 */
@Configuration
@ConditionalOnResource(resources = "/autoconfig/perftest.properties")
@ConditionalOnClass({TransmittableThreadLocal.class,
        WebSocketHandlerMapping.class,
        AbstractWebSocketMessageBrokerConfigurer.class})
public class WebSocketPerfTestConfiguration {

    /**
     * Adds the perf-test marker to WebSocket requests (only the spring-websocket
     * SockJS message-broker development style is supported).
     *
     * <p>Handlers in the mapping that are not {@link SockJsHttpRequestHandler}s, or whose
     * SockJS service is not a {@link TransportHandlingSockJsService}, are left untouched
     * instead of failing with a {@link ClassCastException}.
     *
     * @return a post-processor that installs the marking handshake interceptor on every
     *         SockJS endpoint of each {@link WebSocketHandlerMapping} bean
     */
    @Bean
    public static SpecifiedBeanPostProcessor<WebSocketHandlerMapping> webSocketHandlerMappingSpecifiedBeanPostProcessor(){
        return new SpecifiedBeanPostProcessor<WebSocketHandlerMapping>() {
            @Override
            public Class<WebSocketHandlerMapping> getBeanType() {
                return WebSocketHandlerMapping.class;
            }

            @Override
            public Object postProcessBeforeInitialization(WebSocketHandlerMapping bean, String beanName) throws BeansException {
                return bean;
            }

            @Override
            public Object postProcessAfterInitialization(WebSocketHandlerMapping bean, String beanName) throws BeansException {
                for (Object handler : bean.getUrlMap().values()) {
                    // The mapping may also hold plain WebSocket handlers (endpoints registered
                    // without SockJS); those are not supported here and must be skipped.
                    if (!(handler instanceof SockJsHttpRequestHandler)) {
                        continue;
                    }
                    Object sockJsService = ((SockJsHttpRequestHandler) handler).getSockJsService();
                    if (sockJsService instanceof TransportHandlingSockJsService) {
                        // Index 0: the marker is written before any other handshake interceptor runs.
                        ((TransportHandlingSockJsService) sockJsService).getHandshakeInterceptors()
                                .add(0, new PerfTestHandshakeInterceptor());
                    }
                }
                return bean;
            }

            @Override
            public int getOrder() {
                return 0;
            }
        };
    }

    /**
     * Adds the perf-test marker to WebSocket requests (only the spring-websocket
     * SockJS message-broker development style is supported).
     *
     * @return a configurer that registers no STOMP endpoints of its own and only contributes
     *         an interceptor to the client inbound channel
     */
    @Bean
    public AbstractWebSocketMessageBrokerConfigurer bootWebSocketMessageBrokerConfigurer(){
        return new AbstractWebSocketMessageBrokerConfigurer() {
            @Override
            public void registerStompEndpoints(StompEndpointRegistry registry) {
                //do nothing
            }

            @Override
            public void configureClientInboundChannel(ChannelRegistration registration) {
                registration.interceptors(new PerfTestChannelInterceptor());
            }

        };
    }

    /**
     * Handshake interceptor that copies the perf-test flag of the HTTP handshake request
     * into the handshake attributes. It never rejects a handshake.
     */
    private static final class PerfTestHandshakeInterceptor implements HandshakeInterceptor {

        @Override
        public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
            // Only servlet-backed requests expose the HttpServletRequest that PerfTestUtils inspects.
            if (request instanceof ServletServerHttpRequest
                    && PerfTestUtils.isPerfTestRequest(((ServletServerHttpRequest) request).getServletRequest())) {
                attributes.put(PerfTestUtils.PERF_TEST_KEY, true);
            }
            return true;
        }

        @Override
        public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception exception) {
            //do nothing
        }
    }

    /**
     * Channel interceptor that enables perf-test mode while a message belonging to a
     * perf-test session is being handled. All send/receive callbacks are pass-through;
     * only the handle callbacks carry logic.
     */
    private static final class PerfTestChannelInterceptor implements ExecutorChannelInterceptor {

        @Override
        public Message<?> preSend(Message<?> message, MessageChannel channel) {
            return message;
        }

        @Override
        public void postSend(Message<?> message, MessageChannel channel, boolean sent) {
            //do nothing
        }

        @Override
        public void afterSendCompletion(Message<?> message, MessageChannel channel, boolean sent, Exception ex) {
            //do nothing
        }

        @Override
        public boolean preReceive(MessageChannel channel) {
            return true;
        }

        @Override
        public Message<?> postReceive(Message<?> message, MessageChannel channel) {
            return message;
        }

        @Override
        public void afterReceiveCompletion(Message<?> message, MessageChannel channel, Exception ex) {
            //do nothing
        }

        /**
         * Called before a SockJS (WebSocket) event is handled: turns perf-test mode on when
         * the session attributes carry the marker written during the handshake.
         */
        @Override
        public Message<?> beforeHandle(Message<?> message, MessageChannel channel, MessageHandler handler) {
            Map<String, Object> sessionAttr = SimpMessageHeaderAccessor.getSessionAttributes(message.getHeaders());
            if (sessionAttr != null && Boolean.TRUE.equals(sessionAttr.get(PerfTestUtils.PERF_TEST_KEY))) {
                PerfTestContext._setPerfTestMode(true);
            }
            return message;
        }

        /**
         * Called after a SockJS (WebSocket) event has been handled: always switches
         * perf-test mode off again, whether or not handling succeeded.
         */
        @Override
        public void afterMessageHandled(Message<?> message, MessageChannel channel, MessageHandler handler, Exception ex) {
            // NOTE(review): presumably this clears per-thread state so a reused executor
            // thread does not leak the flag into the next message — confirm against PerfTestContext.
            PerfTestContext._setPerfTestMode(false);
        }
    }

}
