package cn.com.duibaboot.ext.autoconfigure.security;

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Multimap;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.cloud.sleuth.instrument.web.SleuthWebProperties;
import org.springframework.http.MediaType;
import org.springframework.web.util.UrlPathHelper;

import javax.annotation.PostConstruct;
import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.regex.Pattern;

/**
 * Servlet filter that selects the {@link DefensivePolicy} beans matching each request's
 * {@code Content-Type} and runs them through a {@link SecurityCheckSandbox} before the request
 * continues down the filter chain.
 *
 * <p>Policies that declare {@code *}{@code /*} run for every non-skipped request. The others run
 * only when the request's Content-Type has the same type/subtype; parameters such as
 * {@code charset} are ignored. Per the original design note, triggering a security rule throws
 * {@code DuibaSecurityException}.
 */
public class SecurityFilter implements Filter {

    /**
     * Request paths, relative to the application, that bypass all defensive policies: API docs,
     * actuator-style endpoints, swagger, and static resources.
     *
     * <p>NOTE(review): the suffix rules ({@code .*\.html}, {@code .*\.js}, ...) skip any path with
     * those extensions. If the MVC layer maps suffixed paths onto handlers (suffix pattern
     * matching), such requests reach controllers unchecked. Confirm this is acceptable.
     */
    public static final String DEFAULT_SKIP_PATTERN =
            "/api-docs.*|/autoconfig|/configprops|/dump|/health|/info|/metrics.*|/mappings|/trace|/swagger.*|.*\\.png|.*\\.css|.*\\.js|.*\\.html|/favicon\\.ico|/hystrix\\.stream";

    /** Compiled once; {@link Pattern} is immutable and safe to share between request threads. */
    private static final Pattern SKIP_PATTERN = Pattern.compile(DEFAULT_SKIP_PATTERN);

    /** Lookup key for policies that apply to every Content-Type. */
    private static final SimpleMediaType ALL_MEDIA_TYPE = new SimpleMediaType(MediaType.ALL);

    private final UrlPathHelper urlPathHelper = new UrlPathHelper();

    /** Policies indexed by the type/subtype they defend. Filled once in {@link #initialize()}. */
    private final Multimap<SimpleMediaType, DefensivePolicy> defensivePolicyMap = ArrayListMultimap.create();

    @Autowired
    private List<DefensivePolicy> defensivePolicyList;

    @Autowired(required = false)
    private DevEnvSecurityPreprocessor devEnvSecurityPreprocessor;

    /** No container-driven setup is needed; the policy index is built in {@link #initialize()}. */
    @Override
    public void init(FilterConfig filterConfig) {
        // intentionally empty
    }

    /**
     * Indexes the injected policies by the media types they declare. Runs once after dependency
     * injection.
     */
    @PostConstruct
    public void initialize() {
        for (DefensivePolicy policy : defensivePolicyList) {
            Set<MediaType> mediaTypes = policy.getMediaTypes();
            // A */* policy already runs for every request, so it is registered only under the
            // catch-all key. Its other declared types would only make it run twice.
            if (mediaTypes.contains(MediaType.ALL)) {
                defensivePolicyMap.put(ALL_MEDIA_TYPE, policy);
                continue;
            }
            for (MediaType mediaType : mediaTypes) {
                defensivePolicyMap.put(new SimpleMediaType(mediaType), policy);
            }
        }
    }

    /**
     * Runs the applicable defensive policies for the request, then continues the chain through
     * the sandbox. Paths matching {@link #DEFAULT_SKIP_PATTERN} go straight to the chain.
     *
     * <p>A malformed {@code Content-Type} header makes {@link MediaType#parseMediaType(String)}
     * throw {@code InvalidMediaTypeException}. That exception is not caught here.
     */
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        HttpServletRequest request = (HttpServletRequest) servletRequest;
        HttpServletResponse response = (HttpServletResponse) servletResponse;

        String uri = this.urlPathHelper.getPathWithinApplication(request);
        if (SKIP_PATTERN.matcher(uri).matches()) {
            filterChain.doFilter(request, response);
            return;
        }

        List<DefensivePolicy> policyList = Lists.newArrayList();
        policyList.addAll(defensivePolicyMap.get(ALL_MEDIA_TYPE));

        String contentType = request.getHeader("Content-Type");
        if (StringUtils.isNotBlank(contentType)) {
            SimpleMediaType requestType = new SimpleMediaType(MediaType.parseMediaType(contentType));
            // "Content-Type: */*" resolves to the catch-all key whose policies were just added.
            // Adding them again would run every catch-all check twice.
            if (!ALL_MEDIA_TYPE.equals(requestType)) {
                policyList.addAll(defensivePolicyMap.get(requestType));
            }
        }

        if (devEnvSecurityPreprocessor != null) {
            devEnvSecurityPreprocessor.preprocessor(request, response);
        }

        SecurityCheckSandbox sandbox = new SecurityCheckSandbox(request, response, policyList);
        sandbox.doCheck();
        sandbox.doFilter(filterChain);
    }

    /** Drops the policy index when the container takes the filter out of service. */
    @Override
    public void destroy() {
        defensivePolicyMap.clear();
    }

    /**
     * Map key made of a media type's type and subtype only. Parameters such as {@code charset}
     * or {@code q} are ignored, so {@code application/json;charset=UTF-8} and
     * {@code application/json} share one key.
     *
     * <p>Static so that keys do not capture a reference to the enclosing filter.
     */
    private static final class SimpleMediaType {

        private final String type;

        private final String subtype;

        SimpleMediaType(MediaType mediaType) {
            this.type = mediaType.getType();
            this.subtype = mediaType.getSubtype();
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            SimpleMediaType that = (SimpleMediaType) o;
            return Objects.equals(type, that.type)
                    && Objects.equals(subtype, that.subtype);
        }

        @Override
        public int hashCode() {
            return Objects.hash(type, subtype);
        }
    }

}
