linux为什么能跨进程传递socket文件描述符

在Linux中一切皆文件,文件系统是进程所共有的。socket本身是在网络文件系统(sockfs)空间中申请的,也是文件的一种,所以在同一台主机上,socket是可以跨进程传递的。下面仔细跟踪一下socket创建的过程(3.10内核)。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
// User-space call that creates the socket. The kernel functions quoted below
// are presented by this article as part of that creation path.
int server_fd = socket(AF_INET, SOCK_STREAM, 0 );
//  ...
// 
/*
 * Attach @sock to a newly reserved file descriptor.
 *
 * @sock:  socket that needs a descriptor
 * @flags: passed both to the descriptor reservation and to sock_alloc_file()
 *
 * Returns the descriptor number on success. On failure returns the negative
 * value from get_unused_fd_flags(), or PTR_ERR() of the error pointer
 * returned by sock_alloc_file().
 */
static int sock_map_fd(struct socket *sock, int flags)
{
	struct file *sock_file;
	int new_fd;

	/* Step 1: reserve a descriptor number. */
	new_fd = get_unused_fd_flags(flags);
	if (unlikely(new_fd < 0))
		return new_fd;

	/* Step 2: create the file object that represents the socket. */
	sock_file = sock_alloc_file(sock, flags, NULL);
	if (unlikely(IS_ERR(sock_file))) {
		/* No file to install: hand the reserved number back. */
		put_unused_fd(new_fd);
		return PTR_ERR(sock_file);
	}

	/* Step 3: bind the descriptor number to the file. */
	fd_install(new_fd, sock_file);
	return new_fd;
}


/*
 * sock_alloc_file - allocate the struct file that represents @sock.
 *
 * @sock:  socket to wrap
 * @flags: only the O_NONBLOCK bit is copied into file->f_flags
 * @dname: optional dentry name; when NULL the name of the protocol that
 *         created sock->sk is used (or "" if the socket has no sk)
 *
 * The dentry is allocated on sock_mnt's superblock and instantiated with the
 * socket's inode. Returns the new file, ERR_PTR(-ENOMEM) if the dentry cannot
 * be allocated, or the error pointer returned by alloc_file().
 */
struct file *sock_alloc_file(struct socket *sock, int flags, const char *dname)
{
	struct qstr name = { .name = "" };
	struct path path;
	struct file *file;

	/* Choose the dentry name: explicit @dname first, else the protocol name. */
	if (dname) {
		name.name = dname;
		name.len = strlen(name.name);
	} else if (sock->sk) {
		name.name = sock->sk->sk_prot_creator->name;
		name.len = strlen(name.name);
	}
	/* Allocate a dentry on the socket mount's superblock. */
	path.dentry = d_alloc_pseudo(sock_mnt->mnt_sb, &name);
	if (unlikely(!path.dentry))
		return ERR_PTR(-ENOMEM);
	path.mnt = mntget(sock_mnt);

	/* Tie the dentry to the socket's inode and set the inode's file ops. */
	d_instantiate(path.dentry, SOCK_INODE(sock));
	SOCK_INODE(sock)->i_fop = &socket_file_ops;

	file = alloc_file(&path, FMODE_READ | FMODE_WRITE,
		  &socket_file_ops);
	if (unlikely(IS_ERR(file))) {
		/* drop dentry, keep inode */
		ihold(path.dentry->d_inode);
		path_put(&path);
		return file;
	}
	/*
	 * Why the two links below matter (author's note, translated):
	 * every operation made through the socket API first goes through this
	 * file to reach the struct socket. Given a file descriptor, the socket
	 * can be obtained from the file's private_data field and operated on.
	 * That is why a socket file can be passed between processes on the
	 * same host.
	 */
	sock->file = file;
	file->f_flags = O_RDWR | (flags & O_NONBLOCK);
	file->private_data = sock;
	return file;
}

