package cn.com.duiba.wolf.redis;

import cn.com.duiba.catmonitor.CatInstance;
import cn.com.duiba.wolf.log.DegradeLogger;
import com.dianping.cat.Cat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.dao.DataAccessException;
import org.springframework.data.redis.connection.RedisConnection;
import org.springframework.data.redis.core.RedisCallback;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.core.script.DefaultScriptExecutor;
import org.springframework.data.redis.serializer.RedisSerializer;
import org.springframework.data.redis.serializer.StringRedisSerializer;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.TimeUnit;

/**
 * Atomic Redis operations that Redis does not offer as a single built-in command
 * (implemented with Lua scripts where needed).
 * Only master/slave Redis deployments are supported; Alibaba Cloud cluster-edition Redis is not.
 * Created by wenqi.huang on 2017/3/6.
 */
public class RedisAtomicClient {
    private static final Logger logger = DegradeLogger.wrap(LoggerFactory.getLogger(RedisAtomicClient.class));

    /**
     * Lua script: INCRBY the key and, only if this call created the key, give it a TTL in milliseconds.
     * KEYS[1] = counter key, ARGV[1] = delta, ARGV[2] = TTL in milliseconds. Returns the new value.
     * Existence is checked before the increment (instead of comparing the new value with 1) so that the
     * TTL is applied for any delta, and is not re-armed when an existing counter happens to pass through 1.
     */
    private static final String INCR_BY_WITH_TIMEOUT =
            "local existed = redis.call('exists', KEYS[1])\n" +
            "local v = redis.call('incrBy', KEYS[1], ARGV[1])\n" +
            "if existed == 0 then\n" +
            "    redis.call('pexpire', KEYS[1], ARGV[2])\n" +
            "end\n" +
            "return v";

    // Built once and shared so the script object (and its SHA1) is not recreated on every call.
    private static final DefaultRedisScript<String> INCR_BY_WITH_TIMEOUT_SCRIPT =
            new DefaultRedisScript<>(INCR_BY_WITH_TIMEOUT, String.class);

    // The serializers configured on the supplied redisTemplate do not matter: this class always
    // (de)serializes keys, script arguments and results with the string serializers below.
    private final RedisTemplate redisTemplate;

    private final RedisSerializer<String> keyRedisSerializer = new StringRedisSerializer();
    private final RedisSerializer<String> stringRedisSerializer = new StringRedisSerializer();

    private final StringScriptExecutor stringScriptExecutor;

    /**
     * @param redisTemplate template used to obtain Redis connections; its own serializer settings are ignored
     */
    public RedisAtomicClient(RedisTemplate redisTemplate) {
        this.redisTemplate = redisTemplate;
        this.stringScriptExecutor = new StringScriptExecutor(redisTemplate, keyRedisSerializer);
    }

    /**
     * Reads the value stored at {@code key} (string serialization) and parses it as a long.
     *
     * @param key the Redis key
     * @return the parsed value, or {@code null} if the key does not exist or if reading/parsing fails
     *         (failures are logged and, when CAT is enabled, reported to CAT)
     */
    public Long getLong(String key) {
        final byte[] rawKey = rawKey(key);
        try {
            String val = (String) redisTemplate.execute(new RedisCallback<String>() {
                @Override
                public String doInRedis(RedisConnection connection) throws DataAccessException {
                    return stringRedisSerializer.deserialize(connection.get(rawKey));
                }
            }, true);

            if (val == null) {
                return null;
            }
            return Long.valueOf(val);
        } catch (Exception e) {
            // Best effort by design: any failure (connection error, non-numeric value, ...) is
            // logged and reported as "no value".
            logger.error("get key error:" + key, e);
            if (CatInstance.isEnable()) {
                Cat.logError(e);
            }
            return null;
        }
    }

    /**
     * Atomically adds {@code delta} to the counter at {@code key} (string serialization).
     * If the key does not exist it is created with value {@code delta} and given a time-to-live of
     * {@code timeout}; the TTL of an already existing key is left untouched.
     * Errors reported by Redis (e.g. the key holds a non-integer value) are not caught here.
     *
     * @param key      the counter key
     * @param delta    amount to add; may be negative
     * @param timeout  time-to-live applied when this call creates the key (millisecond precision);
     *                 a non-positive value makes Redis delete the key right away
     * @param timeUnit unit of {@code timeout}
     * @return the counter value after the increment
     */
    public Long incrBy(String key, long delta, long timeout, TimeUnit timeUnit) {
        List<String> keys = Collections.singletonList(key);
        // Milliseconds + PEXPIRE: converting to whole seconds would turn sub-second timeouts into
        // "EXPIRE 0", which deletes the key immediately.
        long timeoutMillis = timeUnit.toMillis(timeout);

        // Declared as Object on purpose: although the script is typed String, the driver may hand back
        // the Lua integer reply as a Long, and assigning to Object avoids a compiler-inserted String cast.
        Object currentVal = stringScriptExecutor.execute(INCR_BY_WITH_TIMEOUT_SCRIPT, stringRedisSerializer,
                stringRedisSerializer, keys, String.valueOf(delta), String.valueOf(timeoutMillis));

        if (currentVal instanceof Number) {
            return ((Number) currentVal).longValue();
        }
        return Long.valueOf((String) currentVal);
    }

    private byte[] rawKey(String key) {
        return keyRedisSerializer.serialize(key);
    }

    /**
     * Script executor whose {@code keysAndArgs} serializes KEYS with the given string serializer rather
     * than whatever the parent executor would use by default; arguments use the serializer passed per call.
     */
    private static final class StringScriptExecutor extends DefaultScriptExecutor<String> {
        private final RedisSerializer<String> keySerializer;

        StringScriptExecutor(RedisTemplate<String, ?> template, RedisSerializer<String> keySerializer) {
            super(template);
            this.keySerializer = keySerializer;
        }

        @Override
        @SuppressWarnings({"rawtypes", "unchecked"}) // raw parameter type must match the parent signature
        protected byte[][] keysAndArgs(RedisSerializer argsSerializer, List<String> keys, Object[] args) {
            final int keySize = keys != null ? keys.size() : 0;
            byte[][] keysAndArgs = new byte[args.length + keySize][];
            int i = 0;
            if (keys != null) {
                for (String key : keys) {
                    keysAndArgs[i++] = keySerializer.serialize(key);
                }
            }
            for (Object arg : args) {
                // Already-serialized byte[] arguments are passed through only when no serializer is given.
                if (argsSerializer == null && arg instanceof byte[]) {
                    keysAndArgs[i++] = (byte[]) arg;
                } else {
                    keysAndArgs[i++] = argsSerializer.serialize(arg);
                }
            }
            return keysAndArgs;
        }
    }
}
