跨进程复制socket

linux为什么能跨进程传递socket文件描述符

在Linux中一切皆文件,文件系统是进程所共有的。socket对应的文件是在内核的socket伪文件系统(sockfs)中申请的,因此socket也是文件的一种,所以在同一台主机上,socket是可以跨进程传递的。下面仔细跟踪一下socket创建的过程(3.10内核)。

/* User-space starting point of the trace: create a TCP socket. */
int server_fd = socket(AF_INET, SOCK_STREAM, 0 );
/* ... intermediate kernel steps are elided here ... */
/* The kernel function quoted next binds the new socket to a descriptor. */
/*
 * Bind @sock to a new file descriptor: reserve an unused fd, allocate the
 * socket's struct file, and install that file under the fd.
 *
 * Returns the fd on success. Returns the negative value from
 * get_unused_fd_flags() if no fd could be reserved, or PTR_ERR() of the
 * failed file allocation (the reserved fd is released first).
 */
static int sock_map_fd(struct socket *sock, int flags)
{
	struct file *newfile;
	/* Reserve an fd number. */
	int fd = get_unused_fd_flags(flags);
	if (unlikely(fd < 0))
		return fd;
	/* Allocate the struct file that represents the socket. */
	newfile = sock_alloc_file(sock, flags, NULL);
	if (likely(!IS_ERR(newfile))) {
		/* Publish the file under the reserved fd. */
		fd_install(fd, newfile);
		return fd;
	}

	/* File allocation failed: give the reserved fd back. */
	put_unused_fd(fd);
	return PTR_ERR(newfile);
}


/*
 * Allocate the struct file for @sock on the socket mount (sock_mnt).
 *
 * The dentry name is @dname when given, otherwise the name of the socket's
 * creating protocol (when sock->sk is set), otherwise the empty string.
 *
 * Returns the new file, ERR_PTR(-ENOMEM) if the dentry cannot be allocated,
 * or the error pointer returned by alloc_file().
 */
struct file *sock_alloc_file(struct socket *sock, int flags, const char *dname)
{
	struct qstr name = { .name = "" };
	struct path path;
	struct file *file;

	if (dname) {
		name.name = dname;
		name.len = strlen(name.name);
	} else if (sock->sk) {
		name.name = sock->sk->sk_prot_creator->name;
		name.len = strlen(name.name);
	}
	/* Pseudo dentry on the socket mount's superblock. */
	path.dentry = d_alloc_pseudo(sock_mnt->mnt_sb, &name);
	if (unlikely(!path.dentry))
		return ERR_PTR(-ENOMEM);
	path.mnt = mntget(sock_mnt);

	/* Attach the socket's inode to the dentry and set its file operations. */
	d_instantiate(path.dentry, SOCK_INODE(sock));
	SOCK_INODE(sock)->i_fop = &socket_file_ops;

	file = alloc_file(&path, FMODE_READ | FMODE_WRITE,
		  &socket_file_ops);
	if (unlikely(IS_ERR(file))) {
		/* drop dentry, keep inode */
		ihold(path.dentry->d_inode);
		path_put(&path);
		return file;
	}
	/*
	 * Note on the assignments below: the socket interfaces go through this
	 * file to reach the socket structure. With a file descriptor in hand,
	 * the socket can be found via the file's private_data field and
	 * operated on. That is why a socket file can be passed between
	 * processes on the same host.
	 */
	sock->file = file;
	file->f_flags = O_RDWR | (flags & O_NONBLOCK);
	file->private_data = sock;
	return file;
}

如何传递、传递过程中注意哪些问题


// Socket send and receive functions; ancillary (control) data travels in the msghdr.
ssize_t sendmsg(int socket, const struct msghdr *message, int flags);
ssize_t recvmsg(int socket, struct msghdr *message, int flags);

