How to Develop Your Own RPC Framework (Part 1)#
December 11, 2023 12:00 AM
Preface: This article references the open-source project https://github.com/MIracleCczs/simple-rpc; much of the code is taken from it. This part discusses the implementation of an RPC framework from the perspective of client calls, focusing on the functionality the client side needs to implement.
Definition of RPC#
You can refer to the wiki: https://en.wikipedia.org/wiki/Remote_procedure_call
What Comprises a Simple RPC Framework?#
A basic RPC framework needs three main components:
- Registration center
- Service provider
- Service consumer
From the diagram above, it can be seen that both the service provider and consumer need to communicate with the registration center.
How is a Remote Method Call Implemented?#
Next, we will explain step by step based on the flowchart above. To clarify the entire logic, we will start from actual business requirements.
Requirement:
There exists a service provider, Producer (hereafter the server), which provides a get method.
There exists a service consumer, Consumer (hereafter the client), which needs to call the get method on Producer.
Basic Interface Definition#
Define the UserService interface, which contains the get method.
public interface UserService {
String get(String username);
}
Client Initiates Service Call#
Client Annotation Definition#
How can the client call a remote service as if it were calling a local method? The RPC framework is designed to solve this problem. Typically, local method calls are made using:
@Autowired
private UserService userService;
Through Spring's dependency injection, the required object is injected into the caller. Can we use the same form for RPC calls? The answer is yes. The simplest way to achieve this is to define a custom annotation, RpcClient.
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.FIELD})
public @interface RpcClient {
}
Once the annotation is defined, we should consider what attributes need to be set in the annotation.
The first question: how does the client know which remote service it is calling? For this we need the first attribute, remoteAppKey, the unique identifier of a service. Through remoteAppKey, the client can easily locate the target service in the registration center.
The second question: how do we handle multiple versions of a service? During gray-release upgrades, for example, we need a second attribute, groupName, to find the specific group under a service.
The remaining parameters are relatively simple. The complete parameter configuration is as follows:
/**
* Service interface: used to match providers obtained from the registration center; one provider from the returned list is selected via the load-balancing strategy when a call is initiated.
*/
Class<?> targetItf() default Object.class;
/**
* Timeout: service call timeout.
*/
long timeout() default 3000 * 10L;
/**
* Number of caller threads.
*/
int consumeThreads() default 10;
/**
* Unique identifier for the service provider.
*/
String remoteAppKey() default "";
/**
* Service group name.
*/
String groupName() default "default";
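With the attributes in place, consumer-side usage might look like the following sketch (the bean and the remoteAppKey value are illustrative, not taken from the project):
@Component
public class UserController {

    // Injected by the framework during bean initialization, much like @Autowired
    @RpcClient(remoteAppKey = "user-service", groupName = "default")
    private UserService userService;

    public String hello() {
        // Reads like a local call, but goes through the RPC proxy
        return userService.get("alice");
    }
}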
Client Initialization#
To achieve functionality similar to @Autowired, the framework must perform dependency injection for every field annotated with RpcClient when a Bean is initialized. How can this be implemented? Spring's InstantiationAwareBeanPostProcessor interface allows custom logic to run at various stages of bean instantiation. Define a ConsumerAnnotationBean class that implements the InstantiationAwareBeanPostProcessor interface.
public class ConsumerAnnotationBean implements InstantiationAwareBeanPostProcessor {
// ...other methods omitted
@Override
public PropertyValues postProcessProperties(PropertyValues pvs, Object bean, String beanName) throws BeansException {
// filled in step by step below
return pvs;
}
}
The core of the implementation is the interface's postProcessProperties method, which sets the property values of the Bean.
Before writing the specific code, we need to clarify the objectives we need to achieve here:
- Client service registration (for monitoring purposes).
- Dependency injection of objects.
- Query service nodes and pre-create Netty connections.
Now that we have clarified the requirements, we can start writing the corresponding logic.
@Override
public PropertyValues postProcessProperties(PropertyValues pvs, Object bean, String beanName) throws BeansException {
Class<?> beanClass = bean.getClass();
// Scan declared fields for dependency injection
Field[] fields = beanClass.getDeclaredFields();
for (Field field : fields) {
RpcClient rpcClient = field.getAnnotation(RpcClient.class);
if (rpcClient == null) {
continue;
}
// ...registration, proxy creation, and injection follow
}
return pvs;
}
We get all declared fields of the initialized bean and use getAnnotation(RpcClient.class) to determine whether a field is annotated with RpcClient.
Client Service Registration#
The purpose of this step is to record consumer information in the registration center for subsequent monitoring. This step is relatively simple; we just need to construct the client information and submit it to the registration center.
// Construct consumer object
Consumer consumer = Consumer
.builder()
.groupName(rpcClient.groupName())
.remoteAppKey(rpcClient.remoteAppKey())
.targetItf(targetItf)
.build();
// Register consumer in the registration center
registerCenter.registerConsumer(consumer);
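For reference, the registration-center contract assumed throughout this article can be sketched from the calls that appear in the snippets (the ZooKeeper-backed IRegisterCenterZkImpl in the source project is one implementation; treat the exact signatures as an approximation):
public interface IRegisterCenter {

    // Record consumer metadata in the registration center (used for monitoring)
    void registerConsumer(Consumer consumer);

    // Look up the available provider nodes for a service interface and method
    List<Producer> getServiceProducer(String serviceName, String methodName);
}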
Dependency Injection of Objects#
Again, let's clarify the requirement. We need a dynamic proxy that, during a method call, retrieves the remote provider's node information based on the class and method name, constructs a NettyRequest message, sends it to the service provider, and finally receives the NettyResponse returned by the server.
Now that we have clarified the overall process, let's start writing the specific code.
First, define a class, ClientProxyBeanFactory, that implements the InvocationHandler interface. The main task is to implement the interface's invoke method.
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
// implemented step by step below
}
In the invoke method, we need to retrieve available nodes from the registration center based on the class and method name, so the proxied interface must be passed in when the object is instantiated. We therefore define several member variables in ClientProxyBeanFactory.
@Slf4j
public class ClientProxyBeanFactory implements InvocationHandler {
// Thread pool for issuing Netty requests
private ExecutorService executorService;
// Target (proxied) interface
private Class<?> targetItf;
// Timeout
private long timeout;
// Number of calling threads
private int consumeThreads;
}
When the object is initialized, the values of the member variables need to be set.
private static volatile ClientProxyBeanFactory instance;
public ClientProxyBeanFactory(Class<?> targetItf, long timeout, int consumeThreads) {
this.executorService = new ThreadPoolExecutor(consumeThreads, consumeThreads,
0L, TimeUnit.MILLISECONDS,
new LinkedBlockingQueue<>(), new ThreadFactoryBuilder()
.setNameFormat("simple-rpc-%d").build(), new ThreadPoolExecutor.AbortPolicy());
this.targetItf = targetItf;
this.timeout = timeout;
this.consumeThreads = consumeThreads;
}
/**
* Get proxy object
*
* @return
*/
public Object getProxy() {
return Proxy.newProxyInstance(Thread.currentThread().getContextClassLoader(), new Class[]{targetItf}, this);
}
/**
* Note: a single shared static instance means one proxy per JVM; to consume
* multiple service interfaces, a per-interface cache would be needed (simplified here).
*/
public static ClientProxyBeanFactory getInstance(Class<?> targetItf, long timeout, int consumeThreads) {
if (null == instance) {
synchronized (ClientProxyBeanFactory.class) {
if (null == instance) {
instance = new ClientProxyBeanFactory(targetItf, timeout, consumeThreads);
}
}
}
return instance;
}
After completing the assignment of the above member variables, we can start retrieving service nodes from the registration center.
// ClientProxyBeanFactory.class
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
// Proxy class name
String serviceName = targetItf.getName();
// Registration center service
IRegisterCenter registerCenter = IRegisterCenterZkImpl.getInstance();
// Get available nodes based on serviceName + methodName
List<Producer> producerList = registerCenter.getServiceProducer(serviceName, method.getName());
...
}
Once we have the service nodes, we can pick the node to use according to the load-balancing strategy. Assume we select a Producer node at random.
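Random selection is the simplest strategy; a minimal sketch (the source project supports pluggable load-balancing strategies, so treat this as illustrative):
// Pick one provider at random from the available nodes
public static Producer randomSelect(List<Producer> producers) {
    if (producers == null || producers.isEmpty()) {
        throw new IllegalStateException("no available service provider");
    }
    return producers.get(ThreadLocalRandom.current().nextInt(producers.size()));
}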
Open Netty Connection and Send/Receive Messages
Having obtained the Producer, we have the remote service's Netty IP and port, so we can establish a connection to it. There is an optimization to make here, though: establishing a connection on every method call would be very time-consuming. If we instead pre-establish a pool of Channels keyed by IP + port, a method call only needs to fetch a Channel from the pool, which greatly improves efficiency.
Based on this, we implement a NettyChannelPoolFactory that caches the client's Netty channels and exposes two methods: acquire to obtain a Channel and release to return it.
The specific implementation code is as follows:
@Slf4j
public class NettyChannelPoolFactory {
private static final NettyChannelPoolFactory CHANNEL_POOL_FACTORY = new NettyChannelPoolFactory();
// Connection pool cache, key is service provider address, value is Netty Channel blocking queue
public static final Map<InetSocketAddress, ArrayBlockingQueue<Channel>> CHANNEL_POOL_MAP = Maps.newConcurrentMap();
/**
* Initialize the length of the Netty Channel blocking queue, this value is configurable.
*/
private static final Integer CHANNEL_CONNECT_SIZE = 3;
public static NettyChannelPoolFactory getInstance() {
return CHANNEL_POOL_FACTORY;
}
/**
* Initialize netty connection pool
*/
public void initChannelFactory(List<Producer> producerNodeList) {
for (Producer producer : producerNodeList) {
InetSocketAddress address = new InetSocketAddress(producer.getIp(), producer.getPort());
while (CHANNEL_POOL_MAP.get(address) == null || CHANNEL_POOL_MAP.get(address).size() < CHANNEL_CONNECT_SIZE) {
ArrayBlockingQueue<Channel> channels = CHANNEL_POOL_MAP.get(address);
if (channels == null || channels.size() < CHANNEL_CONNECT_SIZE) {
// Initialize Netty Channel blocking queue
Channel channel = null;
while (channel == null) {
channel = registerChannel(address);
}
if (channels == null) {
channels = new ArrayBlockingQueue<>(CHANNEL_CONNECT_SIZE);
}
boolean offer = channels.offer(channel);
if (!offer) {
log.debug("channelArrayBlockingQueue fail");
} else {
CHANNEL_POOL_MAP.put(address, channels);
}
}
}
}
}
/**
* Get client queue by address
*/
public ArrayBlockingQueue<Channel> acquire(InetSocketAddress address) {
return CHANNEL_POOL_MAP.get(address);
}
/**
* After use, put the channel back into the blocking queue
*/
public void release(ArrayBlockingQueue<Channel> queue, Channel channel, InetSocketAddress address) {
if (queue == null) {
return;
}
// Before recycling, check whether the channel is still usable
if (channel == null || !channel.isOpen() || !channel.isActive() || !channel.isWritable()) {
if (channel != null) {
// Close the broken channel before replacing it
channel.close().awaitUninterruptibly();
}
// Register a fresh channel to take its place in the pool
channel = null;
while (channel == null) {
channel = registerChannel(address);
}
}
queue.offer(channel);
}
/**
* Register netty client
*/
public Channel registerChannel(InetSocketAddress address) {
try {
// Note: creating an event loop group per connection is a simplification; sharing one group is more economical
EventLoopGroup group = new NioEventLoopGroup(10);
Bootstrap bootstrap = new Bootstrap();
bootstrap.remoteAddress(address);
bootstrap.group(group)
.channel(NioSocketChannel.class)
.option(ChannelOption.TCP_NODELAY, true)
.handler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel ch) throws Exception {
// Register Netty encoder
ch.pipeline().addLast(new NettyEncoderHandler());
// Register Netty decoder
ch.pipeline().addLast(new NettyDecoderHandler(NettyResponse.class));
// Register client business logic handler
ch.pipeline().addLast(new NettyHandlerClient());
}
});
ChannelFuture channelFuture = bootstrap.connect().sync();
final Channel channel = channelFuture.channel();
final CountDownLatch countDownLatch = new CountDownLatch(1);
final List<Boolean> isSuccessHolder = Lists.newArrayListWithCapacity(1);
// Listen for whether the channel is successfully established
channelFuture.addListener(future -> {
if (future.isSuccess()) {
isSuccessHolder.add(Boolean.TRUE);
} else {
// If establishment fails, save the failure mark
log.error("registerChannel fail , {}", future.cause().getMessage());
isSuccessHolder.add(Boolean.FALSE);
}
countDownLatch.countDown();
});
countDownLatch.await();
// If the Channel is successfully established, return the newly created Channel
if (Boolean.TRUE.equals(isSuccessHolder.get(0))) {
return channel;
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
log.error("registerChannel fail", e);
}
return null;
}
}
NettyChannelPoolFactory also defines a registerChannel method that takes an InetSocketAddress and returns a Channel. It creates a Netty connection to the given address, registers the serialization encoder and deserialization decoder, adds a NettyHandlerClient for client-side message handling, and finally returns the initialized Channel.
With the NettyChannelPoolFactory in place, we can obtain a Channel from the Producer information retrieved from the registration center, and we are ready to send NettyRequest messages.
Constructing the NettyRequest Message
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
// Get available nodes based on serviceName + methodName
List<Producer> producerList = registerCenter.getServiceProducer(serviceName, method.getName());
// Directly take the first one; load balancing strategies can be applied here
Producer providerCopy = producerList.get(0);
// Construct NettyRequest
NettyRequest request = NettyRequest.builder()
// Service node information
.producer(providerCopy)
// Unique identifier for this request
.uniqueKey(UUID.randomUUID() + "-" + Thread.currentThread().getId())
// Request timeout
.invokeTimeout(timeout)
// Request method name
.invokeMethodName(method.getName())
// Request parameters
.args(args)
.build();
...
}
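NettyRequest itself is just a serializable data holder. Its shape can be inferred from the builder calls above (Lombok annotations assumed, matching the project's style; a sketch, not the project's exact class):
@Data
@Builder
public class NettyRequest implements Serializable {

    private static final long serialVersionUID = 1L;

    // Target provider node chosen for this call
    private Producer producer;
    // Correlates this request with its asynchronous response
    private String uniqueKey;
    // How long the caller is willing to wait, in milliseconds
    private long invokeTimeout;
    // Name of the remote method to invoke
    private String invokeMethodName;
    // Method arguments (must themselves be serializable)
    private Object[] args;
}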
Now that we have both the NettyRequest and a Channel to send it over, we just need to send the message, receive the result, and deserialize it into the method's return value.
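Its counterpart NettyResponse can be inferred the same way from the getUniqueKey() and getResult() calls that appear later (again a sketch under the same Lombok assumption):
@Data
@Builder
public class NettyResponse implements Serializable {

    private static final long serialVersionUID = 1L;

    // Echoes the request's uniqueKey so the client can correlate the reply
    private String uniqueKey;
    // Value returned by the server-side method execution
    private Object result;
}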
We can use a thread pool to handle sending the Netty message and decoding the return value.
Define a ClientServiceCallable that implements the Callable<NettyResponse> interface for tasks with return values. The Callable interface requires implementing a single method, call(), in which we need to: 1. retrieve the Channel object; 2. send the request; 3. return the result.
@Slf4j
public class ClientServiceCallable implements Callable<NettyResponse> {
/**
* Netty communication pipeline
*/
private Channel channel;
/**
* Request parameters
*/
private final NettyRequest request;
public static ClientServiceCallable of(NettyRequest request) {
return new ClientServiceCallable(request);
}
public ClientServiceCallable(NettyRequest request) {
this.request = request;
}
@Override
public NettyResponse call() throws Exception {
InetSocketAddress inetSocketAddress = new InetSocketAddress(request.getProducer().getIp(), request.getProducer().getPort());
// Get the locally cached Channel queue
ArrayBlockingQueue<Channel> blockingQueue = NettyChannelPoolFactory.getInstance().acquire(inetSocketAddress);
try {
if (channel == null) {
// Retrieve a Channel from the queue
channel = blockingQueue.take();
}
if (channel == null) {
throw new RuntimeException("can't find channel to resolve this request");
}
// ...request sending and result retrieval, shown in full below
} catch (Exception e) {
log.error("client send request error", e);
} finally {
// After the request ends, return the Channel to the queue
NettyChannelPoolFactory.getInstance().release(blockingQueue, channel, inetSocketAddress);
}
return null;
}
}
In the call method above, we first fetch the Channel queue from the local cache and, in the finally block, return the Channel to the queue. The remaining logic sends the NettyRequest and returns the result:
try {
if (channel == null) {
channel = blockingQueue.take();
}
if (channel == null) {
throw new RuntimeException("can't find channel to resolve this request");
} else {
1️⃣ ClientResponseHolder.initResponseData(request.getUniqueKey());
2️⃣ while (!channel.isOpen() || !channel.isActive() || !channel.isWritable()) {
log.warn("retry get new channel");
channel = blockingQueue.poll(request.getInvokeTimeout(), TimeUnit.MILLISECONDS);
if (channel == null) {
// If there are no available Channels in the queue, re-register a Channel
channel = NettyChannelPoolFactory.getInstance().registerChannel(inetSocketAddress);
}
}
// Write the information of this call into the Netty channel and initiate an asynchronous call
3️⃣ ChannelFuture channelFuture = channel.writeAndFlush(request);
channelFuture.syncUninterruptibly();
// Retrieve the return result from the result container, setting the wait timeout to invokeTimeout
long invokeTimeout = request.getInvokeTimeout();
4️⃣ return ClientResponseHolder.getValue(request.getUniqueKey(), invokeTimeout);
}
} catch (Exception e) {
log.error("client send request error", e);
} finally {
NettyChannelPoolFactory.getInstance().release(blockingQueue, channel, inetSocketAddress);
}
1️⃣ ClientResponseHolder Class
ClientResponseHolder.initResponseData(request.getUniqueKey()) introduces a new class, ClientResponseHolder. What is this class for?
Since message sending is asynchronous, a Map<String, NettyResponseWrapper> serves as a local cache: the key is the NettyRequest's uniqueKey, and the value wraps the Netty return result, i.e. the value produced by the server-side execution.
The specific implementation of ClientResponseHolder is as follows:
@Slf4j
public class ClientResponseHolder {
private static final Map<String, NettyResponseWrapper> RESPONSE_WRAPPER_MAP = Maps.newConcurrentMap();
private static final ScheduledExecutorService executorService;
static {
executorService = new ScheduledThreadPoolExecutor(1, new RemoveExpireThreadFactory("simple-rpc", false));
// Remove expired data
executorService.scheduleWithFixedDelay(() -> {
for (Map.Entry<String, NettyResponseWrapper> entry : RESPONSE_WRAPPER_MAP.entrySet()) {
boolean expire = entry.getValue().isExpire();
if (expire) {
RESPONSE_WRAPPER_MAP.remove(entry.getKey());
}
}
}, 1, 20, TimeUnit.MILLISECONDS);
}
/**
* Initialize the return result container, requestUniqueKey uniquely identifies this call
*/
public static void initResponseData(String requestUniqueKey) {
RESPONSE_WRAPPER_MAP.put(requestUniqueKey, NettyResponseWrapper.of());
}
/**
* Put the asynchronous return result of the Netty call into the blocking queue
*/
public static void putResultValue(NettyResponse response) {
long currentTimeMillis = System.currentTimeMillis();
NettyResponseWrapper responseWrapper = RESPONSE_WRAPPER_MAP.get(response.getUniqueKey());
if (responseWrapper == null) {
// The caller has timed out and the entry was cleaned up; drop the late response
return;
}
responseWrapper.setResponseTime(currentTimeMillis);
responseWrapper.getResponseBlockingQueue().add(response);
RESPONSE_WRAPPER_MAP.put(response.getUniqueKey(), responseWrapper);
}
/**
* Get the asynchronous return result from the blocking queue
*/
public static NettyResponse getValue(String requestUniqueKey, long timeout) {
NettyResponseWrapper responseWrapper = RESPONSE_WRAPPER_MAP.get(requestUniqueKey);
try {
return responseWrapper.getResponseBlockingQueue().poll(timeout, TimeUnit.MILLISECONDS);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
log.error("get value error", e);
} finally {
RESPONSE_WRAPPER_MAP.remove(requestUniqueKey);
}
return null;
}
}
- initResponseData: initializes the Map entry keyed by uniqueKey.
- putResultValue: inserts the NettyResponse return result.
- getValue: retrieves the result by uniqueKey.
A scheduled executor is also defined to periodically check for expired messages and clean up memory data.
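The NettyResponseWrapper used above pairs a one-slot blocking queue with a response timestamp so the cleanup task can detect expiry. A minimal sketch inferred from the calls in ClientResponseHolder (the expiry window is an assumed constant):
@Data
public class NettyResponseWrapper {

    // Assumed expiry window for the cleanup task, in milliseconds
    private static final long EXPIRE_MILLIS = 60_000L;

    // Holds at most one response; the waiting caller blocks on poll()
    private final BlockingQueue<NettyResponse> responseBlockingQueue = new ArrayBlockingQueue<>(1);

    // Set when the response arrives; 0 means still pending
    private long responseTime;

    public static NettyResponseWrapper of() {
        return new NettyResponseWrapper();
    }

    // Expired once a response has sat unclaimed past the window
    public boolean isExpire() {
        return responseTime > 0 && System.currentTimeMillis() - responseTime > EXPIRE_MILLIS;
    }
}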
2️⃣ Channel State Check
Check the current state of the Netty channel; if the current Channel is unavailable, take another from the queue or register a fresh one.
3️⃣ Netty Message Sending
writeAndFlush(request) hands the request to the channel pipeline, where the encoder serializes it onto the wire; syncUninterruptibly() waits until the write has completed.
4️⃣ Retrieve the Netty Return Result from Local Cache
Because the call is asynchronous, the result does not come back on the write future; ClientResponseHolder.getValue blocks on the response queue keyed by uniqueKey for at most invokeTimeout milliseconds.
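The missing link is NettyHandlerClient, registered in the pipeline earlier: when a decoded NettyResponse arrives, it must be published to ClientResponseHolder so the waiting caller can pick it up. A minimal sketch under that assumption:
public class NettyHandlerClient extends SimpleChannelInboundHandler<NettyResponse> {

    @Override
    protected void channelRead0(ChannelHandlerContext ctx, NettyResponse response) {
        // Hand the response to the holder; the caller blocked in
        // ClientResponseHolder.getValue(uniqueKey, timeout) receives it
        ClientResponseHolder.putResultValue(response);
    }
}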
Asynchronous Call to Netty Service, Using Future to Get Return Result
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
// Get available nodes based on serviceName + methodName
List<Producer> producerList = registerCenter.getServiceProducer(serviceName, method.getName());
// Directly take the first one; load balancing strategies can be applied here
Producer providerCopy = producerList.get(0);
// Construct NettyRequest
NettyRequest request = NettyRequest.builder()
// Service node information
.producer(providerCopy)
// Unique identifier for this request
.uniqueKey(UUID.randomUUID() + "-" + Thread.currentThread().getId())
// Request timeout
.invokeTimeout(timeout)
// Request method name
.invokeMethodName(method.getName())
// Request parameters
.args(args)
.build();
// Initiate asynchronous call, sending request through NettyClient
try {
Future<NettyResponse> responseFuture = executorService.submit(ClientServiceCallable.of(request));
NettyResponse response = responseFuture.get(timeout, TimeUnit.MILLISECONDS);
if (response != null) {
return response.getResult();
}
} catch (Exception e) {
log.error("send request error", e);
}
return null;
}
At this point, the ClientProxyBeanFactory proxy is fully implemented. All that remains is to inject the initialized proxy object.
// ConsumerAnnotationBean.class (inside the postProcessProperties field loop)
// 3. Get service proxy object
Class<?> targetItf = rpcClient.targetItf();
if (targetItf == Object.class) {
targetItf = field.getType();
}
// Initialize proxy object
ClientProxyBeanFactory factory = ClientProxyBeanFactory.getInstance(targetItf, rpcClient.timeout(), rpcClient.consumeThreads());
ReflectionUtils.makeAccessible(field);
try {
// Set proxy object
field.set(bean, factory.getProxy());
} catch (IllegalAccessException e) {
log.error("ReferenceBeanPostProcessor post process properties error, beanName={}", beanName, e);
throw new RuntimeException("ReferenceBeanPostProcessor post process properties error, beanName=" + beanName, e);
}
}
We obtain the proxy object via ClientProxyBeanFactory.getInstance and assign it with the field.set method.
After completing the above operations, when the client executes the get method it invokes ClientProxyBeanFactory's invoke method, which runs the logic from Open Netty Connection and Send/Receive Messages and finally returns the result from the service provider.
Query Service Nodes and Pre-create Netty Connections#
This part overlaps with the second step (invoke already queries provider nodes); at Bean initialization time the framework additionally warms up the Netty connection pool, as sketched below.
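A minimal sketch of this warm-up, still inside the postProcessProperties field loop (whether the node lookup is method-level at this point is an assumption; null is passed for the method name):
// ConsumerAnnotationBean.class (inside the postProcessProperties field loop)
// Query the provider nodes registered for this service interface
List<Producer> producers = registerCenter.getServiceProducer(targetItf.getName(), null);
// Pre-create Netty channels so the first real call pays no connection cost
NettyChannelPoolFactory.getInstance().initChannelFactory(producers);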
At this point, the entire client-side flow is complete. To keep the main line of reasoning clear, this article only briefly touched on load-balancing strategies, serialization, and deserialization, which are also very important parts of an RPC framework.