package cn.com.duibaboot.ext.autoconfigure.cloud.netflix.ribbon;

import cn.com.duiba.boot.netflix.ribbon.RibbonServerListFilter;
import cn.com.duiba.boot.netflix.ribbon.RibbonServerPredicate;
import cn.com.duiba.boot.perftest.PerfTestContext;
import cn.com.duibaboot.ext.autoconfigure.cloud.netflix.eureka.DiscoveryMetadataAutoConfiguration;
import cn.com.duibaboot.ext.autoconfigure.perftest.DubboPerfTestRegistryFactoryWrapper;
import cn.com.duiba.wolf.threadpool.NamedThreadFactory;
import cn.com.duiba.wolf.utils.ConcurrentUtils;
import cn.com.duiba.wolf.utils.NumberUtils;
import com.netflix.client.IClientConfigAware;
import com.netflix.client.config.IClientConfig;
import com.netflix.loadbalancer.*;
import com.netflix.niws.loadbalancer.DiscoveryEnabledServer;
import org.apache.http.pool.LeastActiveServerListFilter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.autoconfigure.AutoConfigureBefore;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.context.embedded.EmbeddedWebApplicationContext;
import org.springframework.cloud.netflix.feign.CustomFeignClientsRegistrar;
import org.springframework.cloud.netflix.ribbon.PropertiesFactory;
import org.springframework.cloud.netflix.ribbon.RibbonAutoConfiguration;
import org.springframework.cloud.netflix.ribbon.SpringClientFactory;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationListener;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import;
import org.springframework.context.annotation.Profile;
import org.springframework.context.event.ContextRefreshedEvent;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.StringUtils;

import javax.annotation.Resource;
import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Registers customized Ribbon components before {@code RibbonClientConfiguration} runs, in order to:<br/>
 * 1. Use a custom ping component that, every 3 seconds, calls each server's /monitor/check endpoint and marks the
 *    server dead when it is unreachable or does not answer OK, removing its traffic (the stock implementation only
 *    checks the UP status pulled from Eureka, which lags considerably).<br/>
 * 2. Use a custom warm-up aware rule (a rule picks one server for each call) that first tries to choose among the
 *    servers marked alive and only falls back to the parent implementation (which by default considers all servers
 *    without checking isAlive) when none is eligible.<br/>
 * 3. Replace the {@code pingStrategy} of each {@link ILoadBalancer}: the stock strategy pings servers sequentially,
 *    which is slow with many servers (a full ping round has a default timeout of 5 seconds), so it is replaced by a
 *    strategy that pings concurrently.
 */
@Configuration
@Import(LeastActiveServerListFilter.class)
@ConditionalOnClass({IPing.class, IRule.class, ILoadBalancer.class, IPingStrategy.class})
public class RibbonCustomAutoConfiguration {

    private static final Logger logger = LoggerFactory.getLogger(RibbonCustomAutoConfiguration.class);

    @Resource
    private SpringClientFactory springClientFactory;

    /**
     * Eagerly initializes the Ribbon client configuration of every enabled Feign client once the main application
     * context has been refreshed, so the first Feign call after startup is not slow.
     *
     * @return listener performing the eager initialization
     */
    @Bean
    public ApplicationListener<ContextRefreshedEvent> springClientFactoryInitListener(){
        return new ApplicationListener<ContextRefreshedEvent>() {
            @Override
            public void onApplicationEvent(ContextRefreshedEvent event) {
                // Only run for the main (embedded web) ApplicationContext to prevent an infinite loop — presumably
                // because getClientConfig creates child contexts whose refresh events also reach this listener.
                if (!(event.getApplicationContext() instanceof EmbeddedWebApplicationContext)) {
                    return;
                }
                // Clients are initialized one after another (sequentially); each may take around two seconds.
                Set<String> enabledFeignClientNames = CustomFeignClientsRegistrar.getEnabledFeignClientNames();
                for (final String name : enabledFeignClientNames) {
                    springClientFactory.getClientConfig(name);
                }
            }
        };
    }

