-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.c
215 lines (176 loc) · 6.63 KB
/
main.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
#include <arpa/inet.h>
#include <errno.h>
#include <limits.h>
#include <pthread.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/epoll.h>
#include <unistd.h>
#include "lib/asyncaddrinfo/asyncaddrinfo.h"
#include "log.h"
#include "poll.h"
#include "proxy/proxy_server.h"
#include "util.h"
#define CONNECT_BACKLOG 512
#define DEFAULT_THREAD_COUNT 8
#define MAX_BLOCKLIST_LEN 100
// Creates a non-blocking IPv4 TCP socket bound to `port` on all local
// interfaces and puts it into the listening state with a backlog of
// CONNECT_BACKLOG.
// Returns the listening socket's file descriptor; any failure terminates the
// process via die().
int create_bind_listen(unsigned short port) {
  int listening_socket = socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP);
  if (listening_socket < 0) {
    die(hsprintf("failed to create listening socket: %s", errno2s(errno)));
  }
  // Designated initializer zeroes every member not named here (notably
  // sin_zero), instead of leaving uninitialized stack bytes in the address.
  // Both fields are converted to network byte order.
  struct sockaddr_in listen_addr = {
      .sin_family = AF_INET,
      .sin_port = htons(port),
      .sin_addr.s_addr = htonl(INADDR_ANY),
  };
  if (bind(listening_socket, (struct sockaddr*)&listen_addr, sizeof(listen_addr)) < 0) {
    die(hsprintf("failed to bind listening socket to port: %s", errno2s(errno)));
  }
  if (listen(listening_socket, CONNECT_BACKLOG) < 0) {
    die(hsprintf("failed to listen: %s", errno2s(errno)));
  }
  return listening_socket;
}
// Arguments handed to handle_connections_pthread_wrapper (via a void*), one
// instance per connection-handling event loop.
struct connection_thread_args {
unsigned short thread_id; // id used to identify the thread in logging; main() gives 0 to the main thread and 1.. to workers
struct proxy_server* server; // server state; main() passes the same proxy_server instance to every thread
};
// Runs one event loop that accepts connections on `server->listening_socket`
// (via accept_incoming_connections) and drives them until poll_run() returns.
// Any failure to create, register with, or run the loop terminates the
// process via die(); on a clean return the poll instance is destroyed.
void handle_connections(struct proxy_server* server) {
  struct poll* event_loop = poll_create();
  if (!event_loop) {
    die(hsprintf("failed to create poll instance: %s", errno2s(errno)));
  }

  // accept_incoming_connections keeps calling `accept4` until nothing more is
  // pending, so an edge-triggered readability registration on the listening
  // socket is enough, and it is cheaper than level-triggered.
  // NOTE(review): assumes the two bool arguments mean (one-shot = false,
  // edge-triggered = true); confirm against poll.h.
  if (poll_wait_for_readability(event_loop, server->listening_socket, server,
                                false, true,
                                (poll_callback)accept_incoming_connections) < 0) {
    die(hsprintf("failed to register readability notification for listening socket: %s", errno2s(errno)));
  }

  // Expected to run until the server terminates.
  if (poll_run(event_loop) < 0) {
    die(hsprintf("poll_run returned error: %s", errno2s(errno)));
  }

  poll_destroy(event_loop);
}
void* handle_connections_pthread_wrapper(void* raw_args) {
struct connection_thread_args* args = raw_args;
thread_id__ = args->thread_id; // to identify the current thread in logging
handle_connections(args->server);
return NULL;
}
// Reads newline-separated blocklist entries from the file at `blocklist_path`.
//
// On return `*blocklist_ptr` points to a heap array of MAX_BLOCKLIST_LEN
// slots. The first N slots (N = return value) hold heap-allocated,
// NUL-terminated entries with the line terminator ("\n" or "\r\n") stripped.
// The remaining slots are NULL. Empty lines are skipped. The caller owns the
// array and each entry and must free() them.
//
// Terminates the process via die() if allocation fails, the file cannot be
// opened or read, or it contains more than MAX_BLOCKLIST_LEN entries.
int read_blocklist(const char* blocklist_path, char*** blocklist_ptr) {
  char** blocklist = calloc(MAX_BLOCKLIST_LEN, sizeof(*blocklist));
  if (blocklist == NULL) {
    die("failed to allocate memory for the blocklist");
  }
  *blocklist_ptr = blocklist;

  FILE* fp = fopen(blocklist_path, "r");
  if (fp == NULL) {
    die(hsprintf("could not open file: '%s'", blocklist_path));
  }

  // Scratch buffer owned by this function until its contents become an entry.
  char* line = NULL;
  size_t line_capacity = 0;
  int blocklist_len = 0;
  while (getline(&line, &line_capacity, fp) != -1) {
    size_t char_count = strcspn(line, "\r\n");
    if (char_count == 0) {
      // Empty line: keep the buffer so getline can reuse it.
      continue;
    }
    // Check capacity only when there is an entry to store, so a file with
    // exactly MAX_BLOCKLIST_LEN entries is accepted.
    if (blocklist_len >= MAX_BLOCKLIST_LEN) {
      die(hsprintf("too many entries in the blocklist. Only up to %d is supported.", MAX_BLOCKLIST_LEN));
    }
    line[char_count] = '\0';
    blocklist[blocklist_len] = line;
    DEBUG_LOG("Read blocklist entry %d: %s", blocklist_len, blocklist[blocklist_len]);
    blocklist_len++;
    // Ownership moved into `blocklist`. Reset both pointer and size so the
    // next getline call allocates a fresh buffer (portable POSIX usage).
    line = NULL;
    line_capacity = 0;
  }
  free(line);

  // getline returns -1 on both EOF and error; do not mistake a read error for
  // the end of a (silently truncated) blocklist.
  int read_error = ferror(fp);
  fclose(fp);
  if (read_error) {
    die(hsprintf("failed to read file: '%s'", blocklist_path));
  }
  return blocklist_len;
}
int main(int argc, char** argv) {
if (argc < 4 || argc > 5) {
die(hsprintf("Usage: %s port flag_stats path_to_blocklist [thread_count]", argv[0]));
}
char* endptr;
unsigned short listening_port = strtol(argv[1], &endptr, 10);
if (*argv[1] == '\0' // empty argument
|| *endptr != '\0' // unrecognized characters
) {
die(hsprintf("failed to parse port number '%s'", argv[1]));
}
bool stats_enabled;
if (strcmp(argv[2], "0") == 0) {
stats_enabled = false;
} else if (strcmp(argv[2], "1") == 0) {
stats_enabled = true;
} else {
die(hsprintf("expected flag_stats to be either 0 or 1, got '%s'", argv[2]));
}
const char* blocklist_path = argv[3];
char** blocklist;
int blocklist_len = read_blocklist(blocklist_path, &blocklist);
unsigned short thread_count = DEFAULT_THREAD_COUNT;
if (argc == 5) {
thread_count = strtol(argv[4], &endptr, 10);
if (*argv[4] == '\0' || *endptr != '\0') {
die(hsprintf("failed to parse thread count '%s'", argv[4]));
}
if (thread_count < 2) {
die("at least 2 threads are required");
}
}
// use a quarter of the threads (or minimally 1) for async getaddrinfo
// use the rest (including the main thread) to run event loops and handle connections
unsigned short asyncaddrinfo_threads = thread_count / 4;
if (asyncaddrinfo_threads < 1) {
asyncaddrinfo_threads = 1;
}
unsigned short connection_threads = thread_count - asyncaddrinfo_threads;
printf("- listening port: %hu\n", listening_port);
printf("- stats enabled: %s\n", stats_enabled ? "yes" : "no");
printf("- path to blocklist file: %s\n", blocklist_path);
printf("- number of entries in the blocklist file: %d\n", blocklist_len);
printf("- number of connection threads: %hu\n", connection_threads);
printf("- number of async addrinfo (DNS) threads: %hu\n", asyncaddrinfo_threads);
// start the addr info lookup threads
asyncaddrinfo_init(asyncaddrinfo_threads);
// start the connection threads
int listening_socket = create_bind_listen(listening_port);
struct proxy_server server = {
.listening_socket = listening_socket,
.stats_enabled = stats_enabled,
.blocklist = blocklist,
.blocklist_len = blocklist_len,
};
struct connection_thread_args args_list[connection_threads];
for (int i = 0; i < connection_threads; i++) {
args_list[i].thread_id = i;
args_list[i].server = &server;
}
pthread_t workers[connection_threads - 1];
for (int i = 0; i < connection_threads - 1; i++) {
// child threads will have id from 1 onwards
// the main thread will be thread 0
if (0 != pthread_create(&workers[i], NULL, handle_connections_pthread_wrapper, &args_list[i + 1])) {
die(hsprintf("error creating thread %d: %s", i + 1, errno2s(errno)));
}
}
printf("Accepting requests\n");
// run another event loop on the main thread
handle_connections_pthread_wrapper(&args_list[0]);
// We will never reach here, the cleanup code below is just for completeness' sake
if (close(listening_socket) < 0) {
die(hsprintf("failed to close listening socket: %s", errno2s(errno)));
}
for (int i = 0; i < connection_threads; i++) {
if (0 != pthread_join(workers[i], NULL)) {
die(hsprintf("error joining thread %d: %s", i + 1, errno2s(errno)));
}
}
for (int i = 0; i < blocklist_len; i++) {
free(blocklist[i]);
}
free(blocklist);
asyncaddrinfo_cleanup();
return 0;
}