Spring Cloud Gateway 网关自定义全局过滤器中请求体(body)无法重复读取的问题
特别注意:在网关层读取 request 的 body 流之后,必须重写 getBody() 方法,把读取到的数据重新放回去,否则后续转发的请求将接收不到任何数据。
这是个大坑:由于版本差异,网上的很多方案都无法兼容。
我的 Spring Boot 环境及依赖包如下:
<parent>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-parent</artifactId>
<version>2.2.8.RELEASE</version>
<relativePath/> <!-- lookup parent from repository -->
</parent>
<!-- gateway 版本号必须与 Spring Boot 版本对应;该 starter 已包含 WebFlux 相关依赖 -->
<dependency>
<groupId>org.springframework.cloud</groupId>
<artifactId>spring-cloud-starter-gateway</artifactId>
<version>2.1.0.RELEASE</version>
</dependency>
<!-- servlet 验证post请求需要重写request流 -->
<dependency>
<groupId>javax.servlet</groupId>
<artifactId>servlet-api</artifactId>
<version>2.5</version>
<scope>provided</scope>
</dependency>
需求描述:前端发起请求时在请求头中携带 signature=xxxx(同时携带 appId),后台验证签名是否正确。
签名生成规则:
1. POST 请求:将 body 解析为 JSON 对象,按 key 的自然升序排列;
2. GET 请求:将所有查询参数按 key 的自然升序排列;
然后把排序后的参数拼接成 key1=value1&key2=value2& 形式的字符串,用 appId 对应的 secret 做 AES 加密并进行 Base64 编码,将结果与前端传来的 signature 进行比对验证。
下面是对称加密(AES)算法工具(注意:AES 是对称加密算法,而不是非对称加密)
ApiAuthAES 工具类
package com.platform.platformgateway.util;
import org.springframework.util.StringUtils;
import javax.crypto.Cipher;
import javax.crypto.KeyGenerator;
import javax.crypto.spec.SecretKeySpec;
import java.math.BigInteger;
import java.net.URLDecoder;
import java.net.URLEncoder;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;
/**
 * AES encryption/decryption helpers plus small URL and query-string utilities used for request
 * signing.
 *
 * <p>Encryption uses {@code AES/ECB/PKCS5Padding}. The key is the UTF-8 bytes of the supplied
 * string, so it must be exactly 16, 24 or 32 bytes long, otherwise {@link Cipher#init} throws an
 * {@code InvalidKeyException}. Plain text is encoded and decoded as UTF-8 on both sides.
 *
 * <p>NOTE(review): ECB mode is deterministic (equal plain text gives equal cipher text); it is kept
 * here because it is part of the signing protocol shared with clients.
 *
 * @author wxq
 */
public class ApiAuthAES {
    /** Cipher transformation used for every encrypt/decrypt call. */
    private static final String ALGORITHMSTR = "AES/ECB/PKCS5Padding";

    /** Charset used for keys and plain text, so encryption and decryption agree on every platform. */
    private static final String CHARSET = "utf-8";

    /**
     * Converts a byte array to a string in the given radix.
     *
     * @param bytes the bytes, interpreted as an unsigned big-endian number
     * @param radix radix between Character.MIN_RADIX and Character.MAX_RADIX; outside that range 10 is used
     * @return the converted string
     */
    public static String binary(byte[] bytes, int radix) {
        return new BigInteger(1, bytes).toString(radix);
    }

    /**
     * Base64-encodes the given bytes.
     *
     * @param bytes the bytes to encode
     * @return the Base64 text
     */
    public static String base64Encode(byte[] bytes) {
        Base64.Encoder encoder = Base64.getEncoder();
        return encoder.encodeToString(bytes);
    }

    /**
     * Base64-decodes the given text.
     *
     * @param base64Code the Base64 text
     * @return the decoded bytes, or {@code null} when the input is null or empty
     * @throws Exception if the input is not valid Base64
     */
    public static byte[] base64Decode(String base64Code) throws Exception {
        Base64.Decoder decoder = Base64.getDecoder();
        return StringUtils.isEmpty(base64Code) ? null : decoder.decode(base64Code);
    }

    /**
     * AES-encrypts the UTF-8 bytes of {@code content}.
     *
     * @param content    the plain text
     * @param encryptKey the key (16, 24 or 32 UTF-8 bytes)
     * @return the cipher bytes
     * @throws Exception if the key is invalid or encryption fails
     */
    public static byte[] aesEncryptToBytes(String content, String encryptKey) throws Exception {
        Cipher cipher = Cipher.getInstance(ALGORITHMSTR);
        cipher.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(encryptKey.getBytes(CHARSET), "AES"));
        return cipher.doFinal(content.getBytes(CHARSET));
    }

    /**
     * AES-encrypts {@code content} and returns the result as Base64.
     *
     * @param content    the plain text
     * @param encryptKey the key (16, 24 or 32 UTF-8 bytes)
     * @return the Base64 cipher text
     * @throws Exception if the key is invalid or encryption fails
     */
    public static String aesEncrypt(String content, String encryptKey) throws Exception {
        return base64Encode(aesEncryptToBytes(content, encryptKey));
    }

    /**
     * AES-decrypts cipher bytes.
     *
     * @param encryptBytes the cipher bytes
     * @param decryptKey   the key (16, 24 or 32 UTF-8 bytes)
     * @return the plain text, decoded as UTF-8 to match {@link #aesEncryptToBytes}
     * @throws Exception if the key is invalid or decryption fails
     */
    public static String aesDecryptByBytes(byte[] encryptBytes, String decryptKey) throws Exception {
        Cipher cipher = Cipher.getInstance(ALGORITHMSTR);
        cipher.init(Cipher.DECRYPT_MODE, new SecretKeySpec(decryptKey.getBytes(CHARSET), "AES"));
        byte[] decryptBytes = cipher.doFinal(encryptBytes);
        // Decode with the same charset used for encryption, not the platform default.
        return new String(decryptBytes, CHARSET);
    }

    /**
     * Decrypts Base64 AES cipher text.
     *
     * @param encryptStr the Base64 cipher text
     * @param decryptKey the key (16, 24 or 32 UTF-8 bytes)
     * @return the plain text, or {@code null} when the input is null or empty
     * @throws Exception if the input or key is invalid
     */
    public static String aesDecrypt(String encryptStr, String decryptKey) throws Exception {
        return StringUtils.isEmpty(encryptStr) ? null : aesDecryptByBytes(base64Decode(encryptStr), decryptKey);
    }

    /**
     * URL-encodes a string using GBK.
     *
     * @param str the text to encode
     * @return the encoded text, or {@code ""} if encoding fails
     */
    public static String urlEncode(String str) {
        try {
            return URLEncoder.encode(str, "GBK");
        } catch (Exception e) {
            e.printStackTrace();
        }
        return "";
    }

    /**
     * URL-decodes a string using GBK.
     *
     * @param str the text to decode
     * @return the decoded text
     * @throws Exception if decoding fails
     */
    public static String urlDncode(String str) throws Exception {
        return URLDecoder.decode(str, "GBK");
    }

    /**
     * Parses a query string such as {@code a=1&b=2} into a map (values are not URL-decoded).
     *
     * <p>Only the first {@code '='} separates key and value, so values containing {@code '='}
     * (e.g. Base64 padding) are kept intact. A key without a value is stored with an empty value;
     * empty segments (e.g. from {@code "a=1&&b=2"}) are skipped.
     *
     * @param strUrlParam the query string; may be null or empty
     * @return a mutable map of parameters (empty when the input is null or empty)
     */
    public static Map<String, String> URLRequestParamMap(String strUrlParam) {
        Map<String, String> mapRequest = new HashMap<>();
        if (StringUtils.isEmpty(strUrlParam)) {
            return mapRequest;
        }
        for (String pair : strUrlParam.split("[&]")) {
            // Limit 2: split on the first '=' only, so '=' inside the value is preserved.
            String[] keyValue = pair.split("[=]", 2);
            if (keyValue.length > 1) {
                mapRequest.put(keyValue[0], keyValue[1]);
            } else if (!keyValue[0].isEmpty()) {
                // Parameter present without a value: record it with an empty value.
                mapRequest.put(keyValue[0], "");
            }
        }
        return mapRequest;
    }

    /**
     * Manual demo that prints a few signatures.
     */
    public static void main(String[] args) throws Exception {
        long time = System.currentTimeMillis();
        // NOTE(review): AES keys must be 16, 24 or 32 bytes; the demo keys below are shorter, so these
        // calls throw InvalidKeyException until they are replaced with real secrets.
        System.out.println("signature:" + aesEncrypt(System.currentTimeMillis() + "", "WYEB77T"));
        System.out.println(time);
        Map<String, Object> map = new HashMap<>();
        map.put("flag", "system");
        map.put("dateStr", "2022-09-01");
        map.put("time", time);
        System.out.println("signature:" + SignUtil.createSign(map, "WYEBHWgS"));
        Map<String, Object> map1 = new HashMap<>();
        map1.put("startPage", 0);
        map1.put("pageSize", 10);
        // Nested parameters (e.g. a "conditions" map inside map1) cannot be signed: the server cannot
        // reproduce their string form.
        map1.put("time", time);
        System.out.println("signature:" + SignUtil.createSign(map1, "WYEB7qf9O1lg"));
    }
}
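RequestWrapper:封装对 body 流的读取。注意:它在 subscribe 回调中同步读取数据,因此只有当传入的 Flux 已经缓存在内存中时(例如下文过滤器中先用 DataBufferUtils.join 合并、再用 Flux.defer 包装的缓存流)才能读到数据;先合并成一个 DataBuffer,也是为了避免 body 数据较多、被拆成多个分片时读取不完整。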
RequestWrapper.java
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.stereotype.Component;
import org.springframework.util.Assert;
import reactor.core.publisher.Flux;
import java.io.IOException;
import java.io.InputStream;
import java.nio.CharBuffer;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.atomic.AtomicReference;
/**
 * Helpers for reading a reactive request body ({@code Flux<DataBuffer>}) into a String.
 *
 * <p>Both methods subscribe to the given {@code Flux} and only see buffers emitted synchronously
 * during {@code subscribe}. They therefore work when the body is already in memory (for example a
 * {@code Flux} built over a buffer cached with {@link DataBufferUtils#join}); for a body that
 * arrives asynchronously they return before any data has been read.
 *
 * @author mingcai
 */
@Slf4j
@Component
public class RequestWrapper {
    @Autowired
    private ObjectMapper objectMapper;

    /**
     * Reads the body as a JSON object and returns it re-serialized as compact JSON.
     *
     * <p>All emitted buffers are concatenated before parsing, so a body split into several chunks
     * is not truncated. Buffers are not released, because callers may re-expose them downstream.
     *
     * @param bodys the request body
     * @return the JSON object text, or {@code ""} when no bytes were emitted synchronously
     * @throws RuntimeException wrapping the {@link IOException} when the body is not valid JSON
     * @throws IllegalArgumentException when the body is valid JSON but not a JSON object
     */
    public String getRequestBodyDoctor(Flux<DataBuffer> bodys) {
        byte[] content = collectBytes(bodys, false);
        // Nothing emitted, or an empty body: there is nothing to parse.
        if (content == null || content.length == 0) {
            return "";
        }
        JsonNode jsonNode;
        try {
            jsonNode = objectMapper.readTree(content);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        Assert.isTrue(jsonNode instanceof ObjectNode, "JSON格式异常");
        String json = jsonNode.toString();
        log.info("最终的JSON数据为:{}", json);
        return json;
    }

    /**
     * Decodes the body as UTF-8 text and releases every buffer it reads.
     *
     * @param body the request body
     * @return the decoded body (all chunks concatenated), or {@code null} if no buffer was emitted
     *         synchronously
     */
    public static String resolveBodyFromRequest(Flux<DataBuffer> body) {
        byte[] content = collectBytes(body, true);
        // Decode once over the joined bytes so multi-byte characters split across chunks stay intact.
        return content == null ? null : new String(content, StandardCharsets.UTF_8);
    }

    /**
     * Copies the readable bytes of every buffer the flux emits synchronously during subscribe.
     *
     * @param body    the flux to drain
     * @param release whether to release each buffer after copying it
     * @return the concatenated bytes, or {@code null} if no buffer was emitted
     */
    private static byte[] collectBytes(Flux<DataBuffer> body, boolean release) {
        AtomicReference<java.io.ByteArrayOutputStream> sink = new AtomicReference<>();
        body.subscribe(dataBuffer -> {
            byte[] chunk = new byte[dataBuffer.readableByteCount()];
            dataBuffer.read(chunk);
            if (release) {
                DataBufferUtils.release(dataBuffer);
            }
            if (sink.get() == null) {
                sink.set(new java.io.ByteArrayOutputStream());
            }
            sink.get().write(chunk, 0, chunk.length);
        });
        java.io.ByteArrayOutputStream collected = sink.get();
        return collected == null ? null : collected.toByteArray();
    }
}
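重点来了:下面是全局过滤器。这里至少坑了我一天时间,各种尝试都无果,最终封装成了下面这样。
难点在于过滤器无法直接返回一个新的 ServerHttpRequest,导致读取参数之后无法把数据重新放回流中。
解决办法是:先用 DataBufferUtils.join 把 body 合并缓存,再通过 ServerHttpRequestDecorator 重写 getBody() 返回缓存的数据,最后用 exchange.mutate() 替换请求后继续转发。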
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.platform.platformgateway.config.db.ConfigBean;
import com.platform.platformgateway.config.redis.BaseRedisCache;
import com.platform.platformgateway.constant.CommonConstants;
import com.platform.platformgateway.util.*;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang.StringUtils;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.Ordered;
import org.springframework.core.io.buffer.*;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.util.MultiValueMap;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import javax.annotation.Resource;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.util.*;
/**
 * Global gateway filter that verifies request signatures and timestamps, checks per-appId
 * interface permissions, and forwards the request with downstream authorization headers.
 *
 * <p>GET requests are signed over their query parameters. POST requests (except
 * multipart/form-data) are signed over their JSON body. Because a reactive body can only be read
 * once, the POST body is joined into one buffer, parsed for verification, and re-exposed to the
 * rest of the chain through a {@link ServerHttpRequestDecorator}.
 *
 * @author mingcai
 */
@Slf4j
@Configuration
public class AccessGatewayFilter implements GlobalFilter, Ordered {

    /** Maximum allowed distance between the client {@code time} parameter and the server clock (5 minutes). */
    private static final long MAX_TIME_SKEW_MILLIS = 300000L;

    @Resource
    private BaseRedisCache redisCache;
    @Resource
    private ConfigBean configBean;
    @Resource
    private ThreadPoolTaskExecutor asyncExecutor;
    @Resource
    RequestWrapper requestWrapper;

    /**
     * Verifies the request and either answers with an error JSON or forwards it down the chain.
     *
     * <p>Reads the {@code token}, {@code appId} and {@code signature} headers. The secret is looked
     * up in redis by appId. Only GET and non-multipart POST requests are accepted.
     */
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        ServerHttpRequest originalRequest = exchange.getRequest();
        String token = originalRequest.getHeaders().getFirst("token");
        String appId = originalRequest.getHeaders().getFirst("appId");
        String dateTime = "";
        String signature = originalRequest.getHeaders().getFirst("signature");
        ServerHttpResponse response = exchange.getResponse();
        URI originalRequestUrl = originalRequest.getURI();
        String path = originalRequest.getPath().toString();
        // NOTE(review): getHost() is the host of the request URI, not the client address.
        log.info("请求路径:{} service Path:{} 访问IP: {}", originalRequestUrl, path, originalRequestUrl.getHost());
        // V2/V3 paths use a token cached in redis instead of the caller-supplied header.
        if (StringUtils.isNotBlank(path) && path.toLowerCase().contains(CommonConstants.V2_PATH)) {
            token = redisCache.get(CommonConstants.V2_TOKEN_CACHE);
        }
        if (StringUtils.isNotBlank(path) && path.toLowerCase().contains(CommonConstants.V3_PATH)) {
            token = redisCache.get(CommonConstants.V3_TOKEN_CACHE);
        }
        if (StringUtils.isBlank(signature)) {
            response.setStatusCode(HttpStatus.UNAUTHORIZED);
            log.debug("signature: 为空! ");
            return returnJson(response, "用户的signature参数不能为空!");
        }
        String secret = (String) redisCache.get_obj(CommonConstants.PLATFORM_CACHE + appId);
        if (StringUtils.isBlank(secret)) {
            return returnJson(response, "用户的appId错误,请核验后重试");
        }
        String method = String.valueOf(originalRequest.getMethod());
        String contentType = originalRequest.getHeaders().getFirst("Content-Type");
        try {
            if ("GET".equals(method)) {
                dateTime = originalRequest.getQueryParams().getFirst("time");
                if (!doGet(originalRequest, secret, signature)) {
                    return returnJson(response, "get请求参数验证失败");
                }
            } else if ("POST".equals(method) && !Objects.requireNonNull(contentType).startsWith("multipart/form-data")) {
                return verifyPostAndForward(exchange, chain, secret, signature, token, method);
            } else {
                return returnJson(response, "不支持的请求方式,仅支持(GET,POST)");
            }
        } catch (Exception e) {
            log.error(" 签名错误! {}", e.getMessage(), e);
            return returnJson(response, method + " 签名错误");
        }
        Mono<Void> verifyMono = verifyUser(exchange, exchange.getResponse(), exchange.getRequest(), dateTime);
        if (null != verifyMono) {
            return verifyMono;
        }
        ServerWebExchange build = getServerWebExchange(exchange, token, originalRequest);
        return chain.filter(build);
    }

    /**
     * Caches the POST body, verifies its signature and timestamp, then forwards the request with
     * the cached body re-exposed through {@code getBody()}.
     *
     * @param exchange  the current exchange
     * @param chain     the filter chain
     * @param secret    the appId's signing secret
     * @param signature the signature sent by the client
     * @param token     the token to forward downstream
     * @param method    the HTTP method name, used in error messages
     * @return the error response or the downstream chain result
     */
    private Mono<Void> verifyPostAndForward(ServerWebExchange exchange, GatewayFilterChain chain,
                                            String secret, String signature, String token, String method) {
        ServerHttpResponse response = exchange.getResponse();
        // An empty body emits no buffers at all. Substitute a zero-length buffer so the join/flatMap
        // below still runs and the empty body is reported to the client.
        DefaultDataBufferFactory defaultDataBufferFactory = new DefaultDataBufferFactory();
        DefaultDataBuffer defaultDataBuffer = defaultDataBufferFactory.allocateBuffer(0);
        Flux<DataBuffer> bodyDataBuffer = exchange.getRequest().getBody().defaultIfEmpty(defaultDataBuffer);
        return DataBufferUtils.join(bodyDataBuffer).flatMap(dataBuffer -> {
            // NOTE(review): this extra reference is never explicitly released; with pooled (Netty)
            // buffers that may leak — confirm against the gateway's buffer lifecycle.
            DataBufferUtils.retain(dataBuffer);
            // Each subscription gets a fresh slice, so the body can be read here and again downstream.
            Flux<DataBuffer> cachedFlux = Flux.defer(() -> Flux.just(dataBuffer.slice(0, dataBuffer.readableByteCount())));
            ServerHttpRequest mutatedRequest = new ServerHttpRequestDecorator(exchange.getRequest()) {
                @Override
                public Flux<DataBuffer> getBody() {
                    return cachedFlux;
                }
            };
            JSONObject jb;
            try {
                String json = requestWrapper.getRequestBodyDoctor(cachedFlux);
                jb = JSONObject.parseObject(json);
            } catch (RuntimeException e) {
                // Malformed or non-object JSON: answer with an error instead of failing the exchange.
                log.error(" 签名错误! {}", e.getMessage(), e);
                return returnJson(response, method + " 签名错误");
            }
            if (null == jb) {
                return returnJson(response, "post请求参数为空");
            }
            String time = String.valueOf(jb.getOrDefault("time", ""));
            log.info("str: " + time);
            if (!doPost(jb, secret, signature)) {
                return returnJson(response, "post请求参数验证失败");
            }
            Mono<Void> verifyMono = verifyUser(exchange, response, exchange.getRequest(), time);
            if (null != verifyMono) {
                return verifyMono;
            }
            ServerWebExchange newexchange = exchange.mutate().request(mutatedRequest).build();
            return chain.filter(getServerWebExchange(newexchange, token, null));
        });
    }

    /**
     * Adds the fixed downstream headers to the exchange's request.
     *
     * <p>Sets Authorization/userTokenHead to {@code token}, plus businessId and serviceGroup.
     * The {@code serverHttpRequest} parameter is currently unused.
     */
    private static ServerWebExchange getServerWebExchange(ServerWebExchange exchange, String token, ServerHttpRequest serverHttpRequest) {
        ServerHttpRequest host = exchange.getRequest().mutate()
                .header("Authorization", token)
                .header("userTokenHead", token)
                .header("businessId", "10000000100001")
                .header("serviceGroup", "sky")
                .build();
        return exchange.mutate().request(host).build();
    }

    @Override
    public int getOrder() {
        return -200;
    }

    /**
     * Checks that the appId exists and has permission for the request path.
     *
     * @return an error message, or {@code ""} when the caller is authorized
     */
    private String authorUser(ServerHttpRequest originalRequest) {
        String appId = originalRequest.getHeaders().getFirst("appId");
        String path = originalRequest.getPath().toString();
        if (StringUtils.isBlank(appId)) {
            return "用户的appId和secret不能为空";
        }
        String secret = (String) redisCache.get_obj(CommonConstants.PLATFORM_CACHE + appId);
        if (StringUtils.isBlank(secret)) {
            return "用户的appId错误,请核验后重试";
        }
        // Per-appId interface permission.
        if (null == redisCache.get_obj(CommonConstants.PLATFORM_CACHE + appId + path)) {
            return "用户" + appId + "无此接口权限";
        }
        return "";
    }

    /**
     * Verifies the signature of a GET request over its query parameters (first value per key).
     *
     * @return {@code true} when the computed signature equals {@code sign}
     * @throws Exception if signature computation fails
     */
    public Boolean doGet(ServerHttpRequest request, String secret, String sign) throws Exception {
        MultiValueMap<String, String> pNames = request.getQueryParams();
        Map<String, Object> map = new HashMap<>();
        for (String key : pNames.keySet()) {
            map.put(key, pNames.getFirst(key));
        }
        String newSign = SignUtil.createSign(map, secret);
        return signatureMatches(newSign, sign);
    }

    /**
     * Verifies the signature of a POST request over its JSON body.
     *
     * @return {@code true} only when the computed signature equals {@code sign}; any error while
     *         computing it fails verification (fail closed) instead of letting the request through
     */
    public Boolean doPost(JSONObject json, String secret, String sign) {
        try {
            String newSign = SignUtil.createSign(json, secret);
            return signatureMatches(newSign, sign);
        } catch (Exception e) {
            log.error("参数请求错误: {}", e.getMessage(), e);
            return false;
        }
    }

    /**
     * Constant-time comparison of two signatures, so response timing does not reveal the
     * expected signature.
     */
    private static boolean signatureMatches(String expected, String actual) {
        if (expected == null || actual == null) {
            return false;
        }
        return MessageDigest.isEqual(expected.getBytes(StandardCharsets.UTF_8), actual.getBytes(StandardCharsets.UTF_8));
    }

    /**
     * Validates the request timestamp and the caller's permissions, and logs the access asynchronously.
     *
     * @param dateTime the client timestamp in epoch milliseconds
     * @return an error response to send, or {@code null} when the request is allowed
     */
    public Mono<Void> verifyUser(ServerWebExchange exchange, ServerHttpResponse response, ServerHttpRequest serverHttpRequest, String dateTime) {
        String path = serverHttpRequest.getPath().toString();
        String appId = serverHttpRequest.getHeaders().getFirst("appId");
        if (StringUtils.isBlank(dateTime)) {
            response.setStatusCode(HttpStatus.UNAUTHORIZED);
            log.debug("用户的time时间戳参数不能为空");
            return returnJson(response, "用户的time时间戳参数不能为空!");
        }
        long time;
        try {
            time = Long.parseLong(dateTime);
        } catch (NumberFormatException e) {
            // Without this, a non-numeric time would escape as an unhandled exception.
            response.setStatusCode(HttpStatus.UNAUTHORIZED);
            log.debug("用户的time时间戳格式错误: {}", dateTime);
            return returnJson(response, "用户的time时间戳格式错误!");
        }
        long nowTime = System.currentTimeMillis();
        // Accept only timestamps within 5 minutes either side of the server clock; a far-future
        // timestamp would otherwise keep a captured request replayable indefinitely.
        if (time < nowTime - MAX_TIME_SKEW_MILLIS || time > nowTime + MAX_TIME_SKEW_MILLIS) {
            response.setStatusCode(HttpStatus.UNAUTHORIZED);
            log.debug("用户的time时间戳过期");
            return returnJson(response, "用户的time时间戳过期!");
        }
        String clientIp = Objects.requireNonNull(serverHttpRequest.getRemoteAddress()).getHostString();
        String authUserMsg = authorUser(serverHttpRequest);
        if (StringUtils.isNotBlank(authUserMsg)) {
            // 401: no permission at the gateway layer.
            response.setStatusCode(HttpStatus.UNAUTHORIZED);
            return returnJson(response, authUserMsg);
        }
        // Save the access log on the executor so the request is not blocked by the log service.
        try {
            asyncExecutor.execute(() -> {
                JSONObject json = new JSONObject();
                MultiValueMap<String, String> params = serverHttpRequest.getQueryParams();
                json.put("path", path);
                json.put("clientIp", clientIp);
                json.put("params", params);
                json.put("appId", appId);
                OkHttpClientUtils.postJsonParams(configBean.getPlatformUrl() + "/hoe/platform/log/save", json.toString());
            });
        } catch (Exception e) {
            log.error("保存日志错误: {} ", e.getMessage());
        }
        return null;
    }

    /**
     * Writes {@code {"msg": mes}} as the response body.
     *
     * @param response the response to write to
     * @param mes      the message
     * @return completion of the write
     */
    private Mono<Void> returnJson(ServerHttpResponse response, String mes) {
        log.info(mes);
        JSONObject json = new JSONObject();
        json.put("msg", mes);
        String message = JSON.toJSONString(json);
        response.getHeaders().add("Content-Type", "application/json;charset=UTF-8");
        // Encode as UTF-8 to match the declared charset; getBytes() would use the platform default.
        DataBuffer dataBuffer = response.bufferFactory().wrap(message.getBytes(StandardCharsets.UTF_8));
        return response.writeWith(Flux.just(dataBuffer));
    }
}
对于这种封装方式,我并不太满意:
return DataBufferUtils.join(bodyDataBuffer).flatMap(dataBuffer -> {...}) 这部分应该单独抽取出来,再统一放到后面处理。目前还不知道该如何拆分,如果有了解的大佬,麻烦分享一下!
下面是排序验签的工具类
import java.util.Comparator;
import java.util.Map;
import java.util.TreeMap;
/**
 * Builds request signatures. Parameters are sorted by key, concatenated as {@code key=value&}
 * pairs, then AES-encrypted and Base64-encoded with the caller's secret via {@link ApiAuthAES}.
 *
 * @author mingcai
 */
public class SignUtil {

    /**
     * Creates the signature for the given parameters.
     *
     * <p>Example: {@code {b=2, a=1}} is signed over the string {@code "a=1&b=2&"}. Values are
     * appended using their {@code toString()} form. An empty map is signed over the empty string.
     *
     * @param originMap request parameters; may be empty
     * @param secret    AES key (16, 24 or 32 bytes)
     * @return the Base64 signature, or {@code null} when {@code originMap} is null
     * @throws Exception if encryption fails (e.g. invalid key length)
     */
    public static String createSign(Map<String, Object> originMap, String secret) throws Exception {
        if (originMap == null) {
            return null;
        }
        // Sort with a local TreeMap rather than sortMapByKey(), which returns null for an empty map
        // and would make this method throw NullPointerException for parameter-less requests.
        Map<String, Object> sorted = new TreeMap<>(originMap);
        StringBuilder originStr = new StringBuilder();
        for (Map.Entry<String, Object> entry : sorted.entrySet()) {
            originStr.append(entry.getKey()).append('=').append(entry.getValue()).append('&');
        }
        return ApiAuthAES.aesEncrypt(originStr.toString(), secret);
    }

    /**
     * Returns a copy of the map sorted by key in ascending natural order (a -> z).
     *
     * @param map the map to sort
     * @return a sorted copy, or {@code null} when {@code map} is null or empty
     */
    public static Map<String, Object> sortMapByKey(Map<String, Object> map) {
        if (map == null || map.isEmpty()) {
            return null;
        }
        Map<String, Object> sortMap = new TreeMap<>(Comparator.naturalOrder());
        sortMap.putAll(map);
        return sortMap;
    }
}