
How to Develop Your Own RPC Framework (Part 1)#

December 11, 2023 12:00 AM

Preface: This article references the open-source project https://github.com/MIracleCczs/simple-rpc; much of the code is based on it. This part discusses the implementation of an RPC framework from the perspective of client calls, focusing on the functionality that needs to be implemented on the client side.

Definition of RPC#

You can refer to the Wikipedia article: https://en.wikipedia.org/wiki/Remote_procedure_call

What Comprises a Simple RPC Framework?#

A basic RPC framework needs to include three main components:

  1. Registration Center
  2. Service Provider
  3. Service Consumer


Both the service provider and the service consumer need to communicate with the registration center: the provider registers the services it exposes, and the consumer queries it for available providers.

How is a Remote Method Call Implemented?#


Next, we will explain the call flow step by step. To clarify the entire logic, we will start from an actual business requirement.

Requirement:

There exists a service provider Producer (hereafter referred to as the server), providing the method get.

There exists a service consumer Consumer (hereafter referred to as the client), which needs to call the get method in Producer.


Basic Interface Definition#

Define the UserService interface, which contains the get method.

public interface UserService {
    String get(String username);
}
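For context, a trivial provider-side implementation of this interface might look like the sketch below (illustrative only; this part focuses on the client side).

public class UserServiceImpl implements UserService {

    // Illustrative Producer-side implementation of the get method
    @Override
    public String get(String username) {
        return "hello, " + username;
    }
}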

Client Initiates Service Call#

Client Annotation Definition#

How can the client call a remote service as if it were calling a local method? The RPC framework is designed to solve this problem. Typically, local method calls are made using:

@Autowired
private UserService userService;

Through Spring's dependency injection, the required methods are injected into the calling object. Can we also use this form for RPC calls? The answer is yes. To achieve the above requirement, the simplest way is to define a custom annotation RpcClient.

@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.FIELD})
public @interface RpcClient {
}

Once the annotation is defined, we should consider what attributes need to be set in the annotation.

The first question that arises is: how does the client know which remote service it is calling? At this point, we need to set the first attribute remoteAppKey, which is the unique identifier for the service. Through remoteAppKey, the client can easily find the target service in the registration center.

The second question is: how do we handle multiple versions of a service? For example, during a gray (canary) release, we need a second attribute, groupName, to locate the specific group under the service.

The remaining parameters are relatively simple. The complete parameter configuration is as follows:

/**
 * Service interface: matches the local service provider obtained from the registration center, gets the list of service providers, and selects one to initiate the service call based on the load balancing strategy.
 */
Class<?> targetItf() default Object.class;

/**
 * Timeout: service call timeout.
 */
long timeout() default 3000 * 10L;

/**
 * Number of caller threads.
 */
int consumeThreads() default 10;

/**
 * Unique identifier for the service provider.
 */
String remoteAppKey() default "";

/**
 * Service group name.
 */
String groupName() default "default";
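With these attributes in place, a consumer-side declaration might look like the following sketch (the remoteAppKey value and the surrounding bean are illustrative, not from the project):

import org.springframework.stereotype.Component;

@Component
public class UserQueryService {

    // The framework injects a remote proxy into this field at bean initialization,
    // analogous to @Autowired for local beans.
    @RpcClient(remoteAppKey = "user-service", groupName = "default")
    private UserService userService;

    public String greet(String username) {
        // Reads like a local call, but executes on a remote Producer node
        return userService.get(username);
    }
}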

Client Initialization#

To achieve functionality similar to @Autowired, the framework needs to perform dependency injection for all fields annotated with RpcClient at Bean initialization time. How can this be implemented? Spring's InstantiationAwareBeanPostProcessor interface allows custom logic to be executed at various stages of Bean instantiation. Define a ConsumerAnnotationBean class that implements the InstantiationAwareBeanPostProcessor interface.

public class ConsumerAnnotationBean implements InstantiationAwareBeanPostProcessor {
    // ...other methods omitted
    @Override
    public PropertyValues postProcessProperties(PropertyValues pvs, Object bean, String beanName) throws BeansException {
        // ...field injection logic goes here (see below)
        return pvs;
    }
}

The work happens in the interface's postProcessProperties method, which lets us set the property values of the Bean object after Spring has populated it.

Before writing the specific code, we need to clarify the objectives we need to achieve here:

  1. Client service registration (for monitoring purposes).
  2. Dependency injection of objects.
  3. Query service nodes and pre-create Netty connections.

Now that we have clarified the requirements, we can start writing the corresponding logic.

@Override
public PropertyValues postProcessProperties(PropertyValues pvs, Object bean, String beanName) throws BeansException {
    Class<?> beanClass = bean.getClass();
    // Scan the bean's declared fields for services that need dependency injection
    Field[] fields = beanClass.getDeclaredFields();
    for (Field field : fields) {
        RpcClient rpcClient = field.getAnnotation(RpcClient.class);
        if (rpcClient == null) {
            continue;
        }
        // ...register the consumer, build the proxy, and inject it (shown below)
    }
    return pvs;
}

Get all declared fields of the initialized bean and use field.getAnnotation(RpcClient.class) to determine whether each field is annotated with RpcClient.

Client Service Registration#

The purpose of this step is to record consumer information in the registration center for subsequent monitoring. This step is relatively simple; we just need to construct the client information and submit it to the registration center.

// Construct consumer object
Consumer consumer = Consumer.builder()
        .groupName(rpcClient.groupName())
        .remoteAppKey(rpcClient.remoteAppKey())
        .targetItf(targetItf)
        .build();
// Register the consumer with the registration center
registerCenter.registerConsumer(consumer);
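registerConsumer itself is not shown in this article. Purely as a hypothetical sketch (not the project's actual code), a ZooKeeper-backed register center such as IRegisterCenterZkImpl could record the consumer as an ephemeral node via Apache Curator; the path layout below is an assumption:

import org.apache.curator.framework.CuratorFramework;
import org.apache.zookeeper.CreateMode;

// Hypothetical sketch of ZooKeeper-based consumer registration.
public class ZkConsumerRegistrar {

    private final CuratorFramework zkClient; // connected Curator client (assumed)

    public ZkConsumerRegistrar(CuratorFramework zkClient) {
        this.zkClient = zkClient;
    }

    public void registerConsumer(Consumer consumer) {
        // Illustrative path layout: /simple-rpc/<appKey>/<group>/consumers/<interface>
        String path = String.format("/simple-rpc/%s/%s/consumers/%s",
                consumer.getRemoteAppKey(), consumer.getGroupName(),
                consumer.getTargetItf().getName());
        try {
            if (zkClient.checkExists().forPath(path) == null) {
                // Ephemeral node: removed automatically when the consumer disconnects
                zkClient.create().creatingParentsIfNeeded()
                        .withMode(CreateMode.EPHEMERAL)
                        .forPath(path);
            }
        } catch (Exception e) {
            throw new RuntimeException("register consumer fail", e);
        }
    }
}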

Dependency Injection of Objects#

Similarly, let's clarify the requirements here. We need to implement a dynamic proxy that, during method calls, retrieves the node information of the remote service provider based on the method call name and class name, constructs a NettyRequest message, sends it to the service provider, and finally receives the NettyResponse returned by the service.


Now that we have clarified the overall process, let's start writing the specific code.

First, define an object ClientProxyBeanFactory that implements the InvocationHandler interface.

The main task is to implement the interface's invoke method.

@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
    // ...
}

In the invoke method, we need to retrieve available nodes from the registration center based on the class name and method name, so the specific proxy class needs to be passed in during object instantiation. Therefore, we need to define several member variables in ClientProxyBeanFactory.

@Slf4j
public class ClientProxyBeanFactory implements InvocationHandler {

    // Thread pool for dispatching Netty requests
    private ExecutorService executorService;

    // Target proxy class
    private Class<?> targetItf;

    // Timeout
    private long timeout;

    // Number of calling threads
    private int consumeThreads;
}

When the object is initialized, the values of the member variables need to be set.

private static volatile ClientProxyBeanFactory instance;

    public ClientProxyBeanFactory(Class<?> targetItf, long timeout, int consumeThreads) {
        this.executorService = new ThreadPoolExecutor(consumeThreads, consumeThreads,
                0L, TimeUnit.MILLISECONDS,
                new LinkedBlockingQueue<>(), new ThreadFactoryBuilder()
                .setNameFormat("simple-rpc-%d").build(), new ThreadPoolExecutor.AbortPolicy());

        this.targetItf = targetItf;
        this.timeout = timeout;
        this.consumeThreads = consumeThreads;
    }

    /**
     * Get proxy object
     *
     * @return
     */
    public Object getProxy() {
        return Proxy.newProxyInstance(Thread.currentThread().getContextClassLoader(), new Class[]{targetItf}, this);
    }

    public static ClientProxyBeanFactory getInstance(Class<?> targetItf, long timeout, int consumeThreads) {
        if (null == instance) {
            synchronized (ClientProxyBeanFactory.class) {
                if (null == instance) {
                    instance = new ClientProxyBeanFactory(targetItf, timeout, consumeThreads);
                }
            }
        }

        return instance;
    }

After completing the assignment of the above member variables, we can start retrieving service nodes from the registration center.

// ClientProxyBeanFactory.class
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
    // Proxied interface name
    String serviceName = targetItf.getName();
    // Registration center service
    IRegisterCenter registerCenter = IRegisterCenterZkImpl.getInstance();
    // Get available nodes based on serviceName + methodName
    List<Producer> producerList = registerCenter.getServiceProducer(serviceName, method.getName());
    // ...
}

Once we have the service nodes, we can select the node to use according to a load balancing strategy. Assume we use random selection to obtain the Producer node, as sketched below.
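A minimal sketch of such a random strategy (the class and method names are illustrative):

import java.util.List;
import java.util.concurrent.ThreadLocalRandom;

// Illustrative random load-balancing helper.
public class RandomLoadBalance {

    public static Producer select(List<Producer> producers) {
        if (producers == null || producers.isEmpty()) {
            throw new IllegalStateException("no available producer");
        }
        // ThreadLocalRandom avoids contention between caller threads
        return producers.get(ThreadLocalRandom.current().nextInt(producers.size()));
    }
}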

Open Netty Connection and Send/Receive Messages

Having obtained the Producer, we have the remote service's IP + port, so we can establish a connection to it. There is an optimization opportunity here, though: establishing a new connection on every method call would be very time-consuming. If we instead pre-establish a Channel pool keyed by IP + port, then a method call only needs to borrow a Channel from the pool, which greatly improves efficiency.


Based on the above logic, we need to implement a NettyChannelPoolFactory that caches the client's Netty Channels and exposes two methods: acquire, to obtain the Channel queue for an address, and release, to return a Channel to the queue.

The specific implementation code is as follows:

@Slf4j
public class NettyChannelPoolFactory {

    private static final NettyChannelPoolFactory CHANNEL_POOL_FACTORY = new NettyChannelPoolFactory();

    // Connection pool cache, key is service provider address, value is Netty Channel blocking queue
    public static final Map<InetSocketAddress, ArrayBlockingQueue<Channel>> CHANNEL_POOL_MAP = Maps.newConcurrentMap();

    /**
     * Initialize the length of the Netty Channel blocking queue, this value is configurable.
     */
    private static final Integer CHANNEL_CONNECT_SIZE = 3;

    public static NettyChannelPoolFactory getInstance() {
        return CHANNEL_POOL_FACTORY;
    }

    /**
     * Initialize netty connection pool
     */
    public void initChannelFactory(List<Producer> producerNodeList) {
        for (Producer producer : producerNodeList) {
            InetSocketAddress address = new InetSocketAddress(producer.getIp(), producer.getPort());
            while (CHANNEL_POOL_MAP.get(address) == null || CHANNEL_POOL_MAP.get(address).size() < CHANNEL_CONNECT_SIZE) {
                ArrayBlockingQueue<Channel> channels = CHANNEL_POOL_MAP.get(address);
                if (channels == null || channels.size() < CHANNEL_CONNECT_SIZE) {
                    // Initialize Netty Channel blocking queue
                    Channel channel = null;
                    while (channel == null) {
                        channel = registerChannel(address);
                    }

                    if (channels == null) {
                        channels = new ArrayBlockingQueue<>(CHANNEL_CONNECT_SIZE);
                    }

                    boolean offer = channels.offer(channel);
                    if (!offer) {
                        log.debug("channelArrayBlockingQueue fail");

                    } else {
                        CHANNEL_POOL_MAP.put(address, channels);
                    }
                }
            }
        }
    }

    /**
     * Get client queue by address
     */
    public ArrayBlockingQueue<Channel> acquire(InetSocketAddress address) {
        return CHANNEL_POOL_MAP.get(address);
    }

    /**
     * After use, put the channel back into the blocking queue
     */
    public void release(ArrayBlockingQueue<Channel> queue, Channel channel, InetSocketAddress address) {
        if (queue == null) {
            return;
        }

        // Before recycling, check whether the channel is still usable
        if (channel == null || !channel.isOpen() || !channel.isActive() || !channel.isWritable()) {
            if (channel != null) {
                // Close the broken channel instead of putting it back into the pool
                channel.close().syncUninterruptibly();
                channel = null;
            }
            // Register a fresh channel so the pool stays at full size
            while (channel == null) {
                channel = registerChannel(address);
            }
        }

        queue.offer(channel);
    }

    /**
     * Register netty client
     */
    public Channel registerChannel(InetSocketAddress address) {
        try {
            EventLoopGroup group = new NioEventLoopGroup(10);
            Bootstrap bootstrap = new Bootstrap();
            bootstrap.remoteAddress(address);
            bootstrap.group(group)
                    .channel(NioSocketChannel.class)
                    .option(ChannelOption.TCP_NODELAY, true)
                    .handler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel ch) throws Exception {
                            // Register Netty encoder
                            ch.pipeline().addLast(new NettyEncoderHandler());
                            // Register Netty decoder
                            ch.pipeline().addLast(new NettyDecoderHandler(NettyResponse.class));
                            // Register client business logic handler
                            ch.pipeline().addLast(new NettyHandlerClient());
                        }
                    });

            ChannelFuture channelFuture = bootstrap.connect().sync();
            final Channel channel = channelFuture.channel();
            final CountDownLatch countDownLatch = new CountDownLatch(1);
            final List<Boolean> isSuccessHolder = Lists.newArrayListWithCapacity(1);
            // Listen for whether the channel is successfully established
            channelFuture.addListener(future -> {
                if (future.isSuccess()) {
                    isSuccessHolder.add(Boolean.TRUE);

                } else {
                    // If establishment fails, save the failure mark
                    log.error("registerChannel fail , {}", future.cause().getMessage());
                    isSuccessHolder.add(Boolean.FALSE);
                }

                countDownLatch.countDown();
            });

            countDownLatch.await();
            // If the Channel is successfully established, return the newly created Channel
            if (Boolean.TRUE.equals(isSuccessHolder.get(0))) {
                return channel;
            }

        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log.error("registerChannel fail", e);
        }

        return null;
    }
}

NettyChannelPoolFactory also defines the registerChannel method, which takes an InetSocketAddress and returns a Channel. It creates a Netty connection to the given address, registers the serialization encoder and deserialization decoder, adds the NettyHandlerClient for client-side message processing, and finally returns the initialized Channel.
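The encoder and decoder themselves are not shown in this article. As a hedged sketch, a length-prefixed frame with a pluggable serializer is the usual shape; the SerializerEngine facade below is an assumption, not the project's API:

import java.util.List;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.MessageToByteEncoder;

// Hypothetical encoder: writes [length][payload] frames.
public class NettyEncoderHandler extends MessageToByteEncoder<Object> {

    @Override
    protected void encode(ChannelHandlerContext ctx, Object msg, ByteBuf out) {
        byte[] payload = SerializerEngine.serialize(msg); // assumed serializer facade
        out.writeInt(payload.length);                      // 4-byte length prefix
        out.writeBytes(payload);
    }
}

// Hypothetical matching decoder; the client constructs it with NettyResponse.class.
public class NettyDecoderHandler extends ByteToMessageDecoder {

    private final Class<?> genericClass;

    public NettyDecoderHandler(Class<?> genericClass) {
        this.genericClass = genericClass;
    }

    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
        if (in.readableBytes() < 4) {
            return; // length header incomplete, wait for more bytes
        }
        in.markReaderIndex();
        int length = in.readInt();
        if (in.readableBytes() < length) {
            in.resetReaderIndex(); // full frame not yet arrived
            return;
        }
        byte[] payload = new byte[length];
        in.readBytes(payload);
        out.add(SerializerEngine.deserialize(payload, genericClass));
    }
}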

With the NettyChannelPoolFactory in place, we can obtain the Channel based on the Producer information retrieved from the registration center, allowing us to send NettyRequest messages.

Constructing the NettyRequest Message

public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
    // Get available nodes based on serviceName + methodName
    List<Producer> producerList = registerCenter.getServiceProducer(serviceName, method.getName());
    // Directly take the first one; load balancing strategies can be applied here
    Producer providerCopy = producerList.get(0);
    // Construct NettyRequest
    NettyRequest request = NettyRequest.builder()
            // Service node information
            .producer(providerCopy)
            // Unique identifier for this request
            .uniqueKey(UUID.randomUUID() + "-" + Thread.currentThread().getId())
            // Request timeout
            .invokeTimeout(timeout)
            // Request method name
            .invokeMethodName(method.getName())
            // Request parameters
            .args(args)
            .build();
    // ...
}

Now that we have both the NettyRequest and the Channel for sending messages, we just need to send the message and receive the result, deserializing it into the method's return value.

We can use a thread pool to handle the sending of Netty messages and the decoding of return values.

Define a ClientServiceCallable that implements the Callable<NettyResponse> interface for tasks with return values.

The Callable interface requires implementing a single method, call(), in which we need to: (1) retrieve the Channel object, (2) send the request, and (3) return the result.

@Slf4j
public class ClientServiceCallable implements Callable<NettyResponse> {

    /**
     * Netty communication channel
     */
    private Channel channel;

    /**
     * Request parameters
     */
    private final NettyRequest request;

    public static ClientServiceCallable of(NettyRequest request) {
        return new ClientServiceCallable(request);
    }

    public ClientServiceCallable(NettyRequest request) {
        this.request = request;
    }

    @Override
    public NettyResponse call() throws Exception {
        InetSocketAddress inetSocketAddress = new InetSocketAddress(request.getProducer().getIp(), request.getProducer().getPort());
        // Get the locally cached Channel queue
        ArrayBlockingQueue<Channel> blockingQueue = NettyChannelPoolFactory.getInstance().acquire(inetSocketAddress);
        try {
            if (channel == null) {
                // Retrieve a Channel from the queue (blocks until one is available)
                channel = blockingQueue.take();
            }
            // ...send the request and await the response (filled in below)
        } catch (Exception e) {
            log.error("client send request error", e);
        } finally {
            // After the request ends, return the Channel to the queue
            NettyChannelPoolFactory.getInstance().release(blockingQueue, channel, inetSocketAddress);
        }
        return null;
    }
}

In the call method of the above code, we first retrieve the Channel queue from the local cache, and then in the finally block, we return the Channel to the queue. The remaining logic in the method is to send the NettyRequest request and return the result.

try {
    if (channel == null) {
        channel = blockingQueue.take();
    }

    if (channel == null) {
        throw new RuntimeException("can't find channel to resolve this request");
    } else {
        // 1️⃣ Initialize the response holder for this request's uniqueKey
        ClientResponseHolder.initResponseData(request.getUniqueKey());

        // 2️⃣ If the current Channel is unusable, fetch or register a new one
        while (channel == null || !channel.isOpen() || !channel.isActive() || !channel.isWritable()) {
            log.warn("retry get new channel");
            channel = blockingQueue.poll(request.getInvokeTimeout(), TimeUnit.MILLISECONDS);
            if (channel == null) {
                // If there are no available Channels in the queue, register a new one
                channel = NettyChannelPoolFactory.getInstance().registerChannel(inetSocketAddress);
            }
        }

        // 3️⃣ Write this call's request into the Netty channel, initiating an asynchronous call
        ChannelFuture channelFuture = channel.writeAndFlush(request);
        channelFuture.syncUninterruptibly();
        // 4️⃣ Retrieve the return result from the result container, waiting up to invokeTimeout
        long invokeTimeout = request.getInvokeTimeout();
        return ClientResponseHolder.getValue(request.getUniqueKey(), invokeTimeout);
    }
} catch (Exception e) {
    log.error("client send request error", e);
} finally {
    NettyChannelPoolFactory.getInstance().release(blockingQueue, channel, inetSocketAddress);
}

1️⃣ ClientResponseHolder Class

ClientResponseHolder.initResponseData(request.getUniqueKey()); introduces a new class, ClientResponseHolder. What is this class for?

Since message sending is asynchronous, a Map<String, NettyResponseWrapper> is used as a local data cache, where the key is the NettyRequest's uniqueKey and the value wraps the Netty return result, i.e., the value produced after the server executes the call.

The specific implementation of ClientResponseHolder is as follows:

@Slf4j
public class ClientResponseHolder {

    private static final Map<String, NettyResponseWrapper> RESPONSE_WRAPPER_MAP = Maps.newConcurrentMap();

    private static final ScheduledExecutorService executorService;

    static {
        executorService = new ScheduledThreadPoolExecutor(1, new RemoveExpireThreadFactory("simple-rpc", false));
        // Remove expired data
        executorService.scheduleWithFixedDelay(() -> {
            for (Map.Entry<String, NettyResponseWrapper> entry : RESPONSE_WRAPPER_MAP.entrySet()) {
                boolean expire = entry.getValue().isExpire();
                if (expire) {
                    RESPONSE_WRAPPER_MAP.remove(entry.getKey());
                }
            }
        }, 1, 20, TimeUnit.MILLISECONDS);
    }

    /**
     * Initialize the return result container, requestUniqueKey uniquely identifies this call
     */
    public static void initResponseData(String requestUniqueKey) {
        RESPONSE_WRAPPER_MAP.put(requestUniqueKey, NettyResponseWrapper.of());
    }

    /**
     * Put the asynchronous return result of the Netty call into the blocking queue
     */
    public static void putResultValue(NettyResponse response) {
        long currentTimeMillis = System.currentTimeMillis();
        NettyResponseWrapper responseWrapper = RESPONSE_WRAPPER_MAP.get(response.getUniqueKey());
        responseWrapper.setResponseTime(currentTimeMillis);
        responseWrapper.getResponseBlockingQueue().add(response);
        RESPONSE_WRAPPER_MAP.put(response.getUniqueKey(), responseWrapper);
    }

    /**
     * Get the asynchronous return result from the blocking queue
     */
    public static NettyResponse getValue(String requestUniqueKey, long timeout) {
        NettyResponseWrapper responseWrapper = RESPONSE_WRAPPER_MAP.get(requestUniqueKey);
        try {
            return responseWrapper.getResponseBlockingQueue().poll(timeout, TimeUnit.MILLISECONDS);
            
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log.error("get value error", e);

        } finally {
            RESPONSE_WRAPPER_MAP.remove(requestUniqueKey);
        }
        return null;
    }

}
  • initResponseData: Initializes the Map based on uniqueKey.
  • putResultValue: Inserts the NettyResponse return result.
  • getValue: Retrieves the result based on uniqueKey.

A scheduled executor is also defined to periodically check for expired messages and clean up memory data.
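NettyResponseWrapper is not shown in the article either; the following is a plausible reconstruction from its usage above (of(), isExpire(), setResponseTime(), getResponseBlockingQueue()). The 60-second expiry window is an assumption:

import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import lombok.Data;

// Hypothetical NettyResponseWrapper, reconstructed from its usage above.
@Data
public class NettyResponseWrapper {

    // Assumed expiry window for unclaimed responses
    private static final long EXPIRE_MILLIS = 60_000L;

    // One-slot queue: getValue blocks on poll() until putResultValue adds the response
    private final BlockingQueue<NettyResponse> responseBlockingQueue = new ArrayBlockingQueue<>(1);

    // Set by putResultValue when the response arrives
    private Long responseTime;

    public static NettyResponseWrapper of() {
        return new NettyResponseWrapper();
    }

    public boolean isExpire() {
        // Expired once a response has sat unclaimed longer than the window
        return responseTime != null
                && System.currentTimeMillis() - responseTime > EXPIRE_MILLIS;
    }
}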

2️⃣ Channel State Check

Check the current state of the Netty channel. If the current Channel is unavailable, a new channel needs to be requested.

3️⃣ Netty Message Sending

channel.writeAndFlush(request) hands the request to the Netty pipeline, where NettyEncoderHandler serializes it onto the wire; syncUninterruptibly() waits until the write has completed.

4️⃣ Retrieve the Netty Return Result from Local Cache

ClientResponseHolder.getValue blocks on the per-request queue for up to invokeTimeout milliseconds; the result arrives asynchronously once the client-side handler delivers it via putResultValue.
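The article never shows how the response actually lands in ClientResponseHolder. A minimal sketch of what NettyHandlerClient (registered on the pipeline earlier) presumably does, assuming it extends Netty's SimpleChannelInboundHandler:

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;

// Hypothetical client-side handler: delivers each decoded NettyResponse
// to ClientResponseHolder, unblocking the caller waiting in getValue.
public class NettyHandlerClient extends SimpleChannelInboundHandler<NettyResponse> {

    @Override
    protected void channelRead0(ChannelHandlerContext ctx, NettyResponse response) {
        // Matched to its pending request via the uniqueKey inside the response
        ClientResponseHolder.putResultValue(response);
    }
}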

Asynchronous Call to Netty Service, Using Future to Get Return Result

public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
    // Get available nodes based on serviceName + methodName
    List<Producer> producerList = registerCenter.getServiceProducer(serviceName, method.getName());
    // Directly take the first one; load balancing strategies can be applied here
    Producer providerCopy = producerList.get(0);
    // Construct NettyRequest
    NettyRequest request = NettyRequest.builder()
            // Service node information
            .producer(providerCopy)
            // Unique identifier for this request
            .uniqueKey(UUID.randomUUID() + "-" + Thread.currentThread().getId())
            // Request timeout
            .invokeTimeout(timeout)
            // Request method name
            .invokeMethodName(method.getName())
            // Request parameters
            .args(args)
            .build();
    // Initiate the asynchronous call, sending the request through the thread pool
    try {
        Future<NettyResponse> responseFuture = executorService.submit(ClientServiceCallable.of(request));
        NettyResponse response = responseFuture.get(timeout, TimeUnit.MILLISECONDS);
        if (response != null) {
            return response.getResult();
        }
    } catch (Exception e) {
        log.error("send request error", e);
    }
    return null;
}

At this point, we have completed the full implementation of the ClientProxyBeanFactory proxy object. Now we just need to perform dependency injection for the initialized proxy object.

// ConsumerAnnotationBean.class — inside the field loop of postProcessProperties
// 3. Get the service proxy interface
Class<?> targetItf = rpcClient.targetItf();
if (targetItf == Object.class) {
    targetItf = field.getType();
}
// Initialize the proxy object
ClientProxyBeanFactory factory = ClientProxyBeanFactory.getInstance(targetItf, rpcClient.timeout(), rpcClient.consumeThreads());
ReflectionUtils.makeAccessible(field);
try {
    // Set the proxy object on the annotated field
    field.set(bean, factory.getProxy());
} catch (IllegalAccessException e) {
    log.error("ReferenceBeanPostProcessor post process properties error, beanName={}", beanName, e);
    throw new RuntimeException("ReferenceBeanPostProcessor post process properties error, beanName=" + beanName, e);
}

We obtain the proxy object through ClientProxyBeanFactory.getInstance and assign it to the annotated field with field.set.

After completing the above steps, when the client executes the get method, it invokes the invoke method of ClientProxyBeanFactory, which runs the logic of "Open Netty Connection and Send/Receive Messages" and finally returns the result from the service provider.

Query Service Nodes and Pre-create Netty Connections#

This part overlaps with the dependency-injection step above. The core logic: query the available service nodes from the registration center, then hand them to NettyChannelPoolFactory.initChannelFactory so that the Channel pool is warmed up before the first call, as sketched below.
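A minimal sketch of this warm-up step, reusing the APIs shown above (performing it inside postProcessProperties, and querying nodes via the interface's get method, are assumptions):

// Hypothetical warm-up inside ConsumerAnnotationBean.postProcessProperties:
// query the provider nodes and pre-create the Channel pool before the first call.
IRegisterCenter registerCenter = IRegisterCenterZkImpl.getInstance();
List<Producer> producerList = registerCenter.getServiceProducer(targetItf.getName(), "get");
// Creates up to CHANNEL_CONNECT_SIZE channels per provider address
NettyChannelPoolFactory.getInstance().initChannelFactory(producerList);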

At this point, all processes on the client side have been completed. However, to clarify the main ideas, the article only briefly touched on load balancing strategies, serialization, and deserialization, which are also very important parts of an RPC framework.
