package cn.com.duibaboot.ext.autoconfigure.actuate;

import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.boot.actuate.autoconfigure.endpoint.web.WebEndpointProperties;
import org.springframework.boot.actuate.autoconfigure.web.ManagementContextConfiguration;
import org.springframework.boot.actuate.endpoint.ExposableEndpoint;
import org.springframework.boot.actuate.endpoint.web.ExposableWebEndpoint;
import org.springframework.boot.actuate.endpoint.web.PathMappedEndpoint;
import org.springframework.boot.actuate.endpoint.web.WebOperation;
import org.springframework.boot.actuate.endpoint.web.annotation.ServletEndpointsSupplier;
import org.springframework.boot.actuate.endpoint.web.servlet.ControllerEndpointHandlerMapping;
import org.springframework.boot.actuate.endpoint.web.servlet.WebMvcEndpointHandlerMapping;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnWebApplication;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.util.ReflectionUtils;
import org.springframework.web.servlet.DispatcherServlet;
import org.springframework.web.servlet.HandlerInterceptor;

import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.List;

/**
 * Restricts actuator HTTP endpoints (for example /refresh and /restart) so that they may only be
 * executed when the call is triggered from the internal network. The access decision itself is
 * delegated to {@link CustomMvcEndpointSecurityInterceptor}.
 */
@ManagementContextConfiguration
@ConditionalOnWebApplication(type = ConditionalOnWebApplication.Type.SERVLET)
@ConditionalOnClass(DispatcherServlet.class)
@EnableConfigurationProperties(WebEndpointProperties.class)
public class SecurityEndpointWebMvcManagementContextConfiguration {

    /**
     * Name of the field looked up reflectively on endpoint handler mappings. Presumably Spring's
     * {@code AbstractHandlerMapping#adaptedInterceptors}, the interceptor list consulted at request
     * time. Re-check this name when upgrading Spring.
     */
    private static final String ADAPTED_INTERCEPTORS_FIELD = "adaptedInterceptors";

