package cn.com.duibaboot.ext.autoconfigure.web.wrapper;

import org.jetbrains.annotations.NotNull;

import javax.servlet.ServletOutputStream;
import javax.servlet.WriteListener;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;
import java.io.*;

/**
 * Response wrapper that keeps an in-memory copy of the response body.
 *
 * <p>Everything written through {@link #getOutputStream()} or {@link #getWriter()} is passed
 * on to the wrapped response and additionally copied into an internal buffer, so the body can
 * be read back later with {@link #getResponseBody()} even though a servlet response body can
 * normally only be produced once.
 *
 * <p>To bound memory usage the copy is abandoned as soon as it would reach
 * {@link #BUFFER_SIZE_THRESHOLD} bytes: the buffer is cleared, {@link #isBufferTooLargeReset()}
 * starts returning {@code true}, and nothing further is copied. Writes to the wrapped response
 * are never affected by this.
 */
public class BodyWriterHttpServletResponseWrapper extends HttpServletResponseWrapper {

    /**
     * Size limit of the captured copy, in bytes.
     */
    private static final int BUFFER_SIZE_THRESHOLD = 10 * 1024 * 1024;

    private final HttpServletResponse response;
    private final ByteArrayOutputStream buffer;
    private ServletOutputStream servletOutputStream;
    private PrintWriter printWriter;

    /**
     * Set to {@code true} once (and only when) the captured copy was dropped because it would
     * have exceeded the threshold.
     */
    private volatile boolean bufferTooLargeReset = false;

    /**
     * Creates a wrapper around the given response with an empty capture buffer.
     *
     * @param response the response to wrap
     * @throws IOException declared for backward compatibility with existing callers
     */
    public BodyWriterHttpServletResponseWrapper(HttpServletResponse response) throws IOException {
        super(response);
        this.response = response;
        buffer = new ByteArrayOutputStream();
    }

    /**
     * Returns a stream that writes to the wrapped response's stream and mirrors the bytes into
     * the capture buffer. The stream is created lazily and reused on later calls.
     */
    @Override
    public ServletOutputStream getOutputStream() throws IOException {
        if (servletOutputStream == null) {
            servletOutputStream = new BodyWriterServletOutputStream(buffer, response.getOutputStream());
        }
        return servletOutputStream;
    }

    /**
     * Returns a writer that writes to the wrapped response's writer and mirrors the characters,
     * encoded with this response's character encoding, into the capture buffer. The writer is
     * created lazily and reused on later calls.
     */
    @Override
    public PrintWriter getWriter() throws IOException {
        if (printWriter == null) {
            printWriter = new BodyWriterPrintWriter(buffer, response.getWriter(), this.getCharacterEncoding());
        }
        return printWriter;
    }

    /**
     * Flushes whichever of the stream / writer has been handed out, both on the wrapper side
     * (so the capture buffer is up to date) and on the wrapped response.
     */
    @Override
    public void flushBuffer() throws IOException {
        if (servletOutputStream != null) {
            servletOutputStream.flush();
            ServletOutputStream out = response.getOutputStream();
            if (out != null) {
                out.flush();
            }
        }
        if (printWriter != null) {
            printWriter.flush();
            if (printWriter instanceof BodyWriterPrintWriter) {
                // Push characters still held by the capture-side encoder into the buffer.
                ((BodyWriterPrintWriter) printWriter).bufferWriterFlush();
            }
            PrintWriter writer = response.getWriter();
            if (writer != null) {
                writer.flush();
            }
        }
    }

    /**
     * Resets the wrapped response and drops the captured copy so that both stay in sync.
     * The wrapped response is reset first; if it refuses (for example because it is already
     * committed) the exception propagates and the captured copy is left untouched.
     * The "too large" flag is not cleared by this method.
     */
    @Override
    public void reset() {
        super.reset();
        discardCapturedBody();
    }

    /**
     * Resets the wrapped response's body buffer and drops the captured copy so that both stay
     * in sync. The "too large" flag is not cleared by this method.
     */
    @Override
    public void resetBuffer() {
        super.resetBuffer();
        discardCapturedBody();
    }

    /**
     * Flushes pending output and returns a copy of the bytes captured so far.
     *
     * @return the captured body; empty when nothing was written or the capture was abandoned
     *         (see {@link #isBufferTooLargeReset()})
     * @throws IOException if flushing the wrapped response fails
     */
    public byte[] getResponseBody() throws IOException {
        flushBuffer();
        return buffer.toByteArray();
    }

    /**
     * Tells whether adding {@code currentLength} more bytes would bring the buffer to or past
     * the threshold.
     *
     * @param buffer        the capture buffer
     * @param currentLength number of bytes about to be added
     * @return {@code true} if the threshold would be reached
     */
    private boolean isBufferTooLarge(ByteArrayOutputStream buffer, long currentLength) {
        // long arithmetic so that a very large pending write cannot overflow the sum
        return ((long) buffer.size() + currentLength) >= BUFFER_SIZE_THRESHOLD;
    }

