Linux I/O 复用高级应用 2：聊天室程序（poll 实现）

客户端:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#define _GNU_SOURCE 1
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <assert.h>
#include <errno.h>
#include <stdio.h>
#include <unistd.h>
#include <string.h>
#include <stdlib.h>
#include <poll.h>
#include <fcntl.h>

#define BUFFER_SIZE 64

/*
 * Chat-room client.
 *
 * poll(2) waits on two descriptors at once:
 *   - fd 0 (standard input): whatever the user types is forwarded to the
 *     server with splice(2) through an intermediate pipe
 *     (stdin -> pipe -> socket), so the bytes never pass through a
 *     user-space buffer;
 *   - the connected socket: anything the server sends is printed to stdout.
 *
 * usage: <program> ip_address port_number
 * Returns 1 on bad arguments or setup failure, 0 once the session ends
 * (server closed, end of input, or an I/O error).
 */
int main(int argc, char* argv[]) {
    if (argc <= 2) {
        printf("usage: %s ip_address port_number\n", basename(argv[0]));
        return 1;
    }
    const char* ip = argv[1];
    int port = atoi(argv[2]);

    struct sockaddr_in server_address;
    bzero(&server_address, sizeof(server_address));
    server_address.sin_family = AF_INET;
    // inet_pton() returns 1 only for a well-formed address; otherwise
    // sin_addr would silently remain 0.0.0.0.
    if (inet_pton(AF_INET, ip, &server_address.sin_addr) != 1) {
        printf("invalid ip address: %s\n", ip);
        return 1;
    }
    server_address.sin_port = htons(port);

    // Explicit checks instead of assert(): assert() compiles away under NDEBUG.
    int sockfd = socket(PF_INET, SOCK_STREAM, 0);
    if (sockfd < 0) {
        printf("socket failure\n");
        return 1;
    }
    if (connect(sockfd, (struct sockaddr*)&server_address, sizeof(server_address)) < 0) {
        printf("connection failed!\n");
        close(sockfd);
        return 1;
    }

    // splice() needs a pipe on one side, so stdin -> socket goes through one.
    int pipefd[2];
    if (pipe(pipefd) == -1) {
        printf("pipe failure\n");
        close(sockfd);
        return 1;
    }

    pollfd fds[2];
    // fd 0 is standard input: register interest in readable data (POLLIN).
    fds[0].fd = 0;
    fds[0].events = POLLIN;
    fds[0].revents = 0;

    // The connected socket: register interest in incoming data (POLLIN) and
    // in the server closing its side of the connection (POLLRDHUP).
    fds[1].fd = sockfd;
    fds[1].events = POLLIN | POLLRDHUP;
    fds[1].revents = 0;

    char read_buf[BUFFER_SIZE];

    while (1) {
        int ret = poll(fds, 2, -1);
        if (ret < 0) {
            if (errno == EINTR) {
                continue;  // interrupted by a signal; just wait again
            }
            printf("poll failure\n");
            break;
        }

        // poll() always reports POLLERR/POLLHUP even when not requested; if
        // they were ignored here, poll() would return immediately forever.
        if (fds[1].revents & (POLLRDHUP | POLLHUP | POLLERR)) {
            printf("server close the connection\n");
            break;
        } else if (fds[1].revents & POLLIN) {
            memset(read_buf, '\0', BUFFER_SIZE);
            // Read at most BUFFER_SIZE - 1 bytes so read_buf stays NUL-terminated.
            ssize_t received = recv(fds[1].fd, read_buf, BUFFER_SIZE - 1, 0);
            if (received == 0) {
                printf("server close the connection\n");
                break;
            }
            if (received < 0) {
                if (errno != EINTR && errno != EAGAIN && errno != EWOULDBLOCK) {
                    printf("recv failure\n");
                    break;
                }
            } else {
                printf("%s\n", read_buf);
            }
        }

        if (fds[0].revents & (POLLIN | POLLHUP | POLLERR)) {
            // Zero-copy forward: stdin -> pipe -> sockfd.
            ssize_t pending = splice(0, NULL, pipefd[1], NULL, 32768,
                                     SPLICE_F_MORE | SPLICE_F_MOVE);
            if (pending <= 0) {
                // 0 = end of input (e.g. Ctrl-D); < 0 = error (e.g. stdin does
                // not support splice). Nothing entered the pipe, and splicing
                // out of an empty pipe would block forever, so stop here.
                if (pending < 0) {
                    printf("splice failure\n");
                }
                break;
            }
            // splice() may move fewer bytes than asked; drain exactly what
            // was just queued so the pipe is empty before the next poll().
            while (pending > 0) {
                ssize_t sent = splice(pipefd[0], NULL, sockfd, NULL, pending,
                                      SPLICE_F_MORE | SPLICE_F_MOVE);
                if (sent <= 0) {
                    printf("splice failure\n");
                    goto done;
                }
                pending -= sent;
            }
        }
    }
done:
    close(pipefd[0]);
    close(pipefd[1]);
    close(sockfd);
    return 0;
}

服务器:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
#define _GNU_SOURCE 1
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <assert.h>
#include <stdio.h>
#include <unistd.h>
#include <errno.h>
#include <string.h>
#include <fcntl.h>
#include <stdlib.h>
#include <poll.h>

#define USER_LIMIT 5
#define BUFFER_SIZE 64
#define FD_LIMIT 65535

/*
 * Per-connection state. main() keeps one entry per possible descriptor value
 * and looks a connection up directly by its socket fd (users[connfd]).
 */