// Related data structures
// Message header handed to sendmsg()/recvmsg().
struct msghdr {
	void		*msg_name;	/* ptr to socket address structure */ // destination address of the data: sockaddr_in for network packets, sockaddr_nl for netlink
	int		msg_namelen;	/* size of socket address structure */ // length of the address msg_name points to
	struct iovec	*msg_iov;	/* scatter/gather array */  // points to the array of buffers
	__kernel_size_t	msg_iovlen;	/* # elements in msg_iov */ // number of buffers in the array
	void		*msg_control;	/* ancillary data */    // ancillary data / control information (can carry any control message)
	__kernel_size_t	msg_controllen;	/* ancillary data buffer length */  // length of the ancillary data
	unsigned int	msg_flags;	/* flags on received message */ // message flags
};

/* One scatter/gather buffer: start address plus length (an msg_iov element). */
struct iovec
{
	void __user *iov_base;	/* BSD uses caddr_t (1003.1g requires void *) */
	__kernel_size_t iov_len; /* Must be size_t (1003.1g) */
};

/* Header of one ancillary-data (control) message carried in msg_control. */
struct cmsghdr {
	__kernel_size_t	cmsg_len;	/* data byte count, including hdr */
        int		cmsg_level;	/* originating protocol */
        int		cmsg_type;	/* protocol-specific type */
};
// A simple example: a parent process passes a listening TCP socket to a forked child over a socketpair
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <iostream>
#include <sys/socket.h>
#include <unistd.h>
#include <sys/wait.h>
#include <fcntl.h>
#include <sys/uio.h>
#include <errno.h>
#include <netinet/in.h>
#include <time.h>
#include <signal.h>
#include <arpa/inet.h>


using namespace std;

int tcpServer();

/**
 * Send the open descriptor `fd` to the peer of `sock` as SCM_RIGHTS
 * ancillary data.
 *
 * @param sock  connected UNIX-domain socket used as the channel
 * @param fd    descriptor to hand over
 *
 * Prints a message and exits the process with status 1 if sendmsg() fails.
 */
void send_fd(int sock, int fd)
{
    // One dummy payload byte accompanies the control message.
    char payload = 0;
    iovec iov[1];
    iov[0].iov_base = &payload;
    iov[0].iov_len  = sizeof(payload);

    // CMSG_SPACE (not CMSG_LEN) sizes the buffer so trailing alignment
    // padding is included; the union gives it cmsghdr alignment.
    union {
        char buf[CMSG_SPACE(sizeof(int))];
        cmsghdr align;
    } control;
    memset(&control, 0, sizeof(control));

    // Zero the header so no field (e.g. msg_flags) is left indeterminate.
    msghdr msg;
    memset(&msg, 0, sizeof(msg));
    msg.msg_iov = iov;
    msg.msg_iovlen = 1;
    msg.msg_name = NULL;
    msg.msg_namelen = 0;
    msg.msg_control = control.buf;
    msg.msg_controllen = sizeof(control.buf);

    cmsghdr* cmptr = CMSG_FIRSTHDR(&msg);
    cmptr->cmsg_level = SOL_SOCKET;
    cmptr->cmsg_type = SCM_RIGHTS; // we are sending fd.
    cmptr->cmsg_len = CMSG_LEN(sizeof(int));
    // memcpy avoids writing through a cast pointer into the char buffer.
    memcpy(CMSG_DATA(cmptr), &fd, sizeof(fd));

    if (sendmsg(sock, &msg, 0) == -1){
        std::cout << "[send_fd] sendmsg error" << std::endl;
        exit(1);
    }
}
 
/**
 * Receive one descriptor sent as SCM_RIGHTS ancillary data over `sock`.
 *
 * @param sock  connected UNIX-domain socket used as the channel
 * @return the received descriptor, or -1 if the message carried no
 *         SCM_RIGHTS control message (e.g. plain data or peer closed)
 *
 * Prints a message and exits the process with status 1 if recvmsg() fails.
 */