    /**
     * Makes Ribbon use the custom classes, including:<br/>
     * 1. A custom ping component that, every 3 seconds, calls each server's /monitor/check endpoint and marks the
     *    server dead when it is unreachable or does not answer OK, removing its traffic (the stock implementation only
     *    checks the UP status pulled from Eureka, which lags considerably).<br/>
     * 2. A custom warm-up aware rule that first tries to choose among the alive servers and only falls back to the
     *    parent implementation (which by default considers all servers without checking isAlive) when none is
     *    eligible.<br/>
     * Algorithm: weighted round robin, where the weight depends on how long each server has been running.
     */
    @Configuration
    @AutoConfigureBefore(RibbonAutoConfiguration.class)
    @ConditionalOnClass({IPing.class, IRule.class, ILoadBalancer.class, IPingStrategy.class})
    public static class RibbonPropertiesFactoryConfiguration{

        @Bean
        public PropertiesFactory propertiesFactory() {
            return new CustomPropertiesFactory();
        }

    }

    /**
     * {@link PropertiesFactory} that maps {@link IRule}, {@link IPing} and {@link ServerList} to the custom
     * implementations, and instantiates the {@link ServerList} itself so it can be autowired from the context.
     */
    private static class CustomPropertiesFactory extends PropertiesFactory{

        @Resource
        ApplicationContext context;

        /**
         * Returns the implementation class name configured for the given Ribbon component type.
         *
         * @param clazz Ribbon component interface
         * @param name  Ribbon client name
         * @return the custom implementation for IRule / IPing / ServerList, otherwise whatever the parent returns
         */
        @Override
        public String getClassName(Class clazz, String name) {
            if(clazz.equals(IRule.class)){
                return CustomZoneAvoidanceRule.class.getName();
            }else if(clazz.equals(IPing.class)){
                return RibbonPing.class.getName();
            }else if(clazz.equals(ServerList.class)){
                return CustomRibbonServerList.class.getName();
            }
            return super.getClassName(clazz, name);
        }

        /**
         * Creates the Ribbon component of the given type. Everything except {@link ServerList} is delegated to the
         * parent; a ServerList is instantiated via {@link #instantiateWithConfig} with the application context.
         *
         * @param clazz  Ribbon component interface
         * @param config client configuration passed to the created instance
         * @param name   Ribbon client name
         * @return the created instance, or {@code null} if no class name is configured for a ServerList
         * @throws IllegalArgumentException if the configured class cannot be loaded
         */
        @Override
        @SuppressWarnings("unchecked")
        public <C> C get(Class<C> clazz, IClientConfig config, String name) {
            if(!ServerList.class.equals(clazz)){
                return super.get(clazz, config, name);
            }
            String className = getClassName(clazz, name);
            if (!StringUtils.hasText(className)) {
                return null;
            }
            try {
                Class<?> toInstantiate = Class.forName(className);
                return (C) instantiateWithConfig(context, toInstantiate, config);
            } catch (ClassNotFoundException e) {
                // Keep the original exception as the cause so the class-loading problem stays diagnosable.
                throw new IllegalArgumentException("Unknown class to load "+className+" for class " + clazz + " named " + name, e);
            }
        }

        /**
         * Instantiates {@code clazz}, preferring a public constructor that takes an {@link IClientConfig}.
         * Otherwise it uses the no-arg constructor, calls {@link IClientConfigAware#initWithNiwsConfig} when
         * applicable, and autowires the instance from {@code context} (if non-null).
         *
         * @param context context used to autowire the instance; may be {@code null}
         * @param clazz   class to instantiate
         * @param config  client configuration
         * @return the new instance
         */
        <C> C instantiateWithConfig(ApplicationContext context,
                Class<C> clazz, IClientConfig config) {
            C result = null;

            try {
                Constructor<C> constructor = clazz.getConstructor(IClientConfig.class);
                result = constructor.newInstance(config);
            } catch (Throwable e) {
                // Deliberately best-effort: any failure (missing constructor, constructor error, ...) falls back to
                // the no-arg constructor below.
            }

            if (result == null) {
                result = BeanUtils.instantiate(clazz);

                if (result instanceof IClientConfigAware) {
                    ((IClientConfigAware) result).initWithNiwsConfig(config);
                }

                if (context != null) {
                    context.getAutowireCapableBeanFactory().autowireBean(result);
                }
            }

            return result;
        }
    }


