package cn.com.duibaboot.ext.autoconfigure.rocketmq.sleuth;

import brave.Span;
import brave.Tags;
import brave.Tracer;
import brave.Tracing;
import brave.propagation.TraceContext;
import brave.propagation.TraceContextOrSamplingFlags;
import org.apache.commons.lang3.StringUtils;
import org.apache.rocketmq.common.message.Message;
import org.apache.rocketmq.common.message.MessageExt;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;

import javax.annotation.PostConstruct;
import javax.annotation.Resource;
import java.util.List;
import java.util.Objects;

/**
 * Sleuth (Brave) tracing support for RocketMQ.
 * <p>
 * Producer side: {@code MQProducer.send*(Message, ..)} and {@code MQProducer.request(Message, ..)} calls are
 * wrapped in a PRODUCER span, and the span context is injected into the message user properties.
 * <p>
 * Consumer side: {@code RocketMQListener.onMessage}, {@code RocketMQReplyListener.onMessage} and
 * {@code MessageListener*.consumeMessage} calls are wrapped in a CONSUMER span whose parent is extracted
 * from the message user properties.
 */
@Aspect
public class RocketmqSleuthPlugin {

	/** Remote service name recorded on every span; identifies the broker side. */
	private static final String REMOTE_SERVICE_NAME = "rocketmq";
	private static final String SEND_SPAN_NAME = "rocketmq-send";
	private static final String CONSUME_SPAN_NAME = "rocketmq-consume";
	/** Value of the {@code rocketmq.type} tag on consumer spans. */
	private static final String CONSUME_TYPE = "consume";

	@Resource
	private Tracer tracer;
	@Resource
	private Tracing tracing;
	@Resource
	private RocketmqSleuthConsumeMessageHook rocketmqSleuthConsumeMessageHook;

	/** Writes trace headers into {@link Message} user properties. */
	private TraceContext.Injector<Message> injector;
	/** Reads trace headers back from {@link Message} user properties. */
	private TraceContext.Extractor<Message> extractor;

	/**
	 * Builds the injector/extractor from the configured propagation once {@link #tracing} is injected.
	 */
	@PostConstruct
	public void init() {
		this.injector = tracing.propagation().injector(Message::putUserProperty);
		this.extractor = tracing.propagation().extractor(Message::getUserProperty);
	}

	/**
	 * Intercepts every {@code MQProducer.send*} overload whose first argument is a single {@link Message}.
	 */
	@Around(value = "execution(* org.apache.rocketmq.client.producer.MQProducer.send*(org.apache.rocketmq.common.message.Message,..)) && args(message,..)",argNames= "point,message")
	public Object send(ProceedingJoinPoint point, Message message) throws Throwable {
		return rocketMqSendJoinPoint(point, message, "send");
	}

	/**
	 * Intercepts every {@code MQProducer.request} overload whose first argument is a single {@link Message}.
	 */
	@Around(value = "execution(* org.apache.rocketmq.client.producer.MQProducer.request(org.apache.rocketmq.common.message.Message,..)) && args(message,..)",argNames= "point,message")
	public Object sendMessages(ProceedingJoinPoint point, Message message) throws Throwable {
		return rocketMqSendJoinPoint(point, message, "request");
	}

	/**
	 * Runs a producer call inside a PRODUCER span and propagates the span context through the message.
	 * <p>
	 * Tracing is skipped (the call simply proceeds) when there is no current span, when the current span
	 * is a noop, or when the message is {@code null}. In the null case the producer's own argument
	 * validation reports the problem instead of an NPE from this aspect.
	 *
	 * @param point   the intercepted producer invocation
	 * @param message the outgoing message; its user properties receive the trace headers
	 * @param type    value recorded in the {@code rocketmq.type} tag
	 * @return whatever the intercepted method returns
	 * @throws Throwable anything thrown by the intercepted method, rethrown unchanged after tagging
	 */
	public Object rocketMqSendJoinPoint(ProceedingJoinPoint point, Message message, String type) throws Throwable {
		Span curSpan = tracer.currentSpan();
		if (Objects.isNull(message) || curSpan == null || curSpan.isNoop()) {
			return point.proceed();
		}
		Span span = tracer.nextSpan().name(SEND_SPAN_NAME)
				.kind(Span.Kind.PRODUCER)
				.remoteServiceName(REMOTE_SERVICE_NAME)
				.start();
		try (Tracer.SpanInScope scope = tracer.withSpanInScope(span)) {
			injector.inject(span.context(), message);
			// "lc" = local component name
			tagMessage(span, "rocketmqProducer", message, type);
			span.tag("producer.thread", Thread.currentThread().getName());
			return point.proceed();
		} catch (Throwable error) {
			// Throwable (not just Exception) so that Errors are also recorded on the span.
			Tags.ERROR.tag(error, span);
			throw error;
		} finally {
			span.finish();
		}
	}

