1 公共部分
1.1 请求、响应对象
/**
 * RPC request sent from the client to the server: names the target service and
 * method and carries the call arguments. Lombok's @Data generates the
 * getters/setters that the codecs and handlers call.
 */
@Data
public class RpcRequest {
// Service interface name; the server uses it as the ServiceMapping lookup key
// (the client fills it with RpcService.class.getName()).
private String serviceName;
// Name of the method to invoke on the service.
private String methodName;
// Declared parameter types; together with methodName they select the method.
private Class<?>[] parameterTypes;
// Argument values passed to the invoked method.
private Object[] parameters;
}
/**
 * RPC response sent from the server back to the client (accessors generated by
 * Lombok's @Data).
 */
@Data
public class RpcResponse {
// Status code; the server handler sets 200 on success and 500 on failure.
private int code;
// Short status text ("ok" / "error" as set by the server handler).
private String msg;
// Return value of the invoked method on success.
// NOTE(review): typed Object, so after the JSON round trip the client receives
// whatever type fastjson2 produces for the value — confirm for non-String returns.
private Object data;
// Failure description; the client handler treats a non-null value as an error.
private String ex;
}
1.2 rpc协议
/**
 * Wire-level frame: a length prefix followed by the serialized message body
 * (the request/response codecs put JSON bytes here).
 */
@Data
public class RpcProtocol {
// Number of bytes in content; written to the stream as a 4-byte int prefix.
private int length;
// Serialized message body.
private byte[] content;
}
1.3 简易注册中心：保存服务名到服务地址列表的映射
/**
 * Minimal in-memory service registry mapping a service (interface) name to the
 * addresses, in {@code host:port} form, of the servers that provide it.
 */
public class ServiceRegister {
    private final Map<String, List<String>> register = new HashMap<>();

    public ServiceRegister() {
        // The demo registers a single local provider for RpcService.
        register.put(RpcService.class.getName(),
                new ArrayList<>(List.of("localhost:1733")));
    }

    /**
     * Looks up the provider addresses of a service.
     *
     * @param serviceName service interface name used as the registry key
     * @return a copy of the registered addresses, so callers cannot modify the
     *         registry; an empty list (never {@code null}) if the service is unknown
     */
    public List<String> findService(String serviceName) {
        List<String> addresses = register.get(serviceName);
        return addresses == null ? new ArrayList<>() : new ArrayList<>(addresses);
    }
}
1.4 RPC 上下文：用于获取单例的 ServiceRegister
/**
 * Shared access point for the process-wide {@link ServiceRegister}.
 * The instance is created on first use, when the nested holder class is
 * initialized by the JVM.
 */
public class RpcContext {

    /** Holder class: initialized (and the registry built) on first access. */
    private static class RegisterHolder {
        private static final ServiceRegister INSTANCE = new ServiceRegister();
    }

    /**
     * @return the single shared ServiceRegister
     */
    public static ServiceRegister register() {
        return RegisterHolder.INSTANCE;
    }
}
1.5 帧解码器与编解码器
/**
 * Frame decoder for the RPC protocol. It must be the first handler in the
 * ChannelPipeline: it splits the inbound byte stream into whole frames, which
 * resolves TCP packet coalescing and splitting for everything after it.
 *
 * <p>Frame layout: a 4-byte length field at offset 0 followed by that many
 * content bytes. The length field is not stripped, so each emitted frame still
 * starts with it (RpcProtocolCodec reads it back).
 */
public class RpcFrameDecoder extends LengthFieldBasedFrameDecoder {
    /** Default limit for a whole frame (length field plus content), in bytes. */
    public static final int DEFAULT_MAX_FRAME_LENGTH = 1024;
    private static final int LENGTH_FIELD_OFFSET = 0;
    private static final int LENGTH_FIELD_LENGTH = 4;

    /** Creates a decoder that accepts frames up to {@link #DEFAULT_MAX_FRAME_LENGTH} bytes. */
    public RpcFrameDecoder() {
        this(DEFAULT_MAX_FRAME_LENGTH);
    }

    /**
     * Creates a decoder with a custom frame-size limit, e.g. for requests whose
     * JSON body exceeds the 1 KiB default.
     *
     * @param maxFrameLength maximum accepted frame size (length field included);
     *                       larger frames are rejected by the decoder
     */
    public RpcFrameDecoder(int maxFrameLength) {
        super(maxFrameLength, LENGTH_FIELD_OFFSET, LENGTH_FIELD_LENGTH);
    }
}
/**
 * Codec between the raw byte stream and {@link RpcProtocol} frames.
 * Wire format: a 4-byte int length prefix followed by that many content bytes.
 */
public class RpcProtocolCodec extends ByteToMessageCodec<RpcProtocol> {
    private static final int LENGTH_FIELD_BYTES = 4;

    /** Encodes an RpcProtocol object into the byte stream. */
    @Override
    protected void encode(ChannelHandlerContext ctx, RpcProtocol msg,
                          ByteBuf out) throws Exception {
        byte[] content = msg.getContent();
        // Derive the prefix from the payload actually written: a length field that
        // disagreed with the content would desynchronize every following frame.
        out.writeInt(content.length);
        out.writeBytes(content);
    }

    /** Decodes the byte stream into RpcProtocol objects. */
    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf in,
                          List<Object> out) throws Exception {
        // RpcFrameDecoder normally hands over exactly one whole frame, but guard
        // against partial input anyway: returning without consuming anything makes
        // ByteToMessageDecoder wait for more bytes.
        if (in.readableBytes() < LENGTH_FIELD_BYTES) {
            return;
        }
        in.markReaderIndex();
        int length = in.readInt();
        if (length < 0) {
            throw new IllegalArgumentException("negative frame length: " + length);
        }
        if (in.readableBytes() < length) {
            in.resetReaderIndex();
            return;
        }
        byte[] content = new byte[length];
        in.readBytes(content);
        RpcProtocol protocol = new RpcProtocol();
        protocol.setLength(length);
        protocol.setContent(content);
        out.add(protocol);
    }
}
/**
 * Codec between {@link RpcRequest} objects and {@link RpcProtocol} frames,
 * using fastjson2 JSON as the body format.
 */
public class RpcRequestCodec extends MessageToMessageCodec<RpcProtocol,
        RpcRequest> {

    /** Encodes a request object into an RpcProtocol frame. */
    @Override
    protected void encode(ChannelHandlerContext ctx, RpcRequest msg,
                          List<Object> out) throws Exception {
        byte[] content = JSON.toJSONBytes(msg);
        RpcProtocol rpcProtocol = new RpcProtocol();
        rpcProtocol.setLength(content.length);
        rpcProtocol.setContent(content);
        out.add(rpcProtocol);
    }

    /** Decodes an RpcProtocol frame back into a request object. */
    @Override
    protected void decode(ChannelHandlerContext ctx, RpcProtocol msg,
                          List<Object> out) throws Exception {
        // The payload is RpcProtocol#content; RpcProtocol has no "data" property,
        // so the previous getData() call did not compile.
        // SupportClassForName lets fastjson2 rebuild parameterTypes (Class<?>[])
        // from class names.
        // SECURITY(review): the class names come from the network and are presumably
        // resolved by class loading; restrict them to an allow-list before exposing
        // this server to untrusted clients.
        RpcRequest request = JSON.parseObject(msg.getContent(),
                RpcRequest.class,
                JSONReader.Feature.SupportClassForName);
        out.add(request);
    }
}
/**
 * Converts between {@link RpcResponse} objects and {@link RpcProtocol} frames;
 * the response travels as fastjson2-serialized JSON.
 */
public class RpcResponseCodec extends MessageToMessageCodec<RpcProtocol,
        RpcResponse> {

    /** Serializes the response to JSON and wraps it in a protocol frame. */
    @Override
    protected void encode(ChannelHandlerContext ctx, RpcResponse msg,
                          List<Object> out) throws Exception {
        out.add(toFrame(JSON.toJSONBytes(msg)));
    }

    /** Parses the frame's JSON content back into a response object. */
    @Override
    protected void decode(ChannelHandlerContext ctx, RpcProtocol msg,
                          List<Object> out) throws Exception {
        out.add(JSON.parseObject(msg.getContent(), RpcResponse.class));
    }

    /** Builds a frame whose length field equals the payload size. */
    private static RpcProtocol toFrame(byte[] payload) {
        RpcProtocol frame = new RpcProtocol();
        frame.setLength(payload.length);
        frame.setContent(payload);
        return frame;
    }
}
1.6 服务接口
/**
 * Remote service contract shared by client and server. The client calls it
 * through a dynamic proxy; the server dispatches requests to the registered
 * implementation by reflecting on this interface.
 */
public interface RpcService {
/**
 * @param name value to greet
 * @return the greeting produced by the server-side implementation
 */
String hello(String name);
}
2 服务端
2.1 接口实现类
/** Server-side implementation of {@link RpcService}. */
@Slf4j
public class RpcServiceImpl implements RpcService {

    /**
     * Logs the incoming argument and answers with a greeting.
     *
     * @param name the value sent by the client
     * @return {@code "hello "} followed by {@code name}
     */
    @Override
    public String hello(String name) {
        log.info("service received: {} ", name);
        final String greeting = "hello " + name;
        return greeting;
    }
}
2.2 接口名与实现类对象的映射：通过接口名查找对应的实现类对象
/**
 * Server-side lookup table from a service (interface) name to the object that
 * implements it. RpcService is pre-registered with an RpcServiceImpl.
 */
public class ServiceMapping {
    private final Map<String, RpcService> services = new HashMap<>();

    public ServiceMapping() {
        services.put(RpcService.class.getName(), new RpcServiceImpl());
    }

    /**
     * Registers, or replaces, the implementation for a service name.
     *
     * @param serviceName lookup key, normally the interface's class name
     * @param service     implementation to dispatch to
     */
    public void registerMapping(String serviceName, RpcService service) {
        services.put(serviceName, service);
    }

    /**
     * @param serviceName lookup key
     * @return the registered implementation, or {@code null} if there is none
     */
    public RpcService findMapping(String serviceName) {
        return services.get(serviceName);
    }
}
2.2 服务端rpc上下文,用来获取单例的ServiceMapping
/**
 * Server-side shared access point for the process-wide {@link ServiceMapping}.
 * The mapping is built on first use, when the nested holder class is
 * initialized by the JVM.
 */
public class RpcServerContext {

    /** Holder class: initialized (and the mapping built) on first access. */
    private static class MappingHolder {
        private static final ServiceMapping INSTANCE = new ServiceMapping();
    }

    /**
     * @return the single shared ServiceMapping
     */
    public static ServiceMapping mapping() {
        return MappingHolder.INSTANCE;
    }
}
2.3 业务处理器handler
/**
 * Server business handler: turns each decoded {@link RpcRequest} into a
 * reflective call on the registered service implementation and writes back an
 * {@link RpcResponse}.
 */
@Slf4j
public class RpcServerHandler extends ChannelInboundHandlerAdapter {

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg)
            throws Exception {
        RpcRequest request = (RpcRequest) msg;
        RpcResponse response = invoke(request);
        ctx.writeAndFlush(response);
    }

    /**
     * Invokes the requested method on the mapped service.
     *
     * <p>Never throws. Any failure becomes a 500 response with a non-null
     * {@code ex}, because the client treats a null {@code ex} as success.
     */
    private RpcResponse invoke(RpcRequest request) {
        RpcResponse response = new RpcResponse();
        try {
            String serviceName = request.getServiceName();
            RpcService rpcService = RpcServerContext.mapping().findMapping(serviceName);
            if (rpcService == null) {
                throw new IllegalArgumentException("no service mapped for: " + serviceName);
            }
            // Look the method up on the RpcService interface only, so a client can
            // reach nothing outside the published contract.
            Method method = RpcService.class.getDeclaredMethod(
                    request.getMethodName(), request.getParameterTypes());
            Object result = method.invoke(rpcService, request.getParameters());
            response.setCode(200);
            response.setMsg("ok");
            response.setData(result);
        } catch (Exception e) {
            // Method.invoke wraps anything the service throws in an
            // InvocationTargetException, which typically has no message of its own;
            // report the real cause instead.
            Throwable failure = e instanceof java.lang.reflect.InvocationTargetException
                    && e.getCause() != null ? e.getCause() : e;
            log.warn("rpc invoke failed: {}", request, failure);
            response.setCode(500);
            response.setMsg("error");
            response.setEx(describe(failure));
        }
        return response;
    }

    /** Returns a non-null description of {@code t}, falling back to its class name. */
    private static String describe(Throwable t) {
        String message = t.getMessage();
        return message != null ? message : t.getClass().getName();
    }

    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        log.info("channelInactive :{}", ctx.channel().remoteAddress());
        ctx.close();
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
            throws Exception {
        log.error("exceptionCaught :{}", ctx.channel().remoteAddress(), cause);
        ctx.close();
    }
}
2.4 启动类
/**
 * Server entry point: binds the RPC pipeline and blocks until the server
 * channel closes.
 */
public class RpcServer {
    /** Port used when none is given on the command line (the registry points at localhost:1733). */
    private static final int DEFAULT_PORT = 1733;

    /**
     * @param args optional first argument: the port to listen on
     */
    public static void main(String[] args) {
        int port = args.length > 0 ? Integer.parseInt(args[0]) : DEFAULT_PORT;
        NioEventLoopGroup bossGroup = new NioEventLoopGroup(1);
        NioEventLoopGroup workerGroup = new NioEventLoopGroup();
        try {
            ChannelFuture channelFuture = new ServerBootstrap()
                    .group(bossGroup, workerGroup)
                    .channel(NioServerSocketChannel.class)
                    .childHandler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel ch) {
                            // Inbound: bytes -> frame -> RpcProtocol -> RpcRequest -> handler.
                            // Outbound: RpcResponse -> RpcProtocol -> bytes.
                            ch.pipeline().addLast(new RpcFrameDecoder());
                            ch.pipeline().addLast(new RpcProtocolCodec());
                            ch.pipeline().addLast(new RpcRequestCodec());
                            ch.pipeline().addLast(new RpcResponseCodec());
                            ch.pipeline().addLast(new RpcServerHandler());
                        }
                    }).bind(port);
            channelFuture.sync();
            channelFuture.channel().closeFuture().sync();
        } catch (InterruptedException e) {
            // Restore the interrupt flag before converting to an unchecked exception.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } finally {
            bossGroup.shutdownGracefully();
            workerGroup.shutdownGracefully();
        }
    }
}
3 客户端
3.1 客户端 RPC 上下文：按 channel 保存等待响应的 Promise
/**
 * Client-side table of calls awaiting a response: one pending promise per channel.
 *
 * <p>setPromise is called from the thread making the RPC call, while getPromise is
 * called from the channel's event loop when a response arrives. The map is shared
 * across threads, so it must be thread-safe.
 *
 * <p>Limitation: a second setPromise on the same channel replaces the first, so only
 * one call per channel can be in flight. Concurrent calls would need a request id
 * carried in RpcRequest/RpcResponse.
 */
public class RpcClientContext {
    private final Map<Channel, Promise<Object>> promises =
            new java.util.concurrent.ConcurrentHashMap<>();

    /**
     * Removes and returns the pending promise for {@code channel}.
     *
     * @return the pending promise, or {@code null} if no call is waiting
     */
    public Promise<Object> getPromise(Channel channel) {
        return promises.remove(channel);
    }

    /** Registers {@code promise} as the pending call on {@code channel}, replacing any previous one. */
    public void setPromise(Channel channel, Promise<Object> promise) {
        promises.put(channel, promise);
    }
}
3.2 业务处理器handler
/**
 * Client business handler: completes the caller's pending promise with each
 * decoded {@link RpcResponse}.
 */
@Slf4j
public class RpcClientHandler extends ChannelInboundHandlerAdapter {
    private final RpcClientContext rpcClientContext;

    public RpcClientHandler(RpcClientContext rpcClientContext) {
        this.rpcClientContext = rpcClientContext;
    }

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg)
            throws Exception {
        log.info("rpc invoke response: {}", msg);
        RpcResponse response = (RpcResponse) msg;
        Promise<Object> promise = rpcClientContext.getPromise(ctx.channel());
        if (promise == null) {
            // Nobody is waiting, e.g. a duplicate or late response; don't NPE on it.
            log.warn("no pending rpc call for response: {}", response);
            return;
        }
        // The server signals failure with a non-null ex and code 500, so check both.
        // try* variants tolerate a promise that was already completed elsewhere.
        if (response.getEx() != null || response.getCode() != 200) {
            String reason = response.getEx() != null ? response.getEx() : response.getMsg();
            promise.tryFailure(new RuntimeException(reason));
        } else {
            promise.trySuccess(response.getData());
        }
    }

    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        log.info("channelInactive :{}", ctx.channel().remoteAddress());
        // No response can arrive any more: fail the waiting caller instead of
        // leaving it blocked forever.
        Promise<Object> pending = rpcClientContext.getPromise(ctx.channel());
        if (pending != null) {
            pending.tryFailure(new IllegalStateException(
                    "channel closed before response: " + ctx.channel().remoteAddress()));
        }
        ctx.close();
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
            throws Exception {
        log.error("exceptionCaught :{}", ctx.channel().remoteAddress(), cause);
        ctx.close();
    }
}
3.3 启动类
/**
 * RPC client entry point.
 *
 * <p>Resolves a provider address from the registry and keeps one Netty channel
 * per address. It hands out a dynamic proxy whose calls are sent as RpcRequests;
 * each call blocks until the matching RpcResponse arrives.
 *
 * <p>Limitation: pending calls are tracked per channel by RpcClientContext, so
 * only one call per channel may be in flight at a time.
 */
@Slf4j
public class RpcClient {
    // Keyed by "host:port". Written from the caller's thread and from channel close
    // listeners running on event-loop threads, hence a concurrent map.
    private final Map<String, NioSocketChannel> nioSocketChannels =
            new java.util.concurrent.ConcurrentHashMap<>();
    private final RpcClientContext rpcClientContext = new RpcClientContext();

    /**
     * Returns a proxy for RpcService bound to the first registered provider.
     * Connects, or reconnects, if there is no live channel to that provider.
     *
     * @throws IllegalStateException if no provider is registered for RpcService
     */
    public RpcService rpcService() {
        String serviceName = RpcService.class.getName();
        List<String> services = RpcContext.register().findService(serviceName);
        if (services == null || services.isEmpty()) {
            throw new IllegalStateException("no provider registered for " + serviceName);
        }
        String url = services.get(0);
        NioSocketChannel channel = nioSocketChannels.get(url);
        if (channel == null || !channel.isActive()) {
            channel = createNioSocketChannel(url);
            nioSocketChannels.put(url, channel);
            log.info("create a new channel: {}", channel);
        }
        final NioSocketChannel nioSocketChannel = channel;
        return (RpcService) Proxy.newProxyInstance(RpcClient.class.getClassLoader(),
                new Class<?>[]{RpcService.class},
                (proxy, method, args) -> {
                    RpcRequest request = new RpcRequest();
                    request.setServiceName(serviceName);
                    request.setMethodName(method.getName());
                    request.setParameterTypes(method.getParameterTypes());
                    request.setParameters(args);
                    // Register the promise BEFORE writing: the response is handled on
                    // the event loop and may arrive before writeAndFlush() returns.
                    DefaultPromise<Object> promise =
                            new DefaultPromise<>(nioSocketChannel.eventLoop());
                    rpcClientContext.setPromise(nioSocketChannel, promise);
                    nioSocketChannel.writeAndFlush(request).addListener(writeFuture -> {
                        // A failed write would otherwise leave the caller waiting forever.
                        if (!writeFuture.isSuccess()) {
                            promise.tryFailure(writeFuture.cause());
                        }
                    });
                    promise.await();
                    if (!promise.isSuccess()) {
                        throw new RuntimeException(promise.cause());
                    }
                    return promise.getNow();
                });
    }

    /**
     * Connects to {@code url} ("host:port") on a dedicated event-loop group.
     * The group is shut down when the channel closes, or right away if the
     * connection cannot be established.
     */
    private NioSocketChannel createNioSocketChannel(String url) {
        // lastIndexOf tolerates hosts that themselves contain ':' (IPv6 literals).
        int colon = url.lastIndexOf(':');
        if (colon <= 0 || colon == url.length() - 1) {
            throw new IllegalArgumentException("expected host:port but got: " + url);
        }
        String host = url.substring(0, colon);
        int port = Integer.parseInt(url.substring(colon + 1));

        EventLoopGroup group = new NioEventLoopGroup();
        boolean connected = false;
        try {
            ChannelFuture channelFuture = new Bootstrap()
                    .group(group)
                    .channel(NioSocketChannel.class)
                    .handler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel ch) {
                            // Inbound: bytes -> frame -> RpcProtocol -> RpcResponse -> handler.
                            // Outbound: RpcRequest -> RpcProtocol -> bytes.
                            ch.pipeline().addLast(new RpcFrameDecoder());
                            ch.pipeline().addLast(new RpcProtocolCodec());
                            ch.pipeline().addLast(new RpcResponseCodec());
                            ch.pipeline().addLast(new RpcRequestCodec());
                            ch.pipeline().addLast(new RpcClientHandler(rpcClientContext));
                        }
                    }).connect(host, port);
            // Rethrows the connect failure, if any.
            channelFuture.sync();
            NioSocketChannel channel = (NioSocketChannel) channelFuture.channel();
            channel.closeFuture().addListener(future -> {
                // The map is keyed by url (the old code removed the service name, so
                // nothing was ever removed). Remove only if still mapped to THIS channel,
                // since a newer channel may already have replaced it.
                nioSocketChannels.remove(url, channel);
                group.shutdownGracefully();
            });
            connected = true;
            return channel;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } finally {
            // A failed or interrupted connect must not leak the group's event-loop threads.
            if (!connected) {
                group.shutdownGracefully();
            }
        }
    }

    /** Closes every open channel; each channel's close listener then shuts down its event-loop group. */
    private void close() {
        nioSocketChannels.values().forEach(NioSocketChannel::close);
    }

    public static void main(String[] args) {
        RpcClient rpcClient = new RpcClient();
        try {
            RpcService rpcService = rpcClient.rpcService();
            String netty = rpcService.hello("netty");
            System.out.println(netty);
            String world = rpcService.hello("world");
            System.out.println(world);
            String java = rpcService.hello("java");
            System.out.println(java);
        } finally {
            // Release channels even if a call fails, so their groups shut down.
            rpcClient.close();
        }
    }
}
4 总结
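这样就实现了一个简单的 RPC 服务：借助公共部分的接口、注册中心和编解码器，以及服务端的服务映射，客户端就能进行远程过程调用了。需要注意的是，客户端按 channel 保存等待中的 Promise，因此同一个 channel 上同一时间只能有一个未完成的调用；若要支持并发调用，需要在请求和响应中加入请求 ID 进行匹配。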