在后续一段时间里, 我会写一系列文章来讲述如何实现一个RPC框架(我已经实现了一个示例框架, 代码在我的github上)。 这是系列第四篇文章, 主要讲述了客户端和服务器之间的网络通信问题。
模型定义
我们需要自己来定义RPC通信所传递的内容的模型, 也就是RPCRequest和RPCResponse。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21 1@Data
2@Builder
3public class RPCRequest {
4 private String requestId;
5 private String interfaceName;
6 private String methodName;
7 private Class<?>[] parameterTypes;
8 private Object[] parameters;
9}
10
11@Data
12public class RPCResponse {
13 private String requestId;
14 private Exception exception;
15 private Object result;
16
17 public boolean hasException() {
18 return exception != null;
19 }
20}
21
这里唯一需要说明一下的是requestId, 你可能会疑惑为什么我们需要这个东西。
原因是,发送请求的顺序和收到返回的顺序可能是不一致的, 因此我们需要有一个标识符来表明某一个返回所对应的请求是什么。 具体怎么利用这个字段, 本文后续会揭晓。
选择NIO还是IO?
NIO和IO的选择要视具体情况而定。对于我们的RPC框架来说, 一个服务可能与多个服务保持连接, 且每次通信只发送少量信息,那么在这种情况下,NIO可能更适合一些。
我选择使用Netty来简化具体的实现, 自然地,我们就引入了Channel, Handler这些相关的概念。如果对Netty没有任何了解, 建议先去简单了解下相关内容再回过头看这篇文章。
如何复用Channel
既然使用了NIO, 我们自然希望服务和服务之间是使用长连接进行通信, 而不是每个请求都重新创建一个channel。
那么我们怎么去复用channel呢? 既然我们已经通过前文的服务发现获取到了service地址,并且与其建立了channel, 那么我们自然就可以建立一个service地址与channel之间的映射关系, 每次拿到地址之后先判断有没有对应channel, 如果有的话就复用。这种映射关系我建立了ChannelManager去管理:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60 1public class ChannelManager {
2 /**
3 * Singleton
4 */
5 private static ChannelManager channelManager;
6
7 private ChannelManager(){}
8
9 public static ChannelManager getInstance() {
10 if (channelManager == null) {
11 synchronized (ChannelManager.class) {
12 if (channelManager == null) {
13 channelManager = new ChannelManager();
14 }
15 }
16 }
17 return channelManager;
18 }
19
20 // Service地址与channel之间的映射
21 private Map<InetSocketAddress, Channel> channels = new ConcurrentHashMap<>();
22
23 public Channel getChannel(InetSocketAddress inetSocketAddress) {
24 Channel channel = channels.get(inetSocketAddress);
25 if (null == channel) {
26 EventLoopGroup group = new NioEventLoopGroup();
27 try {
28 Bootstrap bootstrap = new Bootstrap();
29 bootstrap.group(group)
30 .channel(NioSocketChannel.class)
31 .handler(new RPCChannelInitializer())
32 .option(ChannelOption.SO_KEEPALIVE, true);
33
34 channel = bootstrap.connect(inetSocketAddress.getHostName(), inetSocketAddress.getPort()).sync()
35 .channel();
36 registerChannel(inetSocketAddress, channel);
37
38 channel.closeFuture().addListener(new ChannelFutureListener() {
39 @Override
40 public void operationComplete(ChannelFuture future) throws Exception {
41 removeChannel(inetSocketAddress);
42 }
43 });
44 } catch (Exception e) {
45 log.warn("Fail to get channel for address: {}", inetSocketAddress);
46 }
47 }
48 return channel;
49 }
50
51 private void registerChannel(InetSocketAddress inetSocketAddress, Channel channel) {
52 channels.put(inetSocketAddress, channel);
53 }
54
55 private void removeChannel(InetSocketAddress inetSocketAddress) {
56 channels.remove(inetSocketAddress);
57 }
58
59}
60
有几个地方需要解释一下:
- 这里用单例的目的是, 所有的proxybean都使用同一个ChannelManager。
- 创建Channel的过程很简单,就是最普通的Netty客户端创建channel的方法。
- 在channel被关闭(比如服务器端宕机了)后,需要从map中删除对应的channel
- RPCChannelInitializer是整个过程的核心所在, 用于处理请求和返回的编解码、 收到返回之后的回调等。 下文详细说这个。
编解码
上文的RPCChannelInitializer代码如下:
1
2
3
4
5
6
7
8
9
10
11 1private class RPCChannelInitializer extends ChannelInitializer<SocketChannel> {
2
3 @Override
4 protected void initChannel(SocketChannel ch) throws Exception {
5 ChannelPipeline pipeline = ch.pipeline();
6 pipeline.addLast(new RPCEncoder(RPCRequest.class, new ProtobufSerializer()));
7 pipeline.addLast(new RPCDecoder(RPCResponse.class, new ProtobufSerializer()));
8 pipeline.addLast(new RPCResponseHandler()); //先不用管这个
9 }
10 }
11
这里的Encoder和Decoder都很简单, 继承了Netty中的codec,做一些简单的byte数组和Object对象之间的转换工作:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39 1@AllArgsConstructor
2public class RPCDecoder extends ByteToMessageDecoder {
3
4 private Class<?> genericClass;
5 private Serializer serializer;
6
7 @Override
8 public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
9 if (in.readableBytes() < 4) {
10 return;
11 }
12 in.markReaderIndex();
13 int dataLength = in.readInt();
14 if (in.readableBytes() < dataLength) {
15 in.resetReaderIndex();
16 return;
17 }
18 byte[] data = new byte[dataLength];
19 in.readBytes(data);
20 out.add(serializer.deserialize(data, genericClass));
21 }
22}
23
24@AllArgsConstructor
25public class RPCEncoder extends MessageToByteEncoder {
26
27 private Class<?> genericClass;
28 private Serializer serializer;
29
30 @Override
31 public void encode(ChannelHandlerContext ctx, Object in, ByteBuf out) throws Exception {
32 if (genericClass.isInstance(in)) {
33 byte[] data = serializer.serialize(in);
34 out.writeInt(data.length);
35 out.writeBytes(data);
36 }
37 }
38}
39
这里我选择使用Protobuf序列化协议来做这件事(具体的ProtobufSerializer的实现因为篇幅原因就不贴在这里了, 需要的话请看项目的github)。 总的来说, 这一块还是很简单很好理解的。
发送请求与处理返回内容
请求的发送很简单, 直接用channel.writeAndFlush(request) 就行了。
问题是, 发送之后, 怎么获取这个请求的返回呢?这里,我引入了RPCResponseFuture和ResponseFutureManager来解决这个问题。
RPCResponseFuture实现了Future接口,所包含的值就是RPCResponse, 每个RPCResponseFuture都与一个requestId相关联, 除此之外, 还利用了CountDownLatch来做get方法的阻塞处理:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29 1public class RPCResponseFuture implements Future<Object> {
2 private String requestId;
3
4 private RPCResponse response;
5
6 CountDownLatch latch = new CountDownLatch(1);
7
8 public RPCResponseFuture(String requestId) {
9 this.requestId = requestId;
10 }
11
12 public void done(RPCResponse response) {
13 this.response = response;
14 latch.countDown();
15 }
16
17 @Override
18 public RPCResponse get() throws InterruptedException, ExecutionException {
19 try {
20 latch.await();
21 } catch (InterruptedException e) {
22 log.error(e.getMessage());
23 }
24 return response;
25 }
26
27 // ....
28}
29
既然每个请求都会产生一个ResponseFuture, 那么自然要有一个Manager来管理这些future:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30 1public class ResponseFutureManager {
2 /**
3 * Singleton
4 */
5 private static ResponseFutureManager rpcFutureManager;
6
7 private ResponseFutureManager(){}
8
9 public static ResponseFutureManager getInstance() {
10 if (rpcFutureManager == null) {
11 synchronized (ChannelManager.class) {
12 if (rpcFutureManager == null) {
13 rpcFutureManager = new ResponseFutureManager();
14 }
15 }
16 }
17 return rpcFutureManager;
18 }
19
20 private ConcurrentHashMap<String, RPCResponseFuture> rpcFutureMap = new ConcurrentHashMap<>();
21
22 public void registerFuture(RPCResponseFuture rpcResponseFuture) {
23 rpcFutureMap.put(rpcResponseFuture.getRequestId(), rpcResponseFuture);
24 }
25
26 public void futureDone(RPCResponse response) {
27 rpcFutureMap.remove(response.getRequestId()).done(response);
28 }
29}
30
ResponseFutureManager很好看懂, 就是提供了注册future、完成future的接口。
现在我们再回过头看RPCChannelInitializer中的RPCResponseHandler就很好理解了: 拿到返回值, 把对应的ResponseFuture标记成done就可以了!
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 1/**
2* 处理收到返回后的回调
3*/
4 private class RPCResponseHandler extends SimpleChannelInboundHandler<RPCResponse> {
5
6 @Override
7 public void channelRead0(ChannelHandlerContext ctx, RPCResponse response) throws Exception {
8 log.debug("Get response: {}", response);
9 ResponseFutureManager.getInstance().futureDone(response);
10 }
11
12 @Override
13 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
14 log.warn("RPC request exception: {}", cause);
15 }
16 }
17
前文的FactoryBean的逻辑填充
到这里,我们已经实现了客户端的网络通信, 现在只需要把它加到前文的FactoryBean的doInvoke方法就好了!
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74 1 private Object doInvoke(Object proxy, Method method, Object[] args) throws Throwable {
2 String targetServiceName = type.getName();
3
4 // Create request
5 RPCRequest request = RPCRequest.builder()
6 .requestId(generateRequestId(targetServiceName))
7 .interfaceName(method.getDeclaringClass().getName())
8 .methodName(method.getName())
9 .parameters(args)
10 .parameterTypes(method.getParameterTypes()).build();
11
12 // Get service address
13 InetSocketAddress serviceAddress = getServiceAddress(targetServiceName);
14
15 // Get channel by service address
16 Channel channel = ChannelManager.getInstance().getChannel(serviceAddress);
17 if (null == channel) {
18 throw new RuntimeException("Cann't get channel for address" + serviceAddress);
19 }
20
21 // Send request
22 RPCResponse response = sendRequest(channel, request);
23 if (response == null) {
24 throw new RuntimeException("response is null");
25 }
26 if (response.hasException()) {
27 throw response.getException();
28 } else {
29 return response.getResult();
30 }
31 }
32
33 private String generateRequestId(String targetServiceName) {
34 return targetServiceName + "-" + UUID.randomUUID().toString();
35 }
36
37 private InetSocketAddress getServiceAddress(String targetServiceName) {
38 String serviceAddress = "";
39 if (serviceDiscovery != null) {
40 serviceAddress = serviceDiscovery.discover(targetServiceName);
41 log.debug("Get address: {} for service: {}", serviceAddress, targetServiceName);
42 }
43 if (StringUtils.isEmpty(serviceAddress)) {
44 throw new RuntimeException("server address is empty");
45 }
46 String[] array = StringUtils.split(serviceAddress, ":");
47 String host = array[0];
48 int port = Integer.parseInt(array[1]);
49 return new InetSocketAddress(host, port);
50 }
51
52 private RPCResponse sendRequest(Channel channel, RPCRequest request) {
53 log.debug("Send request, channel: {}, request: {}", channel, request);
54 CountDownLatch latch = new CountDownLatch(1);
55 RPCResponseFuture rpcResponseFuture = new RPCResponseFuture(request.getRequestId());
56 ResponseFutureManager.getInstance().registerFuture(rpcResponseFuture);
57 channel.writeAndFlush(request).addListener((ChannelFutureListener) future -> {
58 log.debug("Request sent.");
59 latch.countDown();
60 });
61 try {
62 latch.await();
63 } catch (InterruptedException e) {
64 log.error(e.getMessage());
65 }
66
67 try {
68 return rpcResponseFuture.get(1, TimeUnit.SECONDS);
69 } catch (Exception e) {
70 log.warn("Exception:", e);
71 return null;
72 }
73 }
74
就这样, 一个简单的RPC客户端就实现了。 完整代码请看我的github。