simon

simon

如何开发一个自己的 RPC 框架 (一)

如何开发一个自己的 RPC 框架 (一)#

December 11, 2023 12:00 AM

写在前面:本文参考了开源项目 https://github.com/MIracleCczs/simple-rpc,其中大部分代码参考了该项目,本章主要从客户端的调用出发,讲讲一个 RPC 框架的实现在客户端测需要实现那些功能

RPC 的定义#

可以参考 wiki: https://zh.wikipedia.org/wiki/ 遠程過程調用

一个简单的 RPC 框架是如何组成的?#

一个基础的 RPC 框架,需要包含三大部分:1. 注册中心 2. 服务提供方 3. 服务消费方

Mermaid Loading...

从上图可以看出,服务提供方和消费方都需要和注册中心通信

一个远程方法的调用是如何实现的?#

Mermaid Loading...

下面,我们将根据上面流程图,一步步进行讲解。为了方便更加清楚的讲清整个逻辑,我们从实际的业务需求出发。

需求:

存在服务提供方 Producer (后面统称服务端),提供方法 get

存在服务消费方 Consumer (后面统称客户端), 需要调用 Producer 中的 get 方法

Mermaid Loading...

基础接口定义#

定义 UserService 接口,接口内包含 get 方法

public interface UserService {
    String get(String username);
}

客户端发起服务调用#

客户端注解定义#

客户端如何才能够像调用本地方法一样调用远程服务呢?RPC 框架就是用来解决这个问题。我们一般本地方法的调用都是采用

@Autowired
private UserService userService;

通过 Spring 依赖注入的方式,将需要用到的方法注入到调用对象中,那么我们 RPC 调用能不能也采用这种形式呢?答案当然是可以的。那么为了实现上面的需求,我们最简单的办法就是自定义一个注解 RpcClient

@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.FIELD})
public @interface RpcClient {
}

注解定义完成后,我们就应该考虑注解中需要设置那些属性呢?

那么冒出来的第一个问题就是:客户端如何知道是调用的那个远程服务呢?这时我们就需要设置第一个属性 remoteAppKey 服务的唯一标识,通过 remoteAppKey 客户端可以轻松地在注册中心找到目标服务。

这个时候又会有第二个疑问,如果一个服务多个版本如何处理呢?比如进行灰度升级等操作的时,那么这个时候就需要第二个参数 groupName 找到具体服务下的具体分组

剩余的参数就比较简单了,完成的参数配置如下:

 		/**
     * 服务接口:匹配从注册中心获取到本地的服务提供者,得到服务提供者列表,再根据负载均衡策略选取一个发起服务调用
     */
    Class<?> targetItf() default Object.class;

    /**
     * 超时时间:服务调用超时时间
     */
    long timeout() default 3000 * 10L;

    /**
     * 调用者线程数
     */
    int consumeThreads() default 10;

    /**
     * 服务提供者唯一标识
     */
    String remoteAppKey() default "";

    /**
     * 服务分组组名
     */
    String groupName() default "default";

客户端初始化#

为了实现类似 @Autowired 的功能,框架需要在 Bean 初始化之时,将所有被 RpcClient 注解的对象进行依赖注入,那么如何实现这个功能呢? SpringInstantiationAwareBeanPostProcessor 接口,可以在 Bean 的实例化的各个阶段执行自定义逻辑。定义一个 ConsumerAnnotaionBean 方法,实现 InstantiationAwareBeanPostProcessor 接口。

public class ConsumerAnnotaionBean implements InstantiationAwareBeanPostProcessor {
		...其他方法省略
		@Override
    public PropertyValues postProcessProperties(PropertyValues pvs, Object bean, String beanName) throws BeansException {
	}
}

主要实现接口的 postProcessProperties 方法,设置 Bean 对象的属性值

进行具体代码编写之前,我们需要先理清楚这里需要实现那些目的:

  1. 客户端服务注册 (监控目的)
  2. 对象依赖的注入
  3. 查询服务节点,预创建 Netty 连接

好了,理清楚完需求之后,我们便开始对应的逻辑编写

