SpringBoot实现限流注解
在高并发系统中,保护系统的三种方式分别为:缓存,降级和限流。
限流的目的是通过对并发访问请求进行限速或者一个时间窗口内的的请求数量进行限速来保护系统,一旦达到限制速率则可以拒绝服务、排队或等待
1、限流类型枚举类
/**
* 限流类型
* @author ss_419
*/
public enum LimitType {
/**
* 默认的限流策略,针对某一个接口进行限流
*/
DEFAULT,
/**
* 针对某一个IP进行限流
*/
IP
}
2、自定义限流注解
/**
* @author ss_419
*/
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface RateLimiter {
/**
* 限流的 key,主要是指前缀
* @return
*/
String key() default "rate_limit:";
/**
* 在时间窗内的限流次数
* @return
*/
int count() default 100;
/**
* 限流类型
* @return
*/
LimitType limitType() default LimitType.DEFAULT;
/**
* 限流时间窗
* @return
*/
int time() default 60;
}
3、限流lua脚本
1、由于我们使用 Redis 进行限流,我们需要引入 Redis 的 maven 依赖,同时需要引入 aop 的依赖
<!-- aop依赖 -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
</dependency>
<!-- redis依赖 -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
2、配置redis以及lua脚本
@Configuration
public class RedisConfig {
@Bean
RedisTemplate<Object,Object> redisTemplate(RedisConnectionFactory factory) {
RedisTemplate<Object, Object> template = new RedisTemplate<>();
template.setConnectionFactory(factory);
Jackson2JsonRedisSerializer jackson2JsonRedisSerializer = new Jackson2JsonRedisSerializer(Object.class);
template.setKeySerializer(jackson2JsonRedisSerializer);
template.setValueSerializer(jackson2JsonRedisSerializer);
template.setHashKeySerializer(jackson2JsonRedisSerializer);
template.setHashValueSerializer(jackson2JsonRedisSerializer);
return template;
}
/**
* 读取lua脚本
* @return
*/
@Bean
DefaultRedisScript<Long> limitScript() {
DefaultRedisScript<Long> script = new DefaultRedisScript<>();
script.setResultType(Long.class);
script.setScriptSource(new ResourceScriptSource(new ClassPathResource("lua/limit.lua")));
return script;
}
}
通过 Lua 脚本,根据 Redis 中缓存的键值判断限流时间(也是 key 的过期时间)内,访问次数是否超出了限流次数,没超出则访问次数 +1,返回 true,超出了则返回 false。
limit.lua:
local key = KEYS[1]
local time = tonumber(ARGV[1])
local count = tonumber(ARGV[2])
local current = redis.call('get', key)
if current and tonumber(current) > count then
return tonumber(current)
end
current = redis.call('incr', key)
if tonumber(current) == 1 then
redis.call('expire', key, time)
end
return tonumber(current)
4、限流切面处理类
1、使用我们刚刚的 Lua 脚本判断是否超出了限流次数,超出了限流次数后返回一个自定义异常,然后在全局异常中去捕捉异常,返回 JSON 数据。
2、根据注解参数,判断限流类型,拼接缓存 key 值
package org.pp.ratelimiter.aspectj;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.reflect.MethodSignature;
import org.pp.ratelimiter.annotation.RateLimiter;
import org.pp.ratelimiter.enums.LimitType;
import org.pp.ratelimiter.exception.RateLimitException;
import org.pp.ratelimiter.utils.IpUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import java.lang.reflect.Method;
import java.util.Collections;
@Aspect
@Component
public class RateLimiterAspect {
private static final Logger logger = LoggerFactory.getLogger(RateLimiterAspect.class);
@Autowired
RedisTemplate<Object, Object> redisTemplate;
@Autowired
RedisScript<Long> redisScript;
@Before("@annotation(rateLimiter)")
public void before(JoinPoint jp, RateLimiter rateLimiter) throws RateLimitException {
int time = rateLimiter.time();
int count = rateLimiter.count();
String combineKey = getCombineKey(rateLimiter, jp);
try {
Long number = redisTemplate.execute(redisScript, Collections.singletonList(combineKey), time, count);
if (number == null || number.intValue() > count) {
//超过限流阈值
logger.info("当前接口以达到最大限流次数");
throw new RateLimitException("访问过于频繁,请稍后访问");
}
logger.info("一个时间窗内请求次数:{},当前请求次数:{},缓存的 key 为 {}", count, number, combineKey);
} catch (Exception e) {
throw e;
}
}
/**
* 这个 key 其实就是接口调用次数缓存在 redis 的 key
* rate_limit:11.11.11.11-org.javaboy.ratelimit.controller.HelloController-hello
* rate_limit:org.javaboy.ratelimit.controller.HelloController-hello
* @param rateLimiter
* @param jp
* @return
*/
private String getCombineKey(RateLimiter rateLimiter, JoinPoint jp) {
StringBuffer key = new StringBuffer(rateLimiter.key());
if (rateLimiter.limitType() == LimitType.IP) {
key.append(IpUtils.getIpAddr(((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest()))
.append("-");
}
MethodSignature signature = (MethodSignature) jp.getSignature();
Method method = signature.getMethod();
key.append(method.getDeclaringClass().getName())
.append("-")
.append(method.getName());
return key.toString();
}
}
5、使用与测试
@RestController
public class HelloController {
/**
* 限流 10 秒之内,这个接口可以访问3次
* @return
*/
@GetMapping("/hello")
@RateLimiter(time = 10,count = 3)
public Map<String, Object> hello() {
Map<String, Object> map = new HashMap<>();
map.put("status", 200);
map.put("message", "Hello RateLimiter");
return map;
}
}
十秒之内访问次数超过3次就会报异常
redis中的数据,每一次访问都加1
当访问次数超过3,则进行限流操作