package cn.com.duiba.cat.message.io;

import cn.com.duiba.cat.Cat;
import cn.com.duiba.cat.exception.ResourceNotFoundException;
import cn.com.duiba.cat.message.internal.MessageIdFactory;
import cn.com.duiba.cat.model.configuration.ClientConfigService;
import cn.com.duiba.cat.model.configuration.DefaultClientConfigService;
import cn.com.duiba.cat.model.configuration.client.entity.Server;
import cn.com.duiba.cat.util.Pair;
import cn.com.duiba.cat.util.Threads;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioSocketChannel;
import org.apache.commons.collections4.CollectionUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.*;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

/**
 * Maintains the pool of Netty channels to the configured CAT servers.
 *
 * <p>The singleton connects to every server returned by {@link ClientConfigService#getServers()} when it is
 * constructed. When executed as a {@link Threads.Task}, {@link #run()} ticks once per second. While CAT is enabled,
 * every tenth tick does three things:
 * <ul>
 *     <li>persists the message id mark;</li>
 *     <li>rebuilds the pool if the configured server set changed;</li>
 *     <li>reconnects channels that were found inactive on several consecutive checks.</li>
 * </ul>
 *
 * <p>Sender threads obtain a channel through {@link #channelFuture()}. It picks one of the connected servers at
 * random, weighted by {@link Server#getWeight()}.
 *
 * <p>Threading: the pool is rebuilt by the health-check thread and published through a volatile field. Readers take
 * a single snapshot of it, so a concurrent swap cannot make a lookup fail half-way.
 */
public class ChannelManager implements Threads.Task {

    private static final Logger LOGGER = LoggerFactory.getLogger(ChannelManager.class);

    private static final ChannelManager instance = new ChannelManager();

    private final ClientConfigService configService = DefaultClientConfigService.getInstance();
    private final MessageIdFactory    idFactory     = MessageIdFactory.getInstance();
    private final Random              random        = new Random();
    private final Bootstrap           bootstrap;

    /** Cleared by {@link #shutdown()}, typically from another thread; volatile so {@link #run()} observes it. */
    private volatile boolean        active = true;

    /** Current pool, or {@code null} when none. Replaced by the health-check thread and read by senders. */
    private volatile ChannelsHolder activeChannelHolder;

    /** Tick counter of the health-check loop; only touched by the thread executing {@link #run()}. */
    private int reconnectCount;

    private ChannelManager() {
        // Single daemon I/O thread so the event loop never keeps the JVM alive on its own.
        EventLoopGroup group = new NioEventLoopGroup(1, r -> {
            Thread t = new Thread(r);

            t.setDaemon(true);
            return t;
        });

        Bootstrap bootstrap = new Bootstrap();
        bootstrap.group(group).channel(NioSocketChannel.class);
        bootstrap.option(ChannelOption.SO_KEEPALIVE, true);
        bootstrap.handler(new ChannelInitializer<Channel>() {

            @Override
            protected void initChannel(Channel ch) {
                // No pipeline handlers are installed here.
            }

        });
        this.bootstrap = bootstrap;

        Set<Server> servers = configService.getServers();

        if (CollectionUtils.isNotEmpty(servers)) {
            ChannelsHolder holder = initChannel(servers);

            if (holder != null) {
                activeChannelHolder = holder;
            }
        }
    }

    /**
     * Returns the process-wide instance.
     *
     * @return the singleton manager
     */
    public static ChannelManager getInstance() {
        return instance;
    }

