• 用Golang手写一个RPC,理解RPC原理


    代码结构

    .
    ├── client.go
    ├── coder.go
    ├── coder_test.go
    ├── rpc_test.go
    ├── server.go
    ├── session.go
    └── session_test.go
    

    代码

    client.go

    package rpc
    
    import (
    	"net"
    	"reflect"
    )
    
    // rpc 客户端实现
    
    // 抽象客户端方法
    type Client struct {
    	conn net.Conn
    }
    
    // client构造方法
    func NewClient(conn net.Conn) *Client {
    	return &Client{conn: conn}
    }
    
    // 客户端调用服务端rpc实现
    // client.RpcCall("login", &req)
    func (c *Client) RpcCall(name string, fpr interface{}) {
    	// 反射获取函数原型
    	fn := reflect.ValueOf(fpr).Elem()
    	// 客户端逻辑的实现
    	f := func(args []reflect.Value) (results []reflect.Value) {
    		// 从匿名函数中构建请求参数
    		inArgs := make([]interface{}, 0, len(args))
    		for _, v := range args {
    			inArgs = append(inArgs, v.Interface())
    		}
    		// 组装rpc data请求数据
    		reqData := RpcData{Name: name, Args: inArgs}
    		// 进行数据编码
    		reqByteData, err := encode(reqData)
    		if err != nil {
    			return
    		}
    		// 创建session 对象
    		session := NewSession(c.conn)
    		// 客户端发送数据
    		err = session.Write(reqByteData)
    		if err != nil {
    			return
    		}
    		// 读取客户端数据
    		rspByteData, err := session.Read()
    		if err != nil {
    			return
    		}
    		// 数据进行解码
    		rspData, err := decode(rspByteData)
    		if err != nil {
    			return
    		}
    		// 处理服务端返回的数据结果
    		outArgs := make([]reflect.Value, 0, len(rspData.Args))
    		for i, v := range rspData.Args {
    			// 数据特殊情况处理
    			if v == nil {
    				// reflect.Zero() 返回某类型的零值的value
    				// .Out()返回函数输出的参数类型
    				// 得到具体第几个位置的参数的零值
    				outArgs = append(outArgs, reflect.Zero(fn.Type().Out(i)))
    				continue
    			}
    			outArgs = append(outArgs, reflect.ValueOf(v))
    		}
    
    		return outArgs
    	}
    
    	// 函数原型到调用的关键,需要2个参数
    	// 参数1:函数原型,是Type类型
    	// 参数2:返回类型是Value类型
    	// 简单理解:参数1是函数原型,参数2是客户端逻辑
    	v := reflect.MakeFunc(fn.Type(), f)
    	fn.Set(v)
    }
    

    coder.go

    package rpc
    
    import (
    	"bytes"
    	"encoding/gob"
    	"fmt"
    )
    
    // 对传输的数据进行编解码
    // 使用Golang自带的一个数据结构序列化编码/解码工具 gob
    
    // 定义rpc数据交互式数据传输格式
    type RpcData struct {
    	Name string        // 调用方法名
    	Args []interface{} // 调用和返回的参数列表
    }
    
    // 编码
    func encode(data RpcData) ([]byte, error) {
    	// gob进行编码
    	var buf bytes.Buffer
    	// 得到字节编码器
    	encoder := gob.NewEncoder(&buf)
    	// 进行编码
    	if err := encoder.Encode(data); err != nil {
    		fmt.Printf("gob encode failed, err: %v
    ", err)
    		return nil, err
    	}
    	return buf.Bytes(), nil
    }
    
    // 解码
    func decode(data []byte) (RpcData, error) {
    	// 得到字节解码器
    	buf := bytes.NewBuffer(data)
    	decoder := gob.NewDecoder(buf)
    	// 解码数据
    	var rd RpcData
    	if err := decoder.Decode(&rd); err != nil {
    		fmt.Printf("gob decode failed, err: %v
    ", err)
    		return rd, err
    	}
    	return rd, nil
    }
    

    server.go

    package rpc
    
    import (
    	"net"
    	"reflect"
    )
    
    // rpc 服务端实现
    
    // 抽象服务端
    type Server struct {
    	add   string                   // 连接地址
    	funcs map[string]reflect.Value // 存储方法名和方法的对应关系,服务注册
    }
    
    // server 构造方法
    func NewServer(addr string) *Server {
    	return &Server{add: addr, funcs: make(map[string]reflect.Value)}
    }
    
    // 注册接口
    func (s *Server) Register(name string, fc interface{}) {
    	if _, ok := s.funcs[name]; ok {
    		return
    	}
    	s.funcs[name] = reflect.ValueOf(fc)
    }
    
    func (s *Server) Run() (err error) {
    	listener, err := net.Listen("tcp", s.add)
    	if err != nil {
    		return
    	}
    	for {
    		// 监听连接
    		conn, err := listener.Accept()
    		if err != nil {
    			conn.Close()
    			continue
    		}
    		// 创建会话
    		session := NewSession(conn)
    		// 读取会话请求数据
    		reqData, err := session.Read()
    		if err != nil {
    			conn.Close()
    			continue
    		}
    		// 数据解码
    		rpcReqData, err := decode(reqData)
    		// 获取客户端要调用的方法
    		fc, ok := s.funcs[rpcReqData.Name];
    		if !ok {
    			conn.Close()
    			continue
    		}
    		// 获取请求的参数列表
    		args := make([]reflect.Value, 0, len(rpcReqData.Args))
    		for _, v := range rpcReqData.Args {
    			args = append(args, reflect.ValueOf(v))
    		}
    		// 调用
    		callReslut := fc.Call(args)
    		// 处理调用返回的数据结果
    		rargs := make([]interface{}, 0, len(callReslut))
    		for _, rv := range callReslut {
    			rargs = append(rargs, rv.Interface())
    		}
    		// 构建返回的rpc数据
    		rpcRspData := RpcData{Name: rpcReqData.Name, Args: rargs}
    		// 返回数据进行编码
    		rspData, err := encode(rpcRspData)
    		if err != nil {
    			conn.Close()
    			continue
    		}
    		err = session.Write(rspData)
    		if err != nil {
    			conn.Close()
    			continue
    		}
    	}
    	return
    }
    

    session.go

    package rpc
    
    import (
    	"encoding/binary"
    	"fmt"
    	"io"
    	"net"
    )
    
    // 处理连接会话
    
    // 会话对象结构体
    type Session struct {
    	conn net.Conn
    }
    
    // 传输数据存储方式
    // 字节数组, 添加4个字节的头,用来存储数据的长度
    
    // 会话构造函数
    func NewSession(conn net.Conn) *Session {
    	return &Session{conn: conn}
    }
    
    // 从连接中读取数据
    func (s *Session) Read() (data []byte, err error) {
    	// 读取数据header数据
    	header := make([]byte, 4)
    	_, err = s.conn.Read(header)
    	if err != nil {
    		fmt.Printf("read conn header data failed, err: %v
    ", err)
    		return
    	}
    	// 读取body数据
    	hlen := binary.BigEndian.Uint32(header)
    	data = make([]byte, hlen)
    	_, err = io.ReadFull(s.conn, data)
    	if err != nil {
    		fmt.Printf("read conn body data failed, err: %v
    ", err)
    		return
    	}
    	return
    }
    
    // 向连接中写入数据
    func (s *Session) Write(data []byte) (err error) {
    	// 创建数据字节切片
    	buf := make([]byte, 4+len(data))
    	// 向header写入数据长度
    	binary.BigEndian.PutUint32(buf[:4], uint32(len(data)))
    	// 写入body内容
    	copy(buf[4:], data)
    	// 写入连接数据
    	_, err = s.conn.Write(buf)
    	if err != nil {
    		fmt.Printf("write conn data failed, err: %v
    ", err)
    		return
    	}
    	return
    }
    

    coder_test.go

    package rpc
    
    import (
    	"testing"
    )
    
    func TestCoder(t *testing.T) {
    	rd := RpcData{
    		Name: "login",
    		Args: []interface{}{"zhangsan", "zs123"},
    	}
    
    	eData, err := encode(rd)
    	if err != nil {
    		t.Error(err)
    		return
    	}
    	t.Logf("gob 编码后数据长度: %d
    ", len(eData))
    
    	dData, err := decode(eData)
    	if err != nil {
    		t.Error(err)
    		return
    	}
    	t.Logf("%#v
    ", dData)
    }
    

    session_test.go

    package rpc
    
    import (
    	"net"
    	"sync"
    	"testing"
    )
    
    func TestSession(t *testing.T) {
    	addr := ":8080"
    	test_data := "my is test data"
    	var wg sync.WaitGroup
    	wg.Add(2)
    	// 写数据
    	go func() {
    		defer wg.Done()
    		listener, err := net.Listen("tcp", addr)
    		if err != nil {
    			t.Fatal(err)
    			return
    		}
    		conn, _ := listener.Accept()
    		s := NewSession(conn)
    		data, err := s.Read()
    		if err != nil {
    			t.Error(err)
    			return
    		}
    		t.Log(string(data))
    	}()
    
    	// 读数据
    	go func() {
    		defer wg.Done()
    		conn, err := net.Dial("tcp", addr)
    		if err != nil {
    			t.Fatal(err)
    			return
    		}
    		s := NewSession(conn)
    		err = s.Write([]byte(test_data))
    		if err != nil {
    			return
    		}
    		t.Log("写入数据成功")
    		return
    	}()
    
    	wg.Wait()
    }
    

    rpc_test.go

    package rpc
    
    import (
    	"encoding/gob"
    	"fmt"
    	"net"
    	"testing"
    )
    
    // rpc 客户端和服务端测试
    
    // 定义一个服务端结构体
    // 定义一个方法
    // 通过调用rpc方法查询用户的信息
    
    type User struct {
    	Name string
    	Age  int
    }
    
    // 定义查询用户的方法
    // 通过用户id查询用户数据
    func queryUser(id int) (User, error) {
    	// 造一些查询user的假数据
    	users := make(map[int]User)
    	users[0] = User{"user01", 22}
    	users[1] = User{"user02", 23}
    	users[2] = User{"user03", 24}
    	if u, ok := users[id]; ok {
    		return u, nil
    	}
    	return User{}, fmt.Errorf("%d id not found", id)
    
    }
    
    func TestRpc(t *testing.T) {
    	// 给gob注册类型
    	gob.Register(User{})
    
    	addr := ":8080"
    
    	// 创建服务端
    	server := NewServer(addr)
    	// 注册服务
    	server.Register("queryUser", queryUser)
    	// 启动服务端
    	go server.Run()
    
    	// 创建客户端连接
    	conn, err := net.Dial("tcp", addr)
    	if err != nil {
    		return
    	}
    	// 创客户端
    	client := NewClient(conn)
    	// 定义函数调用原型
    	var query func(int) (User, error)
    	// 客户端调用rpc
    	client.RpcCall("queryUser", &query)
    	// 得到返回结果
    	user, err := query(1)
    	if err != nil {
    		t.Error(err)
    		return
    	}
    	fmt.Printf("%#v
    ", user)
    }
    
  • 相关阅读:
    SQLZOO:SELECT from WORLD Tutorial
    Spyder——小技巧+快捷键
    JDK国内镜像
    debian 安装 plymouth 美化开机动画
    docker 国内镜像加速
    有关npm镜像加速的问题 yarn nvm yrm
    调整vscode工具栏侧边栏字体大小
    github的淘宝代理?
    fcitx5 主题设置
    debian testing安装qemu-kvm和virt-manager
  • 原文地址:https://www.cnblogs.com/zhichaoma/p/12638184.html
Copyright © 2020-2023  润新知