diff --git a/client.cpp b/client.cpp index accc140..8909484 100644 --- a/client.cpp +++ b/client.cpp @@ -21,6 +21,7 @@ #include "server.h" #include "exception.h" #include "config.h" +#include "utility.h" #include #include @@ -35,6 +36,7 @@ Client::Client(int tunnelMtu, const char *deviceName, uint32_t serverIp, int max { this->serverIp = serverIp; this->maxPolls = maxPolls; + this->nextEchoId = Utility::rand(); state = STATE_CLOSED; } @@ -78,7 +80,7 @@ void Client::sendChallengeResponse(int dataLength) setTimeout(5000); } -bool Client::handleEchoData(const TunnelHeader &header, int dataLength, uint32_t realIp, bool reply, int id, int seq) +bool Client::handleEchoData(const TunnelHeader &header, int dataLength, uint32_t realIp, bool reply, uint16_t id, uint16_t seq) { if (realIp != serverIp || !reply) return false; @@ -90,6 +92,10 @@ bool Client::handleEchoData(const TunnelHeader &header, int dataLength, uint32_t { case TunnelHeader::TYPE_RESET_CONNECTION: syslog(LOG_DEBUG, "reset reveiced"); + + if (privilegesDropped) + throw Exception("cannot reconnect without root privileges"); + sendConnectionRequest(); return true; case TunnelHeader::TYPE_SERVER_FULL: @@ -151,7 +157,10 @@ void Client::sendEchoToServer(int type, int dataLength) if (maxPolls == 0 && state == STATE_ESTABLISHED) setTimeout(KEEP_ALIVE_INTERVAL); - sendEcho(magic, type, dataLength, serverIp, false, ICMP_ID, 0); + sendEcho(magic, type, dataLength, serverIp, false, nextEchoId, 0); + + if (maxPolls > 0) + nextEchoId = nextEchoId + 38543; // some random prime } void Client::startPolling() diff --git a/client.h b/client.h index 6fdf905..c856da9 100644 --- a/client.h +++ b/client.h @@ -43,7 +43,7 @@ protected: STATE_ESTABLISHED }; - virtual bool handleEchoData(const TunnelHeader &header, int dataLength, uint32_t realIp, bool reply, int id, int seq); + virtual bool handleEchoData(const TunnelHeader &header, int dataLength, uint32_t realIp, bool reply, uint16_t id, uint16_t seq); virtual void handleTunData(int dataLength, uint32_t sourceIp, uint32_t destIp); virtual void handleTimeout(); @@ -62,6 +62,8 @@ protected: int maxPolls; int pollTimeoutNr; + uint16_t nextEchoId; + State state; }; diff --git a/config.h b/config.h index 1ef0137..989effb 100644 --- a/config.h +++ b/config.h @@ -17,8 +17,6 @@ * */ -#define ICMP_ID 57251 - #define MAX_BUFFERED_PACKETS 20 #define KEEP_ALIVE_INTERVAL (60 * 1000) diff --git a/echo.cpp b/echo.cpp index e016abb..ccb1d31 100644 --- a/echo.cpp +++ b/echo.cpp @@ -45,7 +45,7 @@ int Echo::headerSize() return sizeof(IpHeader) + sizeof(EchoHeader); } -void Echo::send(int payloadLength, uint32_t realIp, bool reply, int id, int seq) +void Echo::send(int payloadLength, uint32_t realIp, bool reply, uint16_t id, uint16_t seq) { struct sockaddr_in target; target.sin_family = AF_INET; @@ -67,7 +67,7 @@ void Echo::send(int payloadLength, uint32_t realIp, bool reply, int id, int seq) throw Exception("sendto", true); } -int Echo::receive(uint32_t &realIp, bool &reply, int &id, int &seq) +int Echo::receive(uint32_t &realIp, bool &reply, uint16_t &id, uint16_t &seq) { struct sockaddr_in source; int source_addr_len = sizeof(struct sockaddr_in); diff --git a/echo.h b/echo.h index 46e508f..23d671d 100644 --- a/echo.h +++ b/echo.h @@ -31,8 +31,8 @@ public: int getFd() { return fd; } - void send(int payloadLength, uint32_t realIp, bool reply, int id, int seq); - int receive(uint32_t &realIp, bool &reply, int &id, int &seq); + void send(int payloadLength, uint32_t realIp, bool reply, uint16_t id, uint16_t seq); + int receive(uint32_t &realIp, bool &reply, uint16_t &id, uint16_t &seq); char *payloadBuffer() { return buffer + headerSize(); } diff --git a/server.cpp b/server.cpp index 9db42fb..b5b9335 100644 --- a/server.cpp +++ b/server.cpp @@ -49,14 +49,13 @@ Server::~Server() } -void Server::handleUnknownClient(const TunnelHeader &header, int dataLength, uint32_t realIp) +void Server::handleUnknownClient(const TunnelHeader &header, int dataLength, uint32_t realIp, uint16_t echoId) { ClientData client; client.realIp = realIp; - client.maxPolls = 0; + client.maxPolls = 1; -// if (header.type == TunnelHeader::TYPE_POLL) -// return; + pollReceived(&client, echoId); if (header.type != TunnelHeader::TYPE_CONNECTION_REQUEST || dataLength != sizeof(ClientConnectData)) { @@ -65,8 +64,6 @@ void Server::handleUnknownClient(const TunnelHeader &header, int dataLength, uin return; } - pollReceived(&client); - ClientConnectData *connectData = (ClientConnectData *)payloadBuffer(); client.maxPolls = connectData->maxPolls; @@ -146,7 +143,7 @@ void Server::sendReset(ClientData *client) sendEchoToClient(client, TunnelHeader::TYPE_RESET_CONNECTION, 0); } -bool Server::handleEchoData(const TunnelHeader &header, int dataLength, uint32_t realIp, bool reply, int id, int seq) +bool Server::handleEchoData(const TunnelHeader &header, int dataLength, uint32_t realIp, bool reply, uint16_t id, uint16_t seq) { if (reply) return false; @@ -157,11 +154,11 @@ bool Server::handleEchoData(const TunnelHeader &header, int dataLength, uint32_t ClientData *client = getClientByRealIp(realIp); if (client == NULL) { - handleUnknownClient(header, dataLength, realIp); + handleUnknownClient(header, dataLength, realIp, id); return true; } - pollReceived(client); + pollReceived(client, id); switch (header.type) { @@ -172,6 +169,9 @@ bool Server::handleEchoData(const TunnelHeader &header, int dataLength, uint32_t return true; } + while (client->pollIds.size() > 1) + client->pollIds.pop(); + syslog(LOG_DEBUG, "reconnecting %s", Utility::formatIp(realIp).c_str()); sendReset(client); removeClient(client); @@ -230,14 +230,14 @@ void Server::handleTunData(int dataLength, uint32_t sourceIp, uint32_t destIp) sendEchoToClient(client, TunnelHeader::TYPE_DATA, dataLength); } -void Server::pollReceived(ClientData *client) +void Server::pollReceived(ClientData *client, uint16_t echoId) { unsigned int maxSavedPolls = client->maxPolls != 0 ? client->maxPolls : 1; - client->pollTimes.push(now); - if (client->pollTimes.size() > maxSavedPolls) - client->pollTimes.pop(); - DEBUG_ONLY(printf("poll -> %d\n", client->pollTimes.size())); + client->pollIds.push(echoId); + if (client->pollIds.size() > maxSavedPolls) + client->pollIds.pop(); + DEBUG_ONLY(printf("poll (%d) -> %d\n", echoId, client->pollIds.size())); if (client->pendingPackets.size() > 0) { @@ -256,23 +256,21 @@ void Server::sendEchoToClient(ClientData *client, int type, int dataLength) { if (client->maxPolls == 0) { - sendEcho(magic, type, dataLength, client->realIp, true, ICMP_ID, 0); + sendEcho(magic, type, dataLength, client->realIp, true, client->pollIds.front(), 0); return; } - while (client->pollTimes.size() != 0) + if (client->pollIds.size() != 0) { - Time pollTime = client->pollTimes.front(); - client->pollTimes.pop(); + uint16_t id = client->pollIds.front(); + client->pollIds.pop(); - if (pollTime + POLL_INTERVAL * (client->maxPolls + 1) > now) - { - DEBUG_ONLY(printf("sending -> %d\n", client->pollTimes.size())); - sendEcho(magic, type, dataLength, client->realIp, true, ICMP_ID, 0); - return; - } + DEBUG_ONLY(printf("sending (%d) -> %d\n", id, client->pollIds.size())); + sendEcho(magic, type, dataLength, client->realIp, true, id, 0); + return; } - DEBUG_ONLY(printf("queuing -> %d\n", client->pollTimes.size())); + + DEBUG_ONLY(printf("queuing -> %d\n", client->pollIds.size())); if (client->pendingPackets.size() == MAX_BUFFERED_PACKETS) { @@ -314,7 +312,7 @@ void Server::handleTimeout() uint32_t Server::reserveTunnelIp() { uint32_t ip = network + 2; - + list::iterator i; for (i = usedIps.begin(); i != usedIps.end(); ++i) { @@ -322,10 +320,10 @@ uint32_t Server::reserveTunnelIp() break; ip = ip + 1; } - + if (ip - network >= 255) return 0; - + usedIps.insert(i, ip); return ip; } diff --git a/server.h b/server.h index 7686571..fb614b8 100644 --- a/server.h +++ b/server.h @@ -63,7 +63,7 @@ protected: std::queue pendingPackets; int maxPolls; - std::queue