int recv_fd(int sock)
{
    char buf[32]; // the max buf in msg.
    iovec iov[1];
    iov[0].iov_base = buf;
    iov[0].iov_len = sizeof(buf);

    // Stack buffer sized with CMSG_SPACE and aligned for cmsghdr; this
    // replaces the malloc'ed buffer that was never freed.
    union {
        char buf[CMSG_SPACE(sizeof(int))];
        cmsghdr align;
    } control;
    memset(&control, 0, sizeof(control));

    msghdr msg;
    memset(&msg, 0, sizeof(msg));
    msg.msg_iov = iov;
    msg.msg_iovlen  = 1;
    msg.msg_name = NULL;
    msg.msg_namelen = 0;
    msg.msg_control = control.buf;
    msg.msg_controllen = sizeof(control.buf);

    int ret = recvmsg(sock, &msg, 0);
    if (ret == -1) {
        std::cout << "[recv_fd] recvmsg error" << std::endl;
        exit(1);
    }

    // Only trust the payload when a complete SCM_RIGHTS message arrived;
    // otherwise the control buffer holds no descriptor.
    cmsghdr* cmptr = CMSG_FIRSTHDR(&msg);
    if (cmptr == NULL
        || cmptr->cmsg_level != SOL_SOCKET
        || cmptr->cmsg_type != SCM_RIGHTS
        || cmptr->cmsg_len < CMSG_LEN(sizeof(int))) {
        return -1;
    }

    int fd = -1;
    memcpy(&fd, CMSG_DATA(cmptr), sizeof(fd));
    std::cout<< "接收的fd为"<< fd << std::endl;

    return fd;
}
 
/**
 * Master side of the demo: create the listening TCP socket, pass it to the
 * worker over the socketpair, then idle forever.
 *
 * @param fds  the socketpair; the master keeps fds[0] and closes fds[1]
 *
 * Exits the process with status 1 if the TCP server cannot be created.
 * Never returns.
 */
void master_process_cycle(int fds[2]){
    std::cout << "master process #" << getpid() << std::endl;

    // master use fds[0], and close fds[1]
    int fd = fds[0];
    close(fds[1]);
    std::cout << "channel: #" << fds[0] << ", #" << fds[1] << ", fd=#" << fd << std::endl;

    int listenFD = tcpServer();
    if (listenFD < 0){
        std::cout << "tcp server fail" << std::endl;
        // Stop here rather than trying to send an invalid descriptor.
        exit(1);
    }

    send_fd(fd, listenFD);

    for(;;){
        sleep(1);
    }
}
 
/**
 * Worker side of the demo: receive the descriptor sent by the master over
 * the socketpair, then idle forever.
 *
 * @param fds  the socketpair; the worker keeps fds[1] and closes fds[0]
 *
 * Exits the process with status 1 if no valid descriptor is received.
 * Never returns.
 */
void worker_process_cycle(int fds[2]){
    std::cout << "worker process #" << getpid() << std::endl;

    // worker use fds[1], and close fds[0]. While the worker itself keeps
    // the master's end open, recvmsg cannot observe the master going away.
    int fd = fds[1];
    close(fds[0]);

    int file = recv_fd(fd);
    if(file < 0){
        std::cout << "[worker] invalid fd! " << std::endl;
        exit(1);
    }

    for(;;){
        sleep(1);
    }
}
 
/**
 * Demo entry point: create a UNIX-domain socketpair, fork, and run the
 * worker cycle in the child and the master cycle in the parent.
 *
 * Exits with status 1 if the socketpair or the fork fails.
 */
int main(int argc, char** argv){
    std::cout << "current pid: " << getpid() << std::endl;

    int fds[2];
    if(socketpair(AF_UNIX, SOCK_STREAM, 0, fds) == -1){
        std::cout << "failed to create domain socket by socketpair" << std::endl;
        exit(1);
    }

    std::cout << "create domain socket by socketpair success" << std::endl;

    std::cout << "create progress to communicate over domain socket" << std::endl;
    pid_t pid = fork();
    if(pid < 0){
        // fork failed: there is no child, so do not run the master cycle.
        std::cout << "failed to fork worker process" << std::endl;
        exit(1);
    }
    if(pid == 0){
        worker_process_cycle(fds);
    }
    else{
        master_process_cycle(fds);
    }

    // Both cycles loop forever, so this point is not reached.
    return 0;
}


 /**
  * Create a TCP socket bound to 127.0.0.1:6666 and put it in listening
  * state (backlog 100).
  *
  * @return the listening descriptor, or -1 on failure; on failure the
  *         socket is closed before returning
  */
 int tcpServer() {
    int listenFD = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
    if (listenFD < 0) {
        printf("create socket err \n");
        return -1;
    }

    /* Server address */
    struct sockaddr_in server_addr;
    memset(&server_addr, 0, sizeof(server_addr));
    server_addr.sin_family = AF_INET;
    server_addr.sin_addr.s_addr = inet_addr("127.0.0.1");
    server_addr.sin_port = htons(6666);

    /* Bind the address to the socket descriptor */
    if (bind(listenFD, (struct sockaddr *) &server_addr, sizeof(server_addr)) == -1) {
        std::cout << "bind fail" << std::endl;
        close(listenFD); // do not leak the socket on failure
        return -1;
    }

    if (listen(listenFD, 100) == -1) {
        std::cout << "listen fail" << std::endl;
        close(listenFD); // do not leak the socket on failure
        return -1;
    }
    return listenFD;
}