struct client_data
{
    struct sockaddr_in address;  // peer address, as returned by accept()
    char* write_buf;             // message waiting to be sent to this client
                                 // (points at the sender's buf), or NULL
    char buf[BUFFER_SIZE];       // most recent data received from this client
};

/*
 * Switch fd to non-blocking mode by OR-ing O_NONBLOCK into its file status
 * flags. Returns the flags as they were before the change (whatever F_GETFL
 * produced, i.e. -1 if that call failed). The F_SETFL result is not checked.
 */
int setnonblocking(int fd)
{
    const int previous_flags = fcntl(fd, F_GETFL);
    fcntl(fd, F_SETFL, previous_flags | O_NONBLOCK);
    return previous_flags;
}

int main(int argc, char* argv[])
{
if (argc <= 2) {
printf("usage: %s ip_address portnumber\n", basename(argv[0]));
return 1;
}
const char* ip = argv[1];
int port = atoi(argv[2]);

int ret = 0;
struct sockaddr_in address;
bzero(&address, sizeof(address));
address.sin_family = AF_INET;
inet_pton(AF_INET, ip, &address.sin_addr);
address.sin_port = htons(port);

int listenfd = socket(PF_INET, SOCK_STREAM, 0);
assert(listenfd >= 0);

ret = bind(listenfd, (struct sockaddr*)&address, sizeof(address));
assert(ret != 1);

ret = listen(listenfd, 5);
assert(ret != -1);

client_data* users = new client_data[FD_LIMIT];

pollfd fds[USER_LIMIT + 1];
int user_counter = 0;
for (int i = 1; i <= USER_LIMIT; i++) {
fds[i].fd = -1;
fds[i].events = 0;
}
// 注册事件:监听端口是否可读或者出现错误
fds[0].fd = listenfd;
fds[0].events = POLLIN | POLLERR;
fds[0].revents = 0;

while (1) {
ret = poll(fds, user_counter + 1, -1);
if (ret < 0) {
printf("poll failure\n");
break;
}

for (int i = 0; i < user_counter + 1; i++) {
if ((fds[i].fd == listenfd) && (fds[i].revents & POLLIN)) {
struct sockaddr_in client_address;
socklen_t client_addrlength = sizeof(client_address);
int connfd = accept(listenfd, (struct sockaddr*)&client_address, &client_addrlength);

if (connfd < 0) {
printf("errno is: %d\n", errno);
continue;
}

if (user_counter >= USER_LIMIT) {
const char* info = "too many users\n";
printf("%s", info);
send(connfd, info, strlen(info), 0);
close(connfd);
continue;
}

user_counter++;
users[connfd].address = client_address;
setnonblocking(connfd);
// 用于和新连接进来的客户的交互的文件描述符connfd
fds[user_counter].fd = connfd;
fds[user_counter].events = POLLIN | POLLRDHUP | POLLERR;
fds[user_counter].revents = 0;
printf("comes a new user, now have %d users\n", user_counter);
} else if (fds[i].revents & POLLERR) { // 某个和用户连接出错了
printf("get an error from %d\n", fds[i].fd);
char errors[100];
memset(errors, '\0', 100);
socklen_t length = sizeof(errors);
if (getsockopt(fds[i].fd, SOL_SOCKET, SO_ERROR, &errors, &length) < 0) {
printf("get socket options failed]n");
}
continue;
} else if (fds[i].revents & POLLRDHUP) { // 用户断开了连接
users[fds[i].fd] = users[fds[user_counter].fd];
close(fds[i].fd);
fds[i] = fds[user_counter];
i--;
user_counter--;
printf("a client left\n");
} else if (fds[i].revents & POLLIN) { // 用户连接可读,表示用户发送了数据过来
int connfd = fds[i].fd;
memset(users[connfd].buf, '\0', BUFFER_SIZE);
ret = recv(connfd, users[connfd].buf, BUFFER_SIZE - 1, 0);
printf("get %d bytes of client data %s from %d\n", ret, users[connfd].buf, connfd);
if (ret < 0) {
if (errno != EAGAIN) {
close(connfd);
users[fds[i].fd] = users[fds[user_counter].fd];
fds[i] = fds[user_counter];
i--;
user_counter--;
}
} else if (ret == 0) { // 对方关闭了连接 此时有POLLHUP处理

} else {
for (int j = 1; j <= user_counter; j++) {
if (fds[j].fd == connfd) {
continue;
}

fds[j].events |= ~POLLIN; // 原书这里是这样写的 目的为注销可读。但是本人认为应该是 &= ,下同。
fds[j].events |= POLLOUT;
users[fds[j].fd].write_buf = users[connfd].buf; // 其他用户的buf写上当前收到的某fd所发送的信息
}
}
} else if (fds[i].revents & POLLOUT) { // 用户连接可写,把buffer信息写进去
int connfd = fds[i].fd;
if (!users[connfd].write_buf) {
continue;
}
ret = send(connfd, users[connfd].write_buf, strlen(users[connfd].write_buf), 0);
users[connfd].write_buf = NULL;
fds[i].events |= ~POLLOUT;
fds[i].events |= POLLIN;
}
}
}
delete[] users;
close(listenfd);
return 0;

}

演示结果:
服务器:
（此处原为服务器运行截图，图片已缺失）


客户端1:
客户端发送:client 1 : 1111111


客户端2:
客户端发送:client 2 : 2222222


reference:
《Linux 高性能服务器编程》——游双