【easy-rpc】一、自己实现一个RPC服务

1,762 阅读5分钟

本文为原创文章,转载请注明出处

源码地址:这里

目前市面上比较好的RPC服务框架,比如阿里的Dubbo框架、包括Netty也是支持RPC的,他们分别是在RPC的基础上,加上了诸如负载均衡、服务发现等功能,其核心在于对于远程方法的调用封装上面。

具体RPC是什么这里就不说了。

我们的RPC框架主要分成两部分,一个Client端,一个Server端。

需要实现的功能:

某类的方法具体实现在Server端; 在Client端调用Server端的方法。

Client端基本架构包括三层:

  • 底层是基本的协议层,我们使用HTTP协议或者直接使用Java的Socket传输二进制
  • 调用层我们使用对象序列化方法来序列化数据,并使用JDK动态代理来为为应用层提供服务
  • 应用层即为用户调用的使用层

Server端基本架构也是包含三层:

  • 协议层:使用HTTP或者Socket传输接受数据
  • 转发层:转发到具体被调用的Impl实现类上处理
  • 实现层:具体的接口实现

源码地址:gitee.com/wandali/eas…

我们先来定义一个客户端和服务器端都需要调用的类,该类实际上是在服务器端实现:

// 在common中
package com.codelifeliwan.rpc.common;

/**
 * 定义客户端和服务器端都通用的计算器类
 */
public interface Calculator {
    /**
     * 这里只定义一个简单的int的加法
     *
     * @param a
     * @param b
     * @return
     */
    int add(int a, int b);
}

然后再来定义一个在不同端之间通信的消息类型,并负责序列化和反序列化工作:

// 在common中
package com.codelifeliwan.rpc.common;

import java.io.*;

/**
 * Java对象序列化和反序列化类
 */
public class RPCMessage implements Serializable {
    private static final long serialVersionUID = 23875263L;

    private String className; // 使用的序列化类,不包含包名

    /**
     * 发送的消息体
     */
    private String methodName;
    private Class<?>[] paramTypes;
    private Object[] paramValues;

    /**
     * 接收的消息体
     */
    private Object value;

    // 序列化
    public void writeMessage(OutputStream out) throws Exception {
        ObjectOutputStream outputStream = new ObjectOutputStream(out);
        outputStream.writeObject(this);
    }

    // 反序列化
    public static RPCMessage getMessage(InputStream input) throws Exception {
        RPCMessage message = (RPCMessage) new ObjectInputStream(input).readObject();
        return message;
    }

    // getter和setter省略...
}

服务器端,主要包括三个类:

  • CalculatorImpl:对Calculator接口的具体实现
  • RPCServer:服务器的具体实现,包括协议层和转发层,并自定义了一些bean用来处理请求数据(类似Spring的Controller)
  • ServerStarter:启动服务器用

先看具体实现:

package com.codelifeliwan.rpc.server;

import com.codelifeliwan.rpc.common.Calculator;

/**
 * server实现层
 * 在server端定义计算器的具体实现
 */
public class CalculatorImpl implements Calculator {

    @Override
    public int add(int a, int b) {
        return a + b;
    }
}

再看RPCServer的具体实现:

package com.codelifeliwan.rpc.server;

import com.codelifeliwan.rpc.common.RPCMessage;

import java.io.InputStream;
import java.io.ObjectOutputStream;
import java.lang.reflect.Method;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.HashMap;
import java.util.Map;

public class RPCServer {
    /**
     * 定义服务器端监听的地址和端口
     */
    private int port = 8921;

    /**
     * 定义被注册进来的处理消息的bean
     * key=bean-name  value=bean-object
     * bean就是接口的Impl具体实现
     */
    private Map<String, Object> register_unseril_classes = new HashMap<>();

    public RPCServer() {
    }

    public RPCServer(int port) {
        this.port = port;
    }

    /**
     * 服务器协议层,负责请求接收、序列化和反序列化等
     */
    public void start() throws Exception {
        ServerSocket server = new ServerSocket(port);
        while (true) {
            System.out.println("waiting for message...");
            Socket socket = server.accept(); // 监听等待消息
            try {
                InputStream inputStream = socket.getInputStream(); // 获取到消息后取得输入流
                RPCMessage message = RPCMessage.getMessage(inputStream); // 消息发送使用Java的序列化方法

                if (!register_unseril_classes.containsKey(message.getClassName())) {
                    System.out.println("未知对象:message.getClassName()");
                    throw new Exception("未知对象:message.getClassName()");
                }

                // 利用反射机制来实现方法调用
                RPCMessage response = processMessage(message);

                // 将返回结果写回
                new ObjectOutputStream(socket.getOutputStream()).writeObject(response);
            } catch (Exception e) {
                e.printStackTrace();
                RPCMessage response = new RPCMessage();
                response.setValue("error:" + e.getMessage());
                new ObjectOutputStream(socket.getOutputStream()).writeObject(response);
            }
        }
    }

    /**
     * 服务器转发层,负责寻找处理器并执行具体的业务实现
     */
    private RPCMessage processMessage(RPCMessage message) throws Exception {
        // 利用反射机制来实现方法调用
        Object bean = register_unseril_classes.get(message.getClassName());
        Method method = bean.getClass().getMethod(message.getMethodName(), message.getParamTypes());
        RPCMessage response = new RPCMessage();
        response.setClassName(message.getClassName());
        response.setValue(method.invoke(bean, message.getParamValues()));

        return response;
    }