Linux 使用率

测试两个进程同时拥有同一个socket的测试代码,按需求自己打开注释

package main

import (
	"fmt"
	"net"
	"os"
	"syscall"
	"time"

	"golang.org/x/sys/unix"
)

// main runs the demo in one of two roles, chosen by the argument count.
// With no extra argument it starts the TCP server and offers its sockets
// over a unix domain socket. With any extra argument it connects to that
// unix socket and takes the sockets over. Both roles then block forever.
func main() {
	tcpSrv := NewTcpSrv()
	if len(os.Args) <= 1 {
		if err := tcpSrv.Init(); err != nil {
			fmt.Println("tcp srv init fail, err is", err)
			return
		}
		if err := tcpSrv.Start(); err != nil {
			fmt.Println("tcp srv start fail, err is", err)
			return
		}
		//// Migrate the listener (uncomment to test listener hand-over).
		//if err := tcpSrv.SendListenerWithUnixSocket(); err != nil {
		//	fmt.Println("send listener with unix socket fail, err is", err)
		//	return
		//}

		// Migrate the accepted connections.
		if err := tcpSrv.SendConnWithUnixSocket(); err != nil {
			fmt.Println("send conn with unix socket fail, err is", err)
			return
		}

	} else {
		//// Migrate the listener (uncomment to test listener hand-over).
		//if err := tcpSrv.RecvListenerFromUnixSocket(); err != nil {
		//	fmt.Println("recv listener with unix socket fail, err is", err)
		//	return
		//}

		// Migrate the accepted connections.
		if err := tcpSrv.RecvConnFromUnixSocket(); err != nil {
			fmt.Println("recv conn with unix socket fail, err is", err)
			return
		}
	}

	// Keep the process alive for the background goroutines.
	select {}
}

// TcpSrv is the demo TCP echo server whose listener and accepted
// connections can be handed to another process.
//
// NOTE(review): conns is written by the accept goroutine in Start and read
// by the goroutine in SendConnWithUnixSocket without any synchronization.
// Entries are also never removed when a connection is closed.
type TcpSrv struct {
	// listener is the TCP listener created by Init.
	listener *net.TCPListener
	// conns holds accepted connections keyed by remote address string.
	conns    map[string]*net.TCPConn
}

// NewTcpSrv returns a TcpSrv with an empty connection table and no listener.
func NewTcpSrv() *TcpSrv {
	srv := new(TcpSrv)
	srv.conns = map[string]*net.TCPConn{}
	return srv
}

// Init opens the TCP listener on ":7000" and stores it on the server.
// It returns the error from net.Listen, if any.
func (t *TcpSrv) Init() error {
	ln, err := net.Listen("tcp", ":7000")
	if err == nil {
		t.listener = ln.(*net.TCPListener)
	}
	return err
}

// Start launches the accept loop in a background goroutine and returns nil
// immediately. Each accepted connection is recorded in t.conns under its
// remote address and served by clientSrv in its own goroutine.
//
// NOTE(review): t.conns is written here without synchronization while
// SendConnWithUnixSocket may iterate it from another goroutine.
func (t *TcpSrv) Start() error {
	go func() {
		for {
			// AcceptTCP yields *net.TCPConn directly, so no type assertion
			// is needed before storing the connection.
			conn, err := t.listener.AcceptTCP()
			if err != nil {
				fmt.Println("accept fail, err msg is", err)
				// Pause before retrying so a persistent accept error
				// does not turn this loop into a hot spin.
				time.Sleep(100 * time.Millisecond)
				continue
			}
			t.conns[conn.RemoteAddr().String()] = conn
			go t.clientSrv(conn)
		}
	}()
	return nil
}