    /**
     * Registers a post-processor that installs a {@link CustomMvcEndpointSecurityInterceptor} on every
     * {@link WebMvcEndpointHandlerMapping} and {@link ControllerEndpointHandlerMapping}, so that the
     * endpoints they serve are only callable from the internal network.
     *
     * <p>This covers most endpoints, but not servlet endpoints. Those are guarded by
     * {@link #servletEndpointSecurityfilter(ServletEndpointsSupplier, WebEndpointProperties)}.
     *
     * @return the post-processor; declared {@code static} so it can be created without instantiating
     *         this configuration class
     * @throws IllegalStateException (at bean post-processing time) if the interceptor cannot be installed
     */
    @Bean
    public static BeanPostProcessor endpointHandlerMappingPostProcessor() {
        return new BeanPostProcessor() {
            @Override
            public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
                if (bean instanceof WebMvcEndpointHandlerMapping) {
                    CustomMvcEndpointSecurityInterceptor interceptor = new CustomMvcEndpointSecurityInterceptor();
                    ((WebMvcEndpointHandlerMapping) bean).setInterceptors(interceptor);
                    appendAdaptedInterceptor(bean, beanName, interceptor);
                } else if (bean instanceof ControllerEndpointHandlerMapping) {
                    CustomMvcEndpointSecurityInterceptor interceptor = new CustomMvcEndpointSecurityInterceptor();
                    ((ControllerEndpointHandlerMapping) bean).setInterceptors(interceptor);
                    appendAdaptedInterceptor(bean, beanName, interceptor);
                }
                return bean;
            }

            @Override
            public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
                return bean;
            }
        };
    }

    /**
     * Appends {@code interceptor} directly to the handler mapping's adapted-interceptor list.
     *
     * <p>Why: {@code setInterceptors} only records interceptors that are later adapted during
     * application-context initialisation. That step has presumably already run when this hook fires,
     * so without this direct append the interceptor would never take effect. NOTE(review): if that
     * ordering ever changes, the interceptor could end up registered twice.
     *
     * @param handlerMapping the endpoint handler mapping bean
     * @param beanName       the bean name, used only for error messages
     * @param interceptor    the interceptor to append
     * @throws IllegalStateException if the field is missing or is not a {@link List}. Failing startup
     *                               is preferable to silently leaving endpoints unprotected.
     */
    @SuppressWarnings("unchecked")
    private static void appendAdaptedInterceptor(Object handlerMapping, String beanName,
            HandlerInterceptor interceptor) {
        Field field = ReflectionUtils.findField(handlerMapping.getClass(), ADAPTED_INTERCEPTORS_FIELD);
        if (field == null) {
            throw new IllegalStateException("Cannot install endpoint security interceptor on bean '" + beanName
                    + "': field '" + ADAPTED_INTERCEPTORS_FIELD + "' not found on "
                    + handlerMapping.getClass().getName());
        }
        ReflectionUtils.makeAccessible(field);
        Object value = ReflectionUtils.getField(field, handlerMapping);
        if (!(value instanceof List)) {
            throw new IllegalStateException("Cannot install endpoint security interceptor on bean '" + beanName
                    + "': field '" + ADAPTED_INTERCEPTORS_FIELD + "' is not a List on "
                    + handlerMapping.getClass().getName());
        }
        ((List<HandlerInterceptor>) value).add(interceptor);
    }

    /**
     * Creates a filter that guards calls to servlet endpoints (for example /hystrix.stream), which the
     * handler-mapping interceptor does not cover, so they are only allowed from the internal network.
     *
     * @param servletEndpointsSupplier supplies the servlet endpoints whose paths are protected
     * @param properties               provides the actuator base path
     * @return the security filter
     */
    @Bean
    public Filter servletEndpointSecurityfilter(ServletEndpointsSupplier servletEndpointsSupplier,
            WebEndpointProperties properties) {
        final List<String> protectedPaths = resolveServletEndpointPaths(servletEndpointsSupplier, properties);

        return new Filter() {
            @Override
            public void init(FilterConfig filterConfig) throws ServletException {
                // do nothing
            }

            @Override
            public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
                    throws IOException, ServletException {
                HttpServletRequest req = (HttpServletRequest) request;
                HttpServletResponse resp = (HttpServletResponse) response;

                // canContinue receives the same arguments for every matching path, so one evaluation suffices.
                if (targetsProtectedPath(req, protectedPaths)
                        && !CustomMvcEndpointSecurityInterceptor.canContinue(req, resp)) {
                    return;
                }
                chain.doFilter(request, response);
            }

            @Override
            public void destroy() {
                // do nothing
            }
        };
    }

    /**
     * Builds the path prefixes of all servlet endpoints. Each prefix is relative to the context path
     * and has the form {@code /<basePath>/<endpointPath>}.
     *
     * @param servletEndpointsSupplier supplies the servlet endpoints
     * @param properties               provides the actuator base path
     * @return a mutable list of path prefixes to protect
     */
    private static List<String> resolveServletEndpointPaths(ServletEndpointsSupplier servletEndpointsSupplier,
            WebEndpointProperties properties) {
        String basePath = StringUtils.trimToEmpty(properties.getBasePath());
        if (!basePath.startsWith("/")) {
            basePath = "/" + basePath;
        }
        if (!basePath.endsWith("/")) {
            basePath = basePath + "/";
        }

        List<String> paths = new ArrayList<>();
        for (ExposableEndpoint<?> endpoint : servletEndpointsSupplier.getEndpoints()) {
            // A leading '/' on the sub-path would otherwise produce "//", which never matches a resolved path.
            if (endpoint instanceof ExposableWebEndpoint) {
                for (WebOperation operation : ((ExposableWebEndpoint) endpoint).getOperations()) {
                    paths.add(basePath + StringUtils.removeStart(operation.getRequestPredicate().getPath(), "/"));
                }
            } else if (endpoint instanceof PathMappedEndpoint) {
                paths.add(basePath + StringUtils.removeStart(((PathMappedEndpoint) endpoint).getRootPath(), "/"));
            }
        }
        return paths;
    }

    /**
     * Tells whether the request targets one of the protected servlet-endpoint paths.
     *
     * <p>Two forms of the request path are checked:
     * <ul>
     *   <li>The raw request URI, as before.</li>
     *   <li>The container-resolved path within the application ({@code servletPath + pathInfo}). The raw
     *       URI includes the context path and is not decoded, so on its own it misses requests under a
     *       non-root context path or with encoded characters.</li>
     * </ul>
     *
     * @param request        the current request
     * @param protectedPaths path prefixes relative to the context path
     * @return {@code true} if either form of the request path starts with a protected prefix
     */
    private static boolean targetsProtectedPath(HttpServletRequest request, List<String> protectedPaths) {
        String requestUri = request.getRequestURI();
        String servletPath = request.getServletPath();
        String pathInfo = request.getPathInfo();
        String pathWithinApplication = (servletPath == null ? "" : servletPath) + (pathInfo == null ? "" : pathInfo);

        for (String path : protectedPaths) {
            if ((requestUri != null && requestUri.startsWith(path)) || pathWithinApplication.startsWith(path)) {
                return true;
            }
        }
        return false;
    }

}