@Override
    public PropertyValues postProcessProperties(PropertyValues pvs, Object bean, String beanName) throws BeansException {
        Class<?> beanClass = bean.getClass();
        // 依赖注入的服务
        Field[] fields = beanClass.getDeclaredFields();
        for (Field field : fields) {
            RpcClient rpcClient = field.getAnnotation(RpcClient.class);
            if (rpcClient == null) {
                continue;
            }
        }
    }

根据初始化的 bean 获取对象中的所有参数,然后使用 getAnnotation(RpcClient.class); 判断参数是否被 RpcClient 所注解

客户端服务注册#

这一步的目的是为了在注册中心中记录消费者信息,方便后续监控,所以这一步相对来说非常简单,只需要构造客户端信息然后提交到注册中心即可。

// 构造消费者对象
Consumer consumer = Consumer
                    .builder()
                    .groupName(rpcClient.groupName())
                    .remoteAppKey(rpcClient.remoteAppKey())
                    .targetItf(targetItf)
                    .build();
// 注册中心注册消费者
registerCenter.registerConsumer(consumer);

对象依赖的注入#

同样的,我们先梳理一下这里的需求。我们需要实现一个动态代理,在方法调用时,根据方法调用名 + 类名获取远程服务提供方的节点信息,然后构造一个 NettyRequest 信息,发送到服务方,最后需要接收服务放返回的 NettyResponse 解析成方法的返回值进行返回

Mermaid Loading...

好了,我们已经理清楚了上面整体流程,那么就开始具体的代码编写吧

首先,定义一个对象 ClientProxyBeanFactory 实现 InvocationHandler 接口

主要是实现接口的 invoke 方法

@Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {}

invoke 方法中,需要根据类名 + 方法名从注册中心中获取可用的节点,那么具体代理的类这个时候就需要从对象实例化中传入,所以我们在定义 ClientProxyBeanFactory 时,需要定义几个成员变量

@Slf4j
public class ClientProxyBeanFactory implements InvocationHandler {

    // 调用连接池(Netty 请求)
    private ExecutorService executorService;

    // 目标代理类
    private Class<?> targetItf;

    // 超时
    private long timeout;

    // 调用线程数
    private int consumeThreads;
}

对象初始化的时候,需要设置成员变量的值

private static volatile ClientProxyBeanFactory instance;

    public ClientProxyBeanFactory(Class<?> targetItf, long timeout, int consumeThreads) {
        this.executorService = new ThreadPoolExecutor(consumeThreads, consumeThreads,
                0L, TimeUnit.MILLISECONDS,
                new LinkedBlockingQueue<>(), new ThreadFactoryBuilder()
                .setNameFormat("simple-rpc-%d").build(), new ThreadPoolExecutor.AbortPolicy());

        this.targetItf = targetItf;
        this.timeout = timeout;
        this.consumeThreads = consumeThreads;
    }

    /**
     * 获取代理对象
     *
     * @return
     */
    public Object getProxy() {
        return Proxy.newProxyInstance(Thread.currentThread().getContextClassLoader(), new Class[]{targetItf}, this);
    }

    public static ClientProxyBeanFactory getInstance(Class<?> targetItf, long timeout, int consumeThreads) {
        if (null == instance) {
            synchronized (ClientProxyBeanFactory.class) {
                if (null == instance) {
                    instance = new ClientProxyBeanFactory(targetItf, timeout, consumeThreads);
                }
            }
        }

        return instance;
    }

完成上述成员变量赋值后,便可以开始从注册中心中获取服务节点了

// ConsumerAnnotaionBean.class
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
				// 代理 className
        String serviceName = targetItf.getName();
				// 注册中心服务
        IRegisterCenter registerCenter = IRegisterCenterZkImpl.getInstance();
				// 根据 serviceName + methodName 获取可以使用的节点
        List<Producer> producerList = registerCenter.getServiceProducer(serviceName, method.getName());
				...
}

获取到服务节点后,这里可以根据设置的负载均衡策略获取本次使用的节点信息,假设这里采用随机获取的方法获取得到节点 Producer

开启 Netty 连接,进行消息收发送

