manage probing sockets in a specific list instead of searching through all connections

This commit is contained in:
yrutschle 2021-04-19 09:38:22 +02:00
parent 02f6c6999d
commit b258b0e0f7
4 changed files with 83 additions and 28 deletions

View File

@ -80,8 +80,10 @@ struct connection* collection_alloc_cnx_from_fd(struct cnx_collection* collectio
/* Remove a connection from the collection */ /* Remove a connection from the collection */
int collection_remove_cnx(cnx_collection* collection, struct connection *cnx) int collection_remove_cnx(cnx_collection* collection, struct connection *cnx)
{ {
gap_set(collection->fd2cnx, cnx->q[0].fd, NULL); if (cnx->q[0].fd != -1)
gap_set(collection->fd2cnx, cnx->q[1].fd, NULL); gap_set(collection->fd2cnx, cnx->q[0].fd, NULL);
if (cnx->q[1].fd != -1)
gap_set(collection->fd2cnx, cnx->q[1].fd, NULL);
free(cnx); free(cnx);
return 0; return 0;
} }

26
gap.c
View File

@ -100,3 +100,29 @@ void gap_destroy(gap_array* gap)
free(gap); free(gap);
} }
/* In gap, find element pointing to ptr, then shift the rest of the array that
* is considered len elements long.
* A poor man's list, if you will. Currently only used to remove probing
* connections, so it only copies a few pointers at most.
* Returns -1 if ptr was not found */
int gap_remove_ptr(gap_array* gap, void* ptr, int len)
{
int start, i;
for (i = 0; i < len; i++)
if (gap->array[i] == ptr)
break;
if (i < len)
start = i;
else
return -1;
for (i = start; i < len; i++) {
gap->array[i] = gap->array[i+1];
}
return 0;
}

2
gap.h
View File

@ -8,4 +8,6 @@ void* gap_get(gap_array* gap, int index);
int gap_set(gap_array* gap, int index, void* ptr); int gap_set(gap_array* gap, int index, void* ptr);
void gap_destroy(gap_array* gap); void gap_destroy(gap_array* gap);
int gap_remove_ptr(gap_array* gap, void* ptr, int len);
#endif #endif

View File