    /**
     * Ping strategy bean that pings servers concurrently.
     *
     * @return the concurrent ping strategy
     */
    @Bean
    public IPingStrategy getConcurrentRibbonPingStrategy(){
        return new ConcurrentPingStrategy();
    }

    /**
     * Sets the {@code pingStrategy} of every ILoadBalancer. The stock strategy runs sequentially, which is slow with
     * many servers (a full ping round has a default timeout of 5 seconds), so it is replaced with the concurrent one.
     *
     * @return listener that patches the load balancer of each refreshed context
     */
    @Bean
    public ApplicationListener<ContextRefreshedEvent> iLoadBalancerInitListener(){
        // A BeanPostProcessor is not used because ILoadBalancer is created in Spring child contexts, which do not
        // apply BeanPostProcessors by default, whereas ApplicationListeners still take effect.
        return new ApplicationListener<ContextRefreshedEvent>() {
            @Override
            public void onApplicationEvent(ContextRefreshedEvent event) {
                try {
                    ILoadBalancer bean = event.getApplicationContext().getBean(ILoadBalancer.class);
                    Field pingStrategyField = ReflectionUtils.findField(bean.getClass(), "pingStrategy");
                    if(pingStrategyField != null) {
                        pingStrategyField.setAccessible(true);
                        ReflectionUtils.setField(pingStrategyField, bean, getConcurrentRibbonPingStrategy());
                    }else{
                        logger.error("[NOTIFYME]ILoadBalancer中的pingStrategy Field不存在，无法设置使用并发ping策略");
                    }
                }catch(NoSuchBeanDefinitionException e){
                    // Deliberately ignored: this context has no (unique) ILoadBalancer, so there is nothing to patch.
                }
            }
        };
    }

    /**
     * Ribbon server-list filter for the development environment. In dev, Feign prefers services running on the local
     * machine and, when there are none, prefers services started from a jar.
     *
     * @return the dev server-list filter
     */
    @Bean
    @Profile("dev")
    public DevRibbonServerListFilter devRibbonServerListFilter(){
        return new DevRibbonServerListFilter(true);
    }

    /**
     * Ping strategy that pings all servers concurrently on a fixed pool of 20 threads (the stock strategy is
     * sequential). The pool is shut down when the bean is destroyed.
     */
    private static class ConcurrentPingStrategy implements IPingStrategy, DisposableBean {
        private final ExecutorService executor = Executors.newFixedThreadPool(20, new NamedThreadFactory("ribbonPing"));

        private static final Logger logger = LoggerFactory.getLogger(ConcurrentPingStrategy.class);

        /**
         * Pings every server and reports which ones are alive.
         *
         * @param ping    the ping component; if {@code null}, all servers are reported dead
         * @param servers servers to ping
         * @return one flag per server, in the same order; {@code false} (dead) unless the ping returned true
         */
        @Override
        public boolean[] pingServers(final IPing ping, Server[] servers) {
            int numCandidates = servers.length;
            // Default answer is DEAD: every slot stays false unless its ping succeeds.
            boolean[] results = new boolean[numCandidates];

            if (logger.isDebugEnabled()) {
                logger.debug("LoadBalancer:  PingTask executing ["
                        + numCandidates + "] servers configured");
            }
            if(ping == null){
                logger.error("IPing is null, will mark all servers as not alive");
                return results;
            }
            // isShutdown (not isTerminated): after shutdown() the pool rejects new tasks immediately, while
            // isTerminated only becomes true once the worker threads have exited. A shut-down pool most likely
            // means the ApplicationContext has been destroyed; all servers are then reported dead.
            if (numCandidates == 0 || executor.isShutdown()) {
                return results;
            }

            List<Callable<Boolean>> callables = new ArrayList<>(numCandidates);
            for (final Server server : servers) {
                callables.add(() -> ping.isAlive(server));
            }

            try {
                List<Boolean> ret = ConcurrentUtils.submitTasksBlocking(executor, callables);
                int i = 0;
                for (Boolean alive : ret) {
                    // A missing (null) result counts as dead instead of failing with an unboxing NPE.
                    results[i++] = Boolean.TRUE.equals(alive);
                }
            } catch (InterruptedException e) {
                // Restore the interrupt flag; the servers are reported dead for this round.
                Thread.currentThread().interrupt();
            } catch (RejectedExecutionException e) {
                // The pool was shut down concurrently (between the check above and submission); report all dead.
            }

            return results;
        }