拿到了 Producer 就意味着我们可以获取到远程服务 Netty 的 ip + port 信息了,这个时候就可以建立远程服务连接了。但是这里存在一个优化逻辑,就是如果我们每次都是方法调用时再去建立链接,那么建立连接将会是一个非常耗时的操作,但是如果我们提前根据 ip + port 建立一个 Channel 池,方法调用时只需要从连接池中获取 Channel ,那么服务的效率是不是会大大提高了?

Mermaid Loading...

基于上面的逻辑,我们需要实现一个 NettyChannelPoolFactory 用来缓存客户端的 Netty 的请求缓存,同时对外提供两个方法: acquire 获取 Channel 信息 release 释放 Channel 信息

具体实现代码如下:

@Slf4j
public class NettyChannelPoolFactory {

    private static final NettyChannelPoolFactory CHANNEL_POOL_FACTORY = new NettyChannelPoolFactory();

    // 连接池缓存 key 为服务提供者地址,value为Netty Channel阻塞队列
    public static final Map<InetSocketAddress, ArrayBlockingQueue<Channel>> CHANNEL_POOL_MAP = Maps.newConcurrentMap();

    /**
     * 初始化Netty Channel阻塞队列的长度,该值为可配置信息
     */
    private static final Integer CHANNEL_CONNECT_SIZE = 3;

    public static NettyChannelPoolFactory getInstance() {
        return CHANNEL_POOL_FACTORY;
    }

    /**
     * 初始化 netty 连接池
     */
    public void initChannelFactory(List<Producer> producerNodeList) {
        for (Producer producer : producerNodeList) {
            InetSocketAddress address = new InetSocketAddress(producer.getIp(), producer.getPort());
            while (CHANNEL_POOL_MAP.get(address) == null || CHANNEL_POOL_MAP.get(address).size() < CHANNEL_CONNECT_SIZE) {
                ArrayBlockingQueue<Channel> channels = CHANNEL_POOL_MAP.get(address);
                if (channels == null || channels.size() < CHANNEL_CONNECT_SIZE) {
                    // 初始化 Netty Channel 阻塞队列
                    Channel channel = null;
                    while (channel == null) {
                        channel = registerChannel(address);
                    }

                    if (channels == null) {
                        channels = new ArrayBlockingQueue<>(CHANNEL_CONNECT_SIZE);
                    }

                    boolean offer = channels.offer(channel);
                    if (!offer) {
                        log.debug("channelArrayBlockingQueue fail");

                    } else {
                        CHANNEL_POOL_MAP.put(address, channels);
                    }
                }
            }
        }
    }

    /**
     * 根据 address 获取客户端队列
     */
    public ArrayBlockingQueue<Channel> acquire(InetSocketAddress address) {
        return CHANNEL_POOL_MAP.get(address);
    }

    /**
     * 使用完成之后,将 channel 放回到 阻塞队列
     */
    public void release(ArrayBlockingQueue<Channel> queue, Channel channel, InetSocketAddress address) {
        if (queue == null) {
            return;
        }

        // 回收之前判断channel 是否可用
        if (channel == null || !channel.isOpen() || !channel.isActive() || !channel.isWritable()) {
            if (channel != null) {
                channel.deregister().syncUninterruptibly().awaitUninterruptibly();
                channel.closeFuture().syncUninterruptibly().awaitUninterruptibly();

            } else {
                while (channel == null) {
                    channel = registerChannel(address);
                }
            }
        }

        queue.offer(channel);
    }