// StartWithListenSocket launches an accept loop on the given listener in a
// background goroutine and returns nil immediately. Each accepted
// connection is served by clientSrv in its own goroutine.
func (t *TcpSrv) StartWithListenSocket(listener *net.TCPListener) error {
	go func() {
		for {
			c, err := listener.Accept()
			if err != nil {
				fmt.Println("accept fail, err msg is", err)
				// Pause before retrying so a persistent accept error
				// does not turn this loop into a hot spin.
				time.Sleep(100 * time.Millisecond)
				continue
			}
			go t.clientSrv(c)
		}
	}()
	return nil
}

// clientSrv echoes data back on conn until a read or write fails, then
// closes the connection. It sleeps one second before every read.
func (t *TcpSrv) clientSrv(conn net.Conn) {
	defer conn.Close()

	payload := make([]byte, 1024)
	for {
		time.Sleep(time.Second)

		n, readErr := conn.Read(payload)
		if readErr != nil {
			fmt.Println("read msg fail, err is", readErr)
			return
		}
		received := payload[:n]
		fmt.Println("recv msg is", string(received))

		_, writeErr := conn.Write(received)
		if writeErr != nil {
			fmt.Println("write msg fail, err is", writeErr)
			return
		}
	}
}

// SendListenerWithUnixSocket listens on the unix socket
// /tmp/unix_socket_tcp and, in a background goroutine, hands a duplicate of
// the TCP listener's descriptor to every peer that connects.
//
// It returns an error only if the unix socket cannot be set up; transfer
// errors are logged by the goroutine.
func (t *TcpSrv) SendListenerWithUnixSocket() error {
	_ = os.Remove("/tmp/unix_socket_tcp")
	addr, err := net.ResolveUnixAddr("unix", "/tmp/unix_socket_tcp")
	if err != nil {
		fmt.Println("Cannot resolve unix addr: " + err.Error())
		return err
	}

	listener, err := net.ListenUnix("unix", addr)
	if err != nil {
		fmt.Println("Cannot listen to unix domain socket: " + err.Error())
		return err
	}
	fmt.Println("Listening on", listener.Addr())

	go func() {
		// Release the unix listener once the accept loop gives up.
		defer listener.Close()
		for {
			c, err := listener.AcceptUnix()
			if err != nil {
				fmt.Println("Accept: " + err.Error())
				return
			}
			t.sendListenerTo(c)
		}
	}()

	return nil
}

// sendListenerTo sends one duplicate of the TCP listener's descriptor to
// peer, follows it with the non-zero "finish" byte that
// RecvListenerFromUnixSocket treats as end of transfer, and closes peer.
func (t *TcpSrv) sendListenerTo(peer *net.UnixConn) {
	defer peer.Close()

	file, err := t.listener.File()
	if err != nil {
		fmt.Println("dup listen socket fail, err is", err.Error())
		return
	}
	// File() returns a duplicate; close it once the kernel has taken its
	// own reference during the send, otherwise every transfer leaks an fd.
	defer file.Close()

	rights := syscall.UnixRights(int(file.Fd()))
	if _, _, err = peer.WriteMsgUnix([]byte{0}, rights, nil); err != nil {
		fmt.Println("同步listen socket fail, err is", err.Error())
		return
	}
	if _, _, err = peer.WriteMsgUnix([]byte{1}, nil, nil); err != nil {
		fmt.Println("send finish marker fail, err is", err.Error())
	}
}

