diff --git a/spring-boot-demo-ratelimit-redis/src/main/java/com/xkcoding/ratelimit/redis/aspect/RateLimiterAspect.java b/spring-boot-demo-ratelimit-redis/src/main/java/com/xkcoding/ratelimit/redis/aspect/RateLimiterAspect.java new file mode 100644 index 000000000..e372b3596 --- /dev/null +++ b/spring-boot-demo-ratelimit-redis/src/main/java/com/xkcoding/ratelimit/redis/aspect/RateLimiterAspect.java @@ -0,0 +1,98 @@ +package com.xkcoding.ratelimit.redis.aspect; + +import cn.hutool.core.util.StrUtil; +import com.xkcoding.ratelimit.redis.annotation.RateLimiter; +import com.xkcoding.ratelimit.redis.util.IpUtil; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.aspectj.lang.ProceedingJoinPoint; +import org.aspectj.lang.annotation.Around; +import org.aspectj.lang.annotation.Aspect; +import org.aspectj.lang.annotation.Pointcut; +import org.aspectj.lang.reflect.MethodSignature; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.data.redis.core.StringRedisTemplate; +import org.springframework.data.redis.core.script.RedisScript; +import org.springframework.stereotype.Component; + +import java.lang.reflect.Method; +import java.time.Instant; +import java.util.Collections; +import java.util.concurrent.TimeUnit; + +/** + *

+ * 限流切面 + *

+ * + * @author yangkai.shen + * @date Created in 2019/9/30 10:30 + */ +@Slf4j +@Aspect +@Component +@RequiredArgsConstructor(onConstructor_ = @Autowired) +public class RateLimiterAspect { + private final static String SEPARATOR = ":"; + private final static String REDIS_LIMIT_KEY_PREFIX = "limit:"; + private final StringRedisTemplate stringRedisTemplate; + private final RedisScript limitRedisScript; + + @Pointcut("@annotation(com.xkcoding.ratelimit.redis.annotation.RateLimiter)") + public void rateLimit() { + + } + + @Around("rateLimit()") + public Object pointcut(ProceedingJoinPoint point) throws Throwable { + MethodSignature signature = (MethodSignature) point.getSignature(); + Method method = signature.getMethod(); + // 通过 AnnotationUtils.findAnnotation 获取 RateLimiter 注解 + RateLimiter rateLimiter = AnnotationUtils.findAnnotation(method, RateLimiter.class); + if (rateLimiter != null) { + String key = rateLimiter.key(); + // 默认用方法名做限流的 key 前缀 + if (StrUtil.isBlank(key)) { + key = method.getName(); + } + // 最终限流的 key 为 前缀 + IP地址 + // TODO: 此时需要考虑局域网多用户访问的情况,因此 key 后续需要加上方法参数更加合理 + key = key + SEPARATOR + IpUtil.getIpAddr(); + + long max = rateLimiter.max(); + long timeout = rateLimiter.timeout(); + TimeUnit timeUnit = rateLimiter.timeUnit(); + boolean limited = shouldLimited(key, max, timeout, timeUnit); + if (limited) { + throw new RuntimeException("手速太快了,慢点儿吧~"); + } + } + + return point.proceed(); + } + + private boolean shouldLimited(String key, long max, long timeout, TimeUnit timeUnit) { + // 最终的 key 格式为: + // limit:自定义key:IP + // limit:方法名:IP + key = REDIS_LIMIT_KEY_PREFIX + key; + // 统一使用单位毫秒 + long ttl = timeUnit.toMillis(timeout); + // 当前时间毫秒数 + long now = Instant.now().toEpochMilli(); + long expired = now - ttl; + // 注意这里必须转为 String,否则会报错 java.lang.Long cannot be cast to java.lang.String + Long executeTimes = stringRedisTemplate.execute(limitRedisScript, Collections.singletonList(key), now + "", ttl + "", expired + "", max + ""); + if (executeTimes != null) { + if (executeTimes == 0) { + log.error("【{}】在单位时间 {} 毫秒内已达到访问上限,当前接口上限 {}", key, ttl, max); + return true; + } else { + log.info("【{}】在单位时间 {} 毫秒内访问 {} 次", key, ttl, executeTimes); + return false; + } + } + return false; + } +} diff --git a/spring-boot-demo-ratelimit-redis/src/main/java/com/xkcoding/ratelimit/redis/config/RedisConfig.java b/spring-boot-demo-ratelimit-redis/src/main/java/com/xkcoding/ratelimit/redis/config/RedisConfig.java new file mode 100644 index 000000000..d155b582e --- /dev/null +++ b/spring-boot-demo-ratelimit-redis/src/main/java/com/xkcoding/ratelimit/redis/config/RedisConfig.java @@ -0,0 +1,28 @@ +package com.xkcoding.ratelimit.redis.config; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.core.io.ClassPathResource; +import org.springframework.data.redis.core.script.DefaultRedisScript; +import org.springframework.data.redis.core.script.RedisScript; +import org.springframework.scripting.support.ResourceScriptSource; + +/** + *

+ * Redis 配置 + *

+ * + * @author yangkai.shen + * @date Created in 2019/9/30 11:37 + */ +@Configuration +public class RedisConfig { + @Bean + @SuppressWarnings("unchecked") + public RedisScript limitRedisScript() { + DefaultRedisScript redisScript = new DefaultRedisScript<>(); + redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("scripts/redis/limit.lua"))); + redisScript.setResultType(Long.class); + return redisScript; + } +} diff --git a/spring-boot-demo-ratelimit-redis/src/main/java/com/xkcoding/ratelimit/redis/util/IpUtil.java b/spring-boot-demo-ratelimit-redis/src/main/java/com/xkcoding/ratelimit/redis/util/IpUtil.java new file mode 100644 index 000000000..ff2fd39ba --- /dev/null +++ b/spring-boot-demo-ratelimit-redis/src/main/java/com/xkcoding/ratelimit/redis/util/IpUtil.java @@ -0,0 +1,59 @@ +package com.xkcoding.ratelimit.redis.util; + +import cn.hutool.core.util.StrUtil; +import lombok.extern.slf4j.Slf4j; +import org.springframework.web.context.request.RequestContextHolder; +import org.springframework.web.context.request.ServletRequestAttributes; + +import javax.servlet.http.HttpServletRequest; + +/** + *

+ * IP 工具类 + *

+ * + * @author yangkai.shen + * @date Created in 2019/9/30 10:38 + */ +@Slf4j +public class IpUtil { + private final static String UNKNOWN = "unknown"; + private final static int MAX_LENGTH = 15; + + /** + * 获取IP地址 + * 使用Nginx等反向代理软件, 则不能通过request.getRemoteAddr()获取IP地址 + * 如果使用了多级反向代理的话,X-Forwarded-For的值并不止一个,而是一串IP地址,X-Forwarded-For中第一个非unknown的有效IP字符串,则为真实IP地址 + */ + public static String getIpAddr() { + HttpServletRequest request = ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest(); + String ip = null; + try { + ip = request.getHeader("x-forwarded-for"); + if (StrUtil.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) { + ip = request.getHeader("Proxy-Client-IP"); + } + if (StrUtil.isEmpty(ip) || ip.length() == 0 || UNKNOWN.equalsIgnoreCase(ip)) { + ip = request.getHeader("WL-Proxy-Client-IP"); + } + if (StrUtil.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) { + ip = request.getHeader("HTTP_CLIENT_IP"); + } + if (StrUtil.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) { + ip = request.getHeader("HTTP_X_FORWARDED_FOR"); + } + if (StrUtil.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) { + ip = request.getRemoteAddr(); + } + } catch (Exception e) { + log.error("IPUtils ERROR ", e); + } + // 使用代理,则获取第一个IP地址 + if (!StrUtil.isEmpty(ip) && ip.length() > MAX_LENGTH) { + if (ip.indexOf(StrUtil.COMMA) > 0) { + ip = ip.substring(0, ip.indexOf(StrUtil.COMMA)); + } + } + return ip; + } +} diff --git a/spring-boot-demo-ratelimit-redis/src/main/resources/application.yml b/spring-boot-demo-ratelimit-redis/src/main/resources/application.yml index af5002ce9..43382fcd2 100644 --- a/spring-boot-demo-ratelimit-redis/src/main/resources/application.yml +++ b/spring-boot-demo-ratelimit-redis/src/main/resources/application.yml @@ -2,3 +2,20 @@ server: port: 8080 servlet: context-path: /demo +spring: + redis: + host: localhost + # 连接超时时间(记得添加单位,Duration) + timeout: 10000ms + # Redis默认情况下有16个分片,这里配置具体使用的分片 + # database: 0 + lettuce: + pool: + # 连接池最大连接数(使用负值表示没有限制) 默认 8 + max-active: 8 + # 连接池最大阻塞等待时间(使用负值表示没有限制) 默认 -1 + max-wait: -1ms + # 连接池中的最大空闲连接 默认 8 + max-idle: 8 + # 连接池中的最小空闲连接 默认 0 + min-idle: 0 diff --git a/spring-boot-demo-ratelimit-redis/src/main/resources/scripts/redis/limit.lua b/spring-boot-demo-ratelimit-redis/src/main/resources/scripts/redis/limit.lua new file mode 100644 index 000000000..4658052c1 --- /dev/null +++ b/spring-boot-demo-ratelimit-redis/src/main/resources/scripts/redis/limit.lua @@ -0,0 +1,27 @@ +-- 下标从 1 开始 +local key = KEYS[1] +local now = tonumber(ARGV[1]) +local ttl = tonumber(ARGV[2]) +local expired = tonumber(ARGV[3]) +-- 最大访问量 +local max = tonumber(ARGV[4]) + +-- 清除过期的数据 +-- 移除指定分数区间内的所有元素,expired 即已经过期的 score +-- 根据当前时间毫秒数 - 超时毫秒数,得到过期时间 expired +redis.call('zremrangebyscore', key, 0, expired) +-- 获取 zset 中的元素个数 +local current = tonumber(redis.call('zcard', key)) + +-- +local next = current + 1 +if next > max then + -- 达到限流大小 返回 0 + return 0; +else + -- 往 zset 中添加一个值、得分均为当前时间戳的元素,[value,score] + redis.call("zadd", key, now, now) + -- 每次访问均重新设置 zset 的过期时间,单位毫秒 + redis.call("pexpire", key, ttl) + return next +end