package cn.com.duibaboot.ext.autoconfigure.flowreplay.record.aop;

import cn.com.duibaboot.ext.autoconfigure.flowreplay.CaffeineCacheFlowReplaySpan;
import cn.com.duibaboot.ext.autoconfigure.flowreplay.FlowReplayConstants;
import cn.com.duibaboot.ext.autoconfigure.flowreplay.FlowReplayTrace;
import cn.com.duibaboot.ext.autoconfigure.javaagent.core.interceptor.enhance.InstanceMethodsAroundInterceptor;
import cn.com.duibaboot.ext.autoconfigure.javaagent.core.interceptor.enhance.MethodInterceptResult;
import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.LoadingCache;
import com.github.benmanes.caffeine.cache.Policy;
import com.github.benmanes.caffeine.cache.stats.CacheStats;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.Assert;

import javax.annotation.CheckForNull;
import javax.annotation.Nonnull;
import java.lang.reflect.Method;
import java.util.Map;
import java.util.concurrent.ConcurrentMap;
import java.util.function.Function;

/**
 * Method recording for Caffeine caches.
 * <p>
 * After-advice for {@code Caffeine.build(...)} / {@code Caffeine.buildAsync(...)}: synchronous caches returned by
 * {@code build} are wrapped in a decorator that records read calls ({@code getIfPresent}, {@code get},
 * {@code getAllPresent}, {@code getAll}) as sub-spans of the current flow-replay trace. Caches returned by
 * {@code buildAsync} are currently passed through unchanged.
 * Created by guoyanfei .
 * 2019-04-15 .
 */
@Slf4j
public class RecordCaffeineCacheMethodInterceptor implements InstanceMethodsAroundInterceptor {

    @Override
    public void beforeMethod(Object obj, Method method, Object[] allArguments, Class<?>[] argumentsTypes, MethodInterceptResult result) throws Throwable {
        // do nothing
    }

    /**
     * Replaces the cache produced by the intercepted builder method with a recording wrapper.
     *
     * @param ret the cache instance returned by {@code build}/{@code buildAsync}
     * @return a {@link RecordLocalCache} for {@code build()}, a {@link RecordLocalLoadingCache} for
     *         {@code build(CacheLoader)}, or {@code ret} unchanged for any {@code buildAsync} overload
     * @throws IllegalStateException if the intercepted method is neither {@code build} nor {@code buildAsync},
     *                               or is an unexpected {@code build} overload
     */
    @Override
    public Object afterMethod(Object zuperCall, Object obj, Method method, Object[] allArguments, Class<?>[] argumentsTypes, Object ret) throws Throwable {
        String methodName = method.getName();
        Assert.state("build".equals(methodName) || "buildAsync".equals(methodName), "method name must be 'build' or 'buildAsync', will not be here");

        if ("buildAsync".equals(methodName)) {
            // TODO async caches should be enhanced later as well. Until then pass EVERY buildAsync overload
            // through untouched (including a no-arg buildAsync()), rather than letting an unexpected parameter
            // count fall into the IllegalStateException below and break cache construction in the host app.
            return ret;
        }

        int parameterCount = method.getParameterCount();
        if (parameterCount == 0) {
            // Caffeine.build() without arguments
            Cache<?, ?> cache = (Cache<?, ?>) ret;
            return new RecordLocalCache<>(cache);
        }
        if (parameterCount == 1) {
            // Caffeine.build(CacheLoader)
            LoadingCache<?, ?> cache = (LoadingCache<?, ?>) ret;
            return new RecordLocalLoadingCache<>(cache);
        }
        throw new IllegalStateException("will never be here");
    }

    @Override
    public void handleMethodException(Object obj, Method method, Object[] allArguments, Class<?>[] argumentsTypes, Throwable t) {
        // do nothing
    }

    /**
     * Decorator around a Caffeine {@link Cache}.
     * <p>
     * {@code getIfPresent}, {@code get} and {@code getAllPresent} go through {@code executeRecord} and may be
     * recorded. All other operations are delegated to the wrapped cache without recording.
     */
    static class RecordLocalCache<K, V> implements Cache<K, V> {

        private final Cache<K, V> cache;

        private RecordLocalCache(Cache<K, V> cache) {
            this.cache = cache;
        }

        /**
         * @return the wrapped delegate cache
         */
        protected Cache<K, V> getCache() {
            return cache;
        }

        @CheckForNull
        @Override
        public V getIfPresent(@Nonnull Object key) {
            return executeRecord("getIfPresent", key, () -> getCache().getIfPresent(key));
        }

        @CheckForNull
        @Override
        public V get(@Nonnull K key, @Nonnull Function<? super K, ? extends V> mappingFunction) {
            return executeRecord("get", key, () -> getCache().get(key, mappingFunction));
        }

        @Nonnull
        @Override
        public Map<K, V> getAllPresent(@Nonnull Iterable<?> keys) {
            return executeRecord("getAllPresent", keys, () -> getCache().getAllPresent(keys));
        }

        @Override
        public void put(@Nonnull K key, @Nonnull V value) {
            getCache().put(key, value);
        }

        @Override
        public void putAll(@Nonnull Map<? extends K, ? extends V> map) {
            getCache().putAll(map);
        }

        @Override
        public void invalidate(@Nonnull Object key) {
            getCache().invalidate(key);
        }

        @Override
        public void invalidateAll(@Nonnull Iterable<?> keys) {
            getCache().invalidateAll(keys);
        }

        @Override
        public void invalidateAll() {
            getCache().invalidateAll();
        }

        @Override
        public long estimatedSize() {
            return getCache().estimatedSize();
        }