如何传递、传递过程中注意哪些问题

  • 父子进程建立unix socket连接传递
  • 假设传递的socket为client_fd,先调用dup_fd = dup(client_fd)(复制fd),再通过sendmsg将dup_fd发送给子进程,然后close(client_fd)(在fd发往子进程的过程中,socket依然可以接收数据,这时候父进程可能捕获到该事件并读取了数据,导致子进程获取不到该事件,出现数据漏读;如果先dup一个出来,然后把原来的关闭,那么等dup_fd到达子进程之后就可以响应到该数据事件),
  • 子进程recvmsg收到dup_fd之后,调用new_fd = dup(dup_fd),然后close(dup_fd)(原因同样是在传递过程中接收到数据,这样dup_fd没有办法捕捉到,dup之后就能获取到该数据响应事件)
  • DupCloseOnExec close_on_exec:当父进程打开文件时,只需要应用程序设置FD_CLOEXEC标志位,则当fork后exec其他程序的时候,内核会自动将其继承的父进程FD关闭
  • unix socket也可以像其它fd一样进行跨进程复制
  • 跨进程复制的listen fd如果不关闭,都可以accept
  • 跨进程复制的socket如果不关闭,都可以进行收发数据,收数据的时候存在竞争关系
  • 复制过去的socket接收、发送缓存区是同一个
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27

// Socket send / receive functions. Both take a struct msghdr (shown below)
// that describes the message.
ssize_t sendmsg(int socket, const struct msghdr *message, int flags);
ssize_t recvmsg(int socket, struct msghdr *message, int flags);

// Related data structures
// struct msghdr: message descriptor passed to sendmsg()/recvmsg().
struct msghdr {
	void		*msg_name;	/* ptr to socket address structure */ // destination address of the data: points to a sockaddr_in for network packets, to a sockaddr_nl for netlink
	int		msg_namelen;	/* size of socket address structure */ // length of the address that msg_name points to
	struct iovec	*msg_iov;	/* scatter/gather array */  // points to the array of buffers
	__kernel_size_t	msg_iovlen;	/* # elements in msg_iov */ // number of entries in the buffer array
	void		*msg_control;	/* ancillary data */    // ancillary data / control information (used to send any control information)
	__kernel_size_t	msg_controllen;	/* ancillary data buffer length */  // length of the ancillary data
	unsigned int	msg_flags;	/* flags on received message */ // message flags
};

/* One entry of the msg_iov scatter/gather array: a buffer address and its length. */
struct iovec
{
	void __user *iov_base;	/* BSD uses caddr_t (1003.1g requires void *) */
	__kernel_size_t iov_len; /* Must be size_t (1003.1g) */
};

/*
 * Header of one ancillary-data item carried in msg_control. The example
 * program in this article fills it with SOL_SOCKET / SCM_RIGHTS to send a
 * file descriptor.
 */
struct cmsghdr {
	__kernel_size_t	cmsg_len;	/* data byte count, including hdr */
        int		cmsg_level;	/* originating protocol */
        int		cmsg_type;	/* protocol-specific type */
};
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
// 简单的例子
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <iostream>
#include <sys/socket.h>
#include <unistd.h>
#include <sys/wait.h>
#include <fcntl.h>
#include <sys/uio.h>
#include <errno.h>
#include <netinet/in.h>
#include <time.h>
#include <signal.h>
#include <arpa/inet.h>


using namespace std;

int tcpServer();

// Send the open file descriptor `fd` to the peer of the UNIX-domain socket
// `sock` as an SCM_RIGHTS ancillary message.
//
// sock: connected AF_UNIX socket (for example one end of a socketpair)
// fd:   descriptor to pass to the peer
//
// Exits the process with status 1 if sendmsg() fails.
void send_fd(int sock, int fd)
{
    // One byte of ordinary data travels with the ancillary data; unix(7)
    // recommends sending at least one real byte on stream sockets.
    char payload = 0;
    struct iovec iov[1];
    iov[0].iov_base = &payload;
    iov[0].iov_len  = 1;

    // Control buffer lives on the stack: CMSG_SPACE accounts for padding and
    // the union guarantees the alignment struct cmsghdr needs.
    union {
        char buf[CMSG_SPACE(sizeof(int))];
        struct cmsghdr align;
    } control;
    memset(&control, 0, sizeof(control));

    // Zero the whole header so no field is passed to the kernel uninitialised.
    struct msghdr msg;
    memset(&msg, 0, sizeof(msg));
    msg.msg_iov = iov;
    msg.msg_iovlen = 1;
    msg.msg_name = NULL;
    msg.msg_namelen = 0;
    msg.msg_control = control.buf;
    msg.msg_controllen = sizeof(control.buf);

    struct cmsghdr* cmptr = CMSG_FIRSTHDR(&msg);
    cmptr->cmsg_level = SOL_SOCKET;
    cmptr->cmsg_type = SCM_RIGHTS; // we are sending fd.
    cmptr->cmsg_len = CMSG_LEN(sizeof(int));
    // memcpy avoids a type-punned store into the control buffer.
    memcpy(CMSG_DATA(cmptr), &fd, sizeof(fd));

    if (sendmsg(sock, &msg, 0) == -1){
        std::cout << "[send_fd] sendmsg error" << std::endl;
        exit(1);
    }
}
 
