package cn.com.duibaboot.ext.autoconfigure.cloud.netflix.hystrix;

import cn.com.duiba.wolf.perf.timeprofile.DBTimeProfile;
import cn.com.duibaboot.ext.autoconfigure.cat.context.CatConstants;
import cn.com.duibaboot.ext.autoconfigure.cat.context.CatContext;
import cn.com.duibaboot.ext.autoconfigure.core.rpc.RpcContext;
import cn.com.duibaboot.ext.autoconfigure.core.utils.CatUtils;
import com.dianping.cat.Cat;
import com.dianping.cat.message.Message;
import com.dianping.cat.message.Transaction;
import com.netflix.hystrix.strategy.concurrency.HystrixConcurrencyStrategy;
import com.netflix.hystrix.strategy.concurrency.HystrixContextRunnable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.concurrent.Callable;

/**
 * Hystrix concurrency-strategy plugin, enabled through the SPI mechanism.
 *
 * <p>Every {@link Callable} handed to Hystrix is wrapped so that, when it later runs on a Hystrix
 * worker thread, it sees a copy of the wrapping thread's {@link RpcContext} and its
 * {@link DBTimeProfile} threshold. When Cat is enabled, the wrapper additionally records Cat
 * transactions for the queue wait and for the call itself.
 *
 * <p>Thread-local state is copied explicitly by the wrappers below; this class does not use
 * {@code TtlCallable} / {@code TransmittableThreadLocal}.
 */
public class CustomHystrixConcurrencyStrategy extends HystrixConcurrencyStrategy {

    private static final Logger logger = LoggerFactory.getLogger(CustomHystrixConcurrencyStrategy.class);

    /**
     * Wraps {@code callable} so that it runs with the context of the thread that called this method.
     *
     * <p>The {@link RpcContext} copy and the DBTimeProfile threshold are captured when the wrapper is
     * constructed, i.e. on the thread invoking {@code wrapCallable}.
     *
     * @param callable the task Hystrix is about to schedule
     * @param <T>      the task's result type
     * @return a Cat-instrumented wrapper when Cat is enabled, otherwise a plain context-propagating wrapper
     */
    @Override
    public <T> Callable<T> wrapCallable(Callable<T> callable) {
        if (CatUtils.isCatEnabled()) {
            return new CallableWrapperWithCat<>(callable, isTimer(callable));
        }
        return new CallableWrapper<>(callable);
    }

    /**
     * Propagates RpcContext and the DBTimeProfile threshold to the executing thread (Cat disabled).
     */
    private static final class CallableWrapper<T> implements Callable<T> {

        private final Callable<T> callable;

        // Copy of the RpcContext of the thread that created this wrapper.
        private final RpcContext rpcContext = RpcContext.getContext().clone();

        // DBTimeProfile threshold of the thread that created this wrapper (may be null).
        private final Integer currentThreshold = DBTimeProfile.getCurrentThreshold();

        CallableWrapper(Callable<T> callable) {
            this.callable = callable;
        }

        @Override
        public T call() throws Exception {
            // Install the parent thread's context on the current (executing) thread.
            RpcContext.setContext(rpcContext);
            try {
                if (currentThreshold != null) {
                    DBTimeProfile.setCurrentThreshold(currentThreshold);
                }
                DBTimeProfile.start();

                return callable.call();
            } finally {
                // Clear thread-local state so it does not leak into whatever runs next on this thread.
                RpcContext.removeContext();
                DBTimeProfile.end(Thread.currentThread().getName());
            }
        }
    }

    /**
     * Same context propagation as {@link CallableWrapper}, plus Cat transactions around the call.
     * Callables flagged as timers bypass all of this and are invoked directly.
     */
    private static final class CallableWrapperWithCat<T> implements Callable<T> {

        private final Callable<T> callable;

        // System.nanoTime() at wrap time; passed to the HYSTRIX_QUEUE_WAIT transaction.
        private final long submitNanoTime;

        // True when the callable was declared inside HystrixContextRunnable (see isTimer).
        private final boolean isTimer;

        // Copy of the RpcContext of the thread that created this wrapper.
        private final RpcContext rpcContext = RpcContext.getContext().clone();

        // DBTimeProfile threshold of the thread that created this wrapper (may be null).
        private final Integer currentThreshold = DBTimeProfile.getCurrentThreshold();

        CallableWrapperWithCat(Callable<T> callable, boolean isTimer) {
            this.callable = callable;
            this.submitNanoTime = System.nanoTime();
            // must be assigned before logAsyncCall, which reads it
            this.isTimer = isTimer;
            this.logAsyncCall(rpcContext);
        }

