package cn.com.duibabiz.component.redis.impl;

import cn.com.duibabiz.component.redis.RedisCustomClient;
import cn.com.duibabiz.component.redis.RedisCustomClientProperties;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.ValueOperations;
import org.springframework.data.redis.serializer.RedisSerializer;
import org.springframework.data.redis.serializer.SerializationUtils;
import org.springframework.util.Assert;

import javax.annotation.Resource;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.TimeUnit;

/**
 * {@link RedisCustomClient} implementation that delegates to the {@code redisTemplate} bean.
 * {@link #multiGet(Collection)} is split into several MGET commands, each carrying at most
 * {@code properties.getSplit()} keys.
 *
 * @author zhenghuan (zhenghuan@duiba.com.cn)
 * @date 2019/6/3 0003 15:33
 */
public class RedisCustomClientImpl<K, V> implements RedisCustomClient<K, V> {

    @Resource(name = "redisTemplate")
    private RedisTemplate<K, V> redisTemplate;

    @Autowired
    private RedisCustomClientProperties properties;

    /**
     * Reads the value stored at {@code key} via {@code opsForValue().get(key)}.
     *
     * @param key the key to read
     * @return the value returned by the underlying {@link ValueOperations}
     */
    @Override
    public V get(Object key) {
        return opsForValue().get(key);
    }

    /**
     * Fetches the values of {@code keys}, issuing one MGET per batch of at most
     * {@code properties.getSplit()} keys to bound the size of each request.
     * <p>
     * The parameter type stays {@link Collection} for compatibility with the native
     * {@code ValueOperations#multiGet} signature.
     *
     * @param keys keys to read; {@code null} or empty yields an empty list
     * @return the values in the iteration order of {@code keys}; an empty list if the
     *         connection returned no result for a batch
     * @throws IllegalArgumentException if {@code keys.size()} exceeds
     *         {@code properties.getKeysMax()} or {@code properties.getSplit()} is not positive
     */
    @Override
    public List<V> multiGet(Collection<K> keys) {
        if (CollectionUtils.isEmpty(keys)) {
            return Collections.emptyList();
        }
        Assert.isTrue(keys.size() <= properties.getKeysMax(), "keys个数超过最大值");
        // Read the batch size once so every batch of this call uses the same value.
        final int split = properties.getSplit();
        // A non-positive split would otherwise surface later as an array size/index exception.
        Assert.isTrue(split > 0, "split must be greater than 0");

        final int total = keys.size();
        List<byte[]> result = new ArrayList<>(total);
        // Number of keys consumed so far across all batches.
        int counter = 0;
        // Write position inside the current batch.
        int index = 0;
        byte[][] rawKeys = newBatch(total - counter, split);
        for (K key : keys) {
            rawKeys[index++] = rawKey(key);
            counter++;
            // Each batch is sized min(remaining, split), so it is full either when the split
            // threshold is reached or when the last key has been added.
            if (index == rawKeys.length) {
                List<byte[]> rawValues = execute(rawKeys);
                if (rawValues == null) {
                    // The connection produced no result (e.g. pipelined/transactional
                    // connection); mirror the native behaviour instead of an NPE in addAll.
                    return Collections.emptyList();
                }
                result.addAll(rawValues);
                // Only allocate another batch if keys remain.
                if (counter < total) {
                    rawKeys = newBatch(total - counter, split);
                    index = 0;
                }
            }
        }

        return deserializeValues(result);
    }

    /**
     * Allocates the key array for the next batch.
     *
     * @param remaining number of keys not yet placed in a batch
     * @param split     maximum batch size
     * @return an array of {@code min(remaining, split)} slots
     */
    private static byte[][] newBatch(int remaining, int split) {
        return new byte[Math.min(remaining, split)][];
    }

    /**
     * Runs a single MGET for {@code rawKeys} through {@code redisTemplate.execute}.
     *
     * @param rawKeys serialized keys of one batch
     * @return the raw values, possibly {@code null} if the connection yields no result
     */
    private List<byte[]> execute(final byte[][] rawKeys) {
        return redisTemplate.execute(connection -> connection.mGet(rawKeys), true);
    }

    /**
     * Increments the integer value at {@code key} by {@code delta} via {@code opsForValue()}.
     */
    @Override
    public Long increment(K key, long delta) {
        return opsForValue().increment(key, delta);
    }

    /**
     * Increments the floating-point value at {@code key} by {@code delta} via {@code opsForValue()}.
     */
    @Override
    public Double increment(K key, double delta) {
        return opsForValue().increment(key, delta);
    }

    /**
     * Sets a time-to-live on {@code key} via {@code redisTemplate.expire}.
     */
    @Override
    public Boolean expire(K key, long timeout, TimeUnit unit) {
        return redisTemplate.expire(key, timeout, unit);
    }

    /**
     * Stores {@code value} at {@code key} with the given timeout via {@code opsForValue()}.
     */
    @Override
    public void set(K key, V value, long timeout, TimeUnit unit) {
        opsForValue().set(key, value, timeout, unit);
    }

    /**
     * Stores {@code value} at {@code key} only if it is absent, via {@code opsForValue()}.
     */
    @Override
    public Boolean setIfAbsent(K key, V value) {
        return opsForValue().setIfAbsent(key, value);
    }

    /**
     * Serializes {@code key} with the template's key serializer. A {@code byte[]} key is
     * passed through unchanged when no key serializer is configured.
     *
     * @param key the key; must not be {@code null}
     * @return the serialized key
     */
    private byte[] rawKey(Object key) {
        Assert.notNull(key, "non null key required");
        RedisSerializer<Object> serializer = keySerializer();
        if (serializer == null && key instanceof byte[]) {
            return (byte[]) key;
        }
        return serializer.serialize(key);
    }

    /**
     * Deserializes raw values with the template's value serializer. Without a value
     * serializer the raw byte arrays are returned as-is.
     *
     * @param rawValues raw values as returned by MGET
     * @return the deserialized values
     */
    @SuppressWarnings("unchecked")
    private List<V> deserializeValues(List<byte[]> rawValues) {
        RedisSerializer<V> serializer = valueSerializer();
        if (serializer == null) {
            return (List<V>) rawValues;
        }
        return SerializationUtils.deserialize(rawValues, serializer);
    }

    /**
     * Returns the template's key serializer, typed so that arbitrary keys can be passed in.
     */
    @SuppressWarnings("unchecked")
    private RedisSerializer<Object> keySerializer() {
        return (RedisSerializer<Object>) redisTemplate.getKeySerializer();
    }

    /**
     * Returns the template's value serializer, typed to this client's value type.
     */
    @SuppressWarnings("unchecked")
    private RedisSerializer<V> valueSerializer() {
        return (RedisSerializer<V>) redisTemplate.getValueSerializer();
    }

    /**
     * Returns the template's value operations.
     */
    private ValueOperations<K, V> opsForValue() {
        return redisTemplate.opsForValue();
    }
}