// Receive one file descriptor sent as an SCM_RIGHTS ancillary message over
// the UNIX-domain socket `sock`.
//
// Returns the received descriptor (a new descriptor number in this process),
// or -1 when the peer closed the connection or the message carried no
// SCM_RIGHTS descriptor. Exits the process with status 1 if recvmsg() fails.
int recv_fd(int sock)
{
    char buf[32]; // the max buf in msg.
    struct iovec iov[1];
    iov[0].iov_base = buf;
    iov[0].iov_len = sizeof(buf);

    // Stack control buffer: nothing to free, sized with CMSG_SPACE and
    // aligned for struct cmsghdr through the union.
    union {
        char buf[CMSG_SPACE(sizeof(int))];
        struct cmsghdr align;
    } control;
    memset(&control, 0, sizeof(control));

    struct msghdr msg;
    memset(&msg, 0, sizeof(msg));
    msg.msg_iov = iov;
    msg.msg_iovlen  = 1;
    msg.msg_name = NULL;
    msg.msg_namelen = 0;
    msg.msg_control = control.buf;
    msg.msg_controllen = sizeof(control.buf);

    ssize_t ret = recvmsg(sock, &msg, 0);
    if (ret == -1) {
        std::cout << "[recv_fd] recvmsg error" << std::endl;
        exit(1);
    }
    if (ret == 0) {
        // Peer closed the socket: there is no descriptor to read.
        return -1;
    }

    // Only trust the payload if the kernel really delivered an SCM_RIGHTS
    // message large enough to hold one int.
    struct cmsghdr* cmptr = CMSG_FIRSTHDR(&msg);
    if (cmptr == NULL
        || cmptr->cmsg_level != SOL_SOCKET
        || cmptr->cmsg_type != SCM_RIGHTS
        || cmptr->cmsg_len < CMSG_LEN(sizeof(int))) {
        return -1;
    }

    int fd = -1;
    memcpy(&fd, CMSG_DATA(cmptr), sizeof(fd));
    std::cout<< "接收的fd为"<< fd << std::endl;

    return fd;
}
 
// Master side of the demo: creates the TCP listening socket and passes its
// descriptor to the worker over the socketpair.
//
// fds: the socketpair; the master keeps fds[0] and closes fds[1].
//
// Never returns. Exits with status 1 if the listening socket cannot be created.
void master_process_cycle(int fds[2]){
    std::cout << "master process #" << getpid() << std::endl;

    // master use fds[0], and close fds[1]
    int fd = fds[0];
    close(fds[1]);
    std::cout << "channel: #" << fds[0] << ", #" << fds[1] << ", fd=#" << fd << std::endl;

    int listenFD = tcpServer();
    if (listenFD < 0){
        std::cout << "tcp server fail" << std::endl;
        // Do not go on to send an invalid descriptor.
        exit(1);
    }

    send_fd(fd, listenFD);

    // Stay alive so the channel and the listening socket remain open.
    for(;;){
        sleep(1);
    }
}
 
// Worker side of the demo: receives a descriptor from the master over the
// socketpair.
//
// fds: the socketpair; the worker uses fds[1] and closes fds[0].
//
// Never returns. Exits with status 1 if recv_fd() reports a negative descriptor.
void worker_process_cycle(int fds[2]){
    std::cout << "worker process #" << getpid() << std::endl;

    // worker use fds[1], and close fds[0] (mirrors the master). While the
    // worker holds its own copy of the master's end, the channel can never
    // reach end-of-file when the master goes away.
    int fd = fds[1];
    close(fds[0]);

    int received = recv_fd(fd);
    if(received < 0){
        std::cout << "[worker] invalid fd! " << std::endl;
        exit(1);
    }

    // Stay alive so the received descriptor remains open in this process.
    for(;;){
        sleep(1);
    }
}
 
