package com.zcoson.atman.plugin.web.filter;

import com.zcoson.atman.core.InstanceFactory;
import com.zcoson.atman.core.sampler.RatioSampler;
import com.zcoson.atman.core.sampler.Sampler;
import com.zcoson.atman.core.utils.NumberUtils;
import com.zcoson.atman.core.utils.StringUtils;

import java.io.IOException;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

/**
 * Administrative servlet filter that changes the ratio of the active {@link RatioSampler}
 * at runtime.
 *
 * <p>The filter acts as the endpoint itself: it never passes the request on to the rest of
 * the {@link FilterChain}. It only serves requests that originate from the loopback
 * interface. It reads the new ratio from the {@code sampled} request parameter, which must
 * be a decimal integer in the range [1, 100]. The outcome is reported as a single plain-text
 * line in the response body.
 *
 * <p>Example {@code web.xml} registration:
 * <pre>{@code
 * <filter>
 *     <filter-name>AtmanRateSampler</filter-name>
 *     <filter-class>com.zcoson.atman.plugin.web.filter.AtmanRateSamplerFilter</filter-class>
 * </filter>
 * <filter-mapping>
 *     <filter-name>AtmanRateSampler</filter-name>
 *     <url-pattern>/magiceye/sample</url-pattern>
 * </filter-mapping>
 * }</pre>
 *
 * @author zhangshun
 * @version V1.0
 * @since 2017-11-30 14:21
 */
public class AtmanRateSamplerFilter implements Filter {

    /** Name of the request parameter that carries the new sampling ratio. */
    private static final String SAMPLE_PARAM_NAME = "sampled";

    /** Smallest accepted ratio (inclusive). */
    private static final int MIN_RATIO = 1;

    /** Largest accepted ratio (inclusive); presumably a percentage — confirm against RatioSampler. */
    private static final int MAX_RATIO = 100;

    /** Returned by {@link #parseRatio(String)} when the parameter is missing or invalid. */
    private static final int INVALID_RATIO = -1;

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        // no configuration needed
    }

    /**
     * Validates the caller and the {@code sampled} parameter, then applies the new ratio to the
     * sampler returned by {@link InstanceFactory#getSampler()}.
     *
     * <p>Responses:
     * <ul>
     *   <li>406 if the request does not originate from the loopback interface;</li>
     *   <li>416 if {@code sampled} is missing, not a decimal integer, or outside [1, 100];</li>
     *   <li>500 if the current sampler is absent or not a {@link RatioSampler};</li>
     *   <li>200 with body {@code ok} once the ratio has been applied.</li>
     * </ul>
     * The filter chain is never continued.
     *
     * @param request  the incoming request; assumed to be an {@link HttpServletRequest}
     * @param response the response; assumed to be an {@link HttpServletResponse}
     * @param chain    unused, because this filter terminates the request
     * @throws IOException if writing the response fails
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
        throws IOException, ServletException {
        HttpServletRequest httpRequest = (HttpServletRequest) request;
        HttpServletResponse httpResponse = (HttpServletResponse) response;

        if (!isLocalRequest(httpRequest)) {
            reply(httpResponse, HttpServletResponse.SC_NOT_ACCEPTABLE, "only allow 127.0.0.1 \r\n");
            return;
        }

        int ratio = parseRatio(request.getParameter(SAMPLE_PARAM_NAME));
        if (ratio == INVALID_RATIO) {
            reply(httpResponse, HttpServletResponse.SC_REQUESTED_RANGE_NOT_SATISFIABLE,
                "value is not in this range [1 ~ 100] \r\n");
            return;
        }

        Sampler sampler = InstanceFactory.getSampler();
        if (!(sampler instanceof RatioSampler)) {
            // Guard against a missing sampler so the error path cannot itself throw an NPE.
            String samplerType = sampler == null ? "null" : sampler.getClass().getName();
            reply(httpResponse, HttpServletResponse.SC_INTERNAL_SERVER_ERROR,
                "unknown sampler type " + samplerType + "\r\n");
            return;
        }

        // The cast is lossless: parseRatio only returns values in [MIN_RATIO, MAX_RATIO].
        ((RatioSampler) sampler).setRatio((short) ratio);
        reply(httpResponse, HttpServletResponse.SC_OK, "ok\r\n");
    }

    @Override
    public void destroy() {
        // nothing to release
    }

    /**
     * Parses the raw {@code sampled} parameter.
     *
     * @param raw the parameter value; may be {@code null}
     * @return the ratio if {@code raw} is a decimal integer within [MIN_RATIO, MAX_RATIO],
     *         otherwise {@link #INVALID_RATIO}
     */
    private static int parseRatio(String raw) {
        if (StringUtils.isEmpty(raw)) {
            return INVALID_RATIO;
        }
        int value;
        try {
            value = Integer.parseInt(raw);
        } catch (NumberFormatException ignored) {
            // Non-integer input ("1.5", "abc", overflow) is reported as an invalid value.
            return INVALID_RATIO;
        }
        return (value >= MIN_RATIO && value <= MAX_RATIO) ? value : INVALID_RATIO;
    }

    /**
     * Decides whether the request originates from the local host.
     *
     * <p>The TCP peer address must be a loopback address. If an {@code X-Forwarded-For} header
     * is present (for example, added by a reverse proxy on the same host), every hop it lists
     * must also be a loopback address. The header is never trusted on its own, because any
     * remote client can set it to an arbitrary value.
     *
     * @param request the incoming request
     * @return {@code true} only if the peer and all forwarded hops are loopback addresses
     */
    private boolean isLocalRequest(HttpServletRequest request) {
        if (!isLoopbackAddress(request.getRemoteAddr())) {
            return false;
        }
        String forwardedFor = request.getHeader("X-Forwarded-For");
        if (StringUtils.isEmpty(forwardedFor)) {
            return true;
        }
        for (String hop : forwardedFor.split(",")) {
            if (!isLoopbackAddress(hop.trim())) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns whether {@code address} is one of the textual loopback forms that servlet
     * containers report: IPv4 {@code 127.0.0.1}, or IPv6 {@code ::1} in expanded or
     * compressed notation.
     */
    private static boolean isLoopbackAddress(String address) {
        return "127.0.0.1".equals(address)
            || "0:0:0:0:0:0:0:1".equals(address)
            || "::1".equals(address);
    }

    /**
     * Writes a single plain-text status line to the response.
     *
     * @param response the response to write to
     * @param status   the HTTP status code to set
     * @param message  the body text, including its trailing line break
     * @throws IOException if the writer cannot be obtained
     */
    private static void reply(HttpServletResponse response, int status, String message)
        throws IOException {
        response.setStatus(status);
        response.setContentType("text/plain; charset=UTF-8");
        response.getWriter().print(message);
    }
}
