package com.common.aspect;
|
|
import com.common.annotation.RequestLimit;
|
import com.common.core.enums.ResultCodeEnum;
|
import com.common.core.exception.BizException;
|
import com.common.core.utils.DateUtils;
|
import com.common.redis.util.RedisUtil;
|
import com.common.security.utils.IpUtil;
|
import lombok.extern.slf4j.Slf4j;
|
import org.aspectj.lang.ProceedingJoinPoint;
|
import org.aspectj.lang.annotation.Around;
|
import org.aspectj.lang.annotation.Aspect;
|
import org.aspectj.lang.annotation.Pointcut;
|
import org.aspectj.lang.reflect.MethodSignature;
|
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.core.annotation.Order;
|
import org.springframework.stereotype.Component;
|
import org.springframework.util.CollectionUtils;
|
import org.springframework.web.context.request.RequestContextHolder;
|
import org.springframework.web.context.request.ServletRequestAttributes;
|
|
import javax.servlet.http.HttpServletRequest;
|
import java.lang.reflect.Method;
|
import java.security.SecureRandom;
|
import java.util.Set;
|
|
@Aspect
|
@Component
|
@Order(1)
|
@Slf4j
|
public class RequestLimitAspect {
|
|
@Autowired
|
private RedisUtil redisUtil;
|
|
private static final String REQ_LIMIT = "req_limit:%s:%s:";
|
private static final String REQ_LIMIT_FREQUENCY = "req_limit_frequency_%s_%s";
|
|
/**
|
* 定义拦截规则:拦截com.springboot.bcode.api包下面的所有类中,有@RequestLimit Annotation注解的方法
|
* 。
|
*/
|
@Pointcut("@within(org.springframework.web.bind.annotation.RestController) ")
|
public void pointcut() {
|
}
|
|
@Around("pointcut()")
|
public Object method(ProceedingJoinPoint joinPoint) throws Throwable {
|
Object[] args = joinPoint.getArgs();
|
|
MethodSignature signature = (MethodSignature) joinPoint.getSignature();
|
Method method = signature.getMethod(); // 获取被拦截的方法
|
RequestLimit limt = method.getAnnotation(RequestLimit.class);
|
int time,count,size;
|
if (limt == null) {
|
time = RequestLimit.DEFAULT_TIME;
|
count = RequestLimit.DEFAULT_COUNT;
|
size = RequestLimit.DEFAULT_SIZE;
|
}else{
|
time = limt.time();
|
count = limt.count();
|
size = limt.size();
|
}
|
HttpServletRequest request = ((ServletRequestAttributes)RequestContextHolder.getRequestAttributes()).getRequest();
|
|
|
String ip = IpUtil.getIpAddr(request);
|
String url = request.getRequestURI();
|
|
// judge condition
|
//大小限制
|
if(request.getContentLength() > size*1024*1024){
|
throw new BizException(ResultCodeEnum.LARGE_REQUEST_ERROR);
|
}
|
//频次限制
|
String key = String.format(REQ_LIMIT, url, ip);
|
Set<String> valueList = redisUtil.getKeys(key + "*");
|
//将有效的过滤出来,过期时间在当前时间后的
|
if (!CollectionUtils.isEmpty(valueList) && count>-1 && valueList.size() >= count) {
|
throw new BizException(ResultCodeEnum.BUSY_REQUEST_ERROR);
|
}
|
redisUtil.set(key+ new SecureRandom().nextLong(), DateUtils.now().getTime(),time);
|
return joinPoint.proceed();
|
}
|
}
|