// Demo entry point: creates an AF_UNIX socketpair, forks, and runs the worker
// cycle in the child and the master cycle in the parent.
//
// Exits with status 1 if the socketpair or the fork fails.
int main(int argc, char** argv){
    (void)argc;
    (void)argv;
    std::cout << "current pid: " << getpid() << std::endl;

    int fds[2];
    if(socketpair(AF_UNIX, SOCK_STREAM, 0, fds) == -1){
        std::cout << "failed to create domain socket by socketpair" << std::endl;
        exit(1);
    }

    std::cout << "create domain socket by socketpair success" << std::endl;

    std::cout << "create progress to communicate over domain socket" << std::endl;
    pid_t pid = fork();
    if(pid < 0){
        // fork() failed: there is no child, so do not run the master cycle.
        std::cout << "failed to fork worker process" << std::endl;
        exit(1);
    }
    if(pid == 0){
        worker_process_cycle(fds);
    }
    else{
        master_process_cycle(fds);
    }

    // Fallback idle loop in case a cycle function ever returns.
    for(;;){
        sleep(1);
    }
}


 // Create a TCP socket bound to 127.0.0.1:6666 and put it in listening state
 // (backlog 100).
 //
 // Returns the listening descriptor, or -1 on failure. On failure no
 // descriptor is left open.
 int tcpServer() {
    int listenFD = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
    if (listenFD < 0) {
        printf("create socket err \n");
        return -1;
    }

    /* Set the server address */
    struct sockaddr_in server_addr;
    memset(&server_addr, 0, sizeof(server_addr));
    server_addr.sin_family = AF_INET;
    server_addr.sin_addr.s_addr = inet_addr("127.0.0.1");
    server_addr.sin_port = htons(6666);

    /* Bind the address to the socket descriptor */
    if (bind(listenFD, (struct sockaddr *) &server_addr, sizeof(server_addr)) == -1) {
        std::cout << "bind fail" << std::endl;
        close(listenFD); // do not leak the socket on failure
        return -1;
    }

    if (listen(listenFD, 100) == -1) {
        std::cout << "listen fail" << std::endl;
        close(listenFD); // do not leak the socket on failure
        return -1;
    }
    return listenFD;
}

Linux 使用率

测试两个进程同时拥有同一个socket的测试代码,按需求自己打开注释

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
package main

import (
	"fmt"
	"net"
	"os"
	"syscall"
	"time"

	"golang.org/x/sys/unix"
)

// main runs one of two roles, chosen by the command line:
//   - no extra arguments: start the TCP server, then offer its sockets to a
//     peer process over a unix socket;
//   - any extra argument: connect to that unix socket and take the sockets over.
//
// The listener-migration variant is kept commented out; swap the comments to
// try it instead of connection migration. main then blocks forever.
func main() {
	tcpSrv := NewTcpSrv()
	if len(os.Args) <= 1 {
		if err := tcpSrv.Init(); err != nil {
			fmt.Println("tcp srv init fail, err is", err)
			return
		}
		if err := tcpSrv.Start(); err != nil {
			fmt.Println("tcp srv start fail, err is", err)
			return
		}
		//// migrate the listener
		//if err := tcpSrv.SendListenerWithUnixSocket(); err != nil {
		//	fmt.Println("send listener with unix socket fail, err is", err)
		//	return
		//}

		// migrate the connections
		if err := tcpSrv.SendConnWithUnixSocket(); err != nil {
			fmt.Println("send conn with unix socket fail, err is", err)
			return
		}

	} else {
		//// migrate the listener
		//if err := tcpSrv.RecvListenerFromUnixSocket(); err != nil {
		//	fmt.Println("recv listener with unix socket fail, err is", err)
		//	return
		//}

		// migrate the connections
		if err := tcpSrv.RecvConnFromUnixSocket(); err != nil {
			fmt.Println("recv conn with unix socket fail, err is", err)
			return
		}
	}

	// Block forever so the background goroutines keep running.
	select {}
}