    /**
     * 注册 netty 客户端
     */
    public Channel registerChannel(InetSocketAddress address) {
        try {
            EventLoopGroup group = new NioEventLoopGroup(10);
            Bootstrap bootstrap = new Bootstrap();
            bootstrap.remoteAddress(address);
            bootstrap.group(group)
                    .channel(NioSocketChannel.class)
                    .option(ChannelOption.TCP_NODELAY, true)
                    .handler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel ch) throws Exception {
                            // 注册Netty编码器
                            ch.pipeline().addLast(new NettyEncoderHandler());
                            // 注册Netty解码器
                            ch.pipeline().addLast(new NettyDecoderHandler(NettyResponse.class));
                            // 注册客户端业务处理逻辑Handler
                            ch.pipeline().addLast(new NettyHandlerClient());
                        }
                    });

            ChannelFuture channelFuture = bootstrap.connect().sync();
            final Channel channel = channelFuture.channel();
            final CountDownLatch countDownLatch = new CountDownLatch(1);
            final List<Boolean> isSuccessHolder = Lists.newArrayListWithCapacity(1);
            // 监听channel是否建立成功
            channelFuture.addListener(future -> {
                if (future.isSuccess()) {
                    isSuccessHolder.add(Boolean.TRUE);

                } else {
                    // 如果建立失败,保存建立失败标记
                    log.error("registerChannel fail , {}", future.cause().getMessage());
                    isSuccessHolder.add(Boolean.FALSE);
                }

                countDownLatch.countDown();
            });

            countDownLatch.await();
            // 如果Channel建立成功,返回新建的Channel
            if (Boolean.TRUE.equals(isSuccessHolder.get(0))) {
                return channel;
            }

        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log.error("registerChannel fail", e);
        }

        return null;
    }
}

NettyChannelPoolFactory 对象中还定义了一个方法 registerChannel 接收 InetSocketAddress 的入参,返回值为 Channel 。方法中主要根据传入的 address 信息,创建了 Netty 连接,设置了序列化和反序列化的编解码器,然后增加了一个 NettyHandlerClient 的客户端消息处理器。最后将初始化好的 Channel 连接进行返回

有了上面的 NettyChannelPoolFactory ,便可以将从注册中心获得到的 Producer 信息,根据 ip + port 获取 Channel ,从而进行 NettyRequest 消息的发送

**NettyRequest 消息的构造 **

public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
    // 根据 serviceName + methodName 获取可以使用的节点
List<Producer> producerList = registerCenter.getServiceProducer(serviceName, method.getName());
    // 直接取第0 个,这里可以采用负载均衡策略进行获取
    Producer providerCopy =producerList.get(0) ;
		// NettyRequest 构造
NettyRequest request = NettyRequest.builder()
								// 服务节点信息
                .producer(providerCopy)
							// 本次请求的唯一编号
                .uniqueKey(UUID.randomUUID() + "-" + Thread.currentThread().getId())
							// 请求超时时间
                .invokeTimeout(timeout)
							// 请求方法名称
                .invokeMethodName(method.getName())
							// 请求参数
                .args(args)
                .build();

}

好了,现在 NettyRequest 和发送消息的 Channel 都已经有了,只需要将消息发送出去,然后接收消息然后序列成方法的出参即可。

这里可以采用线程池的方式,进行 Netty 消息的发送和返回值的解码

定义一个 ClientServiceCallable 集成自 Callable<NettyResponse> 带返回值的任务的接口

Callable 只有一个需要实现的方法 call() , 在该方法中,需要完成 1. 获取 Channel 对象 2. 发送请求 3. 结果值返回

@Slf4j
public class ClientServiceCallable implements Callable<NettyResponse> {

    /**
     * Netty 通信管道
     */
    private Channel channel;

    /**
     * 请求参数
     */
    private final NettyRequest request;

    public static ClientServiceCallable of(NettyRequest request) {
        return new ClientServiceCallable(request);
    }

    public ClientServiceCallable(NettyRequest request) {
        this.request = request;
    }

@Override
public NettyResponse call() throws Exception {
    InetSocketAddress inetSocketAddress = new InetSocketAddress(request.getProducer().getIp(), request.getProducer().getPort());
    // 获取本地缓存 Channel 队列
    ArrayBlockingQueue<Channel> blockingQueue = NettyChannelPoolFactory.getInstance().acquire(inetSocketAddress);
    try {
        if (channel == null) {
            // 从队列中获取 Channel
            channel = blockingQueue.take();
        }

        if (channel == null) {
            throw new RuntimeException("can't find channel to resolve this request");
        }
    } catch (Exception e) {
        log.error("client send request error", e);

    } finally {
        // 请求结束,队列归还 Channel
        NettyChannelPoolFactory.getInstance().release(blockingQueue, channel, inetSocketAddress);
    }
}
}

