package cn.com.duiba.kjy.base.customweb.web.interceptor.impl;

import cn.com.duiba.kjy.base.customweb.autoconfig.MappingCrosDomainConfig;
import cn.com.duiba.kjy.base.customweb.util.CorsUtils;
import cn.com.duiba.kjy.base.customweb.web.bean.KjjHttpRequest;
import cn.com.duiba.kjy.base.customweb.web.bean.KjjHttpResponse;
import cn.com.duiba.kjy.base.customweb.web.handler.mapping.controller.ControllerMappingHandler;
import cn.com.duiba.kjy.base.customweb.web.interceptor.KjjInterceptor;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpResponseStatus;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang.StringUtils;
import org.springframework.lang.Nullable;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

/**
 * Interceptor that applies CORS (cross-origin resource sharing) handling to requests routed to a
 * {@link ControllerMappingHandler}, using the handler's {@link MappingCrosDomainConfig}.
 *
 * <p>For every request it ensures the response {@code Vary} header lists {@code Origin},
 * {@code Access-Control-Request-Method} and {@code Access-Control-Request-Headers}. For CORS requests
 * it validates the origin (and, for pre-flight requests, the requested headers) and then either
 * writes the matching {@code Access-Control-*} response headers or rejects the request with
 * {@code 403 Forbidden}.
 *
 * @author dugq
 */
@Slf4j
public class CrossDomainInterceptor implements KjjInterceptor {

    /**
     * Value written to {@code Access-Control-Allow-Methods} on pre-flight responses.
     * HTTP methods are intentionally not restricted per mapping.
     */
    private static final String ALLOWED_METHODS = "POST, GET, PUT, OPTIONS, DELETE, PATCH";

    /**
     * Performs CORS processing before the handler is invoked.
     *
     * @param request  the incoming request
     * @param response the response being built
     * @param handler  the resolved handler; only {@link ControllerMappingHandler} carries a CORS config
     * @return {@code true} to continue processing; {@code false} when the request was rejected and the
     *         response has already been written and closed
     */
    @Override
    public boolean applyPreHandle(KjjHttpRequest request, KjjHttpResponse response, Object handler) {
        // Treat a missing Vary header (null result) as empty, consistent with the null-safe
        // Access-Control-Allow-Origin check below.
        Collection<String> varyHeaders =
                CollectionUtils.emptyIfNull(response.getHeader(HttpHeaderNames.VARY.toString()));
        if (!varyHeaders.contains(HttpHeaderNames.ORIGIN.toString())) {
            response.addHeader(HttpHeaderNames.VARY, HttpHeaderNames.ORIGIN);
        }
        if (!varyHeaders.contains(HttpHeaderNames.ACCESS_CONTROL_REQUEST_METHOD.toString())) {
            response.addHeader(HttpHeaderNames.VARY, HttpHeaderNames.ACCESS_CONTROL_REQUEST_METHOD);
        }
        if (!varyHeaders.contains(HttpHeaderNames.ACCESS_CONTROL_REQUEST_HEADERS.toString())) {
            response.addHeader(HttpHeaderNames.VARY, HttpHeaderNames.ACCESS_CONTROL_REQUEST_HEADERS);
        }

        if (!CorsUtils.isCorsRequest(request)) {
            return true;
        }

        if (CollectionUtils.isNotEmpty(response.getHeader(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN.toString()))) {
            log.info("Skip: response already contains \"Access-Control-Allow-Origin\"");
            return true;
        }

        if (!(handler instanceof ControllerMappingHandler)) {
            return true;
        }
        MappingCrosDomainConfig config = ((ControllerMappingHandler) handler).getMappingCrosDomainConfig();
        if (config == null) {
            // NOTE(review): every CORS request (not only pre-flight) is rejected when the mapping has no
            // CORS config; Spring's DefaultCorsProcessor only rejects pre-flight requests here — confirm intended.
            rejectRequest(response);
            return false;
        }

        return handleInternal(request, response, config, CorsUtils.isPreFlightRequest(request));
    }

    /**
     * Rejects the request: responds {@code 403 Forbidden} with body {@code "Invalid CORS request"}
     * and flushes/closes the response.
     *
     * @param response the response to write the rejection to
     */
    protected void rejectRequest(KjjHttpResponse response) {
        response.getResponse().setStatus(HttpResponseStatus.FORBIDDEN);
        response.write("Invalid CORS request");
        response.flushAndClose();
    }

