package cn.com.duibaboot.ext.autoconfigure.perftest.caffeine;

import cn.com.duiba.boot.perftest.InternalPerfTestContext;
import cn.com.duibaboot.ext.autoconfigure.javaagent.core.interceptor.enhance.InstanceMethodsAroundInterceptor;
import cn.com.duibaboot.ext.autoconfigure.javaagent.core.interceptor.enhance.MethodInterceptResult;
import com.github.benmanes.caffeine.cache.AsyncLoadingCache;
import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.LoadingCache;
import com.github.benmanes.caffeine.cache.Policy;
import com.github.benmanes.caffeine.cache.stats.CacheStats;
import org.springframework.util.Assert;

import javax.annotation.CheckForNull;
import javax.annotation.Nonnull;
import java.io.InvalidObjectException;
import java.io.ObjectInputStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Map;
import java.util.concurrent.*;
import java.util.function.BiFunction;
import java.util.function.Function;

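/**
 * Intercepts Caffeine's {@code Caffeine.build}/{@code buildAsync} methods and calls build a second time, producing two
 * cache instances (one for normal traffic, one for perf-test traffic). Both are wrapped in a new routing cache so that
 * perf-test traffic cannot affect normal traffic.
 */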
public class CaffeineCacheInterceptor implements InstanceMethodsAroundInterceptor {

    @Override
    public void beforeMethod(Object objInst, Method method, Object[] allArguments,
                             Class<?>[] argumentsTypes, MethodInterceptResult result) throws Throwable {
        // do nothing
    }

    /**
     * Wraps the cache returned by an intercepted {@code build}/{@code buildAsync} call in a routing cache that
     * sends normal traffic to {@code ret} and perf-test traffic to a second, separately built cache.
     *
     * @param zuperCall handle to the original method; it is cast to {@link Callable} and called once more to
     *                  build the perf-test cache (presumably re-running the same builder — confirm with the agent core)
     * @param ret       the cache produced by the original call; used for normal traffic
     * @return {@code ret} unchanged for methods other than build/buildAsync, otherwise a routing cache
     * @throws IllegalStateException for a build/buildAsync overload with an unexpected parameter count
     */
    @Override
    @SuppressWarnings("unchecked")
    public Object afterMethod(Object zuperCall, Object objInst, Method method, Object[] allArguments,
                              Class<?>[] argumentsTypes, Object ret) throws Throwable {
        String methodName = method.getName();
        if (!"build".equals(methodName) && !"buildAsync".equals(methodName)) {
            return ret;
        }
        int parameterCount = method.getParameterCount();

        if ("buildAsync".equals(methodName) && parameterCount == 1) { // Caffeine.buildAsync(loader)
            AsyncLoadingCache<Object, Object> cacheForNormal = (AsyncLoadingCache<Object, Object>) ret;
            AsyncLoadingCache<Object, Object> cacheForPerfTest =
                    (AsyncLoadingCache<Object, Object>) buildPerfTestCache(zuperCall);
            return new AutoRoutingLocalAsyncLoadingCache<>(cacheForNormal, cacheForPerfTest);
        } else if ("build".equals(methodName) && parameterCount == 0) { // Caffeine.build() without a loader
            Cache<Object, Object> cacheForNormal = (Cache<Object, Object>) ret;
            Cache<Object, Object> cacheForPerfTest = (Cache<Object, Object>) buildPerfTestCache(zuperCall);
            return new AutoRoutingLocalCache<>(cacheForNormal, cacheForPerfTest);
        } else if ("build".equals(methodName) && parameterCount == 1) { // Caffeine.build(CacheLoader)
            LoadingCache<Object, Object> cacheForNormal = (LoadingCache<Object, Object>) ret;
            LoadingCache<Object, Object> cacheForPerfTest =
                    (LoadingCache<Object, Object>) buildPerfTestCache(zuperCall);
            return new AutoRoutingLocalLoadingCache<>(cacheForNormal, cacheForPerfTest);
        } else {
            throw new IllegalStateException("will never be here");
        }
    }

    @Override
    public void handleMethodException(Object objInst, Method method, Object[] allArguments,
                                      Class<?>[] argumentsTypes, Throwable t) {
        // do nothing
    }

    /** Calls the intercepted build method a second time to obtain the cache reserved for perf-test traffic. */
    private static Object buildPerfTestCache(Object zuperCall) throws Exception {
        return ((Callable<?>) zuperCall).call();
    }

    /**
     * Reflectively invokes a method declared directly on the delegate's runtime class (not inherited ones).
     *
     * @throws InvocationTargetException if the invoked method itself threw
     * @throws RuntimeException          wrapping lookup/access failures
     */
    private static Object invokeDelegateHook(Object delegate, String hookName, Class<?>[] parameterTypes,
                                             Object... args) throws InvocationTargetException {
        try {
            Method hook = delegate.getClass().getDeclaredMethod(hookName, parameterTypes);
            hook.setAccessible(true);
            return hook.invoke(delegate, args);
        } catch (NoSuchMethodException | IllegalAccessException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Forwards a {@code readObject(ObjectInputStream)} hook to the delegate, passing the stream along.
     * An {@link InvalidObjectException} thrown by the delegate's hook is rethrown as-is.
     */
    private static void delegateReadObject(Object delegate, ObjectInputStream stream)
            throws InvalidObjectException {
        try {
            invokeDelegateHook(delegate, "readObject", new Class<?>[] {ObjectInputStream.class}, stream);
        } catch (InvocationTargetException e) {
            if (e.getCause() instanceof InvalidObjectException) {
                throw (InvalidObjectException) e.getCause();
            }
            throw new RuntimeException(e);
        }
    }

    /** Forwards a no-arg {@code writeReplace()} hook to the delegate and returns its result. */
    private static Object delegateWriteReplace(Object delegate) {
        try {
            return invokeDelegateHook(delegate, "writeReplace", new Class<?>[0]);
        } catch (InvocationTargetException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Holds two cache instances, one for normal traffic and one for perf-test traffic, and routes every call to
     * one of them based on {@link InternalPerfTestContext#isCurrentInPerfTestMode()} at the time of the call.
     *
     * @param <K> key type
     * @param <V> value type
     */
    static class AutoRoutingLocalCache<K, V> implements Cache<K, V> {

        private final Cache<K, V> cacheForNormal;   // cache used by normal requests
        private final Cache<K, V> cacheForPerfTest; // cache used by perf-test requests

        private AutoRoutingLocalCache(Cache<K, V> cacheForNormal, Cache<K, V> cacheForPerfTest) {
            this.cacheForNormal = cacheForNormal;
            this.cacheForPerfTest = cacheForPerfTest;
        }

        /** Selects the perf-test cache when the current context is in perf-test mode, else the normal cache. */
        protected Cache<K, V> determineCache() {
            return InternalPerfTestContext.isCurrentInPerfTestMode() ? cacheForPerfTest : cacheForNormal;
        }

        // Cache methods

        @Override
        @CheckForNull
        public V getIfPresent(@Nonnull Object key) {
            return determineCache().getIfPresent(key);
        }

        @CheckForNull
        @Override
        public V get(@Nonnull K key, @Nonnull Function<? super K, ? extends V> mappingFunction) {
            return determineCache().get(key, mappingFunction);
        }

        @Nonnull
        @Override
        public Map<K, V> getAllPresent(@Nonnull Iterable<?> keys) {
            return determineCache().getAllPresent(keys);
        }

        @Override
        public void put(@Nonnull K key, @Nonnull V value) {
            determineCache().put(key, value);
        }

        @Override
        public void putAll(@Nonnull Map<? extends K, ? extends V> map) {
            determineCache().putAll(map);
        }

        @Override
        public void invalidate(@Nonnull Object key) {
            determineCache().invalidate(key);
        }

        @Override
        public void invalidateAll(@Nonnull Iterable<?> keys) {
            determineCache().invalidateAll(keys);
        }

        /** Invalidates only the cache selected for the current context. */
        @Override
        public void invalidateAll() {
            determineCache().invalidateAll();
        }

        @Override
        public long estimatedSize() {
            return determineCache().estimatedSize();
        }

        @Nonnull
        @Override
        public CacheStats stats() {
            return determineCache().stats();
        }

        // NOTE(review): the returned map is the view of whichever delegate is selected at the time of this
        // call; callers that keep it across normal/perf-test contexts lose isolation.
        @Nonnull
        @Override
        public ConcurrentMap<K, V> asMap() {
            return determineCache().asMap();
        }

        @Override
        public void cleanUp() {
            determineCache().cleanUp();
        }

        // NOTE(review): like asMap(), the returned policy is bound to the delegate selected at call time.
        @Nonnull
        @Override
        public Policy<K, V> policy() {
            return determineCache().policy();
        }

        // NOTE(review): as far as visible here this class does not declare Serializable, so Java serialization
        // would not invoke readObject/writeReplace on it — confirm whether these hooks are still needed.
        private void readObject(ObjectInputStream stream) throws InvalidObjectException {
            delegateReadObject(determineCache(), stream);
        }

        Object writeReplace() {
            return delegateWriteReplace(determineCache());
        }

    }

    /**
     * Loading variant of {@link AutoRoutingLocalCache}: holds a normal and a perf-test {@link LoadingCache} and
     * routes every call based on the current perf-test flag.
     *
     * @param <K> key type
     * @param <V> value type
     */
    static class AutoRoutingLocalLoadingCache<K, V> extends AutoRoutingLocalCache<K, V>
            implements LoadingCache<K, V> {

        private AutoRoutingLocalLoadingCache(LoadingCache<K, V> cacheForNormal,
                                             LoadingCache<K, V> cacheForPerfTest) {
            super(cacheForNormal, cacheForPerfTest);
        }

        /** Safe downcast: the constructor only accepts LoadingCache delegates. */
        @Override
        protected LoadingCache<K, V> determineCache() {
            return (LoadingCache<K, V>) super.determineCache();
        }

        // LoadingCache methods

        @CheckForNull
        @Override
        public V get(@Nonnull K key) {
            return determineCache().get(key);
        }

        @Nonnull
        @Override
        public Map<K, V> getAll(@Nonnull Iterable<? extends K> keys) {
            return determineCache().getAll(keys);
        }

        @Override
        public void refresh(@Nonnull K key) {
            determineCache().refresh(key);
        }

        private void readObject(ObjectInputStream stream) throws InvalidObjectException {
            delegateReadObject(determineCache(), stream);
        }

        Object writeReplace() {
            return delegateWriteReplace(determineCache());
        }

    }

    /**
     * Async variant: holds a normal and a perf-test {@link AsyncLoadingCache} and routes every call based on the
     * current perf-test flag.
     *
     * @param <K> key type
     * @param <V> value type
     */
    static class AutoRoutingLocalAsyncLoadingCache<K, V> implements AsyncLoadingCache<K, V> {
        private final AsyncLoadingCache<K, V> cacheForNormal;   // cache used by normal requests
        private final AsyncLoadingCache<K, V> cacheForPerfTest; // cache used by perf-test requests

        private AutoRoutingLocalAsyncLoadingCache(AsyncLoadingCache<K, V> cacheForNormal,
                                                  AsyncLoadingCache<K, V> cacheForPerfTest) {
            this.cacheForNormal = cacheForNormal;
            this.cacheForPerfTest = cacheForPerfTest;
        }

        /** Selects the perf-test cache when the current context is in perf-test mode, else the normal cache. */
        protected AsyncLoadingCache<K, V> determineCache() {
            return InternalPerfTestContext.isCurrentInPerfTestMode() ? cacheForPerfTest : cacheForNormal;
        }

        @CheckForNull
        @Override
        public CompletableFuture<V> getIfPresent(@Nonnull Object key) {
            return determineCache().getIfPresent(key);
        }

        @Nonnull
        @Override
        public CompletableFuture<V> get(@Nonnull K key, @Nonnull Function<? super K, ? extends V> mappingFunction) {
            return determineCache().get(key, mappingFunction);
        }

        @Nonnull
        @Override
        public CompletableFuture<V> get(@Nonnull K key,
                                        @Nonnull BiFunction<? super K, Executor, CompletableFuture<V>> mappingFunction) {
            return determineCache().get(key, mappingFunction);
        }

        @Nonnull
        @Override
        public CompletableFuture<V> get(@Nonnull K key) {
            return determineCache().get(key);
        }

        @Nonnull
        @Override
        public CompletableFuture<Map<K, V>> getAll(@Nonnull Iterable<? extends K> keys) {
            return determineCache().getAll(keys);
        }

        @Override
        public void put(@Nonnull K key, @Nonnull CompletableFuture<V> valueFuture) {
            determineCache().put(key, valueFuture);
        }

        /**
         * Returns a synchronous view that routes on every call. Both delegates' views are wrapped, instead of
         * returning the view of the delegate selected right now, so a view that callers keep around (e.g. one
         * obtained during startup) continues to isolate perf-test traffic.
         */
        @Nonnull
        @Override
        public LoadingCache<K, V> synchronous() {
            return new AutoRoutingLocalLoadingCache<>(cacheForNormal.synchronous(), cacheForPerfTest.synchronous());
        }

        // NOTE(review): as far as visible here this class does not declare Serializable, so Java serialization
        // would not invoke readObject/writeReplace on it — confirm whether these hooks are still needed.
        private void readObject(ObjectInputStream stream) throws InvalidObjectException {
            delegateReadObject(determineCache(), stream);
        }

        Object writeReplace() {
            return delegateWriteReplace(determineCache());
        }
    }
}
