如何开发一个自己的 RPC 框架 (一)#
December 11, 2023 12:00 AM
写在前面:本文参考了开源项目 https://github.com/MIracleCczs/simple-rpc,其中大部分代码参考了该项目,本章主要从客户端的调用出发,讲讲一个 RPC 框架的实现在客户端测需要实现那些功能
RPC 的定义#
可以参考 wiki: https://zh.wikipedia.org/wiki/ 遠程過程調用
一个简单的 RPC 框架是如何组成的?#
一个基础的 RPC 框架,需要包含三大部分:1. 注册中心 2. 服务提供方 3. 服务消费方
从上图可以看出,服务提供方和消费方都需要和注册中心通信
一个远程方法的调用是如何实现的?#
下面,我们将根据上面流程图,一步步进行讲解。为了方便更加清楚的讲清整个逻辑,我们从实际的业务需求出发。
需求:
存在服务提供方 Producer (后面统称服务端),提供方法 get
存在服务消费方 Consumer (后面统称客户端), 需要调用 Producer 中的 get
方法
基础接口定义#
定义 UserService
接口,接口内包含 get
方法
public interface UserService {
String get(String username);
}
客户端发起服务调用#
客户端注解定义#
客户端如何才能够像调用本地方法一样调用远程服务呢?RPC 框架就是用来解决这个问题。我们一般本地方法的调用都是采用
@Autowired
private UserService userService;
通过 Spring
依赖注入的方式,将需要用到的方法注入到调用对象中,那么我们 RPC 调用能不能也采用这种形式呢?答案当然是可以的。那么为了实现上面的需求,我们最简单的办法就是自定义一个注解 RpcClient
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.FIELD})
public @interface RpcClient {
}
注解定义完成后,我们就应该考虑注解中需要设置那些属性呢?
那么冒出来的第一个问题就是:客户端如何知道是调用的那个远程服务呢?这时我们就需要设置第一个属性 remoteAppKey
服务的唯一标识,通过 remoteAppKey
客户端可以轻松地在注册中心找到目标服务。
这个时候又会有第二个疑问,如果一个服务多个版本如何处理呢?比如进行灰度升级等操作的时,那么这个时候就需要第二个参数 groupName
找到具体服务下的具体分组
剩余的参数就比较简单了,完成的参数配置如下:
/**
* 服务接口:匹配从注册中心获取到本地的服务提供者,得到服务提供者列表,再根据负载均衡策略选取一个发起服务调用
*/
Class<?> targetItf() default Object.class;
/**
* 超时时间:服务调用超时时间
*/
long timeout() default 3000 * 10L;
/**
* 调用者线程数
*/
int consumeThreads() default 10;
/**
* 服务提供者唯一标识
*/
String remoteAppKey() default "";
/**
* 服务分组组名
*/
String groupName() default "default";
客户端初始化#
为了实现类似 @Autowired
的功能,框架需要在 Bean
初始化之时,将所有被 RpcClient
注解的对象进行依赖注入,那么如何实现这个功能呢? Spring
的 InstantiationAwareBeanPostProcessor
接口,可以在 Bean 的实例化的各个阶段执行自定义逻辑。定义一个 ConsumerAnnotaionBean
方法,实现 InstantiationAwareBeanPostProcessor
接口。
public class ConsumerAnnotaionBean implements InstantiationAwareBeanPostProcessor {
...其他方法省略
@Override
public PropertyValues postProcessProperties(PropertyValues pvs, Object bean, String beanName) throws BeansException {
}
}
主要实现接口的 postProcessProperties
方法,设置 Bean 对象的属性值
进行具体代码编写之前,我们需要先理清楚这里需要实现那些目的:
- 客户端服务注册 (监控目的)
- 对象依赖的注入
- 查询服务节点,预创建 Netty 连接
好了,理清楚完需求之后,我们便开始对应的逻辑编写
@Override
public PropertyValues postProcessProperties(PropertyValues pvs, Object bean, String beanName) throws BeansException {
Class<?> beanClass = bean.getClass();
// 依赖注入的服务
Field[] fields = beanClass.getDeclaredFields();
for (Field field : fields) {
RpcClient rpcClient = field.getAnnotation(RpcClient.class);
if (rpcClient == null) {
continue;
}
}
}
根据初始化的 bean
获取对象中的所有参数,然后使用 getAnnotation(RpcClient.class);
判断参数是否被 RpcClient
所注解
客户端服务注册#
这一步的目的是为了在注册中心中记录消费者信息,方便后续监控,所以这一步相对来说非常简单,只需要构造客户端信息然后提交到注册中心即可。
// 构造消费者对象
Consumer consumer = Consumer
.builder()
.groupName(rpcClient.groupName())
.remoteAppKey(rpcClient.remoteAppKey())
.targetItf(targetItf)
.build();
// 注册中心注册消费者
registerCenter.registerConsumer(consumer);
对象依赖的注入#
同样的,我们先梳理一下这里的需求。我们需要实现一个动态代理,在方法调用时,根据方法调用名 + 类名获取远程服务提供方的节点信息,然后构造一个 NettyRequest
信息,发送到服务方,最后需要接收服务放返回的 NettyResponse
解析成方法的返回值进行返回
好了,我们已经理清楚了上面整体流程,那么就开始具体的代码编写吧
首先,定义一个对象 ClientProxyBeanFactory
实现 InvocationHandler
接口
主要是实现接口的 invoke
方法
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {}
invoke
方法中,需要根据类名 + 方法名从注册中心中获取可用的节点,那么具体代理的类这个时候就需要从对象实例化中传入,所以我们在定义 ClientProxyBeanFactory
时,需要定义几个成员变量
@Slf4j
public class ClientProxyBeanFactory implements InvocationHandler {
// 调用连接池(Netty 请求)
private ExecutorService executorService;
// 目标代理类
private Class<?> targetItf;
// 超时
private long timeout;
// 调用线程数
private int consumeThreads;
}
对象初始化的时候,需要设置成员变量的值
private static volatile ClientProxyBeanFactory instance;
public ClientProxyBeanFactory(Class<?> targetItf, long timeout, int consumeThreads) {
this.executorService = new ThreadPoolExecutor(consumeThreads, consumeThreads,
0L, TimeUnit.MILLISECONDS,
new LinkedBlockingQueue<>(), new ThreadFactoryBuilder()
.setNameFormat("simple-rpc-%d").build(), new ThreadPoolExecutor.AbortPolicy());
this.targetItf = targetItf;
this.timeout = timeout;
this.consumeThreads = consumeThreads;
}
/**
* 获取代理对象
*
* @return
*/
public Object getProxy() {
return Proxy.newProxyInstance(Thread.currentThread().getContextClassLoader(), new Class[]{targetItf}, this);
}
public static ClientProxyBeanFactory getInstance(Class<?> targetItf, long timeout, int consumeThreads) {
if (null == instance) {
synchronized (ClientProxyBeanFactory.class) {
if (null == instance) {
instance = new ClientProxyBeanFactory(targetItf, timeout, consumeThreads);
}
}
}
return instance;
}
完成上述成员变量赋值后,便可以开始从注册中心中获取服务节点了
// ConsumerAnnotaionBean.class
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
// 代理 className
String serviceName = targetItf.getName();
// 注册中心服务
IRegisterCenter registerCenter = IRegisterCenterZkImpl.getInstance();
// 根据 serviceName + methodName 获取可以使用的节点
List<Producer> producerList = registerCenter.getServiceProducer(serviceName, method.getName());
...
}
获取到服务节点后,这里可以根据设置的负载均衡策略获取本次使用的节点信息,假设这里采用随机获取的方法获取得到节点 Producer
开启 Netty 连接,进行消息收发送
拿到了 Producer
就意味着我们可以获取到远程服务 Netty 的 ip + port 信息了,这个时候就可以建立远程服务连接了。但是这里存在一个优化逻辑,就是如果我们每次都是方法调用时再去建立链接,那么建立连接将会是一个非常耗时的操作,但是如果我们提前根据 ip + port 建立一个 Channel
池,方法调用时只需要从连接池中获取 Channel
,那么服务的效率是不是会大大提高了?
基于上面的逻辑,我们需要实现一个 NettyChannelPoolFactory
用来缓存客户端的 Netty
的请求缓存,同时对外提供两个方法: acquire
获取 Channel 信息 release
释放 Channel 信息
具体实现代码如下:
@Slf4j
public class NettyChannelPoolFactory {
private static final NettyChannelPoolFactory CHANNEL_POOL_FACTORY = new NettyChannelPoolFactory();
// 连接池缓存 key 为服务提供者地址,value为Netty Channel阻塞队列
public static final Map<InetSocketAddress, ArrayBlockingQueue<Channel>> CHANNEL_POOL_MAP = Maps.newConcurrentMap();
/**
* 初始化Netty Channel阻塞队列的长度,该值为可配置信息
*/
private static final Integer CHANNEL_CONNECT_SIZE = 3;
public static NettyChannelPoolFactory getInstance() {
return CHANNEL_POOL_FACTORY;
}
/**
* 初始化 netty 连接池
*/
public void initChannelFactory(List<Producer> producerNodeList) {
for (Producer producer : producerNodeList) {
InetSocketAddress address = new InetSocketAddress(producer.getIp(), producer.getPort());
while (CHANNEL_POOL_MAP.get(address) == null || CHANNEL_POOL_MAP.get(address).size() < CHANNEL_CONNECT_SIZE) {
ArrayBlockingQueue<Channel> channels = CHANNEL_POOL_MAP.get(address);
if (channels == null || channels.size() < CHANNEL_CONNECT_SIZE) {
// 初始化 Netty Channel 阻塞队列
Channel channel = null;
while (channel == null) {
channel = registerChannel(address);
}
if (channels == null) {
channels = new ArrayBlockingQueue<>(CHANNEL_CONNECT_SIZE);
}
boolean offer = channels.offer(channel);
if (!offer) {
log.debug("channelArrayBlockingQueue fail");
} else {
CHANNEL_POOL_MAP.put(address, channels);
}
}
}
}
}
/**
* 根据 address 获取客户端队列
*/
public ArrayBlockingQueue<Channel> acquire(InetSocketAddress address) {
return CHANNEL_POOL_MAP.get(address);
}
/**
* 使用完成之后,将 channel 放回到 阻塞队列
*/
public void release(ArrayBlockingQueue<Channel> queue, Channel channel, InetSocketAddress address) {
if (queue == null) {
return;
}
// 回收之前判断channel 是否可用
if (channel == null || !channel.isOpen() || !channel.isActive() || !channel.isWritable()) {
if (channel != null) {
channel.deregister().syncUninterruptibly().awaitUninterruptibly();
channel.closeFuture().syncUninterruptibly().awaitUninterruptibly();
} else {
while (channel == null) {
channel = registerChannel(address);
}
}
}
queue.offer(channel);
}
/**
* 注册 netty 客户端
*/
public Channel registerChannel(InetSocketAddress address) {
try {
EventLoopGroup group = new NioEventLoopGroup(10);
Bootstrap bootstrap = new Bootstrap();
bootstrap.remoteAddress(address);
bootstrap.group(group)
.channel(NioSocketChannel.class)
.option(ChannelOption.TCP_NODELAY, true)
.handler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel ch) throws Exception {
// 注册Netty编码器
ch.pipeline().addLast(new NettyEncoderHandler());
// 注册Netty解码器
ch.pipeline().addLast(new NettyDecoderHandler(NettyResponse.class));
// 注册客户端业务处理逻辑Handler
ch.pipeline().addLast(new NettyHandlerClient());
}
});
ChannelFuture channelFuture = bootstrap.connect().sync();
final Channel channel = channelFuture.channel();
final CountDownLatch countDownLatch = new CountDownLatch(1);
final List<Boolean> isSuccessHolder = Lists.newArrayListWithCapacity(1);
// 监听channel是否建立成功
channelFuture.addListener(future -> {
if (future.isSuccess()) {
isSuccessHolder.add(Boolean.TRUE);
} else {
// 如果建立失败,保存建立失败标记
log.error("registerChannel fail , {}", future.cause().getMessage());
isSuccessHolder.add(Boolean.FALSE);
}
countDownLatch.countDown();
});
countDownLatch.await();
// 如果Channel建立成功,返回新建的Channel
if (Boolean.TRUE.equals(isSuccessHolder.get(0))) {
return channel;
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
log.error("registerChannel fail", e);
}
return null;
}
}
NettyChannelPoolFactory
对象中还定义了一个方法 registerChannel
接收 InetSocketAddress
的入参,返回值为 Channel
。方法中主要根据传入的 address
信息,创建了 Netty
连接,设置了序列化和反序列化的编解码器,然后增加了一个 NettyHandlerClient
的客户端消息处理器。最后将初始化好的 Channel
连接进行返回
有了上面的 NettyChannelPoolFactory
,便可以将从注册中心获得到的 Producer
信息,根据 ip + port 获取 Channel
,从而进行 NettyRequest
消息的发送
**NettyRequest
消息的构造 **
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
// 根据 serviceName + methodName 获取可以使用的节点
List<Producer> producerList = registerCenter.getServiceProducer(serviceName, method.getName());
// 直接取第0 个,这里可以采用负载均衡策略进行获取
Producer providerCopy =producerList.get(0) ;
// NettyRequest 构造
NettyRequest request = NettyRequest.builder()
// 服务节点信息
.producer(providerCopy)
// 本次请求的唯一编号
.uniqueKey(UUID.randomUUID() + "-" + Thread.currentThread().getId())
// 请求超时时间
.invokeTimeout(timeout)
// 请求方法名称
.invokeMethodName(method.getName())
// 请求参数
.args(args)
.build();
}
好了,现在 NettyRequest
和发送消息的 Channel
都已经有了,只需要将消息发送出去,然后接收消息然后序列成方法的出参即可。
这里可以采用线程池的方式,进行 Netty
消息的发送和返回值的解码
定义一个 ClientServiceCallable
集成自 Callable<NettyResponse>
带返回值的任务的接口
Callable
只有一个需要实现的方法 call()
, 在该方法中,需要完成 1. 获取 Channel
对象 2. 发送请求 3. 结果值返回
@Slf4j
public class ClientServiceCallable implements Callable<NettyResponse> {
/**
* Netty 通信管道
*/
private Channel channel;
/**
* 请求参数
*/
private final NettyRequest request;
public static ClientServiceCallable of(NettyRequest request) {
return new ClientServiceCallable(request);
}
public ClientServiceCallable(NettyRequest request) {
this.request = request;
}
@Override
public NettyResponse call() throws Exception {
InetSocketAddress inetSocketAddress = new InetSocketAddress(request.getProducer().getIp(), request.getProducer().getPort());
// 获取本地缓存 Channel 队列
ArrayBlockingQueue<Channel> blockingQueue = NettyChannelPoolFactory.getInstance().acquire(inetSocketAddress);
try {
if (channel == null) {
// 从队列中获取 Channel
channel = blockingQueue.take();
}
if (channel == null) {
throw new RuntimeException("can't find channel to resolve this request");
}
} catch (Exception e) {
log.error("client send request error", e);
} finally {
// 请求结束,队列归还 Channel
NettyChannelPoolFactory.getInstance().release(blockingQueue, channel, inetSocketAddress);
}
}
}
上述代码的 call
方法中,首先从本地缓存中获取到了 Channel
队列,然后在 finally
中将 Channel
归还到队列中。那么方法中剩下的逻辑就是发送 NettyRequest
请求,然后返回结果了
try {
if (channel == null) {
channel = blockingQueue.take();
}
if (channel == null) {
throw new RuntimeException("can't find channel to resolve this request");
} else {
1️⃣ ClientResponseHolder.initResponseData(request.getUniqueKey());
2️⃣while (!channel.isOpen() || !channel.isActive() || !channel.isWritable()) {
log.warn("retry get new channel");
channel = blockingQueue.poll(request.getInvokeTimeout(), TimeUnit.MILLISECONDS);
if (channel == null) {
// 若队列中没有可用的Channel,则重新注册一个Channel
channel = NettyChannelPoolFactory.getInstance().registerChannel(inetSocketAddress);
}
}
// 将本次调用的信息写入Netty通道,发起异步调用
3️⃣ ChannelFuture channelFuture = channel.writeAndFlush(request);
channelFuture.syncUninterruptibly();
// 从返回结果容器中获取返回结果,同时设置等待超时时间为invokeTimeout
long invokeTimeout = request.getInvokeTimeout();
4️⃣ return ClientResponseHolder.getValue(request.getUniqueKey(), invokeTimeout);
}
} catch (Exception e) {
log.error("client send request error", e);
} finally {
NettyChannelPoolFactory.getInstance().release(blockingQueue, channel, inetSocketAddress);
}
1️⃣ ClientResponseHolder 类
ClientResponseHolder.initResponseData(request.getUniqueKey());
这里又增加了一个新的类ClientResponseHolder
, 那么这个类是干嘛的呢?
由于消息的发送都是异步的形式,这里使用了 Map<String,NettyResponseWrapper>
进行本地数据缓存,Map
的 KEY 是 NeettyRequest
的 uniqueKey
,而 Value
就是 Netty
的返回结果,即是服务端执行之后的返回值
ClientResponseHolder
的具体实现如下:
@Slf4j
public class ClientResponseHolder {
private static final Map<String, NettyResponseWrapper> RESPONSE_WRAPPER_MAP = Maps.newConcurrentMap();
private static final ScheduledExecutorService executorService;
static {
executorService = new ScheduledThreadPoolExecutor(1, new RemoveExpireThreadFactory("simple-rpc", false));
// 删除过期的数据
executorService.scheduleWithFixedDelay(() -> {
for (Map.Entry<String, NettyResponseWrapper> entry : RESPONSE_WRAPPER_MAP.entrySet()) {
boolean expire = entry.getValue().isExpire();
if (expire) {
RESPONSE_WRAPPER_MAP.remove(entry.getKey());
}
}
}, 1, 20, TimeUnit.MILLISECONDS);
}
/**
* 初始化返回结果容器,requestUniqueKey唯一标识本次调用
*/
public static void initResponseData(String requestUniqueKey) {
RESPONSE_WRAPPER_MAP.put(requestUniqueKey, NettyResponseWrapper.of());
}
/**
* 将Netty调用异步返回结果放入阻塞队列
*/
public static void putResultValue(NettyResponse response) {
long currentTimeMillis = System.currentTimeMillis();
NettyResponseWrapper responseWrapper = RESPONSE_WRAPPER_MAP.get(response.getUniqueKey());
responseWrapper.setResponseTime(currentTimeMillis);
responseWrapper.getResponseBlockingQueue().add(response);
RESPONSE_WRAPPER_MAP.put(response.getUniqueKey(), responseWrapper);
}
/**
* 从阻塞队列中获取异步返回结果
*/
public static NettyResponse getValue(String requestUniqueKey, long timeout) {
NettyResponseWrapper responseWrapper = RESPONSE_WRAPPER_MAP.get(requestUniqueKey);
try {
return responseWrapper.getResponseBlockingQueue().poll(timeout, TimeUnit.MILLISECONDS);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
log.error("get value error", e);
} finally {
RESPONSE_WRAPPER_MAP.remove(requestUniqueKey);
}
return null;
}
}
- initResponseData: 根据
uniqueKey
初始化Map
- putResultValue: 插入
NettyResponse
返回结果 - getValue: 根据
uniqueKey
获取结果
同时定义了一个定时执行的队列,队列中根据 responseTime
判断消息是否过期进行内存数据清洗
2️⃣ Channel 状态判断
判断当前 Netty
通道状态,如果当前 Channel
不可用,则需要重新申请通道
3️⃣ Netty 消息发送
4️⃣ 从本地缓存中获取 Netty 返回结果
异步调用 Netty 服务,使用 Future
获取返回结果
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
// 根据 serviceName + methodName 获取可以使用的节点
List<Producer> producerList = registerCenter.getServiceProducer(serviceName, method.getName());
// 直接取第0 个,这里可以采用负载均衡策略进行获取
Producer providerCopy =producerList.get(0) ;
// NettyRequest 构造
NettyRequest request = NettyRequest.builder()
// 服务节点信息
.producer(providerCopy)
// 本次请求的唯一编号
.uniqueKey(UUID.randomUUID() + "-" + Thread.currentThread().getId())
// 请求超时时间
.invokeTimeout(timeout)
// 请求方法名称
.invokeMethodName(method.getName())
// 请求参数
.args(args)
.build();
// 发起异步调用,通过 NettyClient 发送请求
try {
Future<NettyResponse> responseFuture = executorService.submit(ClientServiceCallable.of(request));
NettyResponse response = responseFuture.get(timeout, TimeUnit.MILLISECONDS);
if (response != null) {
return response.getResult();
}
} catch (Exception e) {
log.error("send request error", e);
}
}
到这里,我们就完成了 ClientProxyBeanFactory
代理对象的完整编写,现在就需要将初始化好的代理对象进行依赖注入
// ConsumerAnnotaionBean.class
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
// 代理 className
String serviceName = targetItf.getName();
// 注册中心服务
IRegisterCenter registerCenter = IRegisterCenterZkImpl.getInstance();
// 根据 serviceName + methodName 获取可以使用的节点
List<Producer> producerList = registerCenter.getServiceProducer(serviceName, method.getName());
// 3.获取服务代理对象
Class<?> targetItf = rpcClient.targetItf();
if (targetItf == Object.class) {
targetItf = field.getType();
}
// 初始化代理对象
ClientProxyBeanFactory factory = ClientProxyBeanFactory.getInstance(targetItf, rpcClient.timeout(), rpcClient.consumeThreads());
ReflectionUtils.makeAccessible(field);
try {
// 设置代理对象
field.set(bean, factory.getProxy());
} catch (IllegalAccessException e) {
log.error("ReferenceBeanPostProcessor post process properties error, beanName={}", beanName, e);
throw new RuntimeException("ReferenceBeanPostProcessor post process properties error, beanName=" + beanName, e);
}
}
通过 ClientProxyBeanFactory.getInstance
获取到代理对象后,使用 field.set
方法进行执行赋值
完成上述操作之后,当客户端执行 get
方法时,便会 invoke
到 ClientProxyBeanFactory
的 invoke
方法上,随后执行 开启 Netty 连接,进行消息收发送 内容,随后将服务方结果进行返回
查询服务节点,预创建 Netty 连接#
这部分内容和第二步有所重叠,其核心逻辑如下:
到这里,客户端的所有流程就都编写完成了。但是为了理清楚主要思路,文章中对负载均衡策略、序列化和反序列化等都只是一笔带过。这些也是一个 RPC 框架非常很重要的一部分。