    /**
     * Validates a CORS request against {@code config} and writes the CORS response headers.
     *
     * @param request          the incoming CORS request
     * @param response         the response being built
     * @param config           the CORS configuration of the matched mapping (non-null)
     * @param preFlightRequest whether the request is a pre-flight request
     * @return {@code true} to continue processing; {@code false} when the request was rejected
     */
    protected boolean handleInternal(KjjHttpRequest request, KjjHttpResponse response,
                                     MappingCrosDomainConfig config, boolean preFlightRequest) {

        String requestOrigin = request.getHeader(HttpHeaderNames.ORIGIN.toString());
        String allowOrigin = checkOrigin(config, requestOrigin);
        HttpHeaders responseHeaders = response.getHeaders();

        if (allowOrigin == null) {
            log.info("Reject: '{}' origin is not allowed", requestOrigin);
            rejectRequest(response);
            return false;
        }
        responseHeaders.set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN, allowOrigin);

        if (preFlightRequest) {
            // HTTP methods are not restricted per mapping.
            responseHeaders.set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS, ALLOWED_METHODS);

            List<String> requestHeaders = getRequestedHeaders(request);
            List<String> allowHeaders = checkHeaders(config, requestHeaders);
            if (CollectionUtils.isEmpty(allowHeaders)) {
                // Only reject when the client actually asked for headers and none were allowed;
                // a pre-flight without Access-Control-Request-Headers is valid.
                if (!requestHeaders.isEmpty()) {
                    log.info("Reject: headers '{}' are not allowed", requestHeaders);
                    rejectRequest(response);
                    return false;
                }
            } else {
                responseHeaders.set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS, StringUtils.join(allowHeaders, ","));
            }
        }

        if (CollectionUtils.isNotEmpty(config.getExposedHeaders())) {
            // Comma-join for a single header line, consistent with Access-Control-Allow-Headers.
            responseHeaders.set(HttpHeaderNames.ACCESS_CONTROL_EXPOSE_HEADERS,
                    StringUtils.join(config.getExposedHeaders(), ","));
        }

        if (Boolean.TRUE.equals(config.getAllowCredentials())) {
            responseHeaders.set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS, true);
        }

        if (preFlightRequest && config.getMaxAge() != null) {
            responseHeaders.set(HttpHeaderNames.ACCESS_CONTROL_MAX_AGE, config.getMaxAge());
        }
        return true;
    }


    /**
     * Determines which of the requested headers may be returned in the pre-flight response's
     * {@code Access-Control-Allow-Headers}. The default implementation delegates to
     * {@code MappingCrosDomainConfig#checkHeaders}.
     *
     * @param config         the CORS configuration of the matched mapping
     * @param requestHeaders header names requested by the pre-flight request (possibly empty)
     * @return the headers to allow, possibly {@code null}; exact semantics are defined by the config
     */
    @Nullable
    protected List<String> checkHeaders(MappingCrosDomainConfig config, List<String> requestHeaders) {
        return config.checkHeaders(requestHeaders);
    }

    /**
     * Collects the header names listed in the request's {@code Access-Control-Request-Headers}.
     * Browsers send them as one comma-separated value, so every header value is split on commas,
     * trimmed, and empty tokens are dropped.
     *
     * @param request the pre-flight request
     * @return the requested header names; empty if the header is absent
     */
    private List<String> getRequestedHeaders(KjjHttpRequest request) {
        HttpHeaders headers = request.headers();
        List<String> result = new ArrayList<>();
        for (String value : headers.getAll(HttpHeaderNames.ACCESS_CONTROL_REQUEST_HEADERS)) {
            for (String token : value.split(",")) {
                String headerName = token.trim();
                if (!headerName.isEmpty()) {
                    result.add(headerName);
                }
            }
        }
        return result;
    }


    /**
     * Resolves the value for {@code Access-Control-Allow-Origin} by delegating to the config.
     *
     * @return the origin to allow, or {@code null} when the request's origin is not allowed
     */
    private String checkOrigin(MappingCrosDomainConfig config, String requestOrigin) {
        return config.checkOrigin(requestOrigin);
    }

    /** No post-processing is needed for CORS. */
    @Override
    public void applyPostHandle(KjjHttpRequest request, KjjHttpResponse response, Object handler, Object result) {
        // do nothing
    }
}