// RecvListenerFromUnixSocket connects to the unix socket
// /tmp/unix_socket_tcp and starts an accept loop on every TCP listener
// descriptor it receives.
//
// Each message carries one data byte: 0 means "one descriptor attached",
// any other value means "init finish". It returns nil when the finish byte
// arrives or the read yields no data. It returns an error for I/O failures
// and for messages that do not carry exactly one TCP listener descriptor.
func (t *TcpSrv) RecvListenerFromUnixSocket() error {
	connInterface, err := net.Dial("unix", "/tmp/unix_socket_tcp")
	if err != nil {
		fmt.Println("net dial unix fail", err.Error())
		return err
	}
	defer func() {
		_ = connInterface.Close()
	}()

	unixConn, ok := connInterface.(*net.UnixConn)
	if !ok {
		return fmt.Errorf("unexpected connection type %T", connInterface)
	}

	b := make([]byte, 1)
	oob := make([]byte, 32)
	for {
		// NOTE(review): this arms the write deadline although the loop only
		// reads. It is kept as is, because a read deadline would end this
		// loop whenever the sender stays silent for three minutes.
		if err := unixConn.SetWriteDeadline(time.Now().Add(time.Minute * 3)); err != nil {
			fmt.Println(err.Error())
			return err
		}
		n, oobn, _, _, err := unixConn.ReadMsgUnix(b, oob)
		if err != nil {
			fmt.Println(err.Error())
			return err
		}
		if n != 1 {
			// With a 1-byte buffer the only other outcome is n == 0,
			// presumably the peer closing the socket; treat it as end of
			// transfer, as the original code did.
			fmt.Printf("recv fd type error: %d\n", n)
			return nil
		}
		if b[0] != 0 {
			fmt.Println("init finish")
			return nil
		}

		scms, err := unix.ParseSocketControlMessage(oob[0:oobn])
		if err != nil {
			fmt.Println(err.Error())
			return err
		}
		if len(scms) != 1 {
			fmt.Printf("recv fd num != 1 : %d\n", len(scms))
			return fmt.Errorf("recv fd num != 1 : %d", len(scms))
		}
		fds, err := unix.ParseUnixRights(&scms[0])
		if err != nil {
			fmt.Println(err.Error())
			return err
		}
		if len(fds) != 1 {
			// Do not leak descriptors that will not be used.
			for _, fd := range fds {
				_ = unix.Close(fd)
			}
			fmt.Printf("recv fd num != 1 : %d\n", len(fds))
			return fmt.Errorf("recv fd num != 1 : %d", len(fds))
		}
		fmt.Printf("recv fd %d\n", fds[0])

		// FileListener works on its own copy, so the received descriptor
		// must be closed on success and on failure; otherwise every restart
		// leaves one more copy of the socket behind.
		file := os.NewFile(uintptr(fds[0]), "fd-from-old")
		ln, err := net.FileListener(file)
		_ = file.Close()
		if err != nil {
			fmt.Println(err.Error())
			return err
		}
		fmt.Println(ln)

		lc, ok := ln.(*net.TCPListener)
		if !ok {
			_ = ln.Close()
			return fmt.Errorf("received fd is not a TCP listener: %T", ln)
		}
		// StartWithListenSocket spawns its own goroutine and returns at once.
		_ = t.StartWithListenSocket(lc)
	}
}

// SendConnWithUnixSocket listens on the unix socket /tmp/unix_socket_tcp
// and, in a background goroutine, hands a duplicate of every recorded TCP
// connection's descriptor to each peer that connects.
//
// It returns an error only if the unix socket cannot be set up; transfer
// errors are logged by the goroutine.
//
// NOTE(review): t.conns is iterated here without synchronization while the
// accept goroutine started by Start may insert into it.
func (t *TcpSrv) SendConnWithUnixSocket() error {
	_ = os.Remove("/tmp/unix_socket_tcp")
	addr, err := net.ResolveUnixAddr("unix", "/tmp/unix_socket_tcp")
	if err != nil {
		fmt.Println("Cannot resolve unix addr: " + err.Error())
		return err
	}

	listener, err := net.ListenUnix("unix", addr)
	if err != nil {
		fmt.Println("Cannot listen to unix domain socket: " + err.Error())
		return err
	}
	fmt.Println("Listening on", listener.Addr())

	go func() {
		// Release the unix listener once the accept loop gives up.
		defer listener.Close()
		for {
			c, err := listener.AcceptUnix()
			if err != nil {
				fmt.Println("Accept: " + err.Error())
				return
			}
			t.sendConnsTo(c)
		}
	}()

	return nil
}

