基于Redis和令牌桶算法实现的集群接口限流器

任何系统的性能都有一个上限,当并发量超过这个上限之后,可1能会对系统造成毁灭性地打击。因此在任何时刻我们都必须保证系统的并发请求数量不能超过某个阈值,限流就是为了完成这一目的。本限流器是基于Redis记录流量数据,实现对接口的精准限流,保证系统的稳定运行。

1 令牌桶算法

该限流器采用流行的令牌桶算法,现简单讲解令牌桶算法的原理

1.1 简介

令牌桶算法最初来源于计算机网络。在网络传输数据时,为了防止网络拥塞,需限制流出网络的流量,使流量以比较均匀的速度向外发送。令牌桶算法就实现了这个功能,可控制发送到网络上数据的数目,并允许突发数据的发送。

令牌桶算法是网络流量整形(Traffic Shaping)和速率限制(Rate Limiting)中最常使用的一种算法。典型情况下,令牌桶算法用来控制发送到网络上的数据的数目,并允许突发数据的发送。

大小固定的令牌桶可自行以恒定的速率源源不断地产生令牌。如果令牌不被消耗,或者被消耗的速度小于产生的速度,令牌就会不断地增多,直到把桶填满。后面再产生的令牌就会从桶中溢出。最后桶中可以保存的最大令牌数永远不会超过桶的大小。

传送到令牌桶的数据包需要消耗令牌。不同大小的数据包,消耗的令牌数量不一样。令牌桶这种控制机制基于令牌桶中是否存在令牌来指示什么时候可以发送流量。令牌桶中的每一个令牌都代表一个字节。如果令牌桶中存在令牌,则允许发送流量;而如果令牌桶中不存在令牌,则不允许发送流量。因此,如果突发门限被合理地配置并且令牌桶中有足够的令牌,那么流量就可以以峰值速率发送。

1.2 算法过程


算法描述:

  • 假如用户配置的平均发送速率为r,则每隔1/r秒一个令牌被加入到桶中(每秒会有r个令牌放入桶中);
  • 假设桶中最多可以存放b个令牌。如果令牌到达时令牌桶已经满了,那么这个令牌会被丢弃;
  • 当一个n个字节的数据包到达时,就从令牌桶中删除n个令牌(不同大小的数据包,消耗的令牌数量不一样),并且数据包被发送到网络;
  • 如果令牌桶中少于n个令牌,那么不会删除令牌,并且认为这个数据包在流量限制之外(n个字节,需要n个令牌。该数据包将被缓存或丢弃);
  • 算法允许最长b个字节的突发,但从长期运行结果看,数据包的速率被限制成常量r。对于在流量限制外的数据包可以以不同的方式处理:(1)它们可以被丢弃;(2)它们可以排放在队列中以便当令牌桶中累积了足够多的令牌时再传输;(3)它们可以继续发送,但需要做特殊标记,网络过载的时候将这些特殊标记的包丢弃。

2 实现原理

本限流器主要是基于令牌桶思想,并将令牌数存储到Redis中,实现在集群模式下对接口的精准限流,实现思想如下:

  • 对于每个限流接口,记录最大存储令牌数maxPermits, 当前存储令牌数storedPermits, 添加令牌时间间隔intervalMillis, 下次请求可以获取令牌的起始时间nextFreeTicketMillis,这些信息都记录在Redis中
  • 响应本次请求之后,动态计算下一次可以服务的时间,如果下一次请求在这个时间之前则需要进行等待。 nextFreeTicketMicros 记录下一次可以响应的时间。例如,如果我们设置QPS为1,本次请求处理完之后,那么下一次最早的能够响应请求的时间一秒钟之后。
  • 限流器支持处理突发流量请求,突发请求允许个数就是最大存储令牌数maxPermits。例如,我们设置QPS为1,在十秒钟之内没有请求,那么令牌桶中会有10个(假设设置的maxPermits为10)空闲令牌,如果下一次请求是 10个令牌,则可以一次性获取10个令牌,因为令牌桶中已经有10个空闲的令牌。 storedPermits 就是用来表示当前令牌桶中的空闲令牌数。
  • 对于令牌的产生有两种方式,一种是通过后台定时任务来不断产生令牌,一种是延迟生成,在每次获取令牌之前先计算在nextFreeTicketMillis到目前这个时间段内应该产生多少令牌,并更新令牌桶。本限流器采用的是后者。

3 具体实现