	/** Intercepts {@code RocketMQListener.onMessage}. */
	@Around(value = "execution(* org.apache.rocketmq.spring.core.RocketMQListener.onMessage(..))")
	public Object rocketMqListener(ProceedingJoinPoint point) throws Throwable {
		return rocketMqListenerJoinPoint(point, "rocketMQListener");
	}

	/** Intercepts {@code RocketMQReplyListener.onMessage}. */
	@Around(value = "execution(* org.apache.rocketmq.spring.core.RocketMQReplyListener.onMessage(..))")
	public Object rocketMqReplyListener(ProceedingJoinPoint point) throws Throwable {
		return rocketMqListenerJoinPoint(point, "rocketMQReplyListener");
	}

	/**
	 * Runs a rocketmq-spring listener invocation inside a CONSUMER span.
	 * <p>
	 * The raw {@link MessageExt} is taken from the context exposed by
	 * {@link RocketmqSleuthConsumeMessageHook}, via {@code context.poll()}. If there is no context, or it yields
	 * no message, the call proceeds untraced.
	 *
	 * @param point        the intercepted listener invocation
	 * @param listenerName value recorded in the {@code lc} (local component) tag
	 * @return whatever the intercepted method returns
	 * @throws Throwable anything thrown by the listener, rethrown unchanged after tagging
	 */
	public Object rocketMqListenerJoinPoint(ProceedingJoinPoint point, String listenerName) throws Throwable {
		RocketmqSleuthContext context = rocketmqSleuthConsumeMessageHook.getRocketmqSleuthContext();
		if (Objects.isNull(context)) {
			return point.proceed();
		}
		MessageExt message = context.poll();
		if (Objects.isNull(message)) {
			return point.proceed();
		}
		Object target = point.getTarget();
		Span span = startConsumerSpan(message);
		try (Tracer.SpanInScope scope = tracer.withSpanInScope(span)) {
			tagMessage(span, listenerName, message, CONSUME_TYPE);
			span.tag("consumer.class", target.getClass().getSimpleName());
			span.tag("consumer.thread", Thread.currentThread().getName());
			return point.proceed();
		} catch (Throwable error) {
			Tags.ERROR.tag(error, span);
			throw error;
		} finally {
			span.finish();
		}
	}

	/**
	 * Runs a native {@code MessageListener*.consumeMessage(List<MessageExt>, ..)} invocation inside a
	 * CONSUMER span. This covers the listener previously handled by {@code bootRocketMqMessageListener}.
	 * <p>
	 * For a batch, only the first message's trace context is used as the parent of the span.
	 * A {@code null} or empty batch proceeds untraced.
	 *
	 * @param joinPoint the intercepted listener invocation
	 * @param msgs      the delivered batch
	 * @return whatever the intercepted method returns
	 * @throws Throwable anything thrown by the listener, rethrown unchanged after tagging
	 */
	@Around(value="execution(* org.apache.rocketmq.client.consumer.listener.MessageListener*.consumeMessage(java.util.List<org.apache.rocketmq.common.message.MessageExt>,..)) && args(msgs,..)",argNames= "joinPoint,msgs")
	public Object rocketMqConsumeJoinPoint(ProceedingJoinPoint joinPoint, List<MessageExt> msgs) throws Throwable {
		if (Objects.isNull(msgs) || msgs.isEmpty()) {
			return joinPoint.proceed();
		}
		MessageExt msg = msgs.get(0);
		Span span = startConsumerSpan(msg);
		try (Tracer.SpanInScope scope = tracer.withSpanInScope(span)) {
			// "lc" = local component name
			tagMessage(span, "rocketmqConsumer", msg, CONSUME_TYPE);
			span.tag("consumer.thread", Thread.currentThread().getName());
			return joinPoint.proceed();
		} catch (Throwable error) {
			Tags.ERROR.tag(error, span);
			throw error;
		} finally {
			span.finish();
		}
	}

	/**
	 * Starts a CONSUMER span whose parent is the trace context carried in the message's user properties.
	 * When the message carries no context, {@code nextSpan(extracted)} decides the new span's parentage.
	 */
	private Span startConsumerSpan(Message message) {
		TraceContextOrSamplingFlags extracted = extractor.extract(message);
		// CONSUMER (not SERVER) is Brave's span kind for message consumption; the remote side is the broker.
		return tracer.nextSpan(extracted).name(CONSUME_SPAN_NAME)
				.kind(Span.Kind.CONSUMER)
				.remoteServiceName(REMOTE_SERVICE_NAME)
				.start();
	}

	/**
	 * Records the tags shared by producer and consumer spans.
	 * Topic and tags are replaced by "" when null, because Brave's tag API does not accept null values.
	 */
	private static void tagMessage(Span span, String component, Message message, String type) {
		span.tag("lc", component);
		span.tag("rocketmq.topic", StringUtils.defaultString(message.getTopic()));
		span.tag("rocketmq.tags", StringUtils.defaultString(message.getTags()));
		span.tag("rocketmq.type", type);
	}

}
