在给 Spring WebFlux 接口做签名校验和防重放时,往往需要获取请求参数、请求方法等信息,而 Spring WebFlux 不像 Spring MVC 那样可以方便地获取这些信息。下面结合之前的实践具体说明一下:
总体思路:
1、利用过滤器,从原 request 中获取信息后缓存到一个上下文对象中,然后构造新的 request 传给后续的过滤器。因为原 request 的 body 是流式的,读取一次后就无法再次读取。
2、通过exchange的Attributes传递上下文对象,在不同的过滤器中使用即可。
1、上下文对象
/**
 * Per-request snapshot of request data.
 *
 * <p>GatewayContextFilter creates one instance per request and stores it in the exchange
 * attributes under {@link #CACHE_GATEWAY_CONTEXT}. Later filters (such as the signature check)
 * read request data from it instead of consuming the one-shot request body again.
 * Getters, setters and toString are generated by Lombok.
 */
@Getter
@Setter
@ToString
public class GatewayContext {
/** Exchange attribute key under which the context instance is stored. */
public static final String CACHE_GATEWAY_CONTEXT = "cacheGatewayContext";
/**
 * Cached HTTP request method (GatewayContextFilter stores it upper-cased).
 */
private String requestMethod;
/**
 * Cached query parameters of the request URL.
 */
private MultiValueMap<String, String> queryParams;
/**
 * Cached JSON body decoded as a string; only set for JSON requests whose body was read.
 */
private String requestBody;
/**
 * Slot for a response body. It is not populated by the filters in this article.
 * NOTE(review): confirm where (or whether) it is set.
 */
private Object responseBody;
/**
 * Request headers of the original request.
 */
private HttpHeaders requestHeaders;
/**
 * Cached form fields; only set for application/x-www-form-urlencoded requests.
 */
private MultiValueMap<String, String> formData;
/**
 * Merged request data: query parameters plus form fields, or plus the top-level fields of a
 * JSON object body.
 */
private MultiValueMap<String, String> allRequestData = new LinkedMultiValueMap<>(0);
/**
 * Body bytes that the signature check hashes. For form requests these are the re-encoded form
 * bytes that are forwarded downstream, not necessarily the bytes the client sent.
 */
private byte[] requestBodyBytes;
}
2、在过滤器中获取请求参数、请求方法。
这里我们只对 application/json 和 application/x-www-form-urlencoded 两种类型的请求拦截并缓存 body 参数;对于其他类型的请求,则可以直接从 URL 中获取 query 参数。
/**
 * WebFilter that snapshots request data into a {@link GatewayContext}. The context is stored in
 * the exchange attributes under {@link GatewayContext#CACHE_GATEWAY_CONTEXT} so later filters can
 * read it.
 *
 * <p>For JSON and form-urlencoded requests with a positive Content-Length, the body is read here,
 * cached in the context, and re-supplied to the rest of the chain, because a reactive request body
 * can only be consumed once. All other requests pass through with only the method, headers and
 * query parameters captured.
 */
@Slf4j
@Component
public class GatewayContextFilter implements WebFilter, Ordered {

    /**
     * Default HttpMessageReaders, used to decode the cached JSON body into a String.
     */
    private static final List<HttpMessageReader<?>> MESSAGE_READERS = HandlerStrategies.withDefaults().messageReaders();

    /**
     * Creates and publishes the gateway context, then continues the chain. For JSON and form
     * bodies, the chain continues only after the body has been cached.
     */
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
        ServerHttpRequest request = exchange.getRequest();
        HttpHeaders headers = request.getHeaders();

        GatewayContext gatewayContext = new GatewayContext();
        gatewayContext.setRequestHeaders(headers);
        gatewayContext.getAllRequestData().addAll(request.getQueryParams());
        // Locale.ROOT keeps the upper-casing independent of the server's default locale.
        gatewayContext.setRequestMethod(request.getMethodValue().toUpperCase(java.util.Locale.ROOT));
        gatewayContext.setQueryParams(request.getQueryParams());
        // Publish the context before touching the body so later filters can always find it.
        exchange.getAttributes().put(GatewayContext.CACHE_GATEWAY_CONTEXT, gatewayContext);

        MediaType contentType = headers.getContentType();
        // Only bodies with a declared positive Content-Length are cached.
        if (headers.getContentLength() > 0) {
            // equalsTypeAndSubtype ignores media-type parameters. As a result,
            // "application/json;charset=UTF-8" is cached too; equals() would only match the bare
            // "application/json".
            if (MediaType.APPLICATION_JSON.equalsTypeAndSubtype(contentType)) {
                return readBody(exchange, chain, gatewayContext);
            }
            if (MediaType.APPLICATION_FORM_URLENCODED.equalsTypeAndSubtype(contentType)) {
                return readFormData(exchange, chain, gatewayContext);
            }
        }
        String path = request.getPath().value();
        if (!"/".equals(path)) {
            log.info("{} Gateway context is set with {}-{}", path, contentType, gatewayContext);
        }
        return chain.filter(exchange);
    }

    /**
     * Runs right after the highest precedence so the context exists before the filters that use it
     * (for example GatewaySignCheckFilter at {@code Integer.MIN_VALUE + 2}).
     */
    @Override
    public int getOrder() {
        return Integer.MIN_VALUE + 1;
    }

    /**
     * Reads the form-urlencoded body and caches the parsed fields and their re-encoded bytes in the
     * context. It then continues the chain with a request whose body is the re-encoded form data.
     */
    private Mono<Void> readFormData(ServerWebExchange exchange, WebFilterChain chain, GatewayContext gatewayContext) {
        HttpHeaders headers = exchange.getRequest().getHeaders();
        return exchange.getFormData()
                .doOnNext(multiValueMap -> {
                    gatewayContext.setFormData(multiValueMap);
                    gatewayContext.getAllRequestData().addAll(multiValueMap);
                    log.debug("[GatewayContext]Read FormData Success");
                })
                .then(Mono.defer(() -> {
                    Charset charset = headers.getContentType().getCharset();
                    charset = charset == null ? StandardCharsets.UTF_8 : charset;
                    MultiValueMap<String, String> formData = gatewayContext.getFormData();
                    // No fields: nothing to re-supply, continue with the original exchange.
                    if (null == formData || formData.isEmpty()) {
                        return chain.filter(exchange);
                    }
                    log.info("1. Gateway Context formData: {}", formData);

                    String formDataBodyString = encodeFormData(formData, charset.name());
                    byte[] bodyBytes = formDataBodyString.getBytes(charset);
                    // The signature check hashes exactly the bytes forwarded downstream.
                    gatewayContext.setRequestBodyBytes(bodyBytes);

                    HttpHeaders httpHeaders = new HttpHeaders();
                    httpHeaders.putAll(headers);
                    httpHeaders.remove(HttpHeaders.CONTENT_LENGTH);
                    // The re-encoded body may differ in length from what the client sent.
                    httpHeaders.setContentLength(bodyBytes.length);

                    BodyInserter<String, ReactiveHttpOutputMessage> bodyInserter = BodyInserters.fromObject(formDataBodyString);
                    CachedBodyOutputMessage cachedBodyOutputMessage = new CachedBodyOutputMessage(exchange, httpHeaders);
                    log.info("2. GatewayContext Rewrite Form Data :{}", formDataBodyString);
                    return bodyInserter.insert(cachedBodyOutputMessage, new BodyInserterContext())
                            .then(Mono.defer(() -> {
                                ServerHttpRequestDecorator decorator = new ServerHttpRequestDecorator(exchange.getRequest()) {
                                    @Override
                                    public HttpHeaders getHeaders() {
                                        return httpHeaders;
                                    }

                                    @Override
                                    public Flux<DataBuffer> getBody() {
                                        return cachedBodyOutputMessage.getBody();
                                    }
                                };
                                return chain.filter(exchange.mutate().request(decorator).build());
                            }));
                }));
    }

    /**
     * Serializes form fields as an application/x-www-form-urlencoded string, joining pairs with '&'.
     * A field that was parsed without '=' (null value) is written back as the bare key instead of
     * failing with a NullPointerException.
     */
    private static String encodeFormData(MultiValueMap<String, String> formData, String charsetName) {
        StringBuilder body = new StringBuilder();
        try {
            for (Map.Entry<String, List<String>> entry : formData.entrySet()) {
                String encodedKey = encodeFormComponent(entry.getKey(), charsetName);
                for (String value : entry.getValue()) {
                    if (body.length() > 0) {
                        body.append('&');
                    }
                    body.append(encodedKey);
                    if (value != null) {
                        body.append('=').append(encodeFormComponent(value, charsetName));
                    }
                }
            }
        } catch (UnsupportedEncodingException e) {
            // Should not happen: the name comes from an existing Charset instance.
            log.error("GatewayContext readFormData error {}", e.getMessage(), e);
        }
        return body.toString();
    }

    /**
     * URL-encodes one form key or value. It then writes space as "%20", '*' as "%2A" and keeps '~'
     * literal, which differs from URLEncoder's defaults.
     */
    private static String encodeFormComponent(String raw, String charsetName) throws UnsupportedEncodingException {
        return URLEncoder.encode(raw, charsetName)
                .replace("+", "%20")
                .replace("*", "%2A")
                .replace("%7E", "~");
    }

    /**
     * Reads the whole JSON body into memory and caches its bytes and decoded text in the context.
     * It then continues the chain with a request that replays the cached bytes.
     */
    private Mono<Void> readBody(ServerWebExchange exchange, WebFilterChain chain, GatewayContext gatewayContext) {
        return DataBufferUtils.join(exchange.getRequest().getBody())
                .map(dataBuffer -> {
                    byte[] bytes = new byte[dataBuffer.readableByteCount()];
                    dataBuffer.read(bytes);
                    DataBufferUtils.release(dataBuffer);
                    return bytes;
                })
                // An empty body must still reach the rest of the chain. Without this, the flatMap
                // below would never run and the request would complete without being handled.
                .defaultIfEmpty(new byte[0])
                .flatMap(bytes -> {
                    gatewayContext.setRequestBodyBytes(bytes);
                    // Each subscription gets a fresh buffer over the cached bytes, so the body can
                    // be read here and again downstream.
                    Flux<DataBuffer> cachedFlux = Flux.defer(() -> {
                        DataBuffer buffer = exchange.getResponse().bufferFactory().wrap(bytes);
                        DataBufferUtils.retain(buffer);
                        return Mono.just(buffer);
                    });
                    ServerHttpRequest mutatedRequest = new ServerHttpRequestDecorator(exchange.getRequest()) {
                        @Override
                        public Flux<DataBuffer> getBody() {
                            return cachedFlux;
                        }
                    };
                    ServerWebExchange mutatedExchange = exchange.mutate().request(mutatedRequest).build();
                    return ServerRequest.create(mutatedExchange, MESSAGE_READERS)
                            .bodyToMono(String.class)
                            .doOnNext(body -> cacheJsonBody(gatewayContext, body))
                            .then(Mono.defer(() -> chain.filter(mutatedExchange)));
                });
    }

    /**
     * Stores the decoded JSON body in the context. When the body looks like a JSON object, it also
     * merges the object's top-level fields into allRequestData. Non-string values are converted
     * with String.valueOf so the MultiValueMap really only holds Strings.
     */
    private static void cacheJsonBody(GatewayContext gatewayContext, String body) {
        gatewayContext.setRequestBody(body);
        if (!body.trim().startsWith("{")) {
            return;
        }
        try {
            Map<?, ?> json = JsonUtil.fromJson(body, Map.class);
            if (json == null) {
                return;
            }
            MultiValueMap<String, String> allRequestData = gatewayContext.getAllRequestData();
            for (Map.Entry<?, ?> field : json.entrySet()) {
                Object value = field.getValue();
                allRequestData.set(String.valueOf(field.getKey()), value == null ? null : String.valueOf(value));
            }
        } catch (Exception e) {
            log.warn("Gateway context Read JsonBody error:{}", e.getMessage(), e);
        }
    }
}
3、签名、防重放校验
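这里我们从上下文对象中取出参数即可。
签名算法逻辑:待签名串由以下各项依次拼接而成,每项后面都带一个换行符:请求方法、client-id、由 client-id/nonce/timestamp 拼成的 header 串、请求路径加 query 参数、请求 body 的 MD5。再用前后端约定的密钥调用 sha256HMAC 计算签名。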
@Slf4j
@Component
public class GatewaySignCheckFilter implements WebFilter, Ordered {
@Value("${api.rest.prefix}")
private String apiPrefix;
@Autowired
private RedisUtil redisUtil;
//前后端约定签名密钥
private static final String API_SECRET = "secret-xxx";
@Override
public int getOrder() {
return Integer.MIN_VALUE + 2;
}
@NotNull
@Override
public Mono<Void> filter(ServerWebExchange exchange, @NotNull WebFilterChain chain) {
ServerHttpRequest request = exchange.getRequest();
String uri = request.getURI().getPath();
GatewayContext gatewayContext = (GatewayContext) exchange.getAttributes().get(GatewayContext.CACHE_GATEWAY_CONTEXT);
HttpHeaders headers = gatewayContext.getRequestHeaders();
MediaType contentType = headers.getContentType();
log.info("check url:{},method:{},contentType:{}", uri, gatewayContext.getRequestMethod(), contentType == null ? "" : contentType.toString());
//如果contentType为空,只能是get请求
if (contentType == null || StringUtils.isBlank(contentType.toString())) {
if (request.getMethod() != HttpMethod.GET) {
throw new RuntimeException("非法访问");
}
checkSign(uri, gatewayContext, exchange);
} else {
if (MediaType.APPLICATION_JSON.equals(contentType) || MediaType.APPLICATION_FORM_URLENCODED.equalsTypeAndSubtype(contentType)) {
checkSign(uri, gatewayContext, exchange);
}
}
return chain.filter(exchange);
}
private void checkSign(String uri, GatewayContext gatewayContext, ServerWebExchange exchange) {
//忽略掉的请求
List<String> ignores = Lists.newArrayList("/open/**", "/open/login/params", "/open/image");
for (String ignore : ignores) {
ignore = apiPrefix + ignore;
if (uri.equals(ignore) || uri.startsWith(ignore.replace("/**", "/"))) {
log.info("check sign ignore:{}", uri);
return;
}
}
String method = gatewayContext.getRequestMethod();
log.info("start check sign {}-{}", method, uri);
HttpHeaders headers = gatewayContext.getRequestHeaders();
log.info("headers:{}", JsonUtils.objectToJson(headers));
String clientId = getHeaderAttr(headers, SystemSign.CLIENT_ID);
String timestamp = getHeaderAttr(headers, SystemSign.TIMESTAMP);
String nonce = getHeaderAttr(headers, SystemSign.NONCE);
String sign = getHeaderAttr(headers, SystemSign.SIGN);
checkTime(timestamp);
checkOnce(nonce);
String headerStr = String.format("%s=%s&%s=%s&%s=%s", SystemSign.CLIENT_ID, clientId,
SystemSign.NONCE, nonce, SystemSign.TIMESTAMP, timestamp);
String signSecret = API_SECRET;
String queryUri = uri + getQueryParam(gatewayContext.getQueryParams());
log.info("headerStr:{},signSecret:{},queryUri:{}", headerStr, signSecret, queryUri);
String realSign = calculatorSign(clientId, queryUri, gatewayContext, headerStr, signSecret);
log.info("sign:{}, realSign:{}", sign, realSign);
if (!realSign.equals(sign)) {
log.warn("wrong sign");
throw new RuntimeException("Illegal sign");
}
}
private String getQueryParam(MultiValueMap<String, String> queryParams) {
if (queryParams == null || queryParams.size() == 0) {
return StringUtils.EMPTY;
}
StringBuilder builder = new StringBuilder("?");
for (Map.Entry<String, List<String>> entry : queryParams.entrySet()) {
String key = entry.getKey();
List<String> value = entry.getValue();
builder.append(key).append("=").append(value.get(0)).append("&");
}
builder.deleteCharAt(builder.length() - 1);
return builder.toString();
}
private String getHeaderAttr(HttpHeaders headers, String key) {
List<String> values = headers.get(key);
if (CollectionUtils.isEmpty(values)) {
log.warn("GatewaySignCheckFilter empty header:{}", key);
throw new RuntimeException("GatewaySignCheckFilter empty header:" + key);
}
String value = values.get(0);
if (StringUtils.isBlank(value)) {
log.warn("GatewaySignCheckFilter empty header:{}", key);
throw new RuntimeException("GatewaySignCheckFilter empty header:" + key);
}
return value;
}
private String calculatorSign(String clientId, String queryUri, GatewayContext gatewayContext, String headerStr, String signSecret) {
String method = gatewayContext.getRequestMethod();
byte[] bodyBytes = gatewayContext.getRequestBodyBytes();
if (bodyBytes == null) {
//空白的md5固定为:d41d8cd98f00b204e9800998ecf8427e
bodyBytes = new byte[]{};
}
String bodyMd5 = UaaSignUtils.getMd5(bodyBytes);
String ori = String.format("%s\n%s\n%s\n%s\n%s\n", method, clientId, headerStr, queryUri, bodyMd5);
log.info("clientId:{},signSecret:{},headerStr:{},bodyMd5:{},queryUri:{},ori:{}", clientId, signSecret, headerStr, bodyMd5, queryUri, ori);
return UaaSignUtils.sha256HMAC(ori, signSecret);
}
private void checkOnce(String nonce) {
if (StringUtils.isBlank(nonce)) {
log.warn("GatewaySignCheckFilter checkOnce Illegal");
}
String key = "api:auth:" + nonce;
int fifteenMin = 60 * 15 * 1000;
Boolean succ = redisUtil.setNxWithExpire(key, "1", fifteenMin);
if (succ == null || !succ) {
log.warn("GatewaySignCheckFilter checkOnce Repeat");
throw new RuntimeException("checkOnce Repeat");
}
}
private void checkTime(String timestamp) {
long time;
try {
time = Long.parseLong(timestamp);
} catch (Exception ex) {
log.error("GatewaySignCheckFilter checkTime error:{}", ex.getMessage(), ex);
throw new RuntimeException("checkTime error");
}
long now = DateTimeUtil.now();
log.info("now: {}, time: {}", DateTimeUtil.millsToStr(now), DateTimeUtil.millsToStr(time));
int fiveMinutes = 60 * 5 * 1000;
long duration = now - time;
if (duration > fiveMinutes || (-duration) > fiveMinutes) {
log.warn("GatewaySignCheckFilter checkTime Late");
throw new RuntimeException("checkTime Late");
}
}
public interface SystemSign {
/**
* 客户端ID:固定值,由后端给前端颁发约定
*/
String CLIENT_ID = "client-id";
/**
* 客户端计算出的签名
*/
String SIGN = "sign";
/**
* 时间戳
*/
String TIMESTAMP = "timestamp";
/**
* 唯一值
*/
String NONCE = "nonce";
}
}