跨进程复制socket
linux为什么能跨进程传递socket文件描述符
在Linux中一切皆文件,文件系统是所有进程共有的。socket本身是在网络文件系统(sockfs)空间中申请的,它也是文件的一种,所以在同一台主机上,socket是可以跨进程传递的。下面仔细跟踪一下socket的创建过程(3.10内核)。
// User-space starting point of the trace: create an AF_INET stream socket and keep its descriptor
int server_fd = socket(AF_INET, SOCK_STREAM, 0 );
// ...
//
/*
 * Map a socket to a new file descriptor: reserve an fd, create the socket's
 * struct file through sock_alloc_file(), and install that file under the fd.
 * Returns the fd on success. On failure it returns the negative value from
 * get_unused_fd_flags(), or PTR_ERR() of the failed file allocation after
 * releasing the reserved fd.
 */
static int sock_map_fd(struct socket *sock, int flags)
{
struct file *newfile;
// Reserve an unused fd
int fd = get_unused_fd_flags(flags);
if (unlikely(fd < 0))
return fd;
// Allocate the file in the socket (network) filesystem space
newfile = sock_alloc_file(sock, flags, NULL);
if (likely(!IS_ERR(newfile))) {
// Success: attach the file to the reserved fd and hand the fd back
fd_install(fd, newfile);
return fd;
}
// File allocation failed: give the reserved fd back
put_unused_fd(fd);
return PTR_ERR(newfile);
}
/*
 * Allocate the struct file for a socket inside the socket pseudo filesystem
 * (sock_mnt) and link the file and the socket to each other.
 * The dentry is named after dname when it is given, otherwise after the
 * protocol that created sock->sk when sock->sk is set, otherwise "".
 * Returns the new file, ERR_PTR(-ENOMEM) when no dentry could be allocated,
 * or the error pointer produced by alloc_file().
 */