// TcpSrv is the demo TCP echo server whose sockets can be handed to another
// process.
type TcpSrv struct {
	// listener is the TCP listening socket; set by Init.
	listener *net.TCPListener
	// conns holds accepted connections keyed by remote address string;
	// filled by Start and iterated by SendConnWithUnixSocket.
	conns    map[string]*net.TCPConn
}

// NewTcpSrv returns a server with an empty connection table and no listener.
func NewTcpSrv() *TcpSrv {
	srv := new(TcpSrv)
	srv.conns = map[string]*net.TCPConn{}
	return srv
}

// Init opens the TCP listening socket on port 7000 and stores it in
// t.listener. It returns the error from net.Listen, if any.
func (t *TcpSrv) Init() error {
	ln, listenErr := net.Listen("tcp", ":7000")
	if listenErr == nil {
		t.listener = ln.(*net.TCPListener)
	}
	return listenErr
}

// Start launches the accept loop on t.listener in a background goroutine and
// returns nil immediately. Every accepted connection is served by clientSrv
// in its own goroutine and recorded in t.conns under its remote address.
//
// NOTE(review): t.conns is written here without any locking while
// SendConnWithUnixSocket iterates it from another goroutine; confirm whether
// synchronisation is needed.
func (t *TcpSrv) Start() error {
	acceptLoop := func() {
		for {
			accepted, acceptErr := t.listener.Accept()
			if acceptErr != nil {
				fmt.Println("accept fail, err msg is", acceptErr)
				continue
			}
			go t.clientSrv(accepted)

			// Remember the connection so it can be migrated later.
			tcpConn := accepted.(*net.TCPConn)
			remote := accepted.RemoteAddr().String()
			t.conns[remote] = tcpConn
		}
	}
	go acceptLoop()
	return nil
}

// StartWithListenSocket launches an accept loop on the given listener in a
// background goroutine and returns nil immediately. Each accepted connection
// is served by clientSrv in its own goroutine.
func (t *TcpSrv) StartWithListenSocket(listener *net.TCPListener) error {
	acceptLoop := func() {
		for {
			accepted, acceptErr := listener.Accept()
			if acceptErr == nil {
				go t.clientSrv(accepted)
				continue
			}
			fmt.Println("accept fail, err msg is", acceptErr)
		}
	}
	go acceptLoop()
	return nil
}

// clientSrv is the per-connection echo loop: once a second it reads up to
// 1024 bytes from conn, prints them and writes them back. It returns, closing
// conn, on the first read or write error.
func (t *TcpSrv) clientSrv(conn net.Conn) {
	defer conn.Close()

	data := make([]byte, 1024)
	for {
		// Pause before every read.
		time.Sleep(time.Second)

		count, readErr := conn.Read(data)
		if readErr != nil {
			fmt.Println("read msg fail, err is", readErr)
			return
		}

		received := data[:count]
		fmt.Println("recv msg is", string(received))

		_, writeErr := conn.Write(received)
		if writeErr != nil {
			fmt.Println("write msg fail, err is", writeErr)
			return
		}
	}
}

// SendListenerWithUnixSocket listens on the unix socket /tmp/unix_socket_tcp
// (removing any stale socket file first) and, in a background goroutine,
// sends the TCP listener's file descriptor to every peer that connects.
//
// Each message is one zero byte plus the descriptor as SCM_RIGHTS ancillary
// data. Returns an error only if the unix socket cannot be set up.
func (t *TcpSrv) SendListenerWithUnixSocket() error {
	_ = os.Remove("/tmp/unix_socket_tcp")
	addr, err := net.ResolveUnixAddr("unix", "/tmp/unix_socket_tcp")
	if err != nil {
		fmt.Println("Cannot resolve unix addr: " + err.Error())
		return err
	}

	listener, err := net.ListenUnix("unix", addr)
	if err != nil {
		fmt.Println("Cannot listen to unix domain socket: " + err.Error())
		return err
	}
	fmt.Println("Listening on", listener.Addr())

	go func() {
		for {
			c, err := listener.Accept()
			if err != nil {
				fmt.Println("Accept: " + err.Error())
				return
			}

			// File returns a duplicate of the listener's descriptor that this
			// code owns and must close.
			file, err := t.listener.File()
			if err != nil {
				fmt.Println("get listen socket file fail, err is", err.Error())
				_ = c.Close()
				continue
			}

			rights := syscall.UnixRights(int(file.Fd()))
			_, _, err = c.(*net.UnixConn).WriteMsgUnix([]byte{0}, rights, nil)
			// The duplicate is no longer needed after the send attempt;
			// closing it stops one descriptor leaking per accepted peer.
			// The peer connection c is deliberately left open.
			_ = file.Close()
			if err != nil {
				fmt.Println("同步listen socket fail, err is", err.Error())
			}
		}
	}()

	return nil
}