        /**
         * Runs the wrapped callable inside the HYSTRIX_CALLABLE / HYSTRIX_CALL Cat transactions.
         *
         * <p>If the callable throws, both transactions get the exception class name as status
         * (instead of always being reported as successful), and the exception is rethrown unchanged.
         */
        @Override
        public T call() throws Exception {

            if (isTimer) {
                // Timer callables run as-is: no Cat transactions, no context propagation.
                return callable.call();
            }

            Transaction all = Cat.newTransaction(CatConstants.THREAD_POOL, CatConstants.HYSTRIX_CALLABLE);
            Cat.logRemoteCallServer(getCatContext(rpcContext));
            CatUtils.newCompletedTransaction(CatConstants.THREAD_POOL, CatConstants.HYSTRIX_QUEUE_WAIT, submitNanoTime);
            Transaction call = Cat.newTransaction(CatConstants.THREAD_POOL, CatConstants.HYSTRIX_CALL);
            String status = Message.SUCCESS;
            try {
                if (currentThreshold != null) {
                    DBTimeProfile.setCurrentThreshold(currentThreshold);
                }
                DBTimeProfile.start();
                if (rpcContext.getTargetServiceId() != null) {
                    // Presumably lets a downstream remote call pick up these Cat ids from the
                    // RpcContext attachments — TODO confirm against the RPC client.
                    Cat.Context catCtx = new CatContext();
                    Cat.logRemoteCallClient(catCtx);
                    setCatContext(rpcContext, catCtx);
                }
                RpcContext.setContext(rpcContext);

                return callable.call();

            } catch (Exception | Error e) {
                // Record the failure on the Cat transactions, then rethrow unchanged.
                status = e.getClass().getName();
                throw e;
            } finally {
                try {
                    call.setStatus(status);
                    call.complete();
                    all.setStatus(status);
                    all.complete();
                } finally {
                    // Always clear thread-local state, even if completing the Cat transactions fails.
                    RpcContext.removeContext();
                    DBTimeProfile.end(Thread.currentThread().getName());
                }
            }
        }

        /**
         * When Cat is enabled, this is not a timer, and the current thread has an open Cat transaction,
         * logs an async call via {@code CatUtils.logAsyncCall} and copies the resulting Cat context ids
         * into {@code rpcContext}'s attachments. Any exception is logged and swallowed.
         */
        private void logAsyncCall(RpcContext rpcContext) {
            try {
                if (CatUtils.isCatEnabled() && !isTimer && Cat.getManager().getPeekTransaction() != null) {
                    Cat.Context catCtx = new CatContext();
                    CatUtils.logAsyncCall(catCtx);
                    setCatContext(rpcContext, catCtx);
                }
            } catch (Exception e) {
                logger.error("logAsyncCall", e);
            }
        }

        /**
         * Builds a Cat context from the ROOT / CHILD / PARENT attachments of {@code rpcContext}.
         *
         * @return the context, or {@code null} if building it failed (the failure is logged)
         */
        private Cat.Context getCatContext(RpcContext rpcContext) {
            try {
                Cat.Context ctx = new CatContext();
                ctx.addProperty(Cat.Context.ROOT, rpcContext.getAttachment(Cat.Context.ROOT));
                ctx.addProperty(Cat.Context.CHILD, rpcContext.getAttachment(Cat.Context.CHILD));
                ctx.addProperty(Cat.Context.PARENT, rpcContext.getAttachment(Cat.Context.PARENT));
                return ctx;
            } catch (Exception e) {
                logger.error("get cat context", e);
            }
            return null;
        }

        /**
         * Copies the ROOT / CHILD / PARENT properties of {@code ctx} into {@code rpcContext}'s attachments.
         * Any exception is logged and swallowed.
         */
        private void setCatContext(RpcContext rpcContext, Cat.Context ctx) {
            try {
                rpcContext.setAttachment(Cat.Context.ROOT, ctx.getProperty(Cat.Context.ROOT));
                rpcContext.setAttachment(Cat.Context.CHILD, ctx.getProperty(Cat.Context.CHILD));
                rpcContext.setAttachment(Cat.Context.PARENT, ctx.getProperty(Cat.Context.PARENT));
            } catch (Exception e) {
                logger.error("set cat Context", e);
            }
        }
    }

    /**
     * Returns whether {@code callable}'s class is declared inside {@link HystrixContextRunnable}
     * (presumably the callable Hystrix builds for timer/timeout tasks — TODO confirm).
     * Any failure while inspecting the class is treated as "not a timer".
     */
    private static boolean isTimer(Callable<?> callable) {
        try {
            Class<?> enclosing = callable.getClass().getEnclosingClass();
            return enclosing == HystrixContextRunnable.class;
        } catch (Exception e) {
            return false;
        }
    }

}