3.1 令牌桶

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
/**
* Redis令牌桶
*/
@Data
public class RedisPermits implements Serializable {

private static final long serialVersionUID = 1L;
/**
* maxPermits 最大存储令牌数
*/
private Long maxPermits;
/**
* storedPermits 当前存储令牌数
*/
private Long storedPermits;
/**
* intervalMillis 添加令牌时间间隔
*/
private Long intervalMillis;
/**
* nextFreeTicketMillis 下次请求可以获取令牌的起始时间,默认当前系统时间
*/
private Long nextFreeTicketMillis;

/**
* @param permitsPerSecond 每秒放入的令牌数
* @param maxBurstSeconds maxPermits由此字段计算,最大存储maxBurstSeconds秒生成的令牌
*/
public RedisPermits(Double permitsPerSecond, Integer maxBurstSeconds) {
if (null == maxBurstSeconds) {
maxBurstSeconds = 60;
}
this.maxPermits = (long) (permitsPerSecond * maxBurstSeconds);
this.storedPermits = permitsPerSecond.longValue();
this.intervalMillis = (long) (TimeUnit.SECONDS.toMillis(1) / permitsPerSecond);
this.nextFreeTicketMillis = System.currentTimeMillis();
}

/**
* redis的过期时长
* @return
*/
public Long expires() {
long now = System.currentTimeMillis();
return 2 * TimeUnit.MINUTES.toSeconds(1)
+ TimeUnit.MILLISECONDS.toSeconds(Math.max(nextFreeTicketMillis, now) - now);
}

public Map<String, String> toMap() {
Map<String, String> resultMap = new HashMap<>();
resultMap.put("maxPermits", maxPermits.toString());
resultMap.put("storedPermits", storedPermits.toString());
resultMap.put("intervalMillis", intervalMillis.toString());
resultMap.put("nextFreeTicketMillis", nextFreeTicketMillis.toString());
return resultMap;
}

该类主要存储了令牌桶核心的四个参数

3.2 限流器

主要方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
@Slf4j
@Data
public class RateLimiter {
/**
* 在超时时间内尝试获取{tokenCount}个令牌
* @param tokenCount
* @param timeout
* @param timeUnit
* @return
* @throws InterruptedException
*/
public boolean tryAcquire(Long tokenCount, Long timeout, TimeUnit timeUnit) throws InterruptedException{
if(checkTokens(tokenCount)) {
Long timeoutMillis = Math.max(timeUnit.toMillis(timeout), 0);
Long millisToWait = tryAndGetWaitTime(tokenCount, timeoutMillis);
if(millisToWait <= timeoutMillis) {
log.info("tryAcquire for {}ms {}", millisToWait, Thread.currentThread().getName());
Thread.sleep(millisToWait);
return true;
}
}
return false;
}

/**
* 等待直到获取指定数量的令牌
* @param tokenCount
* @return
* @throws InterruptedException
*/
public Long acquire(Long tokenCount) throws InterruptedException {
long milliToWait = this.reserve(tokenCount);
log.info("acquire for {}ms {}", milliToWait, Thread.currentThread().getName());
Thread.sleep(milliToWait);
return milliToWait;
}

/**
* 获取令牌n个需要等待的时间
* @param tokenCount
* @return
*/
private long reserve(Long tokenCount) {
if (checkTokens(tokenCount)) {
return reserveAndGetWaitTime(tokenCount);
} else {
return -1;
}
}

/**
* 预定@{tokenCount}个令牌并返回所需要等待的时间
* @param tokenCount
* @return
*/
private Long reserveAndGetWaitTime(Long tokenCount){
putDefaultPermits();
String script = "redis.replicate_commands() " +
"local redisKey = KEYS[1] " +
"local timeStrArray = redis.call('time') " +
"local seconds = tonumber(timeStrArray[1]) " +
"local microseconds = tonumber(timeStrArray[2]) " +
"local nowMilliseconds = seconds * 1000 + math.modf(microseconds/1000) " +
"local redisPermitsValues = redis.call('hmget', redisKey, 'nextFreeTicketMillis', 'maxPermits', 'storedPermits', 'intervalMillis') " +
"local nextFreeTicketMillis = tonumber(redisPermitsValues[1]) " +
"local maxPermits = tonumber(redisPermitsValues[2]) " +
"local storedPermits = tonumber(redisPermitsValues[3]) " +
"local intervalMillis = tonumber(redisPermitsValues[4]) " +
"if(nowMilliseconds > nextFreeTicketMillis) " +
"then " +
"storedPermits = math.min(maxPermits, storedPermits + math.modf((nowMilliseconds - nextFreeTicketMillis) / intervalMillis)) " +
"nextFreeTicketMillis = nowMilliseconds " +
"end " +
"local tokenCount = tonumber(ARGV[1]) " +
"local storedPermitsToSpend = math.min(tokenCount, storedPermits) " +
"local freshPermits = tokenCount - storedPermitsToSpend " +
"local waitMillis = freshPermits * intervalMillis " +
"nextFreeTicketMillis = nextFreeTicketMillis + waitMillis " +
"storedPermits = storedPermits - storedPermitsToSpend " +
"redis.call('hmset', redisKey, 'nextFreeTicketMillis', nextFreeTicketMillis, 'storedPermits', storedPermits) " +
"redis.call('expire', redisKey, 120) " +
"return nextFreeTicketMillis - nowMilliseconds";
List<String> keys = Collections.singletonList(key);
List<String> args = Collections.singletonList(tokenCount.toString());
Object obj = redisUtil.eval(script, keys, args);
Long result = null;
if(obj != null) {
result = (Long) obj;
}
return result;
}

/**
* 判断{timeout}时间内能否获取{tokenCount}令牌,如果能获取到则预定令牌
* @param tokenCount
* @return 需要等待时长
*/
private Long tryAndGetWaitTime(Long tokenCount, Long timeoutMillis) {
putDefaultPermits();
String script = "redis.replicate_commands() " +
"local redisKey = KEYS[1] " +
"local timeStrArray = redis.call('time') " +
"local seconds = tonumber(timeStrArray[1]) " +
"local microseconds = tonumber(timeStrArray[2]) " +
"local nowMilliseconds = seconds * 1000 + math.modf(microseconds/1000) " +
"local redisPermitsValues = redis.call('hmget', redisKey, 'nextFreeTicketMillis', 'maxPermits', 'storedPermits', 'intervalMillis') " +
"local nextFreeTicketMillis = tonumber(redisPermitsValues[1]) " +
"local maxPermits = tonumber(redisPermitsValues[2]) " +
"local storedPermits = tonumber(redisPermitsValues[3]) " +
"local intervalMillis = tonumber(redisPermitsValues[4]) " +
"if(nowMilliseconds > nextFreeTicketMillis) " +
"then " +
"storedPermits = math.min(maxPermits, storedPermits + math.modf((nowMilliseconds - nextFreeTicketMillis) / intervalMillis)) " +
"nextFreeTicketMillis = nowMilliseconds " +
"end " +
"local tokenCount = tonumber(ARGV[1]) " +
"local timeoutMillis = tonumber(ARGV[2]) " +
"local storedPermitsToSpend = math.min(tokenCount, storedPermits) " +
"local freshPermits = tokenCount - storedPermitsToSpend " +
"local waitMillis = freshPermits * intervalMillis " +
"local actualWaitMillis = nextFreeTicketMillis + waitMillis - nowMilliseconds " +
"if(actualWaitMillis <= timeoutMillis) " +
"then " +
"nextFreeTicketMillis = nextFreeTicketMillis + waitMillis " +
"storedPermits = storedPermits - storedPermitsToSpend " +
"redis.call('hmset', redisKey, 'nextFreeTicketMillis', nextFreeTicketMillis, 'storedPermits', storedPermits) " +
"redis.call('expire', redisKey, 120) " +
"end " +
"return actualWaitMillis";
List<String> keys = Collections.singletonList(key);
List<String> args = Arrays.asList(tokenCount.toString(), timeoutMillis.toString());
Object obj = redisUtil.eval(script, keys, args);
Long result = null;
if(obj != null) {
result = (Long) obj;
}
return result;
}
}

可以看到,限流器的主要方法是acquire和tryAcquire,前者是进行线程阻塞以等待令牌桶中达到所需令牌,后者是设定超时时间,并判断在超时时间内能否获取所需令牌,可以的话再进行线程阻塞等待令牌。获取由于存储在Redis中的令牌桶信息在集群环境下会有线程不同步问题,虽然采用Redis分布锁可以解决该问题,但是会造成线程阻塞,降低并发效率。而Redis运行lua脚本是原子性操作,因此本文采用lua脚本执行对令牌桶的计算和更新操作。可以看到核心方法reserveAndGetWaitTime和tryAndGetWaitTime方法都使用了lua脚本,下面简单讲解一下这两个方法的实现逻辑。

reserveAndGetWaitTime

  • 更新令牌桶,这一步操作就是上文讲到的延迟更新令牌
  • 计算所需令牌数与令牌桶中令牌数的插值,确定补全所需令牌数需要等待的时间
  • 取令牌并将令牌桶数据更新到Redis

tryAndGetWaitTime

  • 同样是先更新令牌桶
  • 计算所需令牌数与令牌桶中令牌数的插值,确定补全所需令牌数需要等待的时间
  • 判断等待的时间是否在超时时间内,如果是的话再取令牌将令牌桶数据更新到Redis