struct file *sock_alloc_file(struct socket *sock, int flags, const char *dname)
{
struct qstr name = { .name = "" };
struct path path;
struct file *file;
if (dname) {
name.name = dname;
name.len = strlen(name.name);
} else if (sock->sk) {
name.name = sock->sk->sk_prot_creator->name;
name.len = strlen(name.name);
}
path.dentry = d_alloc_pseudo(sock_mnt->mnt_sb, &name);
if (unlikely(!path.dentry))
return ERR_PTR(-ENOMEM);
path.mnt = mntget(sock_mnt);
d_instantiate(path.dentry, SOCK_INODE(sock));
SOCK_INODE(sock)->i_fop = &socket_file_ops;
file = alloc_file(&path, FMODE_READ | FMODE_WRITE,
&socket_file_ops);
if (unlikely(IS_ERR(file))) {
/* drop dentry, keep inode */
ihold(path.dentry->d_inode);
path_put(&path);
return file;
}
// Why the assignments below matter:
// every operation made through the socket interface reaches the socket structure via this file, so whoever
// holds the file descriptor can fetch the socket from the file's private_data field and operate on it.
// That is why a socket file can be passed between processes on the same host.
sock->file = file;
file->f_flags = O_RDWR | (flags & O_NONBLOCK);
file->private_data = sock;
return file;
}
如何传递、传递过程中注意哪些问题
- 父子进程建立unix socket连接传递
- 假设要传递的socket为client_fd:先调用dup_fd = dup(client_fd)复制fd,再通过sendmsg将dup_fd发送给子进程,然后close(client_fd)。原因是:在fd发往子进程的过程中,socket依然可以接收数据,这时父进程可能捕获到该事件并读取数据,导致子进程获取不到该事件而漏读数据;如果先dup一个出来,再把原来的关闭,那么等dup_fd到达子进程之后,子进程就可以响应该数据事件。
- 子进程通过recvmsg收到dup_fd之后,调用new_fd = dup(dup_fd),然后close(dup_fd)。原因同样是传递过程中可能已有数据到达,dup_fd本身捕捉不到该事件,而dup之后的new_fd能够获取到该数据事件。
- DupCloseOnExec(close_on_exec):父进程打开文件时,应用程序只需设置FD_CLOEXEC标志位,那么在fork之后exec其他程序时,内核会自动关闭从父进程继承的该fd。
- unix socket也可以像其它fd一样进行跨进程复制
- 跨进程复制的listen fd如果不关闭,都可以accept
- 跨进程复制的socket如果不关闭,各进程都可以收发数据,收数据时彼此存在竞争关系
- 复制过去的socket与原socket共用同一个接收、发送缓冲区
// Socket send and receive functions; both take the message as a struct msghdr
ssize_t sendmsg(int socket, const struct msghdr *message, int flags);
ssize_t recvmsg(int socket, struct msghdr *message, int flags);
// Related data structures
// msghdr describes one message: optional address, data buffers, and ancillary (control) data
struct msghdr {
void *msg_name; /* ptr to socket address structure */ // destination address of the data: sockaddr_in for network packets, sockaddr_nl for netlink
int msg_namelen; /* size of socket address structure */ // length of the address msg_name points to
struct iovec *msg_iov; /* scatter/gather array */ // points to the array of data buffers
__kernel_size_t msg_iovlen; /* # elements in msg_iov */ // number of entries in the buffer array
void *msg_control; /* ancillary data */ // ancillary data / control information (can carry any control message)
__kernel_size_t msg_controllen; /* ancillary data buffer length */ // length of the ancillary data
unsigned int msg_flags; /* flags on received message */ // message flags
};
// One entry of the scatter/gather array referenced by msghdr.msg_iov: a buffer base address and its length
struct iovec
{
void __user *iov_base; /* BSD uses caddr_t (1003.1g requires void *) */
__kernel_size_t iov_len; /* Must be size_t (1003.1g) */
};
// Header of one ancillary-data (control) message; cmsg_len counts this header as well as the data
struct cmsghdr {
__kernel_size_t cmsg_len; /* data byte count, including hdr */
int cmsg_level; /* originating protocol */
int cmsg_type; /* protocol-specific type */
};
// 简单的例子
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <iostream>
#include <sys/socket.h>
#include <unistd.h>
#include <sys/wait.h>
#include <fcntl.h>
#include <sys/uio.h>
#include <errno.h>
#include <netinet/in.h>
#include <time.h>
#include <signal.h>
#include <arpa/inet.h>
using namespace std;
// Forward declaration: creates the listening TCP socket, returns its fd or -1 on failure
int tcpServer();
/**
 * Send the open file descriptor `fd` to the peer of the UNIX-domain socket
 * `sock` as SCM_RIGHTS ancillary data, accompanied by one dummy payload byte.
 *
 * Prints a diagnostic and terminates the process with exit(1) if sendmsg()
 * fails.
 */
void send_fd(int sock, int fd)
{
    // Dummy one-byte payload that carries the control message.
    char payload = 0;
    iovec iov[1];
    iov[0].iov_base = &payload;
    iov[0].iov_len = 1;

    // Control buffer sized with CMSG_SPACE (header + data + padding); the
    // union guarantees the alignment a cmsghdr needs. Living on the stack, it
    // needs neither an allocation check nor a free().
    union {
        char buf[CMSG_SPACE(sizeof(int))];
        cmsghdr align;
    } control;
    memset(&control, 0, sizeof(control));

    // Zero the header so that no field is passed to the kernel uninitialized.
    msghdr msg;
    memset(&msg, 0, sizeof(msg));
    msg.msg_iov = iov;
    msg.msg_iovlen = 1;
    msg.msg_name = NULL;
    msg.msg_namelen = 0;
    msg.msg_control = control.buf;
    msg.msg_controllen = sizeof(control.buf);

    cmsghdr *cmptr = CMSG_FIRSTHDR(&msg);
    cmptr->cmsg_level = SOL_SOCKET;
    cmptr->cmsg_type = SCM_RIGHTS; // we are sending fd.
    cmptr->cmsg_len = CMSG_LEN(sizeof(int));
    // memcpy instead of a cast-and-store keeps the write alignment-safe.
    memcpy(CMSG_DATA(cmptr), &fd, sizeof(int));

    if (sendmsg(sock, &msg, 0) == -1) {
        cout << "[send_fd] sendmsg error" << endl;
        exit(1);
    }
}
/**
 * Receive one file descriptor sent as SCM_RIGHTS ancillary data over the
 * UNIX-domain socket `sock`.
 *
 * Returns the received descriptor, or -1 when the peer closed the socket or
 * the message did not carry a complete SCM_RIGHTS control message.
 * Prints a diagnostic and terminates the process with exit(1) if recvmsg()
 * itself fails.
 */
