diff --git a/probe.c b/probe.c index cf21070..f2dc901 100644 --- a/probe.c +++ b/probe.c @@ -233,7 +233,7 @@ static int is_tls_protocol(const char *p, int len, struct proto *proto) valid_tls = parse_tls_header(proto->data, p, len); - if(valid_tls < 0) + if(valid_tls <= 0) return -1 == valid_tls ? PROBE_AGAIN : PROBE_NEXT; /* There *was* a valid match */ diff --git a/tls.c b/tls.c index 84b89b3..3eec1f9 100644 --- a/tls.c +++ b/tls.c @@ -157,19 +157,20 @@ parse_tls_header(const struct TLSProtocol *tls_data, const char *data, size_t da if (pos + len > data_len) return -5; - /* By now we know it's TLS. if SNI/ALPN is set, parse extensions to see if + /* By now we know it's TLS. if SNI or ALPN is set, parse extensions to see if * they match. Otherwise, it's a match already */ - if (tls_data->match_mode.tls_match_alpn || tls_data->match_mode.tls_match_sni) + if (tls_data->match_mode.tls_match_alpn || tls_data->match_mode.tls_match_sni) { return parse_extensions(tls_data, data + pos, len); - else + } else { return 1; + } } static int parse_extensions(const struct TLSProtocol *tls_data, const char *data, size_t data_len) { size_t pos = 0; size_t len; - int last_matched = 0; + int sni_match = 0, alpn_match = 0; if (tls_data == NULL) return -3; @@ -186,49 +187,12 @@ parse_extensions(const struct TLSProtocol *tls_data, const char *data, size_t da size_t extension_type = ((unsigned char) data[pos] << 8) + (unsigned char) data[pos + 1]; - - /* Check if it's a server name extension */ - /* There can be only one extension of each type, so we break - our state and move pos to beginning of the extension here */ - if (tls_data->match_mode.tls_match_sni && tls_data->match_mode.tls_match_alpn) { - /* we want BOTH alpn and sni to match */ - if (extension_type == 0x00) { /* Server Name */ - if (parse_server_name_extension(tls_data, data + pos + 4, len) > 0) { - /* SNI matched */ - if(last_matched) { - /* this is only true if ALPN matched, so return true */ - return last_matched; - } else { - /* otherwise store that SNI matched */ - last_matched = 1; - } - } else { - /* both can't match */ - return -2; - } - } else if (extension_type == 0x10) { /* ALPN */ - if (parse_alpn_extension(tls_data, data + pos + 4, len) > 0) { - /* ALPN matched */ - if(last_matched) { - /* this is only true if SNI matched, so return true */ - return last_matched; - } else { - /* otherwise store that ALPN matched */ - last_matched = 1; - } - } else { - /* both can't match */ - return -2; - } - } - - } else if (extension_type == 0x00 && tls_data->match_mode.tls_match_sni) { /* Server Name */ - return parse_server_name_extension(tls_data, data + pos + 4, len); + if (extension_type == 0x00 && tls_data->match_mode.tls_match_sni) { /* Server Name */ + sni_match = parse_server_name_extension(tls_data, data + pos + 4, len); + if (sni_match < 0) return sni_match; } else if (extension_type == 0x10 && tls_data->match_mode.tls_match_alpn) { /* ALPN */ - if (parse_alpn_extension(tls_data, data + pos + 4, len) > 0) { - return 1; - } - return parse_alpn_extension(tls_data, data + pos + 4, len); + alpn_match = parse_alpn_extension(tls_data, data + pos + 4, len); + if (alpn_match < 0) return alpn_match; } pos += 4 + len; /* Advance to the next extension header */ @@ -238,7 +202,9 @@ parse_extensions(const struct TLSProtocol *tls_data, const char *data, size_t da if (pos != data_len) return -5; - return -2; + return (sni_match && alpn_match) + || (!tls_data->match_mode.tls_match_sni && alpn_match) + || (!tls_data->match_mode.tls_match_alpn && sni_match); } static int