        @Nonnull
        @Override
        public CacheStats stats() {
            return getCache().stats();
        }

        @Nonnull
        @Override
        public ConcurrentMap<K, V> asMap() {
            return getCache().asMap();
        }

        @Override
        public void cleanUp() {
            getCache().cleanUp();
        }

        @Nonnull
        @Override
        public Policy<K, V> policy() {
            return getCache().policy();
        }
    }

    /**
     * Decorator around a Caffeine {@link LoadingCache}.
     * <p>
     * In addition to the reads recorded by {@link RecordLocalCache}, {@code get(key)} and {@code getAll(keys)}
     * go through {@code executeRecord}. {@code refresh} is delegated without recording.
     */
    static class RecordLocalLoadingCache<K, V> extends RecordLocalCache<K, V> implements LoadingCache<K, V> {

        private RecordLocalLoadingCache(LoadingCache<K, V> cache) {
            super(cache);
        }

        /**
         * @return the wrapped delegate; the constructor only accepts a {@link LoadingCache}, so this cast
         *         cannot fail
         */
        @Override
        protected LoadingCache<K, V> getCache() {
            return (LoadingCache<K, V>) super.getCache();
        }

        @CheckForNull
        @Override
        public V get(@Nonnull K key) {
            return executeRecord("get", key, () -> getCache().get(key));
        }

        @Nonnull
        @Override
        public Map<K, V> getAll(@Nonnull Iterable<? extends K> keys) {
            return executeRecord("getAll", keys, () -> getCache().getAll(keys));
        }

        @Override
        public void refresh(@Nonnull K key) {
            getCache().refresh(key);
        }
    }

    /**
     * Callback wrapping the actual call to the delegate cache.
     *
     * @param <R> result type of the delegated call
     */
    @FunctionalInterface
    private interface Callback<R> {

        R invoke();
    }

    /**
     * Runs a delegated cache read and records it as a sub-span when {@link #canRecord(Object)} allows it.
     *
     * @param methodName name of the cache method, stored in the span
     * @param key        the key (or key collection) passed to the cache method; recorded as the single argument
     * @param callback   performs the actual call on the delegate cache
     * @param <T>        result type
     * @return whatever the delegate call returned
     */
    private static <T> T executeRecord(String methodName, Object key, Callback<T> callback) {
        if (!canRecord(key)) {
            return callback.invoke();
        }

        Object[] allArguments = new Object[] { key };
        Class<?>[] argumentsTypes = new Class<?>[] { key.getClass() };

        // Mark that sub-invocations of this call (e.g. work done inside a mapping function) must not be recorded.
        IgnoreSubInvokesContext.instMark(callback, methodName, allArguments);

        T v;
        try {
            v = callback.invoke();
        } catch (Throwable t) {
            // The cache call failed while recording is in progress: do not record this case, and clear the mark.
            FlowReplayTrace.remove();
            IgnoreSubInvokesContext.unmark();
            throw t;
        }

        // record() clears the mark in its finally block.
        record(methodName, allArguments, argumentsTypes, v);

        return v;
    }

    /**
     * Decides whether a cache call with the given key can be recorded.
     *
     * @param arg the key, or an {@link Iterable} of keys, passed to the cache method
     * @return {@code false} if:
     *         <ul>
     *           <li>the current thread is not traced;</li>
     *           <li>{@code arg} is {@code null};</li>
     *           <li>the call is nested inside a call that is already being recorded;</li>
     *           <li>{@code arg}, or any non-null element of an iterable {@code arg}, has a class listed in
     *               {@code FlowReplayConstants.CANNOT_DESERIALIZE_CLASSES}.</li>
     *         </ul>
     *         {@code true} otherwise
     */
    private static boolean canRecord(Object arg) {
        if (!FlowReplayTrace.isTraced()) {
            return false;
        }
        // Keys are declared @Nonnull: let the delegate cache reject a null key instead of NPE-ing here.
        if (arg == null) {
            return false;
        }
        // This invocation is already marked as ignored (nested inside a recorded call): do not record.
        // Checked before walking a key collection because it is cheap.
        if (IgnoreSubInvokesContext.isMarked()) {
            return false;
        }
        if (FlowReplayConstants.CANNOT_DESERIALIZE_CLASSES.contains(arg.getClass().getName())) {
            return false;
        }
        if (arg instanceof Iterable<?>) {
            for (Object o : (Iterable<?>) arg) {
                if (o != null && FlowReplayConstants.CANNOT_DESERIALIZE_CLASSES.contains(o.getClass().getName())) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Creates a Caffeine span for the call and adds it as a sub-span of the current trace.
     * Failures are logged and swallowed so recording never breaks the business call. The
     * "ignore sub-invocations" mark is always cleared.
     *
     * @param methodName     name of the cache method
     * @param allArguments   recorded arguments
     * @param argumentsTypes runtime classes of the recorded arguments
     * @param ret            value returned by the cache call
     */
    private static void record(String methodName, Object[] allArguments, Class<?>[] argumentsTypes, Object ret) {
        try {
            CaffeineCacheFlowReplaySpan span = CaffeineCacheFlowReplaySpan.createSpan(methodName, allArguments, argumentsTypes, ret);
            span.setTraceId(FlowReplayTrace.getCurrentTraceId());
            FlowReplayTrace.addSubSpan(span);
            log.debug("Caffeine录制_traceId={}_spanId={}", span.getTraceId(), span.getSpanId());
        } catch (Throwable t) {
            log.error("CaffeineCache_录制异常", t);
        } finally {
            IgnoreSubInvokesContext.unmark();
        }
    }
}
