package cn.com.duibaboot.ext.autoconfigure.perftest;

import cn.com.duiba.boot.perftest.PerfTestContext;
import org.apache.rocketmq.common.message.Message;
import org.apache.rocketmq.common.message.MessageExt;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.List;

/**
 * AOP aspect that propagates the perf-test (load-test) marker through RocketMQ.
 *
 * <p>On the producer side, messages sent while the current context is in perf-test mode are
 * tagged with the user property {@code perfTest=true}. On the consumer side, messages carrying
 * that property cause the listener to be invoked with perf-test mode switched on.
 */
@Aspect
public class RocketMqPerfAspect {

    /** Name of the RocketMQ user property that marks a message as perf-test traffic. */
    private static final String PERF_TEST_PROPERTY = "perfTest";

    /** Value of {@link #PERF_TEST_PROPERTY} that marks a message as perf-test traffic. */
    private static final String PERF_TEST_VALUE = "true";

    private static final Logger logger = LoggerFactory.getLogger(RocketMqPerfAspect.class);

    /**
     * Tags outgoing messages with the perf-test property when the caller is in perf-test mode.
     *
     * <p>Applies to every {@code send*} method whose first parameter is declared either as a
     * single {@link Message} or as an {@link Iterable} (the batch overloads that take a
     * collection of messages). All other invocations are passed through untouched.
     *
     * @param joinPoint the intercepted producer invocation
     * @return whatever the intercepted method returns
     * @throws Throwable whatever the intercepted method throws
     */
    @Around("execution(* org.apache.rocketmq.client.producer.DefaultMQProducer+.*(..))")
    public Object rocketMqProducerJoinPoint(ProceedingJoinPoint joinPoint) throws Throwable {
        MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        Class<?>[] parameterTypes = signature.getParameterTypes();
        if (!signature.getMethod().getName().startsWith("send") || parameterTypes.length < 1) {
            return joinPoint.proceed();
        }

        Class<?> firstParameterType = parameterTypes[0];
        boolean singleSend = Message.class.equals(firstParameterType);
        // Batch overloads declare a collection of messages as their first parameter; without
        // this branch batch traffic produced under perf test would look like real traffic.
        boolean batchSend = Iterable.class.isAssignableFrom(firstParameterType);
        if (!(singleSend || batchSend) || !PerfTestContext.isCurrentInPerfTestMode()) {
            return joinPoint.proceed();
        }

        PerfTestContext.debugInfo("rocketMQProducer");

        Object[] args = joinPoint.getArgs();
        markAsPerfTest(args[0]);
        return joinPoint.proceed(args);
    }

    /**
     * Switches perf-test mode on around the consumption of perf-test messages.
     *
     * <p>Applies to {@code consumeMessage*} methods whose first parameter is declared as a
     * {@link List}. Behaviour depends on the content of that list:
     * <ul>
     *   <li>no perf-test message: the listener is invoked unchanged;</li>
     *   <li>only perf-test messages (including the usual single-message case): the listener is
     *       invoked with perf-test mode on, and the mode is switched off again afterwards;</li>
     *   <li>a mix of both: perf-test messages are removed from the list, a warning is logged
     *       and the listener is invoked with the remaining messages outside perf-test mode.</li>
     * </ul>
     *
     * @param joinPoint the intercepted listener invocation
     * @return whatever the intercepted method returns
     * @throws Throwable whatever the intercepted method throws
     */
    @Around("execution(* org.apache.rocketmq.client.consumer.listener.MessageListener+.*(..))")
    public Object rocketMqConsumerJoinPoint(ProceedingJoinPoint joinPoint) throws Throwable {
        MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        Class<?>[] parameterTypes = signature.getParameterTypes();
        if (!signature.getMethod().getName().startsWith("consumeMessage")
                || parameterTypes.length < 1
                || !List.class.equals(parameterTypes[0])) {
            return joinPoint.proceed();
        }

        Object[] args = joinPoint.getArgs();
        // Safe as long as the matched listener methods are only ever handed MessageExt lists;
        // the declared parameter type was checked to be List just above.
        @SuppressWarnings("unchecked")
        List<MessageExt> messages = (List<MessageExt>) args[0];
        if (messages == null || messages.isEmpty()) {
            return joinPoint.proceed();
        }

        List<MessageExt> normalMessages = new ArrayList<>(messages.size());
        for (MessageExt message : messages) {
            if (!isPerfTestMessage(message)) {
                normalMessages.add(message);
            }
        }

        if (normalMessages.size() == messages.size()) {
            // Nothing in this batch is perf-test traffic.
            return joinPoint.proceed();
        }

        if (normalMessages.isEmpty()) {
            // Every message is perf-test traffic, so the whole batch can be consumed in
            // perf-test mode instead of handing the listener an empty list.
            PerfTestContext._setPerfTestMode(true);
            PerfTestContext.debugInfo("rocketMQConsumer");
            try {
                return joinPoint.proceed();
            } finally {
                // Always switch the mode off again so it does not leak into later work done
                // by the same consumer thread (assumes the mode is bound to the current thread).
                PerfTestContext._setPerfTestMode(false);
            }
        }

        // Mixed batch: a single invocation cannot be in both modes, so the perf-test messages
        // are dropped and only the normal ones are delivered.
        logger.warn("检测到rocketmq消费者中一次消费了多条消息且消息中有压测消息，将丢弃其中的压测消息（如需支持压测请设置DefaultMQPushConsumer.setConsumeMessageBatchMaxSize 为1）");
        args[0] = normalMessages;
        return joinPoint.proceed(args);
    }

    /**
     * Adds the perf-test property to a single message or to every message of a batch.
     * Null values and elements that are not messages are ignored.
     */
    private static void markAsPerfTest(Object payload) {
        if (payload instanceof Message) {
            ((Message) payload).putUserProperty(PERF_TEST_PROPERTY, PERF_TEST_VALUE);
        } else if (payload instanceof Iterable) {
            for (Object element : (Iterable<?>) payload) {
                if (element instanceof Message) {
                    ((Message) element).putUserProperty(PERF_TEST_PROPERTY, PERF_TEST_VALUE);
                }
            }
        }
    }

    /** Returns whether the message carries the perf-test property; null counts as not marked. */
    private static boolean isPerfTestMessage(MessageExt message) {
        return message != null && PERF_TEST_VALUE.equals(message.getUserProperty(PERF_TEST_PROPERTY));
    }
}
