package com.zcoson.atman.plugin.web.filter;

import io.opentracing.Span;
import io.opentracing.tag.Tags;

import java.io.IOException;
import java.util.Enumeration;
import java.util.Map;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.zcoson.atman.core.Atman;
import com.zcoson.atman.core.AtmanConfig;
import com.zcoson.atman.core.propagation.TextMapCodec;
import com.zcoson.atman.core.utils.ServletParamUtil;
import com.zcoson.atman.core.utils.StringUtils;
import com.zcoson.atman.core.utils.UUIDUtils;
import com.zcoson.atman.plugin.web.servlet.http.TraceLogServletRequestWrapper;

/**
 * AtmanRateSamplerFilter
 *
 * @author zhangshun
 * @version V1.0
 * @since 2017-11-30 14:21
 * <p>springboot
 * Configuration
 * public class TraceServletFilter extends WebMvcConfigurerAdapter {
 * Bean
 * public FilterRegistrationBean traceFilter() {
 * FilterRegistrationBean filterRegistration = new FilterRegistrationBean();
 * filterRegistration.setOrder(1);
 * filterRegistration.setUrlPatterns(Arrays.asList("/*"));
 * filterRegistration.setName("AtmanServletFilter");
 * filterRegistration.setFilter(new AtmanServletFilter());
 * return filterRegistration;
 * }
 * }
 * </code>
 * <p>
 * <p>web.xml 配置文件
 * <code>
 * <filter>
 * <filter-name>AtmanServletFilter</filter-name>
 * <filter-class>com.zcoson.atman.plugin.web.filter.AtmanServletFilter</filter-class>
 * </filter>
 * <filter-mapping>
 * <filter-name>AtmanServletFilter</filter-name>
 * <url-pattern>/magiceye/sample</url-pattern>
 * </filter-mapping>
 * </code>
 */
public class AtmanServletFilter implements Filter {

    private static final Logger LOGGER = LoggerFactory.getLogger(AtmanServletFilter.class);

    private static final String SERVICE_LABEL = "SERVLET";

    private String[] exclusions;

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        // Do nothing because of X and Y.
        String exclusion = filterConfig.getInitParameter("exclusions");
        if (null == exclusion || 0 == exclusion.trim().length()) {
            return;
        }
        exclusions = exclusion.split(",");
    }

    /**
     * 过滤器操作方法
     *
     * @param req
     * @param resp
     * @param chain
     * @throws java.io.IOException
     * @throws javax.servlet.ServletException
     */
    @Override
    public void doFilter(ServletRequest req, ServletResponse resp, FilterChain chain)
        throws IOException, ServletException {

        if (!(req instanceof HttpServletRequest)) {
            chain.doFilter(req, resp);
            return;
        }

        HttpServletRequest request = (HttpServletRequest) req;

        if (isExclusions(request.getRequestURI())) {
            chain.doFilter(req, resp);
            return;
        }

        Span span = null;
        try {

            String traceValue = this.getTraceValue(request);
            span = Atman.newSpan(request.getRequestURI(), Atman.extract(traceValue));
            span.setTag(AtmanConfig.SERVICE_TYPE_KEY, AtmanConfig.ServiceType.HTTP.getCode());
            span.setTag(AtmanConfig.SERVICE_LABEL_KEY, SERVICE_LABEL);

            Tags.SPAN_KIND.set(span, Tags.SPAN_KIND_SERVER);
            Tags.PEER_HOSTNAME.set(span, request.getLocalName());

            // 设置头信息
            this.handlerHeader(span, request);

            // 处理参数
            this.handlerParameter(span, request);

        } catch (Throwable t) {
            LOGGER.warn(SERVICE_LABEL + "Before", t);
        }

        try {
            // 处理Payload参数
            ServletRequest wrapper = this.handlerPayload(span, request);
            if (null != wrapper) {
                chain.doFilter(wrapper, resp);
            } else {
                chain.doFilter(req, resp);
            }
        } catch (Throwable t) {
            // 打印日志
            LOGGER.warn(SERVICE_LABEL + "Parameter", t);
            // 继续执行
            chain.doFilter(req, resp);
        } finally {
            if (null != span) {
                span.finish();
            }
        }
    }

    @Override
    public void destroy() {
        // Do nothing because of X and Y.
    }

    private String getTraceValue(HttpServletRequest request) {
        String traceValue = request.getHeader(TextMapCodec.SPAN_CONTEXT_KEY);
        if (StringUtils.isEmpty(traceValue)) {
            return UUIDUtils.randomString();
        }
        return traceValue;
    }

    private void handlerHeader(Span span, HttpServletRequest request) {
        Enumeration<String> names = request.getHeaderNames();
        if (null == names) {
            return;
        }
        String name = null;
        while (names.hasMoreElements()) {
            name = names.nextElement();
            if (!ServletParamUtil.excludeHeaderName(name)) {
                span.log("header." + ServletParamUtil.getParameterName(name), request.getHeader(name));
            }
        }
    }

    private void handlerParameter(Span span, HttpServletRequest request) {
        Map<String, String[]> parameter = request.getParameterMap();
        if (null == parameter || 0 == parameter.size()) {
            return;
        }
        for (Map.Entry<String, String[]> entry : parameter.entrySet()) {
            span.log("parameter." + ServletParamUtil.getParameterName(entry.getKey()),
                ServletParamUtil.getParameterValue(entry.getValue()));
        }
    }

    private ServletRequest handlerPayload(Span span, HttpServletRequest request) {
        TraceLogServletRequestWrapper wrapper = null;
        try {
            wrapper = new TraceLogServletRequestWrapper(request);
            Map<String, Object> headerMap = AtmanConfig.GSON.fromJson(wrapper.getBody(), Map.class);
            if (null == headerMap || 0 == headerMap.size()) {
                return wrapper;
            }
            for (Map.Entry<String, Object> entry : headerMap.entrySet()) {
                span.log("parameter." + entry.getKey(), entry.getValue());
            }
        } catch (Exception e) {
            LOGGER.warn(SERVICE_LABEL + "Payload", e);
        } finally {
            return wrapper;
        }
    }

    private boolean isExclusions(String uri) {
        if (null == exclusions || 0 == exclusions.length) {
            return false;
        }
        if (null == uri) {
            return false;
        }
        uri = uri.trim();
        if (0 == uri.length()) {
            return false;
        }
        for (String ex : exclusions) {
            ex = ex.trim();
            if (uri.startsWith(ex) || uri.endsWith(ex)) {
                return true;
            }
        }
        return false;
    }
}