上述代码的 call 方法中,首先从本地缓存中获取到了 Channel 队列,然后在 finally 中将 Channel 归还到队列中。那么方法中剩下的逻辑就是发送 NettyRequest 请求,然后返回结果了

try {
            if (channel == null) {
                channel = blockingQueue.take();
            }

            if (channel == null) {
                throw new RuntimeException("can't find channel to resolve this request");

            } else {
                1️⃣ ClientResponseHolder.initResponseData(request.getUniqueKey());

                2️⃣while (!channel.isOpen() || !channel.isActive() || !channel.isWritable()) {
                    log.warn("retry get new channel");
                    channel = blockingQueue.poll(request.getInvokeTimeout(), TimeUnit.MILLISECONDS);
                    if (channel == null) {
                        // 若队列中没有可用的Channel,则重新注册一个Channel
                        channel = NettyChannelPoolFactory.getInstance().registerChannel(inetSocketAddress);
                    }
                }

                // 将本次调用的信息写入Netty通道,发起异步调用
                3️⃣ ChannelFuture channelFuture = channel.writeAndFlush(request);
                channelFuture.syncUninterruptibly();
                // 从返回结果容器中获取返回结果,同时设置等待超时时间为invokeTimeout
                long invokeTimeout = request.getInvokeTimeout();
                4️⃣ return ClientResponseHolder.getValue(request.getUniqueKey(), invokeTimeout);
            }

        } catch (Exception e) {
            log.error("client send request error", e);

        } finally {
            NettyChannelPoolFactory.getInstance().release(blockingQueue, channel, inetSocketAddress);
        }

1️⃣ ClientResponseHolder 类

ClientResponseHolder.initResponseData(request.getUniqueKey()); 这里又增加了一个新的类ClientResponseHolder, 那么这个类是干嘛的呢?

由于消息的发送都是异步的形式,这里使用了 Map<String,NettyResponseWrapper> 进行本地数据缓存,Map 的 KEY 是 NeettyRequestuniqueKey ,而 Value 就是 Netty 的返回结果,即是服务端执行之后的返回值

ClientResponseHolder 的具体实现如下:

@Slf4j
public class ClientResponseHolder {

    private static final Map<String, NettyResponseWrapper> RESPONSE_WRAPPER_MAP = Maps.newConcurrentMap();

    private static final ScheduledExecutorService executorService;

    static {
        executorService = new ScheduledThreadPoolExecutor(1, new RemoveExpireThreadFactory("simple-rpc", false));
        // 删除过期的数据
        executorService.scheduleWithFixedDelay(() -> {
            for (Map.Entry<String, NettyResponseWrapper> entry : RESPONSE_WRAPPER_MAP.entrySet()) {
                boolean expire = entry.getValue().isExpire();
                if (expire) {
                    RESPONSE_WRAPPER_MAP.remove(entry.getKey());
                }
            }
        }, 1, 20, TimeUnit.MILLISECONDS);
    }

    /**
     * 初始化返回结果容器,requestUniqueKey唯一标识本次调用
     */
    public static void initResponseData(String requestUniqueKey) {
        RESPONSE_WRAPPER_MAP.put(requestUniqueKey, NettyResponseWrapper.of());
    }

    /**
     * 将Netty调用异步返回结果放入阻塞队列
     */
    public static void putResultValue(NettyResponse response) {
        long currentTimeMillis = System.currentTimeMillis();
        NettyResponseWrapper responseWrapper = RESPONSE_WRAPPER_MAP.get(response.getUniqueKey());
        responseWrapper.setResponseTime(currentTimeMillis);
        responseWrapper.getResponseBlockingQueue().add(response);
        RESPONSE_WRAPPER_MAP.put(response.getUniqueKey(), responseWrapper);
    }