        @Override
        public void destroy() throws Exception {
            executor.shutdown();
        }
    }

    /**
     * Custom rule component (a rule picks one server for each call). It first tries to choose among the servers
     * marked alive and only falls back to the parent implementation (which by default considers all servers without
     * checking isAlive) when none is eligible.
     */
    public static class CustomZoneAvoidanceRule extends ZoneAvoidanceRule{
        /** Default server weight when none is configured in the discovery metadata. */
        private static final int DEFAULT_WEIGHT = 100;
        /** Warm-up period in ms: a freshly started server reaches its full weight after this long. */
        private static final int WARMUP_MILLIS = 300_000;

        /** Round-robin counter; may wrap around, so always read it through {@link #nextNonNegativeIndex()}. */
        private final AtomicInteger nextIndex = new AtomicInteger();

        private RibbonServerPredicate compositeRibbonServerPredicate;
        private List<RibbonServerListFilter> ribbonServerListFilters = Collections.emptyList();

        /**
         * Injects the optional server predicates (all must match) and server-list filters.
         *
         * @param ribbonServerPredicates predicates combined with logical AND; may be {@code null} or empty
         * @param ribbonServerListFilters filters applied in order; may be {@code null}
         */
        @Autowired(required=false)
        public void setRibbonServerPredicates(List<RibbonServerPredicate> ribbonServerPredicates
                , List<RibbonServerListFilter> ribbonServerListFilters){
            if(ribbonServerPredicates != null && !ribbonServerPredicates.isEmpty()) {
                this.compositeRibbonServerPredicate = discoveryEnabledServer -> ribbonServerPredicates.stream().allMatch(p -> p.test(discoveryEnabledServer));
            }
            if(ribbonServerListFilters != null) {
                this.ribbonServerListFilters = ribbonServerListFilters;
            }
        }

        /**
         * Applies the custom filtering: first every {@link RibbonServerListFilter} in order, then (if configured)
         * the composite predicate, which keeps only {@link DiscoveryEnabledServer}s it accepts.
         *
         * @param serverList candidate servers
         * @return the filtered servers
         */
        private List<Server> filterServerList(List<Server> serverList){
            for(RibbonServerListFilter filter : ribbonServerListFilters){
                serverList = filter.filter(serverList);
            }

            if(compositeRibbonServerPredicate == null){
                return serverList;
            }
            List<Server> filteredList = new ArrayList<>();
            for(Server server : serverList){
                if(server instanceof DiscoveryEnabledServer
                        && compositeRibbonServerPredicate.test((DiscoveryEnabledServer)server)){
                    filteredList.add(server);
                }
            }

            return filteredList;
        }

        /**
         * Chooses a server for the call, adapted from {@code PredicateBasedRule.choose}.
         *
         * @param key load-balancer key
         * @return the chosen server, possibly {@code null} if the parent rule finds none
         * @throws IllegalStateException if the current request is perf-test traffic but the chosen server does not
         *         support perf testing
         */
        @Override
        public Server choose(Object key) {
            ILoadBalancer lb = getLoadBalancer();
            // getReachableServers only returns the servers currently marked alive.
            List<Server> eligible = filterServerList(
                    getPredicate().getEligibleServers(lb.getReachableServers(), key));

            Server chosenServer;
            if (!eligible.isEmpty()) {
                chosenServer = chooseRoundRobinWithTimeBasedWeight(eligible);
            } else {
                // Nothing eligible: fall back to the parent, which uses every server obtained from Eureka as UP.
                chosenServer = super.choose(key);
            }

            if (chosenServer instanceof DiscoveryEnabledServer) {
                checkPerfTestSupported((DiscoveryEnabledServer) chosenServer);
            }

            return chosenServer;
        }