    /**
     * Decides whether a pending write may be copied into the capture buffer. If the write would
     * reach the threshold, the capture is abandoned as a side effect.
     *
     * @param buffer       the capture buffer
     * @param pendingBytes number of bytes about to be added
     * @return {@code true} if the caller should copy the data into the buffer
     */
    private boolean acceptForCapture(ByteArrayOutputStream buffer, long pendingBytes) {
        if (isBufferTooLargeReset()) {
            return false;
        }
        if (isBufferTooLarge(buffer, pendingBytes)) {
            bufferTooLargeReset();
            return false;
        }
        return true;
    }

    /**
     * Empties the capture buffer. The capture-side writer is flushed first so that bytes still
     * sitting in its encoder are thrown away together with the rest instead of reappearing in
     * the buffer on a later flush.
     */
    private void discardCapturedBody() {
        if (printWriter instanceof BodyWriterPrintWriter) {
            ((BodyWriterPrintWriter) printWriter).bufferWriterFlush();
        }
        buffer.reset();
    }

    /**
     * Empties the capture buffer and marks the capture as abandoned.
     */
    public void bufferTooLargeReset() {
        bufferTooLargeReset = true;
        discardCapturedBody();
    }

    /**
     * @return {@code true} if the capture was abandoned because the body grew too large
     */
    public boolean isBufferTooLargeReset() {
        return bufferTooLargeReset;
    }

    /**
     * Writer that forwards to the wrapped response's writer and mirrors the characters into the
     * capture buffer through a second, encoding writer.
     */
    class BodyWriterPrintWriter extends PrintWriter {

        private final PrintWriter bufferWriter;
        private final ByteArrayOutputStream buffer;
        /** Charset used to measure how many bytes a string will add to the capture buffer. */
        private final java.nio.charset.Charset charset;

        BodyWriterPrintWriter(ByteArrayOutputStream buffer, @NotNull Writer out, String characterEncoding) throws UnsupportedEncodingException {
            super(out);
            this.buffer = buffer;
            this.bufferWriter = new PrintWriter(new OutputStreamWriter(buffer, characterEncoding));
            // Resolved after the writer above so an unsupported name still surfaces as
            // UnsupportedEncodingException, exactly as before.
            this.charset = java.nio.charset.Charset.forName(characterEncoding);
        }

        @Override
        public void write(int c) {
            super.write(c);
            // 4 is a conservative upper bound for the encoded size of one character.
            if (acceptForCapture(buffer, 4)) {
                bufferWriter.write(c);
            }
        }

        @Override
        public void write(@NotNull char[] buf, int off, int len) {
            super.write(buf, off, len);
            // Estimate two bytes per character, counting only the region actually written.
            if (acceptForCapture(buffer, (long) len * 2)) {
                bufferWriter.write(buf, off, len);
            }
        }

        @Override
        public void write(@NotNull String s, int off, int len) {
            super.write(s, off, len);
            if (isBufferTooLargeReset()) {
                return;
            }
            // Measure only the written region, in the encoding the capture writer uses.
            long encodedLength = s.substring(off, off + len).getBytes(charset).length;
            if (acceptForCapture(buffer, encodedLength)) {
                bufferWriter.write(s, off, len);
            }
        }

        /**
         * Routes the line terminator through {@link #write(String)} so it is mirrored into the
         * capture buffer; the inherited implementation writes it straight to the underlying
         * writer, bypassing the overridden {@code write} methods.
         */
        @Override
        public void println() {
            write(System.lineSeparator());
        }

        /** Flushes the capture-side writer so its pending bytes reach the capture buffer. */
        void bufferWriterFlush() {
            bufferWriter.flush();
        }
    }

    /**
     * Stream that forwards to the wrapped response's stream and mirrors the bytes into the
     * capture buffer.
     */
    class BodyWriterServletOutputStream extends ServletOutputStream {

        private final ByteArrayOutputStream buffer;
        private final ServletOutputStream out;

        BodyWriterServletOutputStream(ByteArrayOutputStream buffer, ServletOutputStream out) {
            this.buffer = buffer;
            this.out = out;
        }

        @Override
        public void write(int b) throws IOException {
            out.write(b);
            if (acceptForCapture(buffer, 1)) {
                buffer.write(b);
            }
        }

        @Override
        public void write(byte[] b) throws IOException {
            write(b, 0, b.length);
        }

        /**
         * Bulk write. Without this override the inherited implementation would fall back to
         * {@link #write(int)} once per byte.
         */
        @Override
        public void write(byte[] b, int off, int len) throws IOException {
            out.write(b, off, len);
            if (acceptForCapture(buffer, len)) {
                buffer.write(b, off, len);
            }
        }

        /** Flushes the wrapped response's stream. */
        @Override
        public void flush() throws IOException {
            out.flush();
        }

        @Override
        public boolean isReady() {
            return out.isReady();
        }

        @Override
        public void setWriteListener(WriteListener writeListener) {
            out.setWriteListener(writeListener);
        }
    }
}