    /**
     * Picks a channel from {@code holder} at random, weighted by each server's weight.
     *
     * <p>If all weights are equal, or their sum is not positive, every channel is equally likely.
     *
     * @param holder the pool snapshot to choose from; must not be null
     * @return the chosen entity, or {@code null} when the pool holds no channels
     */
    private ChannelEntity determineChannelEntity(ChannelsHolder holder) {
        List<String> activeIps = holder.getActiveIps();
        int length = activeIps.size(); // number of candidates

        // Every connection attempt may have failed; nextInt(0) below would throw.
        if (length == 0) {
            return null;
        }

        int totalWeight = 0; // sum of all weights
        boolean sameWeight = true; // whether all weights are identical
        int[] weights = new int[length];

        for (int i = 0; i < length; i++) {
            int weight = holder.getServer(activeIps.get(i)).getWeight();
            weights[i] = weight;
            totalWeight += weight;

            if (sameWeight && i > 0 && weight != weights[i - 1]) {
                sameWeight = false;
            }
        }

        if (totalWeight > 0 && !sameWeight) {
            // Draw a point in [0, totalWeight) and find the weight segment it falls in.
            int offset = random.nextInt(totalWeight);

            for (int i = 0; i < length; i++) {
                offset -= weights[i];

                if (offset < 0) {
                    return holder.getChannelEntity(activeIps.get(i));
                }
            }
        }

        // Equal (or non-positive total) weights: a uniform pick is enough.
        return holder.getChannelEntity(activeIps.get(random.nextInt(length)));
    }

    /**
     * Selects a channel to send on.
     *
     * @return the future of a channel that is active, open and writable, or {@code null} if none is available
     */
    public ChannelFuture channelFuture() {
        // Single read: the health-check thread may swap or null the holder at any time.
        ChannelsHolder holder = activeChannelHolder;

        if (holder == null) {
            return null;
        }

        ChannelEntity channelEntity = determineChannelEntity(holder);

        if (checkWritable(channelEntity)) {
            return channelEntity.getChannelFuture();
        }
        return null;
    }

    /**
     * Reports whether the future's channel is both active and open. An error is logged when it is not.
     *
     * @param future channel future to inspect; may be null
     * @return {@code true} only for a non-null future whose channel is active and open
     */
    private boolean checkActive(ChannelFuture future) {
        if (future == null) {
            return false;
        }

        Channel channel = future.channel();

        if (channel.isActive() && channel.isOpen()) {
            return true;
        }
        LOGGER.error("channel buf is not active ,current channel {}", channel.remoteAddress());
        return false;
    }

    /**
     * Rebuilds the pool when the configured server set differs from the current one.
     *
     * <p>Channels that are carried over into the new pool stay open. Only channels that did not survive the rebuild
     * are closed.
     */
    private void checkServerChanged() {
        Pair<Boolean, Set<Server>> pair = serversChanged();

        if (!pair.getKey()) {
            return;
        }

        Set<Server> servers = pair.getValue();
        LOGGER.info("router config changed :" + servers);

        ChannelsHolder last = activeChannelHolder;
        ChannelsHolder newHolder = initChannel(servers);

        activeChannelHolder = newHolder;
        closeStaleChannels(last, newHolder);

        if (newHolder == null) {
            LOGGER.info("switch active channel to [none]");
        } else if (newHolder.isConnectChanged()) {
            LOGGER.info("switch active channel to " + newHolder);
        }
    }

    /**
     * Reports whether a message can be written to the entity's channel right now.
     *
     * <p>When the channel is open but not writable, it is flushed so the outbound buffer can drain. When the channel
     * is closed or inactive, the entity's attempt counter is incremented and the 1st and every 1000th failure are
     * logged.
     *
     * @param channelEntity candidate entity; may be null
     * @return {@code true} if the channel is active, open and writable
     */
    private boolean checkWritable(ChannelEntity channelEntity) {
        if (channelEntity == null) {
            return false;
        }

        Channel channel = channelEntity.getChannelFuture().channel();

        if (channel.isActive() && channel.isOpen()) {
            if (channel.isWritable()) {
                return true;
            }
            channel.flush();
            return false;
        }

        int count = channelEntity.getAttempts().incrementAndGet();

        if (count % 1000 == 0 || count == 1) {
            LOGGER.error("channel buf is is close when send msg! Attempts: " + count);
        }
        return false;
    }

