diff --git a/processes.h b/processes.h new file mode 100644 index 0000000..52e07e2 --- /dev/null +++ b/processes.h @@ -0,0 +1,22 @@ +#ifndef PROCESSES_H +#define PROCESSES_H + +#ifndef WATCHERS_TYPE_DEFINED +#error Define watchers type before including processes.h +#endif + +/* Global state for a loop */ +struct loop_info { + int num_probing; /* Number of connections currently probing + * We use this to know if we need to time out of + * select() */ + gap_array* probing_list; /* Pointers to cnx that are in probing mode */ + + watchers watchers; + + cnx_collection* collection; /* Collection of connections linked to this loop */ + + time_t next_timeout; /* time at which next UDP connection times out */ +}; + +#endif diff --git a/sslh-select.c b/sslh-select.c index a5cc45b..f2258ee 100644 --- a/sslh-select.c +++ b/sslh-select.c @@ -41,35 +41,59 @@ const char* server_type = "sslh-select"; -/* Global state for a select() loop */ -struct select_info { - int max_fd; /* Highest fd number to pass to select() */ - - int num_probing; /* Number of connections currently probing - * We use this to know if we need to time out of - * select() */ - gap_array* probing_list; /* Pointers to cnx that are in probing mode */ - +/* watcher type for a select() loop */ +typedef struct watchers { fd_set fds_r, fds_w; /* reference fd sets (used to init working copies) */ - cnx_collection* collection; /* Collection of connections linked to this loop */ + int max_fd; /* Highest fd number to pass to select() */ +} watchers; +#define WATCHERS_TYPE_DEFINED /* To notify processes.h */ - time_t next_timeout; /* time at which next UDP connection times out */ -}; +#include "processes.h" + +void watchers_init(watchers* w) +{ + FD_ZERO(&w->fds_r); + FD_ZERO(&w->fds_w); +} + +void watchers_add_read(watchers* w, int fd) +{ + FD_SET(fd, &w->fds_r); + if (fd > w->max_fd) + w->max_fd = fd + 1; +} + +void watchers_del_read(watchers* w, int fd) +{ + FD_CLR(fd, &w->fds_r); +} + +void watchers_add_write(watchers* w, int fd) +{ + FD_SET(fd, &w->fds_w); + if (fd > w->max_fd) + w->max_fd = fd + 1; +} + +void watchers_del_write(watchers* w, int fd) +{ + FD_CLR(fd, &w->fds_w); +} +/* /end watchers */ -static int tidy_connection(struct connection *cnx, struct select_info* fd_info) + +static int tidy_connection(struct connection *cnx, struct loop_info* fd_info) { int i; - fd_set* fds = &fd_info->fds_r; - fd_set* fds2 = &fd_info->fds_w; for (i = 0; i < 2; i++) { if (cnx->q[i].fd != -1) { print_message(msg_fd, "closing fd %d\n", cnx->q[i].fd); - FD_CLR(cnx->q[i].fd, fds); - FD_CLR(cnx->q[i].fd, fds2); + watchers_del_read(&fd_info->watchers, cnx->q[i].fd); + watchers_del_write(&fd_info->watchers, cnx->q[i].fd); close(cnx->q[i].fd); if (cnx->q[i].deferred_data) free(cnx->q[i].deferred_data); @@ -125,7 +149,7 @@ static struct connection* accept_new_connection(int listen_socket, struct cnx_co /* Connect queue 1 of connection to SSL; returns new file descriptor */ static int connect_queue(struct connection* cnx, - struct select_info* fd_info) + struct loop_info* fd_info) { struct queue *q = &cnx->q[1]; @@ -134,10 +158,10 @@ static int connect_queue(struct connection* cnx, log_connection(NULL, cnx); flush_deferred(q); if (q->deferred_data) { - FD_SET(q->fd, &fd_info->fds_w); - FD_CLR(cnx->q[0].fd, &fd_info->fds_r); + FD_SET(q->fd, &fd_info->watchers.fds_w); + FD_CLR(cnx->q[0].fd, &fd_info->watchers.fds_r); } - FD_SET(q->fd, &fd_info->fds_r); + FD_SET(q->fd, &fd_info->watchers.fds_r); collection_add_fd(fd_info->collection, cnx, q->fd); return q->fd; } else { @@ -149,7 +173,7 @@ static int connect_queue(struct connection* cnx, /* shovels data from active fd to the other returns after one socket closed or operation would block */ -static void shovel(struct connection *cnx, int active_fd, struct select_info* fd_info) +static void shovel(struct connection *cnx, int active_fd, struct loop_info* fd_info) { struct queue *read_q, *write_q; @@ -165,8 +189,8 @@ static void shovel(struct connection *cnx, int active_fd, struct select_info* fd break; case FD_STALLED: - FD_SET(write_q->fd, &fd_info->fds_w); - FD_CLR(read_q->fd, &fd_info->fds_r); + watchers_add_write(&fd_info->watchers, write_q->fd); + watchers_del_read(&fd_info->watchers, read_q->fd); break; default: /* Nothing */ @@ -259,13 +283,13 @@ static void connect_proxy(struct connection *cnx) } /* Removes cnx from probing list */ -static void remove_probing_cnx(struct select_info* fd_info, struct connection* cnx) +static void remove_probing_cnx(struct loop_info* fd_info, struct connection* cnx) { gap_remove_ptr(fd_info->probing_list, cnx, fd_info->num_probing); fd_info->num_probing--; } -static void add_probing_cnx(struct select_info* fd_info, struct connection* cnx) +static void add_probing_cnx(struct loop_info* fd_info, struct connection* cnx) { gap_set(fd_info->probing_list, fd_info->num_probing, cnx); fd_info->num_probing++; @@ -278,7 +302,7 @@ static void add_probing_cnx(struct select_info* fd_info, struct connection* cnx) * */ static void probing_read_process(struct connection* cnx, - struct select_info* fd_info) + struct loop_info* fd_info) { int res; @@ -318,9 +342,6 @@ static void probing_read_process(struct connection* cnx, } else { res = connect_queue(cnx, fd_info); } - - if (res >= fd_info->max_fd) - fd_info->max_fd = res + 1;; } @@ -335,7 +356,7 @@ int active_queue(struct connection* cnx, int fd) } /* Process a connection that is active in read */ -static void tcp_read_process(struct select_info* fd_info, +static void tcp_read_process(struct loop_info* fd_info, int fd) { cnx_collection* collection = fd_info->collection; @@ -368,7 +389,7 @@ static void tcp_read_process(struct select_info* fd_info, } } -static void cnx_read_process(struct select_info* fd_info, int fd) +static void cnx_read_process(struct loop_info* fd_info, int fd) { cnx_collection* collection = fd_info->collection; struct connection* cnx = collection_get_cnx_from_fd(collection, fd); @@ -389,7 +410,7 @@ static void cnx_read_process(struct select_info* fd_info, int fd) } /* Process a connection that is active in write */ -static void cnx_write_process(struct select_info* fd_info, int fd) +void cnx_write_process(struct loop_info* fd_info, int fd) { struct connection* cnx = collection_get_cnx_from_fd(fd_info->collection, fd); int res; @@ -403,21 +424,22 @@ static void cnx_write_process(struct select_info* fd_info, int fd) /* If no deferred data is left, stop monitoring the fd * for write, and restart monitoring the other one for reads*/ if (!cnx->q[queue].deferred_data_size) { - FD_CLR(cnx->q[queue].fd, &fd_info->fds_w); - FD_SET(cnx->q[1-queue].fd, &fd_info->fds_r); + watchers_del_write(&fd_info->watchers, cnx->q[queue].fd); + watchers_add_read(&fd_info->watchers, cnx->q[1-queue].fd); } } } /* Process a connection that accepts a socket * (For UDP, this means all traffic coming from remote clients) + * Returns new file descriptor, or -1 * */ -void cnx_accept_process(struct select_info* fd_info, struct listen_endpoint* listen_socket) +void cnx_accept_process(struct loop_info* fd_info, struct listen_endpoint* listen_socket) { int fd = listen_socket->socketfd; int type = listen_socket->type; struct connection* cnx; - int new_fd; + int new_fd = -1; switch (type) { case SOCK_STREAM: @@ -430,7 +452,7 @@ void cnx_accept_process(struct select_info* fd_info, struct listen_endpoint* lis break; case SOCK_DGRAM: - new_fd = udp_c2s_forward(fd, fd_info->collection, fd_info->max_fd); + new_fd = udp_c2s_forward(fd, fd_info->collection, fd_info->watchers.max_fd); print_message(msg_fd, "new_fd %d\n", new_fd); if (new_fd == -1) return; @@ -442,18 +464,16 @@ void cnx_accept_process(struct select_info* fd_info, struct listen_endpoint* lis return; } - FD_SET(new_fd, &fd_info->fds_r); - if (new_fd >= fd_info->max_fd) - fd_info->max_fd = new_fd + 1; - + watchers_add_read(&fd_info->watchers, new_fd); } + /* Check all connections to see if a UDP connections has timed out, then free * it. At the same time, keep track of the closest, next timeout. Only do the * search through connections if that timeout actually happened. If the * connection that would have timed out has had activity, it doesn't matter: we * go through connections to find the next timeout, which was needed anyway. */ -static void udp_timeouts(struct select_info* fd_info) +static void udp_timeouts(struct loop_info* fd_info) { time_t now = time(NULL); @@ -461,10 +481,10 @@ static void udp_timeouts(struct select_info* fd_info) time_t next_timeout = INT_MAX; - for (int i = 0; i < fd_info->max_fd; i++) { + for (int i = 0; i < fd_info->watchers.max_fd; i++) { /* if it's either in read or write set, there is a connection * behind that file descriptor */ - if (FD_ISSET(i, &fd_info->fds_r) || FD_ISSET(i, &fd_info->fds_w)) { + if (FD_ISSET(i, &fd_info->watchers.fds_r) || FD_ISSET(i, &fd_info->watchers.fds_w)) { struct connection* cnx = collection_get_cnx_from_fd(fd_info->collection, i); if (cnx) { time_t timeout = udp_timeout(cnx); @@ -472,8 +492,8 @@ static void udp_timeouts(struct select_info* fd_info) if (cnx && (timeout <= now)) { print_message(msg_fd, "timed out UDP %d\n", cnx->target_sock); close(cnx->target_sock); - FD_CLR(i, &fd_info->fds_r); - FD_CLR(i, &fd_info->fds_w); + watchers_del_read(&fd_info->watchers, i); + watchers_del_write(&fd_info->watchers, i); collection_remove_cnx(fd_info->collection, cnx); } else { if (timeout < next_timeout) next_timeout = timeout; @@ -502,35 +522,34 @@ static void udp_timeouts(struct select_info* fd_info) */ void main_loop(struct listen_endpoint listen_sockets[], int num_addr_listen) { - struct select_info fd_info = {0}; + struct loop_info fd_info = {0}; fd_set readfds, writefds; /* working read and write fd sets */ struct timeval tv; int i, res; fd_info.num_probing = 0; - FD_ZERO(&fd_info.fds_r); - FD_ZERO(&fd_info.fds_w); fd_info.probing_list = gap_init(0); + watchers_init(&fd_info.watchers); + for (i = 0; i < num_addr_listen; i++) { - FD_SET(listen_sockets[i].socketfd, &fd_info.fds_r); + watchers_add_read(&fd_info.watchers, listen_sockets[i].socketfd); set_nonblock(listen_sockets[i].socketfd); } - fd_info.max_fd = listen_sockets[num_addr_listen-1].socketfd + 1; - fd_info.collection = collection_init(fd_info.max_fd); + fd_info.collection = collection_init(fd_info.watchers.max_fd); while (1) { memset(&tv, 0, sizeof(tv)); tv.tv_sec = cfg.timeout; - memcpy(&readfds, &fd_info.fds_r, sizeof(readfds)); - memcpy(&writefds, &fd_info.fds_w, sizeof(writefds)); + memcpy(&readfds, &fd_info.watchers.fds_r, sizeof(readfds)); + memcpy(&writefds, &fd_info.watchers.fds_w, sizeof(writefds)); print_message(msg_fd, "selecting... max_fd=%d num_probing=%d\n", - fd_info.max_fd, fd_info.num_probing); - res = select(fd_info.max_fd, &readfds, &writefds, + fd_info.watchers.max_fd, fd_info.num_probing); + res = select(fd_info.watchers.max_fd, &readfds, &writefds, NULL, fd_info.num_probing ? &tv : NULL); if (res < 0) perror("select"); @@ -550,7 +569,7 @@ void main_loop(struct listen_endpoint listen_sockets[], int num_addr_listen) } /* Check all sockets for write activity */ - for (i = 0; i < fd_info.max_fd; i++) { + for (i = 0; i < fd_info.watchers.max_fd; i++) { if (FD_ISSET(i, &writefds)) { cnx_write_process(&fd_info, i); } @@ -572,11 +591,11 @@ void main_loop(struct listen_endpoint listen_sockets[], int num_addr_listen) } /* Check all sockets for read activity */ - for (i = 0; i < fd_info.max_fd; i++) { + for (i = 0; i < fd_info.watchers.max_fd; i++) { /* Check if it's active AND currently monitored (if a connection * died, it gets tidied, which closes both sockets, but readfs does * not know about that */ - if (FD_ISSET(i, &readfds) && FD_ISSET(i, &fd_info.fds_r)) { + if (FD_ISSET(i, &readfds) && FD_ISSET(i, &fd_info.watchers.fds_r)) { cnx_read_process(&fd_info, i); } }