// RecvListenerFromUnixSocket connects to the unix socket /tmp/unix_socket_tcp
// and, in a loop, receives TCP listening sockets sent as SCM_RIGHTS ancillary
// data, starting an accept loop on each one.
//
// Every message must carry exactly one data byte. A zero byte means "a
// descriptor follows in the ancillary data"; any other value ends the loop
// and returns nil. All other failures return a non-nil error.
func (t *TcpSrv) RecvListenerFromUnixSocket() error {
	connInterface, err := net.Dial("unix", "/tmp/unix_socket_tcp")
	if err != nil {
		fmt.Println("net dial unix fail", err.Error())
		return err
	}
	defer func() {
		_ = connInterface.Close()
	}()

	unixConn := connInterface.(*net.UnixConn)

	b := make([]byte, 1)
	oob := make([]byte, 32)
	for {
		// NOTE(review): only reads happen in this loop, so a read deadline
		// may have been intended here; confirm before changing.
		err = unixConn.SetWriteDeadline(time.Now().Add(time.Minute * 3))
		if err != nil {
			fmt.Println(err.Error())
			return err
		}
		n, oobn, _, _, err := unixConn.ReadMsgUnix(b, oob)
		if err != nil {
			fmt.Println(err.Error())
			return err
		}
		if n != 1 {
			fmt.Printf("recv fd type error: %d\n", n)
			return fmt.Errorf("recv fd type error: %d", n)
		}
		if b[0] != 0 {
			// A non-zero marker byte ends the transfer normally.
			fmt.Println("init finish")
			return nil
		}

		scms, err := unix.ParseSocketControlMessage(oob[0:oobn])
		if err != nil {
			fmt.Println(err.Error())
			return err
		}
		if len(scms) != 1 {
			fmt.Printf("recv fd num != 1 : %d\n", len(scms))
			return fmt.Errorf("recv fd num != 1 : %d", len(scms))
		}
		fds, err := unix.ParseUnixRights(&scms[0])
		if err != nil {
			fmt.Println(err.Error())
			return err
		}
		if len(fds) != 1 {
			fmt.Printf("recv fd num != 1 : %d\n", len(fds))
			return fmt.Errorf("recv fd num != 1 : %d", len(fds))
		}
		fmt.Printf("recv fd %d\n", fds[0])

		// The file must be closed here, on success and on failure alike;
		// otherwise every restart leaves one more copy of the socket open.
		file := os.NewFile(uintptr(fds[0]), "fd-from-old")
		ln, err := net.FileListener(file)
		_ = file.Close()
		if err != nil {
			fmt.Println(err.Error())
			return err
		}
		fmt.Println(ln)

		lc, ok := ln.(*net.TCPListener)
		if !ok {
			_ = ln.Close()
			fmt.Println("recv fd is not a tcp listener")
			return fmt.Errorf("recv fd is not a tcp listener: %T", ln)
		}
		go t.StartWithListenSocket(lc)
	}
}