    /**
     * Closes the channel behind {@code channelFuture}, logging its remote address when known.
     *
     * @param channelFuture future whose channel should be closed; ignored when null
     */
    private void closeChannelFuture(ChannelFuture channelFuture) {
        if (channelFuture == null) {
            return;
        }

        SocketAddress address = channelFuture.channel().remoteAddress();

        if (address != null) {
            LOGGER.info("close channel " + address);
        }
        channelFuture.channel().close();
    }

    /**
     * Closes every channel of {@code previous} whose future was not carried over into {@code current}.
     *
     * @param previous pool being replaced; nothing happens when null
     * @param current  replacement pool; when null, every channel of {@code previous} is closed
     */
    private void closeStaleChannels(ChannelsHolder previous, ChannelsHolder current) {
        if (previous == null) {
            return;
        }

        // Futures are compared by identity: reuse in initChannel shares the very same future instance.
        Set<ChannelFuture> retained = Collections.newSetFromMap(new IdentityHashMap<>());

        if (current != null) {
            for (ChannelEntity entity : current.getChannels()) {
                retained.add(entity.getChannelFuture());
            }
        }

        for (ChannelEntity entity : previous.getChannels()) {
            ChannelFuture future = entity.getChannelFuture();

            if (!retained.contains(future)) {
                closeChannelFuture(future);
            }
        }
    }

    /**
     * Connects to {@code address} and waits up to the configured client connect timeout (milliseconds).
     *
     * @param address server address
     * @return the connected future, or {@code null} if the attempt failed or timed out (the channel is closed then)
     */
    private ChannelFuture createChannelFuture(InetSocketAddress address) {
        LOGGER.info("start connect server{}", address);
        ChannelFuture future = null;

        try {
            future = bootstrap.connect(address);
            future.awaitUninterruptibly(configService.getClientConnectTimeout(), TimeUnit.MILLISECONDS);

            if (future.isSuccess()) {
                LOGGER.info("Connected to CAT server at " + address);
                return future;
            }
            LOGGER.error("Error when try connecting to " + address);
            closeChannelFuture(future);
        } catch (Throwable e) {
            LOGGER.error("Error when connect server {}", address.getAddress(), e);

            if (future != null) {
                closeChannelFuture(future);
            }
        }
        return null;
    }

    @Override
    public String getName() {
        return "netty-channel-health-check";
    }

    /**
     * Builds a new pool for {@code servers}.
     *
     * <p>Channels are keyed by {@link Server#getIp()}. This is the same key that
     * {@link ChannelsHolder#isEqualServers(Set)} compares against. A channel already present in the current pool for
     * the same ip is reused. Otherwise a new connection is opened, and servers that cannot be connected are left out.
     * Later servers whose ip repeats an earlier one are ignored.
     *
     * @param servers configured servers
     * @return the new pool (possibly without any channel), or {@code null} when {@code servers} is empty
     */
    private ChannelsHolder initChannel(Set<Server> servers) {
        if (CollectionUtils.isEmpty(servers)) {
            return null;
        }

        ChannelsHolder previous = activeChannelHolder;
        ChannelsHolder newHolder = new ChannelsHolder();
        boolean connectChanged = false;

        for (Server server : servers) {
            String ip = server.getIp();

            if (newHolder.getActiveIps().contains(ip)) {
                LOGGER.warn("duplicate server ip ignored, ip={}", ip);
                continue;
            }

            InetSocketAddress address = new InetSocketAddress(ip, server.getPort());
            ChannelEntity newChannelEntity = new ChannelEntity(ip, server, address);
            ChannelFuture existChannelFuture = previous == null ? null : previous.getActiveChannelFuture(ip);

            if (existChannelFuture != null) {
                newChannelEntity.setChannelFuture(existChannelFuture);
            } else {
                ChannelFuture newChannelFuture = createChannelFuture(address);

                if (newChannelFuture == null) {
                    LOGGER.error("createChannelFuture error, ip={}", ip);
                    continue;
                }
                newChannelEntity.setChannelFuture(newChannelFuture);
                connectChanged = true;
            }
            newHolder.addChannelEntity(newChannelEntity);
        }

        // The set of connected ips also changes when servers were dropped or failed to connect.
        if (previous == null
            || !new HashSet<>(previous.getActiveIps()).equals(new HashSet<>(newHolder.getActiveIps()))) {
            connectChanged = true;
        }
        newHolder.setConnectChanged(connectChanged);
        return newHolder;
    }

