文章目录
- 1、demo结构
- 2、自定义接口
- 3、编写写LUA脚本
- 4、通过AOP切面识别需要限流的接口
- 编写切面
- AOP通知类型
- 5、Redis限流自定义异常构建
- Redis限流自定义异常
- 声明这个类为全局异常处理器
- 专属日志
- 6、流量限制器
- RateLimiter
- RateLimitAlg
- ApiLimit
- RateLimitRule
- RuleConfig
- 7、Guavalimit
- 限流自定义异常
- 限流key类型枚举
- 基于Guava cache缓存存储实现限流切面
- 8、测试控制层
- 9、测试结果
1、demo结构
2、自定义接口
通过自定义接口标注需要限流的接口
/**
* redis限流自定义注解
* @author zyw
*/
//注解的保留位置,RUNTIME表示这种类型的Annotations将被JVM保留,所以他们能在运行时被JVM或其他使用反射机制的代码所读取和使用。
@Retention(RetentionPolicy.RUNTIME)
//说明注解的作用目标,METHOD表示用来修饰方法
@Target({ElementType.METHOD})
//说明该注解将被包含在javadoc中
@Documented
public @interface RedisLimit {
/**
* 资源的key,唯一
* 作用:不同的接口,不同的流量控制
*/
String key() default "";
/**
* 最多的访问限制次数
*/
long permitsPerSecond() default 2;
/**
* 过期时间也可以理解为单位时间,单位秒,默认60
*/
long expire() default 60;
/**
* 得不到令牌的提示语
*/
String msg() default "系统繁忙,请稍后再试.";
}
3、编写写LUA脚本
通过Lua脚本动态实现动态的创建redis缓存
--获取KEY
local key = KEYS[1]
local limit = tonumber(ARGV[1])
local curentLimit = tonumber(redis.call('get', key) or "0")
if curentLimit + 1 > limit
then return 0
else
-- 自增长 1
redis.call('INCRBY', key, 1)
-- 设置过期时间
redis.call('EXPIRE', key, ARGV[2])
return curentLimit + 1
end
4、通过AOP切面识别需要限流的接口
编写切面
- 1 定义一个类,该类添加了@Component、@Aspect注解
- 2 定义切点(切点定义方式可参考《Spring AOP配置 之 @PointCut注解》)
- 3 配置增强,给方法添加@Before、@After、@AfterReturning、@AfterThrowing、@Around等增强配置。
AOP通知类型
- @Around 环绕通知
- @Before 通知执行
- @Before 通知执行结束
- @Around 环绕通知执行结束
- @After 后置通知执行了
- @AfterReturning 第一个后置返回通知后执行
/**
* Limit AOP
*/
@Slf4j
@Aspect
@Component
public class RedisLimitAop {
@Autowired
private StringRedisTemplate stringRedisTemplate;
@Pointcut("@annotation(com.example.redislimit.aop.RedisLimit)")
private void check() {
}
private DefaultRedisScript<Long> redisScript;
@PostConstruct
public void init() {
redisScript = new DefaultRedisScript<>();
redisScript.setResultType(Long.class);
redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("rateLimiter.lua")));
}
@Before("check()")
public void before(JoinPoint joinPoint) {
MethodSignature signature = (MethodSignature) joinPoint.getSignature();
Method method = signature.getMethod();
// 请求对象
ServletRequestAttributes sra = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
HttpServletRequest servletRequest = sra.getRequest();
//拿到RedisLimit注解,如果存在则说明需要限流
RedisLimit redisLimit = method.getAnnotation(RedisLimit.class);
if (redisLimit != null) {
//获取redis的key
String key = redisLimit.key();
String className = method.getDeclaringClass().getName();
String name = method.getName();
String limitKey = key + className + method.getName();
log.info(limitKey);
if (StringUtils.isEmpty(key)) {
throw new RedisLimitException("key cannot be null");
}
long limit = redisLimit.permitsPerSecond();
long expire = redisLimit.expire();
List<String> keys = new ArrayList<>();
keys.add(key);
Long count = stringRedisTemplate.execute(redisScript, keys, String.valueOf(limit), String.valueOf(expire));
log.info("Access try count is {} for key={}", count, key);
if (count != null && count == 0) {
log.debug("获取key失败,key为{}", key);
throw new RedisLimitException(redisLimit.msg());
}
}
}
5、Redis限流自定义异常构建
Redis限流自定义异常
/**
* Redis限流自定义异常
* @date 2023/3/10 21:43
*/
public class RedisLimitException extends RuntimeException{
public RedisLimitException(String msg) {
super( msg );
}
}
声明这个类为全局异常处理器
@RestControllerAdvice// 声明这个类为全局异常处理器
public class GlobalExceptionHandler {
@ExceptionHandler(RedisLimitException.class) // 声明当前方法要处理的异常类型
public ResultInfo handlerCustomException(RedisLimitException e) {
//1. 打印日志
// e.printStackTrace();
//2. 给前端提示
return ResultInfo.error(e.getMessage());
}
//非预期异常 对于他们,我们直接捕获,捕获完了,记录日志, 给前端一个假提示
@ExceptionHandler(Exception.class)
public ResultInfo handlerException(Exception e) {
//1. 打印日志
e.printStackTrace();
//2. 给前端提示
return ResultInfo.error("当前系统异常");
}
}
专属日志
@Getter
@Setter
public class ResultInfo<T> {
private String message;
private String code;
private T data;
public ResultInfo(String message, String code, T data) {
this.message = message;
this.code = code;
this.data = data;
}
public static ResultInfo error(String message) {
return new ResultInfo(message,"502",null);
}
}
6、流量限制器
RateLimiter
public class RateLimiter {
private static final Logger log = LoggerFactory.getLogger(RateLimiter.class);
// 为每个api在内存中存储限流计数器
private ConcurrentHashMap<String, RateLimitAlg> counters = new ConcurrentHashMap<>();
private RateLimitRule rule;
public RateLimiter() {
// 将限流规则配置文件ratelimiter-rule.yaml中的内容读取到RuleConfig中
InputStream in = null;
RuleConfig ruleConfig = null;
try {
in = this.getClass().getResourceAsStream("/ratelimiter-rule.yaml");
if (in != null) {
Yaml yaml = new Yaml();
ruleConfig = yaml.loadAs(in, RuleConfig.class);
}
} finally {
if (in != null) {
try {
in.close();
} catch (IOException e) {
log.error("close file error:{}", e);
}
}
}
// 将限流规则构建成支持快速查找的数据结构RateLimitRule
this.rule = new RateLimitRule(ruleConfig);
}
public boolean limit(String appId, String url) throws Exception {
ApiLimit apiLimit = rule.getLimit(appId, url);
if (apiLimit == null) {
return true;
}
// 获取api对应在内存中的限流计数器(rateLimitCounter)
String counterKey = appId + ":" + apiLimit.getApi();
RateLimitAlg rateLimitCounter = counters.get(counterKey);
if (rateLimitCounter == null) {
RateLimitAlg newRateLimitCounter = new RateLimitAlg(apiLimit.getLimit());
rateLimitCounter = counters.putIfAbsent(counterKey, newRateLimitCounter);
if (rateLimitCounter == null) {
rateLimitCounter = newRateLimitCounter;
}
}
// 判断是否限流
return rateLimitCounter.tryAcquire();
}
}
RateLimitAlg
public class RateLimitAlg {
/* timeout for {@code Lock.tryLock() }. */
private static final long TRY_LOCK_TIMEOUT = 200L; // 200ms.
private Stopwatch stopwatch;
private AtomicInteger currentCount = new AtomicInteger(0);
private final int limit;
private Lock lock = new ReentrantLock();
public RateLimitAlg(int limit) {
this(limit, Stopwatch.createStarted());
}
@VisibleForTesting
protected RateLimitAlg(int limit, Stopwatch stopwatch) {
this.limit = limit;
this.stopwatch = stopwatch;
}
public boolean tryAcquire() throws Exception {
int updatedCount = currentCount.incrementAndGet();
if (updatedCount <= limit) {
return true;
}
try {
if (lock.tryLock(TRY_LOCK_TIMEOUT, TimeUnit.MILLISECONDS)) {
try {
if (stopwatch.elapsed(TimeUnit.MILLISECONDS) > TimeUnit.SECONDS.toMillis(1)) {
currentCount.set(0);
stopwatch.reset();
}
updatedCount = currentCount.incrementAndGet();
return updatedCount <= limit;
} finally {
lock.unlock();
}
} else {
throw new Exception("tryAcquire() wait lock too long:" + TRY_LOCK_TIMEOUT + "ms");
}
} catch (InterruptedException e) {
throw new Exception("tryAcquire() is interrupted by lock-time-out.", e);
}
}
}
ApiLimit
public class ApiLimit {
private static final int DEFAULT_TIME_UNIT = 1; // 1 second
private String api;
private int limit;
private int unit = DEFAULT_TIME_UNIT;
public ApiLimit() {}
public ApiLimit(String api, int limit) {
this(api, limit, DEFAULT_TIME_UNIT);
}
public ApiLimit(String api, int limit, int unit) {
this.api = api;
this.limit = limit;
this.unit = unit;
}
public String getApi() {
return api;
}
public void setApi(String api) {
this.api = api;
}
public int getLimit() {
return limit;
}
public void setLimit(int limit) {
this.limit = limit;
}
public int getUnit() {
return unit;
}
public void setUnit(int unit) {
this.unit = unit;
}
}
RateLimitRule
public class RateLimitRule {
public RateLimitRule(RuleConfig ruleConfig) {
//...
}
public ApiLimit getLimit(String appId, String api) {
return null;
}
}
RuleConfig
public class RuleConfig {
private List<AppRuleConfig> configs;
public List<AppRuleConfig> getConfigs() {
return configs;
}
public void setConfigs(List<AppRuleConfig> configs) {
this.configs = configs;
}
public static class AppRuleConfig {
private String appId;
private List<ApiLimit> limits;
public AppRuleConfig() {}
public AppRuleConfig(String appId, List<ApiLimit> limits) {
this.appId = appId;
this.limits = limits;
}
public String getAppId() {
return appId;
}
public void setAppId(String appId) {
this.appId = appId;
}
public List<ApiLimit> getLimits() {
return limits;
}
public void setLimits(List<ApiLimit> limits) {
this.limits = limits;
}
}
}
7、Guavalimit
限流自定义异常
/**
* @Description 限流自定义异常
* @Author zyw
* @Date 2019/8/7 16:01
*/
public class LimitAccessException extends RuntimeException {
private static final long serialVersionUID = -3608667856397125671L;
public LimitAccessException(String message) {
super(message);
}
}
限流key类型枚举
/**
* @Description 限流key类型枚举
* @Author zyw
* @Date 2020/5/17 14:28
*/
public enum LimitKeyTypeEnum {
IPADDR("IPADDR", "根据Ip地址来限制"),
CUSTOM("CUSTOM", "自定义根据业务唯一码来限制,需要在请求参数中添加 String limitKeyValue");
private String keyType;
private String desc;
LimitKeyTypeEnum(String keyType, String desc) {
this.keyType = keyType;
this.desc = desc;
}
public String getKeyType() {
return keyType;
}
public String getDesc() {
return desc;
}
}
自定义限流注解
/**
* @Description 自定义限流注解
* @Author zyw
* @Date 2020/5/17 11:49
*/
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface LxRateLimit {
//资源名称
String name() default "默认资源";
//限制每秒访问次数,默认为3次
double perSecond() default 3;
/**
* 限流Key类型
* 自定义根据业务唯一码来限制需要在请求参数中添加 String limitKeyValue
*/
LimitKeyTypeEnum limitKeyType() default LimitKeyTypeEnum.IPADDR;
}
基于Guava cache缓存存储实现限流切面
/**
* @Description 基于Guava cache缓存存储实现限流切面
* @Author 张佑威
* @Date 2020/5/17 11:51
*/
@Slf4j
@Aspect
@Component
public class LxRateLimitAspect {
/**
* 缓存
* maximumSize 设置缓存个数
* expireAfterWrite 写入后过期时间
*/
private static LoadingCache<String, RateLimiter> limitCaches = CacheBuilder.newBuilder()
.maximumSize(1000)
.expireAfterWrite(1, TimeUnit.DAYS)
.build(new CacheLoader<String, RateLimiter>() {
@Override
public RateLimiter load(String key) throws Exception {
double perSecond = LxRateLimitUtil.getCacheKeyPerSecond(key);
return RateLimiter.create(perSecond);
}
});
/**
* 切点
* 通过扫包切入 @Pointcut("execution(public * com.ycn.springcloud.*.*(..))")
* 带有指定注解切入 @Pointcut("@annotation(com.ycn.springcloud.annotation.LxRateLimit)")
*/
@Pointcut("@annotation(com.example.guavalimit.limit.LxRateLimit)")
public void pointcut() {
}
@Around("pointcut()")
public Object around(ProceedingJoinPoint point) throws Throwable {
log.info("限流拦截到了{}方法...", point.getSignature().getName());
HttpServletRequest request = ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest();
MethodSignature signature = (MethodSignature) point.getSignature();
Method method = signature.getMethod();
if (method.isAnnotationPresent(LxRateLimit.class)) {
String cacheKey = LxRateLimitUtil.generateCacheKey(method, request);
RateLimiter limiter = limitCaches.get(cacheKey);
if (!limiter.tryAcquire()) {
throw new LimitAccessException("【限流】这位小同志的手速太快了");
}
}
return point.proceed();
}
}
限流工具类
/**
* @Description 限流工具类
* @Author zyw
* @Date 2020/5/17 15:37
*/
public class LxRateLimitUtil {
/**
* 获取唯一key根据注解类型
* <p>
* 规则 资源名:业务key:perSecond
*
* @param method
* @param request
* @return
*/
public static String generateCacheKey(Method method, HttpServletRequest request) {
//获取方法上的注解
LxRateLimit lxRateLimit = method.getAnnotation(LxRateLimit.class);
StringBuffer cacheKey = new StringBuffer(lxRateLimit.name() + ":");
switch (lxRateLimit.limitKeyType()) {
case IPADDR:
cacheKey.append(getIpAddr(request) + ":");
break;
case CUSTOM:
String limitKeyValue = request.getParameter("limitKeyValue");
if (StringUtils.isEmpty(limitKeyValue)) {
throw new LimitAccessException("【" + method.getName() + "】自定义业务Key缺少参数String limitKeyValue,或者参数为空");
}
cacheKey.append(limitKeyValue + ":");
break;
}
cacheKey.append(lxRateLimit.perSecond());
return cacheKey.toString();
}
/**
* 获取缓存key的限制每秒访问次数
* <p>
* 规则 资源名:业务key:perSecond
*
* @param cacheKey
* @return
*/
public static double getCacheKeyPerSecond(String cacheKey) {
String perSecond = cacheKey.split(":")[2];
return Double.parseDouble(perSecond);
}
/**
* 获取客户端IP地址
*
* @param request 请求
* @return
*/
public static String getIpAddr(HttpServletRequest request) {
String ip = request.getHeader("x-forwarded-for");
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
ip = request.getHeader("Proxy-Client-IP");
}
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
ip = request.getHeader("WL-Proxy-Client-IP");
}
if ("0:0:0:0:0:0:0:1".equals(ip)) {
ip = "127.0.0.1";
}
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
ip = request.getRemoteAddr();
if ("127.0.0.1".equals(ip)) {
//根据网卡取本机配置的IP
InetAddress inet = null;
try {
inet = InetAddress.getLocalHost();
} catch (UnknownHostException e) {
e.printStackTrace();
}
ip = inet.getHostAddress();
}
}
// 对于通过多个代理的情况,第一个IP为客户端真实IP,多个IP按照','分割
if (ip != null && ip.length() > 15) {
if (ip.indexOf(",") > 0) {
ip = ip.substring(0, ip.indexOf(","));
}
}
return ip;
}
}
8、测试控制层
@RestController
public class TestController {
@GetMapping("/test")
public String getTest(){
return "jxj";
}
@GetMapping("/guavalimit")
@LxRateLimit
public String guavaLimit(){
return "ok";
}
@GetMapping("/redislimit")
@RedisLimit(key = "redis-limit:test", permitsPerSecond = 2, expire = 1, msg = "当前排队人数较多,请稍后再试!")
public String redisLimit(){
return "ok";
}
}
9、测试结果
Redis+LUA脚本