// sendConnsTo sends a duplicate descriptor of every connection in t.conns
// to peer, one message per connection. It follows them with the non-zero
// "finish" byte that RecvConnFromUnixSocket treats as end of transfer, and
// closes peer.
func (t *TcpSrv) sendConnsTo(peer *net.UnixConn) {
	defer peer.Close()

	for _, conn := range t.conns {
		file, err := conn.File()
		if err != nil {
			// For example a connection that clientSrv already closed.
			fmt.Println("dup conn socket fail, err is", err.Error())
			continue
		}
		rights := syscall.UnixRights(int(file.Fd()))
		_, _, err = peer.WriteMsgUnix([]byte{0}, rights, nil)
		// File() returns a duplicate; close it after the send so each
		// transfer does not leak an fd.
		_ = file.Close()
		if err != nil {
			fmt.Println("同步conn socket fail, err is", err.Error())
		}
	}

	if _, _, err := peer.WriteMsgUnix([]byte{1}, nil, nil); err != nil {
		fmt.Println("send finish marker fail, err is", err.Error())
	}
}

// RecvConnFromUnixSocket connects to the unix socket /tmp/unix_socket_tcp
// and serves every connection descriptor it receives with clientSrv.
//
// Each message carries one data byte: 0 means "one descriptor attached",
// any other value means "init finish". It returns nil when the finish byte
// arrives or the read yields no data. It returns an error for I/O failures
// and for messages that do not carry exactly one usable descriptor.
func (t *TcpSrv) RecvConnFromUnixSocket() error {
	connInterface, err := net.Dial("unix", "/tmp/unix_socket_tcp")
	if err != nil {
		fmt.Println("net dial unix fail", err.Error())
		return err
	}
	defer func() {
		_ = connInterface.Close()
	}()

	unixConn, ok := connInterface.(*net.UnixConn)
	if !ok {
		return fmt.Errorf("unexpected connection type %T", connInterface)
	}

	b := make([]byte, 1)
	oob := make([]byte, 32)
	for {
		// NOTE(review): this arms the write deadline although the loop only
		// reads. It is kept as is, because a read deadline would end this
		// loop whenever the sender stays silent for three minutes.
		if err := unixConn.SetWriteDeadline(time.Now().Add(time.Minute * 3)); err != nil {
			fmt.Println(err.Error())
			return err
		}
		n, oobn, _, _, err := unixConn.ReadMsgUnix(b, oob)
		if err != nil {
			fmt.Println(err.Error())
			return err
		}
		if n != 1 {
			// With a 1-byte buffer the only other outcome is n == 0,
			// presumably the peer closing the socket; treat it as end of
			// transfer, as the original code did.
			fmt.Printf("recv fd type error: %d\n", n)
			return nil
		}
		if b[0] != 0 {
			fmt.Println("init finish")
			return nil
		}

		scms, err := unix.ParseSocketControlMessage(oob[0:oobn])
		if err != nil {
			fmt.Println(err.Error())
			return err
		}
		if len(scms) != 1 {
			fmt.Printf("recv fd num != 1 : %d\n", len(scms))
			return fmt.Errorf("recv fd num != 1 : %d", len(scms))
		}
		fds, err := unix.ParseUnixRights(&scms[0])
		if err != nil {
			fmt.Println(err.Error())
			return err
		}
		if len(fds) != 1 {
			// Do not leak descriptors that will not be used.
			for _, fd := range fds {
				_ = unix.Close(fd)
			}
			fmt.Printf("recv fd num != 1 : %d\n", len(fds))
			return fmt.Errorf("recv fd num != 1 : %d", len(fds))
		}
		fmt.Printf("recv fd %d\n", fds[0])

		// FileConn works on its own copy, so the received descriptor must
		// be closed on success and on failure; otherwise every restart
		// leaves one more copy of the socket behind.
		file := os.NewFile(uintptr(fds[0]), "fd-from-old")
		conn, err := net.FileConn(file)
		_ = file.Close()
		if err != nil {
			fmt.Println(err.Error())
			return err
		}
		fmt.Println(conn)

		// Serve in the background, as Start does; a blocking call here
		// would stall every later descriptor until this connection ends.
		go t.clientSrv(conn)
	}
}