// SendConnWithUnixSocket listens on the unix socket /tmp/unix_socket_tcp
// (removing any stale socket file first) and, in a background goroutine,
// sends the descriptor of every connection in t.conns to each peer that
// connects.
//
// Each connection goes out as one message: a zero byte plus the descriptor as
// SCM_RIGHTS ancillary data. Returns an error only if the unix socket cannot
// be set up.
//
// NOTE(review): t.conns is iterated here while Start inserts into it from
// another goroutine, with no locking; confirm whether synchronisation is needed.
func (t *TcpSrv) SendConnWithUnixSocket() error {
	_ = os.Remove("/tmp/unix_socket_tcp")
	addr, err := net.ResolveUnixAddr("unix", "/tmp/unix_socket_tcp")
	if err != nil {
		fmt.Println("Cannot resolve unix addr: " + err.Error())
		return err
	}

	listener, err := net.ListenUnix("unix", addr)
	if err != nil {
		fmt.Println("Cannot listen to unix domain socket: " + err.Error())
		return err
	}
	fmt.Println("Listening on", listener.Addr())

	go func() {
		for {
			c, err := listener.Accept()
			if err != nil {
				fmt.Println("Accept: " + err.Error())
				return
			}
			peer := c.(*net.UnixConn)
			for _, conn := range t.conns {
				// File returns a duplicate of the connection's descriptor
				// that this code owns and must close.
				file, err := conn.File()
				if err != nil {
					fmt.Println("get conn socket file fail, err is", err.Error())
					continue
				}

				rights := syscall.UnixRights(int(file.Fd()))
				_, _, err = peer.WriteMsgUnix([]byte{0}, rights, nil)
				// The duplicate is no longer needed after the send attempt;
				// closing it stops one descriptor leaking per connection.
				_ = file.Close()
				if err != nil {
					fmt.Println("同步conn socket fail, err is", err.Error())
				}
			}
		}
	}()

	return nil
}

// RecvConnFromUnixSocket connects to the unix socket /tmp/unix_socket_tcp
// and, in a loop, receives TCP connections sent as SCM_RIGHTS ancillary data,
// serving each one with clientSrv in its own goroutine.
//
// Every message must carry exactly one data byte. A zero byte means "a
// descriptor follows in the ancillary data"; any other value ends the loop
// and returns nil. All other failures return a non-nil error.
func (t *TcpSrv) RecvConnFromUnixSocket() error {
	connInterface, err := net.Dial("unix", "/tmp/unix_socket_tcp")
	if err != nil {
		fmt.Println("net dial unix fail", err.Error())
		return err
	}
	defer func() {
		_ = connInterface.Close()
	}()

	unixConn := connInterface.(*net.UnixConn)

	b := make([]byte, 1)
	oob := make([]byte, 32)
	for {
		// NOTE(review): only reads happen in this loop, so a read deadline
		// may have been intended here; confirm before changing.
		err = unixConn.SetWriteDeadline(time.Now().Add(time.Minute * 3))
		if err != nil {
			fmt.Println(err.Error())
			return err
		}
		n, oobn, _, _, err := unixConn.ReadMsgUnix(b, oob)
		if err != nil {
			fmt.Println(err.Error())
			return err
		}
		if n != 1 {
			fmt.Printf("recv fd type error: %d\n", n)
			return fmt.Errorf("recv fd type error: %d", n)
		}
		if b[0] != 0 {
			// A non-zero marker byte ends the transfer normally.
			fmt.Println("init finish")
			return nil
		}

		scms, err := unix.ParseSocketControlMessage(oob[0:oobn])
		if err != nil {
			fmt.Println(err.Error())
			return err
		}
		if len(scms) != 1 {
			fmt.Printf("recv fd num != 1 : %d\n", len(scms))
			return fmt.Errorf("recv fd num != 1 : %d", len(scms))
		}
		fds, err := unix.ParseUnixRights(&scms[0])
		if err != nil {
			fmt.Println(err.Error())
			return err
		}
		if len(fds) != 1 {
			fmt.Printf("recv fd num != 1 : %d\n", len(fds))
			return fmt.Errorf("recv fd num != 1 : %d", len(fds))
		}
		fmt.Printf("recv fd %d\n", fds[0])

		// The file must be closed here, on success and on failure alike;
		// otherwise every restart leaves one more copy of the socket open.
		file := os.NewFile(uintptr(fds[0]), "fd-from-old")
		conn, err := net.FileConn(file)
		_ = file.Close()
		if err != nil {
			fmt.Println(err.Error())
			return err
		}
		fmt.Println(conn)

		// Serve in the background so the loop can keep receiving the
		// remaining connections instead of waiting for this one to end.
		go t.clientSrv(conn)
	}
}