        /**
         * Rejects perf-test traffic aimed at a server whose metadata does not flag perf-test support.
         *
         * @param server the chosen server
         * @throws IllegalStateException if in perf-test mode and the server does not support it
         */
        private static void checkPerfTestSupported(DiscoveryEnabledServer server) {
            String isPerfTestSupportted = server.getInstanceInfo().getMetadata().get(DubboPerfTestRegistryFactoryWrapper.IS_PERF_TEST_SUPPORTTED_KEY);
            if(!"1".equals(isPerfTestSupportted) && PerfTestContext.isCurrentInPerfTestMode()){
                throw new IllegalStateException("识别到当前请求是压测流量，但是调用的服务不支持性能压测，放弃本次请求，请通知该服务负责人加入spring-boot-starter-perftest依赖；尝试调用的服务端：" +
                        server.getInstanceInfo().getVIPAddress()+
                        ",ip:"+server.getInstanceInfo().getIPAddr()+
                        ",port:"+server.getInstanceInfo().getPort());
            }
        }

        /**
         * Picks a server with weighted round robin. The weights depend on server uptime, so that a freshly started
         * server gradually receives more traffic.
         *
         * @param eligible non-empty list of candidate servers
         * @return the chosen server
         */
        private Server chooseRoundRobinWithTimeBasedWeight(List<Server> eligible){
            int length = eligible.size();
            // long: per-server weights come from metadata and are unbounded, so an int sum could overflow.
            long totalWeight = 0;
            boolean sameWeight = true;
            int[] weights = new int[length];
            for (int i = 0; i < length; i++) {
                int weight = getTimeBasedWeight(eligible.get(i));
                weights[i] = weight;
                totalWeight += weight;
                if (sameWeight && i > 0 && weight != weights[i - 1]) {
                    sameWeight = false;
                }
            }
            if (totalWeight > 0 && !sameWeight) {
                // Weights differ: take the next position in [0, totalWeight) and find the slice it falls into.
                long offset = nextNonNegativeIndex() % totalWeight;
                for (int i = 0; i < length; i++) {
                    offset -= weights[i];
                    if (offset < 0) {
                        return eligible.get(i);
                    }
                }
            }

            // All weights equal (or all zero): plain round robin.
            return eligible.get(nextNonNegativeIndex() % length);
        }

        /**
         * Advances the round-robin counter and returns a non-negative value. This also holds after the counter
         * overflows Integer.MAX_VALUE, where a plain {@code %} would yield a negative index.
         */
        private int nextNonNegativeIndex() {
            return nextIndex.getAndIncrement() & Integer.MAX_VALUE;
        }

        /**
         * Computes the server weight from its uptime: the longer it has run, the larger the weight. Warm-up lasts
         * 5 minutes, after which the weight reaches the configured maximum. Servers that are not
         * {@link DiscoveryEnabledServer}s get the default weight.
         *
         * @param server the server
         * @return the weight (>= 0)
         */
        private int getTimeBasedWeight(Server server) {
            if (!(server instanceof DiscoveryEnabledServer)) {
                return DEFAULT_WEIGHT;
            }
            DiscoveryEnabledServer discoveryServer = (DiscoveryEnabledServer) server;

            String weightStr = discoveryServer.getInstanceInfo().getMetadata().get(DiscoveryMetadataAutoConfiguration.WEIGHT_KEY);
            int weight = Math.max(NumberUtils.parseInt(weightStr, DEFAULT_WEIGHT), 0);
            if (weight == 0) {
                return 0;
            }

            // Server start-up timestamp from the discovery metadata (0 when absent or invalid).
            String tsStr = discoveryServer.getInstanceInfo().getMetadata().get(DiscoveryMetadataAutoConfiguration.SERVER_START_UP_TIME_KEY);
            long timestamp = Math.max(NumberUtils.parseLong(tsStr, 0), 0);
            if (timestamp > 0L) {
                // Keep uptime as long: an int cast overflows after ~24.8 days and periodically wraps back into
                // the warm-up window, wrongly throttling long-running servers.
                long uptime = System.currentTimeMillis() - timestamp;
                if (uptime > 0 && uptime < WARMUP_MILLIS) {
                    weight = calculateWarmupWeight((int) uptime, WARMUP_MILLIS, weight);
                }
            }
            return weight;
        }

        /**
         * Scales {@code weight} linearly by {@code uptime / warmup}, clamped to [1, weight].
         */
        private int calculateWarmupWeight(int uptime, int warmup, int weight) {
            int ww = (int) ( (float) uptime / ( (float) warmup / (float) weight ) );
            if(ww < 1){
                return 1;
            }
            return (ww > weight ? weight : ww);
        }

    }
}
