diff --git a/collection.c b/collection.c index ccb768c..97ea9a1 100644 --- a/collection.c +++ b/collection.c @@ -80,8 +80,10 @@ struct connection* collection_alloc_cnx_from_fd(struct cnx_collection* collectio /* Remove a connection from the collection */ int collection_remove_cnx(cnx_collection* collection, struct connection *cnx) { - gap_set(collection->fd2cnx, cnx->q[0].fd, NULL); - gap_set(collection->fd2cnx, cnx->q[1].fd, NULL); + if (cnx->q[0].fd != -1) + gap_set(collection->fd2cnx, cnx->q[0].fd, NULL); + if (cnx->q[1].fd != -1) + gap_set(collection->fd2cnx, cnx->q[1].fd, NULL); free(cnx); return 0; } diff --git a/gap.c b/gap.c index 04212bf..f3f71d2 100644 --- a/gap.c +++ b/gap.c @@ -100,3 +100,29 @@ void gap_destroy(gap_array* gap) free(gap); } + +/* In gap, find element pointing to ptr, then shift the rest of the array that + * is considered len elements long. + * A poor man's list, if you will. Currently only used to remove probing + * connections, so it only copies a few pointers at most. + * Returns -1 if ptr was not found */ +int gap_remove_ptr(gap_array* gap, void* ptr, int len) +{ + int start, i; + + for (i = 0; i < len; i++) + if (gap->array[i] == ptr) + break; + + if (i < len) + start = i; + else + return -1; + + for (i = start; i < len; i++) { + gap->array[i] = gap->array[i+1]; + } + + return 0; +} + diff --git a/gap.h b/gap.h index 58a7e4b..8044673 100644 --- a/gap.h +++ b/gap.h @@ -8,4 +8,6 @@ void* gap_get(gap_array* gap, int index); int gap_set(gap_array* gap, int index, void* ptr); void gap_destroy(gap_array* gap); +int gap_remove_ptr(gap_array* gap, void* ptr, int len); + #endif diff --git a/sslh-select.c b/sslh-select.c index 9ccea2b..6de23f9 100644 --- a/sslh-select.c +++ b/sslh-select.c @@ -25,6 +25,7 @@ #include "common.h" #include "probe.h" #include "collection.h" +#include "gap.h" static int debug = 0; @@ -33,9 +34,12 @@ const char* server_type = "sslh-select"; /* Global state for a select() loop */ struct select_info { int max_fd; /* Highest fd number to pass to select() */ + int num_probing; /* Number of connections currently probing * We use this to know if we need to time out of * select() */ + gap_array* probing_list; /* Pointers to cnx that are in probing mode */ + fd_set fds_r, fds_w; /* reference fd sets (used to init working copies) */ cnx_collection* collection; /* Collection of connections linked to this loop */ }; @@ -92,7 +96,7 @@ static int fd_is_in_range(int fd) { /* Accepts a connection from the main socket and assigns it to an empty slot. * If no slots are available, allocate another few. If that fails, drop the * connexion */ -static int accept_new_connection(int listen_socket, struct cnx_collection *collection) +static struct connection* accept_new_connection(int listen_socket, struct cnx_collection *collection) { int in_socket, res; @@ -100,26 +104,26 @@ static int accept_new_connection(int listen_socket, struct cnx_collection *colle if (cfg.verbose) fprintf(stderr, "accepting from %d\n", listen_socket); in_socket = accept(listen_socket, 0, 0); - CHECK_RES_RETURN(in_socket, "accept", -1); + CHECK_RES_RETURN(in_socket, "accept", NULL); if (!fd_is_in_range(in_socket)) { close(in_socket); - return -1; + return NULL; } res = set_nonblock(in_socket); if (res == -1) { close(in_socket); - return -1; + return NULL; } struct connection* cnx = collection_alloc_cnx_from_fd(collection, in_socket); if (!cnx) { close(in_socket); - return -1; + return NULL; } - return in_socket; + return cnx; } @@ -263,10 +267,27 @@ static void connect_proxy(struct connection *cnx) exit(0); } +/* Removes cnx from probing list */ +static void remove_probing_cnx(struct select_info* fd_info, struct connection* cnx) +{ + fprintf(stderr, "remove_probing_cnx %d\n", fd_info->num_probing); + gap_remove_ptr(fd_info->probing_list, cnx, fd_info->num_probing); + fd_info->num_probing--; +} + +static void add_probing_cnx(struct select_info* fd_info, struct connection* cnx) +{ + fprintf(stderr, "add_probing_cnx %d\n", fd_info->num_probing); + gap_set(fd_info->probing_list, fd_info->num_probing, cnx); + fd_info->num_probing++; +} + + /* Process read activity on a socket in probe state * IN/OUT cnx: connection data, updated if connected * IN/OUT info: updated if connected * */ + static void probing_read_process(struct connection* cnx, struct select_info* fd_info) { @@ -286,7 +307,7 @@ static void probing_read_process(struct connection* cnx, return; } - fd_info->num_probing--; + remove_probing_cnx(fd_info, cnx); cnx->state = ST_SHOVELING; /* libwrap check if required for this protocol */ @@ -375,7 +396,7 @@ static void cnx_write_process(struct select_info* fd_info, int fd) res = flush_deferred(&cnx->q[queue]); if ((res == -1) && ((errno == EPIPE) || (errno == ECONNRESET))) { - if (cnx->state == ST_PROBING) fd_info->num_probing--; + if (cnx->state == ST_PROBING) remove_probing_cnx(fd_info, cnx); tidy_connection(cnx, fd_info); } else { /* If no deferred data is left, stop monitoring the fd @@ -392,12 +413,14 @@ void cnx_accept_process(struct select_info* fd_info, int fd) { if (debug) fprintf(stderr, "cnx_accept_process fd %d\n", fd); - int in_socket = accept_new_connection(fd, fd_info->collection); - if (in_socket > 0) { - fd_info->num_probing++; - FD_SET(in_socket, &fd_info->fds_r); - if (in_socket >= fd_info->max_fd) - fd_info->max_fd = in_socket + 1; + struct connection* cnx = accept_new_connection(fd, fd_info->collection); + + if (cnx) { + add_probing_cnx(fd_info, cnx); + int new_socket = cnx->q[0].fd; + FD_SET(new_socket, &fd_info->fds_r); + if (new_socket >= fd_info->max_fd) + fd_info->max_fd = new_socket + 1; } } @@ -418,7 +441,6 @@ void cnx_accept_process(struct select_info* fd_info, int fd) void main_loop(struct listen_endpoint listen_sockets[], int num_addr_listen) { struct select_info fd_info = {0}; - fd_set readfds, writefds; /* working read and write fd sets */ struct timeval tv; int i, res; @@ -426,6 +448,7 @@ void main_loop(struct listen_endpoint listen_sockets[], int num_addr_listen) fd_info.num_probing = 0; FD_ZERO(&fd_info.fds_r); FD_ZERO(&fd_info.fds_w); + fd_info.probing_list = gap_init(); for (i = 0; i < num_addr_listen; i++) { FD_SET(listen_sockets[i].socketfd, &fd_info.fds_r); @@ -468,17 +491,19 @@ void main_loop(struct listen_endpoint listen_sockets[], int num_addr_listen) } } - /* Check all sockets for timeouts */ - /* TODO: refactor to use a list of probing connections to avoid linear - * search through all connections */ - for (i = 0; i < fd_info.max_fd; i++) { - struct connection* cnx = collection_get_cnx_from_fd(fd_info.collection, i); - if (cnx) { - if ((cnx->state == ST_PROBING) && (cnx->probe_timeout < time(NULL))) { - if (cfg.verbose) - fprintf(stderr, "timeout slot %d\n", i); - probing_read_process(cnx, &fd_info); - } + /* Check sockets in probing state for timeouts */ + for (i = 0; i < fd_info.num_probing; i++) { + struct connection* cnx = gap_get(fd_info.probing_list, i); + if (!cnx || cnx->state != ST_PROBING) { + log_message(LOG_ERR, "Inconsistent probing: cnx=%0xp\n", cnx); + if (cnx) + log_message(LOG_ERR, "Inconsistent probing: state=%d\n", cnx); + exit(1); + } + if (cnx->probe_timeout < time(NULL)) { + if (cfg.verbose) + fprintf(stderr, "timeout slot %d\n", i); + probing_read_process(cnx, &fd_info); } }