    /**
     * 注册bean
     *
     * @param beanName
     * @param clazz
     */
    public void registerBean(String beanName, Class<?> clazz) {
        if (beanName == null || clazz == null) return;
        try {
            if (register_unseril_classes.containsKey(beanName)) throw new Exception("bean name existed");
            register_unseril_classes.put(beanName, clazz.getConstructor().newInstance());
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}

最后,我们启动服务器:

package com.codelifeliwan.rpc.server;

import com.codelifeliwan.rpc.common.Calculator;

public class ServerStarter {

    public static void main(String[] args) throws Exception {
        RPCServer server = new RPCServer();

        // 注册解析bean
        server.registerBean(Calculator.class.getName(), CalculatorImpl.class);

        // 启动服务器
        server.start();
    }
}

到此,服务器端启动成功,客户端可以与服务器端进行通信。

客户端也包括了三个类:

  • ClientProtocol:协议层,用来发送和接收数据
  • RPCBeanFactory:调用层,使用JDK动态代理用来动态创建目标对象,并对调用进行封装
  • ClientTest:应用层,像使用本地方法一样调用rpc服务

协议层实现代码:

package com.codelifeliwan.rpc.client;

import com.codelifeliwan.rpc.common.RPCMessage;

import java.io.ObjectOutputStream;
import java.net.Socket;

/**
 * 协议层
 */
public class ClientProtocol {
    private String host = "localhost";
    private int port = 8921;

    public ClientProtocol(String host, int port) {
        this.host = host;
        this.port = port;
    }

    /**
     * 与server建立连接,发送和接收数据
     *
     * @return
     */
    public RPCMessage call(RPCMessage message) throws Exception {
        Socket socket = new Socket(host, port);

        // 向server发送请求
        new ObjectOutputStream(socket.getOutputStream()).writeObject(message);
        socket.shutdownOutput();

        // 获取server的回复
        RPCMessage response = RPCMessage.getMessage(socket.getInputStream());
        socket.close();
        // System.out.println("client socket closed");
        return response;
    }
}

调用层实现代码:

package com.codelifeliwan.rpc.client;

import com.codelifeliwan.rpc.common.RPCMessage;

import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;

/**
 * 调用层,在该层中调用对象
 * 代理动态对象
 */
public class RPCBeanFactory<T> implements InvocationHandler {
    /**
     * 定义服务器端监听的地址和端口
     */
    private String host = "localhost";
    private int port = 8921;
    private ClientProtocol protocol;
    private T object;
    private Class<T> clazz;

    public RPCBeanFactory(String host, int port, Class<T> interfaceClass) {
        this.host = host;
        this.port = port;
        this.protocol = new ClientProtocol(host, port);
        this.clazz = interfaceClass;
        this.object = (T) Proxy.newProxyInstance(RPCBeanFactory.class.getClassLoader(), new Class[]{interfaceClass}, this);
    }

    /**
     * 负责使用JDK动态代理来生成对象
     *
     * @param proxy
     * @param method
     * @param args
     * @return
     * @throws Throwable
     */
    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {

        // 序列化数据并通过协议层调用
        RPCMessage message = new RPCMessage();
        message.setClassName(this.clazz.getName());
        message.setMethodName(method.getName());
        message.setParamTypes(method.getParameterTypes());
        message.setParamValues(args);

        RPCMessage response = protocol.call(message);
        return response.getValue();
    }

    /**
     * 获取可被调用的对象
     *
     * @return
     * @throws Exception
     */
    public T getObject() throws Exception {
        return this.object;
    }
}

最后,应用层代码:

package com.codelifeliwan.rpc.client;

import com.codelifeliwan.rpc.common.Calculator;

public class ClientTest {

    // 应用层
    public static void main(String[] args) throws Exception {
        RPCBeanFactory<Calculator> factory = new RPCBeanFactory<Calculator>("localhost", 8921, Calculator.class);

        Calculator cal = factory.getObject();

        /**
         * 直接像调用本地方法一样调用远程方法
         */
        System.out.println("1+4=" + cal.add(1, 4));
    }
}

在使用的时候,先运行ServerStarter启动服务器,再运行ClientTest调用服务器来进行测试。

我们再使用多线程来测试一下:

package com.codelifeliwan.rpc.client;

import com.codelifeliwan.rpc.common.Calculator;

import java.util.Random;

public class ClientTest {

    // 应用层
    public static void main(String[] args) throws Exception {
        RPCBeanFactory<Calculator> factory = new RPCBeanFactory<Calculator>("localhost", 8921, Calculator.class);

        Calculator cal = factory.getObject();

        Runnable r = () -> {
            for (int i = 0; i < 5; i++) {
                Random random = new Random(System.currentTimeMillis());
                int a = random.nextInt(20);
                int b = random.nextInt(20);

                // 直接像调用本地方法一样调用远程方法
                System.out.println("" + a + " + " + b + " = " + cal.add(a, b));

                try {
                    Thread.sleep(10);
                } catch (InterruptedException e) {
                }
            }
        };

        // 开启多线程调用试试
        for (int i = 0; i < 5; i++) new Thread(r).start();
    }
}

多线程的输出结果:

0 + 3 = 3
5 + 10 = 15
0 + 3 = 3
0 + 3 = 3
0 + 3 = 3
10 + 17 = 27
10 + 17 = 27
10 + 17 = 27
16 + 5 = 21
6 + 12 = 18
7 + 9 = 16
7 + 9 = 16
15 + 1 = 16
9 + 13 = 22
3 + 5 = 8
10 + 0 = 10
16 + 8 = 24
16 + 8 = 24
6 + 16 = 22
11 + 4 = 15
18 + 1 = 19
7 + 13 = 20
7 + 13 = 20
4 + 9 = 13
17 + 0 = 17

至此,一个rpc实现完了,读者可自行优化。