int recv_fd(int sock)
{
    char buf[32]; // the max buf in msg.
    iovec iov[1];
    iov[0].iov_base = buf;
    iov[0].iov_len = sizeof(buf);

    // Stack control buffer: correctly sized and aligned, and nothing to leak.
    union {
        char buf[CMSG_SPACE(sizeof(int))];
        cmsghdr align;
    } control;
    memset(&control, 0, sizeof(control));

    msghdr msg;
    memset(&msg, 0, sizeof(msg));
    msg.msg_iov = iov;
    msg.msg_iovlen = 1;
    msg.msg_name = NULL;
    msg.msg_namelen = 0;
    msg.msg_control = control.buf;
    msg.msg_controllen = sizeof(control.buf);

    int ret = recvmsg(sock, &msg, 0);
    if (ret == -1) {
        cout << "[recv_fd] recvmsg error" << endl;
        exit(1);
    }

    // Only trust the control data when a complete SCM_RIGHTS message actually
    // arrived; otherwise the buffer holds no descriptor at all.
    cmsghdr *cmptr = CMSG_FIRSTHDR(&msg);
    if (ret == 0 || cmptr == NULL ||
        cmptr->cmsg_level != SOL_SOCKET ||
        cmptr->cmsg_type != SCM_RIGHTS ||
        cmptr->cmsg_len < CMSG_LEN(sizeof(int))) {
        cout << "[recv_fd] no fd received" << endl;
        return -1;
    }

    int fd = -1;
    memcpy(&fd, CMSG_DATA(cmptr), sizeof(int));
    cout<< "接收的fd为"<< fd << endl;
    return fd;
}
/**
 * Master side of the demo: create the listening TCP socket, pass its
 * descriptor to the worker over the socketpair, then idle forever.
 *
 * `fds` is the socketpair created in main(); the master keeps fds[0] and
 * closes fds[1]. Terminates the process with exit(1) when the TCP server
 * cannot be created, instead of trying to send an invalid descriptor.
 */
void master_process_cycle(int fds[2]){
    cout << "master process #" << getpid() << endl;

    // master use fds[0], and close fds[1]
    int fd = fds[0];
    close(fds[1]);
    cout << "channel: #" << fds[0] << ", #" << fds[1] << ", fd=#" << fd << endl;

    int listenFD = tcpServer();
    if (listenFD < 0){
        cout << "tcp server fail" << endl;
        // There is nothing valid to hand over; stop here.
        exit(1);
    }

    send_fd(fd, listenFD);

    for(;;){
        sleep(1);
    }
}
/**
 * Worker side of the demo: receive the descriptor sent by the master over
 * the socketpair, then idle forever.
 *
 * `fds` is the socketpair created in main(); the worker keeps fds[1] and
 * closes fds[0]. Terminates the process with exit(1) when no valid
 * descriptor is received.
 */
void worker_process_cycle(int fds[2]){
    cout << "worker process #" << getpid() << endl;

    // worker uses fds[1] and closes fds[0], mirroring the master, which
    // closes fds[1]; otherwise this process keeps the master's end open.
    close(fds[0]);
    int fd = fds[1];

    int file = recv_fd(fd);
    if(file < 0){
        cout << "[worker] invalid fd! " << endl;
        exit(1);
    }

    for(;;){
        sleep(1);
    }
}
/**
 * Demo entry point: create a UNIX-domain socketpair, fork, and run the
 * worker cycle in the child and the master cycle in the parent.
 *
 * Exits with status 1 when the socketpair or the fork cannot be created.
 */