    /**
     * Records one health check for {@code channel} and decides whether it should be reconnected.
     *
     * <p>Each inactive check increments the entity's stalled counter, and a reconnect is requested on every third
     * increment. Each active check decrements the counter (never below zero).
     *
     * @param channel entity to check
     * @return {@code true} if the channel should be reconnected now
     */
    private boolean isChannelStalled(ChannelEntity channel) {
        boolean active = checkActive(channel.getChannelFuture());
        AtomicInteger channelStalledTimesCounter = channel.getStalledTimesCounter();

        if (!active) {
            return channelStalledTimesCounter.incrementAndGet() % 3 == 0;
        }
        if (channelStalledTimesCounter.get() > 0) {
            channelStalledTimesCounter.decrementAndGet();
        }
        return false;
    }

    /**
     * Compares the configured servers against the current pool.
     *
     * @return a pair of (changed?, configured servers). With no current pool, "changed" means servers are configured.
     */
    private Pair<Boolean, Set<Server>> serversChanged() {
        Set<Server> servers = configService.getServers();
        ChannelsHolder holder = activeChannelHolder;
        boolean changed = holder == null ? CollectionUtils.isNotEmpty(servers) : !holder.isEqualServers(servers);

        return new Pair<>(changed, servers);
    }

    /**
     * Replaces the channels of {@code channelHolder} that {@link #isChannelStalled} flags.
     *
     * <p>The old channel is closed before a new connection is attempted. If the new attempt fails, the entity keeps
     * the closed future and is retried on a later check.
     *
     * @param channelHolder pool to check; ignored when null
     */
    private void doubleCheckAndReconnectActiveServer(ChannelsHolder channelHolder) {
        if (channelHolder == null) {
            return;
        }

        for (ChannelEntity channelEntity : channelHolder.getChannels()) {
            if (!isChannelStalled(channelEntity)) {
                continue;
            }

            closeChannelFuture(channelEntity.getChannelFuture());

            ChannelFuture newFuture = createChannelFuture(channelEntity.getAddress());

            if (newFuture == null) {
                LOGGER.error("createChannelFuture error, ip={}", channelEntity.getIp());
                continue;
            }
            channelEntity.setChannelFuture(newFuture);
        }
    }

    /**
     * Health-check loop. It ticks roughly once per second until {@link #shutdown()} is called.
     *
     * <p>While CAT is enabled, every tenth tick does three things: it saves the message id mark, rebuilds the pool if
     * the server configuration changed, and reconnects stalled channels. Exceptions from a tick are logged and do not
     * stop the loop.
     */
    @Override
    public void run() {
        while (active) {
            try {
                if (Cat.isEnabled()) {
                    reconnectCount++;

                    if (reconnectCount % 10 == 0) {
                        // Save the message id index asynchronously every 10 seconds.
                        idFactory.saveMark();
                        checkServerChanged();

                        doubleCheckAndReconnectActiveServer(activeChannelHolder);
                    }
                }
            } catch (Exception e) {
                LOGGER.error("ChannelManager reconnect error", e);
            }
            try {
                Thread.sleep(1000L);
            } catch (InterruptedException ignored) {
                // Deliberately ignored: the loop's lifetime is governed solely by the `active` flag.
            }
        }
    }