@ -25,6 +25,7 @@
#include "common.h" #include "common.h"
#include "probe.h" #include "probe.h"
#include "collection.h" #include "collection.h"
#include "gap.h"
static int debug = 0; static int debug = 0;
@ -33,9 +34,12 @@ const char* server_type = "sslh-select";
/* Global state for a select() loop */ /* Global state for a select() loop */
struct select_info { struct select_info {
int max_fd; /* Highest fd number to pass to select() */ int max_fd; /* Highest fd number to pass to select() */
int num_probing; /* Number of connections currently probing int num_probing; /* Number of connections currently probing
* We use this to know if we need to time out of * We use this to know if we need to time out of
* select() */ * select() */
gap_array* probing_list; /* Pointers to cnx that are in probing mode */
fd_set fds_r, fds_w; /* reference fd sets (used to init working copies) */ fd_set fds_r, fds_w; /* reference fd sets (used to init working copies) */
cnx_collection* collection; /* Collection of connections linked to this loop */ cnx_collection* collection; /* Collection of connections linked to this loop */
}; };
@ -92,7 +96,7 @@ static int fd_is_in_range(int fd) {
/* Accepts a connection from the main socket and assigns it to an empty slot. /* Accepts a connection from the main socket and assigns it to an empty slot.
* If no slots are available, allocate another few. If that fails, drop the * If no slots are available, allocate another few. If that fails, drop the
* connexion */ * connexion */
static int accept_new_connection(int listen_socket, struct cnx_collection *collection) static struct connection* accept_new_connection(int listen_socket, struct cnx_collection *collection)
{ {
int in_socket, res; int in_socket, res;
@ -100,26 +104,26 @@ static int accept_new_connection(int listen_socket, struct cnx_collection *colle
if (cfg.verbose) fprintf(stderr, "accepting from %d\n", listen_socket); if (cfg.verbose) fprintf(stderr, "accepting from %d\n", listen_socket);
in_socket = accept(listen_socket, 0, 0); in_socket = accept(listen_socket, 0, 0);
CHECK_RES_RETURN(in_socket, "accept", -1); CHECK_RES_RETURN(in_socket, "accept", NULL);
if (!fd_is_in_range(in_socket)) { if (!fd_is_in_range(in_socket)) {
close(in_socket); close(in_socket);
return -1; return NULL;
} }
res = set_nonblock(in_socket); res = set_nonblock(in_socket);
if (res == -1) { if (res == -1) {
close(in_socket); close(in_socket);
return -1; return NULL;
} }
struct connection* cnx = collection_alloc_cnx_from_fd(collection, in_socket); struct connection* cnx = collection_alloc_cnx_from_fd(collection, in_socket);
if (!cnx) { if (!cnx) {
close(in_socket); close(in_socket);
return -1; return NULL;
} }
return in_socket; return cnx;
} }
@ -263,10 +267,27 @@ static void connect_proxy(struct connection *cnx)
exit(0); exit(0);
} }
/* Removes cnx from probing list */
static void remove_probing_cnx(struct select_info* fd_info, struct connection* cnx)
{
fprintf(stderr, "remove_probing_cnx %d\n", fd_info->num_probing);
gap_remove_ptr(fd_info->probing_list, cnx, fd_info->num_probing);
fd_info->num_probing--;
}
static void add_probing_cnx(struct select_info* fd_info, struct connection* cnx)
{
fprintf(stderr, "add_probing_cnx %d\n", fd_info->num_probing);
gap_set(fd_info->probing_list, fd_info->num_probing, cnx);
fd_info->num_probing++;
}
/* Process read activity on a socket in probe state /* Process read activity on a socket in probe state
* IN/OUT cnx: connection data, updated if connected * IN/OUT cnx: connection data, updated if connected
* IN/OUT info: updated if connected * IN/OUT info: updated if connected
* */ * */
static void probing_read_process(struct connection* cnx, static void probing_read_process(struct connection* cnx,
struct select_info* fd_info) struct select_info* fd_info)
{ {
@ -286,7 +307,7 @@ static void probing_read_process(struct connection* cnx,
return; return;
} }
fd_info->num_probing--; remove_probing_cnx(fd_info, cnx);
cnx->state = ST_SHOVELING; cnx->state = ST_SHOVELING;
/* libwrap check if required for this protocol */ /* libwrap check if required for this protocol */
@ -375,7 +396,7 @@ static void cnx_write_process(struct select_info* fd_info, int fd)
res = flush_deferred(&cnx->q[queue]); res = flush_deferred(&cnx->q[queue]);
if ((res == -1) && ((errno == EPIPE) || (errno == ECONNRESET))) { if ((res == -1) && ((errno == EPIPE) || (errno == ECONNRESET))) {
if (cnx->state == ST_PROBING) fd_info->num_probing--; if (cnx->state == ST_PROBING) remove_probing_cnx(fd_info, cnx);
tidy_connection(cnx, fd_info); tidy_connection(cnx, fd_info);
} else { } else {
/* If no deferred data is left, stop monitoring the fd /* If no deferred data is left, stop monitoring the fd
@ -392,12 +413,14 @@ void cnx_accept_process(struct select_info* fd_info, int fd)
{ {
if (debug) fprintf(stderr, "cnx_accept_process fd %d\n", fd); if (debug) fprintf(stderr, "cnx_accept_process fd %d\n", fd);
int in_socket = accept_new_connection(fd, fd_info->collection); struct connection* cnx = accept_new_connection(fd, fd_info->collection);
if (in_socket > 0) {
fd_info->num_probing++; if (cnx) {
FD_SET(in_socket, &fd_info->fds_r); add_probing_cnx(fd_info, cnx);
if (in_socket >= fd_info->max_fd) int new_socket = cnx->q[0].fd;
fd_info->max_fd = in_socket + 1; FD_SET(new_socket, &fd_info->fds_r);
if (new_socket >= fd_info->max_fd)
fd_info->max_fd = new_socket + 1;
} }
} }
@ -418,7 +441,6 @@ void cnx_accept_process(struct select_info* fd_info, int fd)
void main_loop(struct listen_endpoint listen_sockets[], int num_addr_listen) void main_loop(struct listen_endpoint listen_sockets[], int num_addr_listen)
{ {
struct select_info fd_info = {0}; struct select_info fd_info = {0};
fd_set readfds, writefds; /* working read and write fd sets */ fd_set readfds, writefds; /* working read and write fd sets */
struct timeval tv; struct timeval tv;
int i, res; int i, res;
@ -426,6 +448,7 @@ void main_loop(struct listen_endpoint listen_sockets[], int num_addr_listen)
fd_info.num_probing = 0; fd_info.num_probing = 0;
FD_ZERO(&fd_info.fds_r); FD_ZERO(&fd_info.fds_r);
FD_ZERO(&fd_info.fds_w); FD_ZERO(&fd_info.fds_w);
fd_info.probing_list = gap_init();
for (i = 0; i < num_addr_listen; i++) { for (i = 0; i < num_addr_listen; i++) {
FD_SET(listen_sockets[i].socketfd, &fd_info.fds_r); FD_SET(listen_sockets[i].socketfd, &fd_info.fds_r);
@ -468,17 +491,19 @@ void main_loop(struct listen_endpoint listen_sockets[], int num_addr_listen)
} }
} }
/* Check all sockets for timeouts */ /* Check sockets in probing state for timeouts */
/* TODO: refactor to use a list of probing connections to avoid linear for (i = 0; i < fd_info.num_probing; i++) {
* search through all connections */ struct connection* cnx = gap_get(fd_info.probing_list, i);
for (i = 0; i < fd_info.max_fd; i++) { if (!cnx || cnx->state != ST_PROBING) {
struct connection* cnx = collection_get_cnx_from_fd(fd_info.collection, i); log_message(LOG_ERR, "Inconsistent probing: cnx=%0xp\n", cnx);
if (cnx) { if (cnx)
if ((cnx->state == ST_PROBING) && (cnx->probe_timeout < time(NULL))) { log_message(LOG_ERR, "Inconsistent probing: state=%d\n", cnx);
if (cfg.verbose) exit(1);
fprintf(stderr, "timeout slot %d\n", i); }
probing_read_process(cnx, &fd_info); if (cnx->probe_timeout < time(NULL)) {
} if (cfg.verbose)
fprintf(stderr, "timeout slot %d\n", i);
probing_read_process(cnx, &fd_info);
} }
} }