package cn.com.duiba.wolf.dubbo;

import java.io.PrintWriter;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.Set;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.alibaba.dubbo.common.Constants;
import com.alibaba.dubbo.common.extension.Activate;
import com.alibaba.dubbo.common.io.UnsafeStringWriter;
import com.alibaba.dubbo.rpc.*;
import com.alibaba.dubbo.rpc.service.GenericService;

/**
 * Provider-side Dubbo filter, modelled on Dubbo's built-in {@code ExceptionFilter}.
 * <p>
 * When an invocation (other than a {@link GenericService} call) ends with an exception that is
 * not a checked exception, not an {@link RpcException} and not declared (by exact class) in the
 * interface method's {@code throws} clause, the exception is logged together with the call
 * arguments and the result is replaced by a {@link DubboException} carrying the original message
 * and a shortened stack trace. All other results are returned unchanged.
 * <p>
 * Created by wenqi.huang on 16/6/1.
 */
@Activate(group = { Constants.PROVIDER })
public class RuntimeExceptionFilter implements Filter {

    private static final Logger log = LoggerFactory.getLogger(RuntimeExceptionFilter.class);

    /** Frames of classes whose name starts with this prefix are always kept in the short trace. */
    private static final String KEPT_CLASS_PREFIX = "cn.com.duiba";

    /** Number of leading frames always kept for each exception in the cause chain. */
    private static final int LEADING_FRAMES_KEPT = 2;

    @Override
    public Result invoke(Invoker<?> invoker, Invocation invocation) throws RpcException {
        Result result = invoker.invoke(invocation);
        if (!result.hasException() || GenericService.class == invoker.getInterface()) {
            return result;
        }
        try {
            Throwable exception = result.getException();

            // Checked exceptions are returned to the caller as-is.
            if (!(exception instanceof RuntimeException) && (exception instanceof Exception)) {
                return result;
            }

            // Dubbo's own exceptions are returned as-is.
            if (exception instanceof RpcException) {
                return result;
            }

            // Exceptions whose exact class is declared in the method signature are returned as-is.
            try {
                Method method = invoker.getInterface().getMethod(invocation.getMethodName(),
                                                                 invocation.getParameterTypes());
                for (Class<?> declaredType : method.getExceptionTypes()) {
                    if (exception.getClass().equals(declaredType)) {
                        return result;
                    }
                }
            } catch (NoSuchMethodException e) {
                return result;
            }

            // Anything else is logged and replaced by a DubboException whose stack trace is
            // shortened (see toShortString). Arrays.toString renders the argument values and
            // tolerates a null array; Object[].toString() would only print type and identity hash.
            Object[] arguments = invocation.getArguments();
            log.error(exception.getMessage() + " params : " + Arrays.toString(arguments), exception);
            return new RpcResult(new DubboException(exception.getMessage(), toShortString(exception)));
        } catch (Throwable e) {
            // Best effort: if wrapping fails, fall back to the original result, but leave a trace
            // instead of silently swallowing the failure.
            log.warn("RuntimeExceptionFilter failed to wrap exception, returning original result", e);
            return result;
        }
    }

    /**
     * Renders {@code e} and its cause chain as a compact stack trace.
     *
     * @param e the exception to render
     * @return the shortened stack trace text
     */
    private String toShortString(Throwable e) {
        UnsafeStringWriter w = new UnsafeStringWriter();
        PrintWriter p = new PrintWriter(w);
        try {
            printStackTrace(e, p);
            return w.toString();
        } finally {
            p.close();
        }
    }

    /**
     * Prints {@code e} followed by each exception in its cause chain ("Caused by: ..."). For each
     * one, only the first {@link #LEADING_FRAMES_KEPT} frames and frames of classes starting with
     * {@link #KEPT_CLASS_PREFIX} are printed; a run of skipped frames between two printed frames
     * is shown as "...". A circular cause chain is detected and cut off, mirroring
     * {@link Throwable#printStackTrace()}.
     *
     * @param e the top-level exception; nothing is printed when {@code null}
     * @param p the writer to print to
     */
    private void printStackTrace(Throwable e, PrintWriter p) {
        // Identity-based: distinct exceptions may be equal(), but a cycle means the same instance.
        Set<Throwable> seen = Collections.newSetFromMap(new IdentityHashMap<Throwable, Boolean>());
        boolean isTop = true;
        for (Throwable current = e; current != null; current = current.getCause()) {
            if (!seen.add(current)) {
                p.println("\t[CIRCULAR REFERENCE: " + current + "]");
                return;
            }
            if (isTop) {
                p.println(current);
            } else {
                p.println("Caused by: " + current);
            }
            printShortFrames(current.getStackTrace(), p);
            isTop = false;
        }
    }

    /**
     * Prints the kept frames of a single stack trace, marking gaps between kept frames with "...".
     *
     * @param trace the frames of one exception
     * @param p     the writer to print to
     */
    private void printShortFrames(StackTraceElement[] trace, PrintWriter p) {
        int lastPrinted = -1;
        for (int i = 0; i < trace.length; i++) {
            StackTraceElement frame = trace[i];
            if (i < LEADING_FRAMES_KEPT || frame.getClassName().startsWith(KEPT_CLASS_PREFIX)) {
                if (lastPrinted < i - 1) {
                    p.println("\t...");
                }
                p.println("\tat " + frame);
                lastPrinted = i;
            }
        }
    }
}
