refactor: abstract watchers from loop data

This commit is contained in:
yrutschle 2021-10-11 22:40:46 +02:00
parent 0cde3d794a
commit 25abd765cb
2 changed files with 101 additions and 60 deletions

22
processes.h Normal file
View File

@ -0,0 +1,22 @@
#ifndef PROCESSES_H
#define PROCESSES_H
#ifndef WATCHERS_TYPE_DEFINED
#error Define watchers type before including processes.h
#endif
/* Global state for a loop */
struct loop_info {
int num_probing; /* Number of connections currently probing
* We use this to know if we need to time out of
* select() */
gap_array* probing_list; /* Pointers to cnx that are in probing mode */
watchers watchers;
cnx_collection* collection; /* Collection of connections linked to this loop */
time_t next_timeout; /* time at which next UDP connection times out */
};
#endif

View File

@ -41,35 +41,59 @@
const char* server_type = "sslh-select"; const char* server_type = "sslh-select";
/* Global state for a select() loop */ /* watcher type for a select() loop */
struct select_info { typedef struct watchers {
int max_fd; /* Highest fd number to pass to select() */
int num_probing; /* Number of connections currently probing
* We use this to know if we need to time out of
* select() */
gap_array* probing_list; /* Pointers to cnx that are in probing mode */
fd_set fds_r, fds_w; /* reference fd sets (used to init working copies) */ fd_set fds_r, fds_w; /* reference fd sets (used to init working copies) */
cnx_collection* collection; /* Collection of connections linked to this loop */ int max_fd; /* Highest fd number to pass to select() */
} watchers;
#define WATCHERS_TYPE_DEFINED /* To notify processes.h */
time_t next_timeout; /* time at which next UDP connection times out */ #include "processes.h"
};
void watchers_init(watchers* w)
{
FD_ZERO(&w->fds_r);
FD_ZERO(&w->fds_w);
}
void watchers_add_read(watchers* w, int fd)
{
FD_SET(fd, &w->fds_r);
if (fd > w->max_fd)
w->max_fd = fd + 1;
}
void watchers_del_read(watchers* w, int fd)
{
FD_CLR(fd, &w->fds_r);
}
void watchers_add_write(watchers* w, int fd)
{
FD_SET(fd, &w->fds_w);
if (fd > w->max_fd)
w->max_fd = fd + 1;
}
void watchers_del_write(watchers* w, int fd)
{
FD_CLR(fd, &w->fds_w);
}
/* /end watchers */
static int tidy_connection(struct connection *cnx, struct select_info* fd_info)
static int tidy_connection(struct connection *cnx, struct loop_info* fd_info)
{ {
int i; int i;
fd_set* fds = &fd_info->fds_r;
fd_set* fds2 = &fd_info->fds_w;
for (i = 0; i < 2; i++) { for (i = 0; i < 2; i++) {
if (cnx->q[i].fd != -1) { if (cnx->q[i].fd != -1) {
print_message(msg_fd, "closing fd %d\n", cnx->q[i].fd); print_message(msg_fd, "closing fd %d\n", cnx->q[i].fd);
FD_CLR(cnx->q[i].fd, fds); watchers_del_read(&fd_info->watchers, cnx->q[i].fd);
FD_CLR(cnx->q[i].fd, fds2); watchers_del_write(&fd_info->watchers, cnx->q[i].fd);
close(cnx->q[i].fd); close(cnx->q[i].fd);
if (cnx->q[i].deferred_data) if (cnx->q[i].deferred_data)
free(cnx->q[i].deferred_data); free(cnx->q[i].deferred_data);
@ -125,7 +149,7 @@ static struct connection* accept_new_connection(int listen_socket, struct cnx_co
/* Connect queue 1 of connection to SSL; returns new file descriptor */ /* Connect queue 1 of connection to SSL; returns new file descriptor */
static int connect_queue(struct connection* cnx, static int connect_queue(struct connection* cnx,
struct select_info* fd_info) struct loop_info* fd_info)
{ {
struct queue *q = &cnx->q[1]; struct queue *q = &cnx->q[1];
@ -134,10 +158,10 @@ static int connect_queue(struct connection* cnx,
log_connection(NULL, cnx); log_connection(NULL, cnx);
flush_deferred(q); flush_deferred(q);
if (q->deferred_data) { if (q->deferred_data) {
FD_SET(q->fd, &fd_info->fds_w); FD_SET(q->fd, &fd_info->watchers.fds_w);
FD_CLR(cnx->q[0].fd, &fd_info->fds_r); FD_CLR(cnx->q[0].fd, &fd_info->watchers.fds_r);
} }
FD_SET(q->fd, &fd_info->fds_r); FD_SET(q->fd, &fd_info->watchers.fds_r);
collection_add_fd(fd_info->collection, cnx, q->fd); collection_add_fd(fd_info->collection, cnx, q->fd);
return q->fd; return q->fd;
} else { } else {
@ -149,7 +173,7 @@ static int connect_queue(struct connection* cnx,
/* shovels data from active fd to the other /* shovels data from active fd to the other
returns after one socket closed or operation would block returns after one socket closed or operation would block
*/ */
static void shovel(struct connection *cnx, int active_fd, struct select_info* fd_info) static void shovel(struct connection *cnx, int active_fd, struct loop_info* fd_info)
{ {
struct queue *read_q, *write_q; struct queue *read_q, *write_q;
@ -165,8 +189,8 @@ static void shovel(struct connection *cnx, int active_fd, struct select_info* fd
break; break;
case FD_STALLED: case FD_STALLED:
FD_SET(write_q->fd, &fd_info->fds_w); watchers_add_write(&fd_info->watchers, write_q->fd);
FD_CLR(read_q->fd, &fd_info->fds_r); watchers_del_read(&fd_info->watchers, read_q->fd);
break; break;
default: /* Nothing */ default: /* Nothing */
@ -259,13 +283,13 @@ static void connect_proxy(struct connection *cnx)
} }
/* Removes cnx from probing list */ /* Removes cnx from probing list */
static void remove_probing_cnx(struct select_info* fd_info, struct connection* cnx) static void remove_probing_cnx(struct loop_info* fd_info, struct connection* cnx)
{ {
gap_remove_ptr(fd_info->probing_list, cnx, fd_info->num_probing); gap_remove_ptr(fd_info->probing_list, cnx, fd_info->num_probing);
fd_info->num_probing--; fd_info->num_probing--;
} }
static void add_probing_cnx(struct select_info* fd_info, struct connection* cnx) static void add_probing_cnx(struct loop_info* fd_info, struct connection* cnx)
{ {
gap_set(fd_info->probing_list, fd_info->num_probing, cnx); gap_set(fd_info->probing_list, fd_info->num_probing, cnx);
fd_info->num_probing++; fd_info->num_probing++;
@ -278,7 +302,7 @@ static void add_probing_cnx(struct select_info* fd_info, struct connection* cnx)
* */ * */
static void probing_read_process(struct connection* cnx, static void probing_read_process(struct connection* cnx,
struct select_info* fd_info) struct loop_info* fd_info)
{ {
int res; int res;
@ -318,9 +342,6 @@ static void probing_read_process(struct connection* cnx,
} else { } else {
res = connect_queue(cnx, fd_info); res = connect_queue(cnx, fd_info);
} }
if (res >= fd_info->max_fd)
fd_info->max_fd = res + 1;;
} }
@ -335,7 +356,7 @@ int active_queue(struct connection* cnx, int fd)
} }
/* Process a connection that is active in read */ /* Process a connection that is active in read */
static void tcp_read_process(struct select_info* fd_info, static void tcp_read_process(struct loop_info* fd_info,
int fd) int fd)
{ {
cnx_collection* collection = fd_info->collection; cnx_collection* collection = fd_info->collection;
@ -368,7 +389,7 @@ static void tcp_read_process(struct select_info* fd_info,
} }
} }
static void cnx_read_process(struct select_info* fd_info, int fd) static void cnx_read_process(struct loop_info* fd_info, int fd)
{ {
cnx_collection* collection = fd_info->collection; cnx_collection* collection = fd_info->collection;
struct connection* cnx = collection_get_cnx_from_fd(collection, fd); struct connection* cnx = collection_get_cnx_from_fd(collection, fd);
@ -389,7 +410,7 @@ static void cnx_read_process(struct select_info* fd_info, int fd)
} }
/* Process a connection that is active in write */ /* Process a connection that is active in write */
static void cnx_write_process(struct select_info* fd_info, int fd) void cnx_write_process(struct loop_info* fd_info, int fd)
{ {
struct connection* cnx = collection_get_cnx_from_fd(fd_info->collection, fd); struct connection* cnx = collection_get_cnx_from_fd(fd_info->collection, fd);
int res; int res;
@ -403,21 +424,22 @@ static void cnx_write_process(struct select_info* fd_info, int fd)
/* If no deferred data is left, stop monitoring the fd /* If no deferred data is left, stop monitoring the fd
* for write, and restart monitoring the other one for reads*/ * for write, and restart monitoring the other one for reads*/
if (!cnx->q[queue].deferred_data_size) { if (!cnx->q[queue].deferred_data_size) {
FD_CLR(cnx->q[queue].fd, &fd_info->fds_w); watchers_del_write(&fd_info->watchers, cnx->q[queue].fd);
FD_SET(cnx->q[1-queue].fd, &fd_info->fds_r); watchers_add_read(&fd_info->watchers, cnx->q[1-queue].fd);
} }
} }
} }
/* Process a connection that accepts a socket /* Process a connection that accepts a socket
* (For UDP, this means all traffic coming from remote clients) * (For UDP, this means all traffic coming from remote clients)
* Returns new file descriptor, or -1
* */ * */
void cnx_accept_process(struct select_info* fd_info, struct listen_endpoint* listen_socket) void cnx_accept_process(struct loop_info* fd_info, struct listen_endpoint* listen_socket)
{ {
int fd = listen_socket->socketfd; int fd = listen_socket->socketfd;
int type = listen_socket->type; int type = listen_socket->type;
struct connection* cnx; struct connection* cnx;
int new_fd; int new_fd = -1;
switch (type) { switch (type) {
case SOCK_STREAM: case SOCK_STREAM:
@ -430,7 +452,7 @@ void cnx_accept_process(struct select_info* fd_info, struct listen_endpoint* lis
break; break;
case SOCK_DGRAM: case SOCK_DGRAM:
new_fd = udp_c2s_forward(fd, fd_info->collection, fd_info->max_fd); new_fd = udp_c2s_forward(fd, fd_info->collection, fd_info->watchers.max_fd);
print_message(msg_fd, "new_fd %d\n", new_fd); print_message(msg_fd, "new_fd %d\n", new_fd);
if (new_fd == -1) if (new_fd == -1)
return; return;
@ -442,18 +464,16 @@ void cnx_accept_process(struct select_info* fd_info, struct listen_endpoint* lis
return; return;
} }
FD_SET(new_fd, &fd_info->fds_r); watchers_add_read(&fd_info->watchers, new_fd);
if (new_fd >= fd_info->max_fd)
fd_info->max_fd = new_fd + 1;
} }
/* Check all connections to see if a UDP connections has timed out, then free /* Check all connections to see if a UDP connections has timed out, then free
* it. At the same time, keep track of the closest, next timeout. Only do the * it. At the same time, keep track of the closest, next timeout. Only do the
* search through connections if that timeout actually happened. If the * search through connections if that timeout actually happened. If the
* connection that would have timed out has had activity, it doesn't matter: we * connection that would have timed out has had activity, it doesn't matter: we
* go through connections to find the next timeout, which was needed anyway. */ * go through connections to find the next timeout, which was needed anyway. */
static void udp_timeouts(struct select_info* fd_info) static void udp_timeouts(struct loop_info* fd_info)
{ {
time_t now = time(NULL); time_t now = time(NULL);
@ -461,10 +481,10 @@ static void udp_timeouts(struct select_info* fd_info)
time_t next_timeout = INT_MAX; time_t next_timeout = INT_MAX;
for (int i = 0; i < fd_info->max_fd; i++) { for (int i = 0; i < fd_info->watchers.max_fd; i++) {
/* if it's either in read or write set, there is a connection /* if it's either in read or write set, there is a connection
* behind that file descriptor */ * behind that file descriptor */
if (FD_ISSET(i, &fd_info->fds_r) || FD_ISSET(i, &fd_info->fds_w)) { if (FD_ISSET(i, &fd_info->watchers.fds_r) || FD_ISSET(i, &fd_info->watchers.fds_w)) {
struct connection* cnx = collection_get_cnx_from_fd(fd_info->collection, i); struct connection* cnx = collection_get_cnx_from_fd(fd_info->collection, i);
if (cnx) { if (cnx) {
time_t timeout = udp_timeout(cnx); time_t timeout = udp_timeout(cnx);
@ -472,8 +492,8 @@ static void udp_timeouts(struct select_info* fd_info)
if (cnx && (timeout <= now)) { if (cnx && (timeout <= now)) {
print_message(msg_fd, "timed out UDP %d\n", cnx->target_sock); print_message(msg_fd, "timed out UDP %d\n", cnx->target_sock);
close(cnx->target_sock); close(cnx->target_sock);
FD_CLR(i, &fd_info->fds_r); watchers_del_read(&fd_info->watchers, i);
FD_CLR(i, &fd_info->fds_w); watchers_del_write(&fd_info->watchers, i);
collection_remove_cnx(fd_info->collection, cnx); collection_remove_cnx(fd_info->collection, cnx);
} else { } else {
if (timeout < next_timeout) next_timeout = timeout; if (timeout < next_timeout) next_timeout = timeout;
@ -502,35 +522,34 @@ static void udp_timeouts(struct select_info* fd_info)
*/ */
void main_loop(struct listen_endpoint listen_sockets[], int num_addr_listen) void main_loop(struct listen_endpoint listen_sockets[], int num_addr_listen)
{ {
struct select_info fd_info = {0}; struct loop_info fd_info = {0};
fd_set readfds, writefds; /* working read and write fd sets */ fd_set readfds, writefds; /* working read and write fd sets */
struct timeval tv; struct timeval tv;
int i, res; int i, res;
fd_info.num_probing = 0; fd_info.num_probing = 0;
FD_ZERO(&fd_info.fds_r);
FD_ZERO(&fd_info.fds_w);
fd_info.probing_list = gap_init(0); fd_info.probing_list = gap_init(0);
watchers_init(&fd_info.watchers);
for (i = 0; i < num_addr_listen; i++) { for (i = 0; i < num_addr_listen; i++) {
FD_SET(listen_sockets[i].socketfd, &fd_info.fds_r); watchers_add_read(&fd_info.watchers, listen_sockets[i].socketfd);
set_nonblock(listen_sockets[i].socketfd); set_nonblock(listen_sockets[i].socketfd);
} }
fd_info.max_fd = listen_sockets[num_addr_listen-1].socketfd + 1;
fd_info.collection = collection_init(fd_info.max_fd); fd_info.collection = collection_init(fd_info.watchers.max_fd);
while (1) while (1)
{ {
memset(&tv, 0, sizeof(tv)); memset(&tv, 0, sizeof(tv));
tv.tv_sec = cfg.timeout; tv.tv_sec = cfg.timeout;
memcpy(&readfds, &fd_info.fds_r, sizeof(readfds)); memcpy(&readfds, &fd_info.watchers.fds_r, sizeof(readfds));
memcpy(&writefds, &fd_info.fds_w, sizeof(writefds)); memcpy(&writefds, &fd_info.watchers.fds_w, sizeof(writefds));
print_message(msg_fd, "selecting... max_fd=%d num_probing=%d\n", print_message(msg_fd, "selecting... max_fd=%d num_probing=%d\n",
fd_info.max_fd, fd_info.num_probing); fd_info.watchers.max_fd, fd_info.num_probing);
res = select(fd_info.max_fd, &readfds, &writefds, res = select(fd_info.watchers.max_fd, &readfds, &writefds,
NULL, fd_info.num_probing ? &tv : NULL); NULL, fd_info.num_probing ? &tv : NULL);
if (res < 0) if (res < 0)
perror("select"); perror("select");
@ -550,7 +569,7 @@ void main_loop(struct listen_endpoint listen_sockets[], int num_addr_listen)
} }
/* Check all sockets for write activity */ /* Check all sockets for write activity */
for (i = 0; i < fd_info.max_fd; i++) { for (i = 0; i < fd_info.watchers.max_fd; i++) {
if (FD_ISSET(i, &writefds)) { if (FD_ISSET(i, &writefds)) {
cnx_write_process(&fd_info, i); cnx_write_process(&fd_info, i);
} }
@ -572,11 +591,11 @@ void main_loop(struct listen_endpoint listen_sockets[], int num_addr_listen)
} }
/* Check all sockets for read activity */ /* Check all sockets for read activity */
for (i = 0; i < fd_info.max_fd; i++) { for (i = 0; i < fd_info.watchers.max_fd; i++) {
/* Check if it's active AND currently monitored (if a connection /* Check if it's active AND currently monitored (if a connection
* died, it gets tidied, which closes both sockets, but readfs does * died, it gets tidied, which closes both sockets, but readfs does
* not know about that */ * not know about that */
if (FD_ISSET(i, &readfds) && FD_ISSET(i, &fd_info.fds_r)) { if (FD_ISSET(i, &readfds) && FD_ISSET(i, &fd_info.watchers.fds_r)) {
cnx_read_process(&fd_info, i); cnx_read_process(&fd_info, i);
} }
} }