    /**
     * 从阻塞队列中获取异步返回结果
     */
    public static NettyResponse getValue(String requestUniqueKey, long timeout) {
        NettyResponseWrapper responseWrapper = RESPONSE_WRAPPER_MAP.get(requestUniqueKey);
        try {
            return responseWrapper.getResponseBlockingQueue().poll(timeout, TimeUnit.MILLISECONDS);
            
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log.error("get value error", e);

        } finally {
            RESPONSE_WRAPPER_MAP.remove(requestUniqueKey);
        }
        return null;
    }

}
  • initResponseData: 根据 uniqueKey 初始化 Map
  • putResultValue: 插入 NettyResponse 返回结果
  • getValue: 根据 uniqueKey 获取结果

同时定义了一个定时执行的队列,队列中根据 responseTime 判断消息是否过期进行内存数据清洗

2️⃣ Channel 状态判断

判断当前 Netty 通道状态,如果当前 Channel 不可用,则需要重新申请通道

3️⃣ Netty 消息发送

4️⃣ 从本地缓存中获取 Netty 返回结果

异步调用 Netty 服务,使用 Future 获取返回结果

public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
    // 根据 serviceName + methodName 获取可以使用的节点
List<Producer> producerList = registerCenter.getServiceProducer(serviceName, method.getName());
    // 直接取第0 个,这里可以采用负载均衡策略进行获取
    Producer providerCopy =producerList.get(0) ;
		// NettyRequest 构造
NettyRequest request = NettyRequest.builder()
								// 服务节点信息
                .producer(providerCopy)
							// 本次请求的唯一编号
                .uniqueKey(UUID.randomUUID() + "-" + Thread.currentThread().getId())
							// 请求超时时间
                .invokeTimeout(timeout)
							// 请求方法名称
                .invokeMethodName(method.getName())
							// 请求参数
                .args(args)
                .build();
        // 发起异步调用,通过 NettyClient 发送请求
        try {
            Future<NettyResponse> responseFuture = executorService.submit(ClientServiceCallable.of(request));
            NettyResponse response = responseFuture.get(timeout, TimeUnit.MILLISECONDS);
            if (response != null) {
                return response.getResult();
            }

        } catch (Exception e) {
            log.error("send request error", e);
        }
}

到这里,我们就完成了 ClientProxyBeanFactory 代理对象的完整编写,现在就需要将初始化好的代理对象进行依赖注入

// ConsumerAnnotaionBean.class
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
				// 代理 className
        String serviceName = targetItf.getName();
				// 注册中心服务
        IRegisterCenter registerCenter = IRegisterCenterZkImpl.getInstance();
				// 根据 serviceName + methodName 获取可以使用的节点
        List<Producer> producerList = registerCenter.getServiceProducer(serviceName, method.getName());
				// 3.获取服务代理对象
            Class<?> targetItf = rpcClient.targetItf();
            if (targetItf == Object.class) {
                targetItf = field.getType();
           }
				// 初始化代理对象
				ClientProxyBeanFactory factory = ClientProxyBeanFactory.getInstance(targetItf, rpcClient.timeout(), rpcClient.consumeThreads());
            ReflectionUtils.makeAccessible(field);
            try {
                // 设置代理对象
                field.set(bean, factory.getProxy());

            } catch (IllegalAccessException e) {
                log.error("ReferenceBeanPostProcessor post process properties error, beanName={}", beanName, e);
                throw new RuntimeException("ReferenceBeanPostProcessor post process properties error, beanName=" + beanName, e);
            }
}

通过 ClientProxyBeanFactory.getInstance 获取到代理对象后,使用 field.set 方法进行执行赋值

完成上述操作之后,当客户端执行 get 方法时,便会 invokeClientProxyBeanFactoryinvoke 方法上,随后执行 开启 Netty 连接,进行消息收发送 内容,随后将服务方结果进行返回

查询服务节点,预创建 Netty 连接#

这部分内容和第二步有所重叠,其核心逻辑如下:

Mermaid Loading...

到这里,客户端的所有流程就都编写完成了。但是为了理清楚主要思路,文章中对负载均衡策略、序列化和反序列化等都只是一笔带过。这些也是一个 RPC 框架非常很重要的一部分。

加载中...
此文章数据所有权由区块链加密技术和智能合约保障仅归创作者所有。