int main(int argc, char** argv){
    cout << "current pid: " << getpid() << endl;

    int fds[2];
    if(socketpair(AF_UNIX, SOCK_STREAM, 0, fds) == -1){
        cout << "failed to create domain socket by socketpair" << endl;
        exit(1);
    }
    cout << "create domain socket by socketpair success" << endl;

    cout << "create progress to communicate over domain socket" << endl;
    pid_t pid = fork();
    if(pid < 0){
        // Without this check a failed fork would fall into the master branch
        // and run with no worker on the other end.
        cout << "failed to fork worker process" << endl;
        exit(1);
    }

    if(pid == 0){
        worker_process_cycle(fds);
    }
    else{
        master_process_cycle(fds);
    }

    for(;;){
        sleep(1);
    }
}
/**
 * Create a TCP socket bound to 127.0.0.1:6666 and put it into listening
 * state with a backlog of 100.
 *
 * Returns the listening descriptor, or -1 on failure. On a bind() or
 * listen() failure the socket is closed before returning, so no descriptor
 * leaks.
 */
int tcpServer() {
    int listenFD = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
    if (listenFD < 0) {
        printf("create socket err \n");
        return -1;
    }

    /* Fill in the server address */
    struct sockaddr_in server_addr;
    memset(&server_addr, 0, sizeof(server_addr));
    server_addr.sin_family = AF_INET;
    server_addr.sin_addr.s_addr = inet_addr("127.0.0.1");
    server_addr.sin_port = htons(6666);

    /* Bind the address to the socket descriptor */
    if (bind(listenFD, (struct sockaddr *) &server_addr, sizeof(server_addr)) == -1) {
        cout << "bind fail" << endl;
        close(listenFD);
        return -1;
    }

    if (listen(listenFD, 100) == -1) {
        cout << "listen fail" << endl;
        close(listenFD);
        return -1;
    }

    return listenFD;
}
测试两个进程同时拥有同一个socket的代码,按需求自行打开注释
package main
import (
"fmt"
"net"
"os"
"syscall"
"time"
"golang.org/x/sys/unix"
)
// main runs the demo in one of two roles, selected by the command line.
//
// With no extra argument the process is the original server: it listens on
// TCP, starts accepting, and then offers its accepted connections to a peer
// over a unix domain socket. With any extra argument it is the receiver and
// pulls those connections from the unix domain socket.
//
// The commented-out calls switch the demo from migrating connections to
// migrating the listener; uncomment them as needed.
func main() {
	tcpSrv := NewTcpSrv()
	if len(os.Args) <= 1 {
		if err := tcpSrv.Init(); err != nil {
			fmt.Println("tcp srv init fail, err is", err)
			return
		}
		if err := tcpSrv.Start(); err != nil {
			fmt.Println("tcp srv start fail, err is", err)
			return
		}
		//// migrate the listener
		//if err := tcpSrv.SendListenerWithUnixSocket(); err != nil {
		//	fmt.Println("send listener with unix socket fail, err is", err)
		//	return
		//}
		// migrate the connections
		if err := tcpSrv.SendConnWithUnixSocket(); err != nil {
			// The message names the operation that actually failed.
			fmt.Println("send conn with unix socket fail, err is", err)
			return
		}
	} else {
		//// migrate the listener
		//if err := tcpSrv.RecvListenerFromUnixSocket(); err != nil {
		//	fmt.Println("recv listener with unix socket fail, err is", err)
		//	return
		//}
		// migrate the connections
		if err := tcpSrv.RecvConnFromUnixSocket(); err != nil {
			fmt.Println("recv conn with unix socket fail, err is", err)
			return
		}
	}
	// Block forever so the background goroutines keep running.
	select {}
}
// TcpSrv is the demo TCP echo server whose listener or accepted connections
// can be handed to another process over a unix domain socket.
//
// NOTE(review): conns is written by the accept goroutine started in Start and
// ranged over by the goroutine started in SendConnWithUnixSocket; nothing in
// this file synchronizes those accesses.
type TcpSrv struct {
// listener is the TCP listener created by Init.
listener *net.TCPListener
// conns holds the accepted connections, keyed by the remote address string.
conns map[string]*net.TCPConn
}
// NewTcpSrv returns a TcpSrv with an empty connection table and no listener
// yet; call Init to create the listener.
func NewTcpSrv() *TcpSrv {
	srv := new(TcpSrv)
	srv.conns = make(map[string]*net.TCPConn)
	return srv
}
// Init opens the TCP listener on ":7000" and stores it on the server.
// It returns the error from net.Listen, or nil on success.
func (t *TcpSrv) Init() error {
	ln, listenErr := net.Listen("tcp", ":7000")
	if listenErr == nil {
		t.listener = ln.(*net.TCPListener)
	}
	return listenErr
}
// Start launches the accept loop for t.listener in a background goroutine and
// returns nil immediately.
//
// Every accepted connection is served by clientSrv in its own goroutine and
// recorded in t.conns under its remote address. A failed Accept is logged and
// retried after a short pause, so a persistent failure does not turn the loop
// into a busy spin.
func (t *TcpSrv) Start() error {
	go func() {
		for {
			conn, err := t.listener.Accept()
			if err != nil {
				fmt.Println("accept fail, err msg is", err)
				// Back off before retrying instead of spinning on the error.
				time.Sleep(100 * time.Millisecond)
				continue
			}
			go t.clientSrv(conn)
			storeConn := conn.(*net.TCPConn)
			t.conns[conn.RemoteAddr().String()] = storeConn
		}
	}()
	return nil
}
// StartWithListenSocket launches an accept loop for the given listener in a
// background goroutine and returns nil immediately.
//
// Every accepted connection is served by clientSrv in its own goroutine. A
// failed Accept is logged and retried after a short pause, so a persistent
// failure does not turn the loop into a busy spin.
func (t *TcpSrv) StartWithListenSocket(listener *net.TCPListener) error {
	go func() {
		for {
			c, err := listener.Accept()
			if err != nil {
				fmt.Println("accept fail, err msg is", err)
				// Back off before retrying instead of spinning on the error.
				time.Sleep(100 * time.Millisecond)
				continue
			}
			go t.clientSrv(c)
		}
	}()
	return nil
}
// clientSrv echoes data on conn until a read or write fails, then closes the
// connection. It pauses one second before every read and logs each message it
// receives.
func (t *TcpSrv) clientSrv(conn net.Conn) {
	defer conn.Close()

	data := make([]byte, 1024)
	for {
		time.Sleep(time.Second)

		count, readErr := conn.Read(data)
		if readErr != nil {
			fmt.Println("read msg fail, err is", readErr)
			return
		}

		received := data[:count]
		fmt.Println("recv msg is", string(received))

		_, writeErr := conn.Write(received)
		if writeErr != nil {
			fmt.Println("write msg fail, err is", writeErr)
			return
		}
	}
}
// SendListenerWithUnixSocket offers the TCP listener's file descriptor on the
// unix domain socket /tmp/unix_socket_tcp.
//
// It removes any stale socket file, listens on the path, and starts a
// goroutine that sends a duplicate of t.listener's descriptor (SCM_RIGHTS
// ancillary data plus a single 0 payload byte) to every peer that connects.
// Errors while setting up the unix listener are returned; errors while
// serving a peer are only logged. The accept goroutine ends on the first
// Accept error.
func (t *TcpSrv) SendListenerWithUnixSocket() error {
	_ = os.Remove("/tmp/unix_socket_tcp")
	addr, err := net.ResolveUnixAddr("unix", "/tmp/unix_socket_tcp")
	if err != nil {
		fmt.Println("Cannot resolve unix addr: " + err.Error())
		return err
	}
	listener, err := net.ListenUnix("unix", addr)
	if err != nil {
		fmt.Println("Cannot listen to unix domain socket: " + err.Error())
		return err
	}
	fmt.Println("Listening on", listener.Addr())
	go func() {
		for {
			c, err := listener.Accept()
			if err != nil {
				fmt.Println("Accept: " + err.Error())
				return
			}
			// NOTE(review): c is deliberately not closed here. The receiving
			// side returns on any ReadMsgUnix error, so closing the peer
			// connection may end the receiver — confirm before changing.
			file, err := t.listener.File()
			if err != nil {
				// Without a duplicated descriptor there is nothing to send.
				fmt.Println("同步listen socket fail, err is", err.Error())
				continue
			}
			rights := syscall.UnixRights(int(file.Fd()))
			_, _, err = c.(*net.UnixConn).WriteMsgUnix([]byte{0}, rights, nil)
			// The duplicate is only needed for the send; closing it here
			// avoids leaking one descriptor per peer and keeps it open until
			// the write has completed.
			_ = file.Close()
			if err != nil {
				fmt.Println("同步listen socket fail, err is", err.Error())
			}
		}
	}()
	return nil
}
// RecvListenerFromUnixSocket connects to the unix domain socket
// /tmp/unix_socket_tcp and receives TCP listeners passed as SCM_RIGHTS
// ancillary data, starting an accept loop on each one.
//
// Each message must carry exactly one payload byte and exactly one file
// descriptor. A payload byte other than 0 marks the end of the transfer and
// makes the function return nil. Dial and read errors, malformed messages
// and descriptors that are not TCP listeners are returned as errors.
func (t *TcpSrv) RecvListenerFromUnixSocket() error {
	connInterface, err := net.Dial("unix", "/tmp/unix_socket_tcp")
	if err != nil {
		fmt.Println("net dial unix fail", err.Error())
		return err
	}
	defer func() {
		_ = connInterface.Close()
	}()
	unixConn := connInterface.(*net.UnixConn)
	b := make([]byte, 1)
	oob := make([]byte, 32)
	for {
		// NOTE(review): this sets a write deadline although the loop only
		// reads; a read deadline may have been intended — confirm.
		err = unixConn.SetWriteDeadline(time.Now().Add(time.Minute * 3))
		if err != nil {
			fmt.Println(err.Error())
			return err
		}
		n, oobn, _, _, err := unixConn.ReadMsgUnix(b, oob)
		if err != nil {
			fmt.Println(err.Error())
			return err
		}
		if n != 1 {
			fmt.Printf("recv fd type error: %d\n", n)
			// err is nil at this point; report the failure explicitly.
			return fmt.Errorf("recv fd: unexpected payload length %d", n)
		}
		if b[0] != 0 {
			fmt.Println("init finish")
			return nil
		}
		scms, err := unix.ParseSocketControlMessage(oob[0:oobn])
		if err != nil {
			fmt.Println(err.Error())
			return err
		}
		if len(scms) != 1 {
			fmt.Printf("recv fd num != 1 : %d\n", len(scms))
			return fmt.Errorf("recv fd: expected 1 control message, got %d", len(scms))
		}
		fds, err := unix.ParseUnixRights(&scms[0])
		if err != nil {
			fmt.Println(err.Error())
			return err
		}
		if len(fds) != 1 {
			fmt.Printf("recv fd num != 1 : %d\n", len(fds))
			return fmt.Errorf("recv fd: expected 1 descriptor, got %d", len(fds))
		}
		fmt.Printf("recv fd %d\n", fds[0])
		// The file must be closed here, otherwise every restart leaves one
		// more copy of the socket behind. It is closed on the error path too.
		file := os.NewFile(uintptr(fds[0]), "fd-from-old")
		ln, err := net.FileListener(file)
		_ = file.Close()
		if err != nil {
			fmt.Println(err.Error())
			return err
		}
		fmt.Println(ln)
		lc, ok := ln.(*net.TCPListener)
		if !ok {
			// The peer sent something that is not a TCP listener.
			_ = ln.Close()
			return fmt.Errorf("recv fd: descriptor %d is not a TCP listener", fds[0])
		}
		// StartWithListenSocket spawns its own goroutine and returns at once.
		if err := t.StartWithListenSocket(lc); err != nil {
			return err
		}
	}
}
// SendConnWithUnixSocket offers the accepted TCP connections on the unix
// domain socket /tmp/unix_socket_tcp.
//
// It removes any stale socket file, listens on the path, and starts a
// goroutine that, for every peer that connects, sends a duplicate of each
// descriptor in t.conns (SCM_RIGHTS ancillary data plus a single 0 payload
// byte, one message per connection). Errors while setting up the unix
// listener are returned; errors while serving a peer are only logged. The
// accept goroutine ends on the first Accept error.
//
// NOTE(review): t.conns is ranged over here while the accept goroutine
// started by Start may be inserting into it; no lock guards the map.
func (t *TcpSrv) SendConnWithUnixSocket() error {
	_ = os.Remove("/tmp/unix_socket_tcp")
	addr, err := net.ResolveUnixAddr("unix", "/tmp/unix_socket_tcp")
	if err != nil {
		fmt.Println("Cannot resolve unix addr: " + err.Error())
		return err
	}
	listener, err := net.ListenUnix("unix", addr)
	if err != nil {
		fmt.Println("Cannot listen to unix domain socket: " + err.Error())
		return err
	}
	fmt.Println("Listening on", listener.Addr())
	go func() {
		for {
			c, err := listener.Accept()
			if err != nil {
				fmt.Println("Accept: " + err.Error())
				return
			}
			// NOTE(review): c is deliberately not closed here. The receiving
			// side returns on any ReadMsgUnix error, so closing the peer
			// connection may end the receiver — confirm before changing.
			for _, conn := range t.conns {
				file, err := conn.File()
				if err != nil {
					// Without a duplicated descriptor there is nothing to
					// send for this connection; move on to the next one.
					fmt.Println("同步conn socket fail, err is", err.Error())
					continue
				}
				rights := syscall.UnixRights(int(file.Fd()))
				_, _, err = c.(*net.UnixConn).WriteMsgUnix([]byte{0}, rights, nil)
				// The duplicate is only needed for the send; closing it here
				// avoids leaking one descriptor per connection and keeps it
				// open until the write has completed.
				_ = file.Close()
				if err != nil {
					fmt.Println("同步conn socket fail, err is", err.Error())
				}
			}
		}
	}()
	return nil
}
// RecvConnFromUnixSocket connects to the unix domain socket
// /tmp/unix_socket_tcp and receives connections passed as SCM_RIGHTS
// ancillary data, serving each one with clientSrv.
//
// Each message must carry exactly one payload byte and exactly one file
// descriptor. A payload byte other than 0 marks the end of the transfer and
// makes the function return nil. Dial and read errors, malformed messages
// and net.FileConn failures are returned as errors.
//
// NOTE(review): clientSrv is called synchronously, so the next descriptor is
// only read after the current connection has finished — confirm this is
// intended.
func (t *TcpSrv) RecvConnFromUnixSocket() error {
	connInterface, err := net.Dial("unix", "/tmp/unix_socket_tcp")
	if err != nil {
		fmt.Println("net dial unix fail", err.Error())
		return err
	}
	defer func() {
		_ = connInterface.Close()
	}()
	unixConn := connInterface.(*net.UnixConn)
	b := make([]byte, 1)
	oob := make([]byte, 32)
	for {
		// NOTE(review): this sets a write deadline although the loop only
		// reads; a read deadline may have been intended — confirm.
		err = unixConn.SetWriteDeadline(time.Now().Add(time.Minute * 3))
		if err != nil {
			fmt.Println(err.Error())
			return err
		}
		n, oobn, _, _, err := unixConn.ReadMsgUnix(b, oob)
		if err != nil {
			fmt.Println(err.Error())
			return err
		}
		if n != 1 {
			fmt.Printf("recv fd type error: %d\n", n)
			// err is nil at this point; report the failure explicitly.
			return fmt.Errorf("recv fd: unexpected payload length %d", n)
		}
		if b[0] != 0 {
			fmt.Println("init finish")
			return nil
		}
		scms, err := unix.ParseSocketControlMessage(oob[0:oobn])
		if err != nil {
			fmt.Println(err.Error())
			return err
		}
		if len(scms) != 1 {
			fmt.Printf("recv fd num != 1 : %d\n", len(scms))
			return fmt.Errorf("recv fd: expected 1 control message, got %d", len(scms))
		}
		fds, err := unix.ParseUnixRights(&scms[0])
		if err != nil {
			fmt.Println(err.Error())
			return err
		}
		if len(fds) != 1 {
			fmt.Printf("recv fd num != 1 : %d\n", len(fds))
			return fmt.Errorf("recv fd: expected 1 descriptor, got %d", len(fds))
		}
		fmt.Printf("recv fd %d\n", fds[0])
		// The file must be closed here, otherwise every restart leaves one
		// more copy of the socket behind. It is closed on the error path too.
		file := os.NewFile(uintptr(fds[0]), "fd-from-old")
		conn, err := net.FileConn(file)
		_ = file.Close()
		if err != nil {
			fmt.Println(err.Error())
			return err
		}
		fmt.Println(conn)
		t.clientSrv(conn)
	}
}