    /** Asks {@link #run()} to stop after its current tick. */
    @Override
    public void shutdown() {
        active = false;
    }

    /**
     * One configured server together with its current channel and health counters.
     */
    public static class ChannelEntity {

        private final String ip;

        private final Server server;

        private final InetSocketAddress address;

        /** Consecutive-failure bookkeeping used by the health check. */
        private final AtomicInteger stalledTimesCounter = new AtomicInteger(0);

        /** Number of send attempts that found the channel closed or inactive. */
        private final AtomicInteger attempts = new AtomicInteger(0);

        /** Replaced by the health-check thread and read by sender threads, hence volatile. */
        private volatile ChannelFuture channelFuture;

        public ChannelEntity(String ip, Server server, InetSocketAddress address) {
            this.ip = ip;
            this.server = server;
            this.address = address;
        }

        public String getIp() {
            return ip;
        }

        public Server getServer() {
            return server;
        }

        public InetSocketAddress getAddress() {
            return address;
        }

        public AtomicInteger getStalledTimesCounter() {
            return stalledTimesCounter;
        }

        public AtomicInteger getAttempts() {
            return attempts;
        }

        public ChannelFuture getChannelFuture() {
            return channelFuture;
        }

        public void setChannelFuture(ChannelFuture channelFuture) {
            this.channelFuture = channelFuture;
        }
    }

    /**
     * A pool of {@link ChannelEntity} keyed by ip, plus the ips in insertion order.
     *
     * <p>Instances are filled by a single thread before being published and are not modified afterwards.
     */
    public static class ChannelsHolder {

        private final Map<String, ChannelEntity> channelMap = new HashMap<>();

        private final List<String> activeIps = new ArrayList<>();

        /** Whether this pool's connections differ from the pool it replaced (used for logging). */
        private boolean connectChanged;

        public boolean isConnectChanged() {
            return connectChanged;
        }

        public void setConnectChanged(boolean connectChanged) {
            this.connectChanged = connectChanged;
        }

        /**
         * @throws ResourceNotFoundException if no channel is registered for {@code ip}
         */
        public Server getServer(String ip) {
            ChannelEntity channel = this.channelMap.get(ip);
            if (channel == null) {
                throw new ResourceNotFoundException("Channel not found, ip={}", ip);
            }
            return channel.getServer();
        }

        /**
         * @return the channel future registered for {@code ip}, or {@code null} if none
         */
        public ChannelFuture getActiveChannelFuture(String ip) {
            ChannelEntity channelEntity = this.channelMap.get(ip);
            if (channelEntity == null) {
                return null;
            }
            return channelEntity.getChannelFuture();
        }

        public List<String> getActiveIps() {
            return activeIps;
        }

        public Collection<ChannelEntity> getChannels() {
            return this.channelMap.values();
        }

        public void addChannelEntity(ChannelEntity newChannel) {
            this.channelMap.put(newChannel.getIp(), newChannel);
            this.activeIps.add(newChannel.getIp());
        }

        /**
         * @throws ResourceNotFoundException if no channel is registered for {@code ip}
         */
        public ChannelEntity getChannelEntity(String ip) {
            ChannelEntity channelEntity = this.channelMap.get(ip);
            if (channelEntity == null) {
                throw new ResourceNotFoundException("Channel not found, ip={}", ip);
            }
            return channelEntity;
        }

        /**
         * Reports whether this pool has exactly one channel per ip in {@code servers}.
         * Only ips are compared; ports are not.
         */
        public boolean isEqualServers(Set<Server> servers) {
            if (CollectionUtils.isEmpty(servers)) {
                return this.channelMap.isEmpty();
            }
            if (servers.size() != this.channelMap.size()) {
                return false;
            }
            for (Server s : servers) {
                if (!this.channelMap.containsKey(s.getIp())) {
                    return false;
                }
            }
            return true;
        }

        @Override
        public String toString() {
            return this.activeIps.toString();
        }
    }

}
