From ee4495a88754e4ff46583af7376d5611c1eeafcb Mon Sep 17 00:00:00 2001 From: Davide Depau Date: Mon, 8 Mar 2021 11:21:17 +0100 Subject: [PATCH] Handle edge-case in which we receive a partial WS header --- src/AsyncWebSocket.cpp | 111 +++++++++++++++++++++++++++++++++++------ src/AsyncWebSocket.h | 4 ++ 2 files changed, 99 insertions(+), 16 deletions(-) diff --git a/src/AsyncWebSocket.cpp b/src/AsyncWebSocket.cpp index f76f2fc..7a8a456 100644 --- a/src/AsyncWebSocket.cpp +++ b/src/AsyncWebSocket.cpp @@ -473,6 +473,8 @@ AsyncWebSocketClient::AsyncWebSocketClient(AsyncWebServerRequest *request, Async _clientId = _server->_getNextId(); _status = WS_CONNECTED; _pstate = 0; + _partialHeader = nullptr; + _partialHeaderLen = 0; _lastMessageTime = millis(); _keepAlivePeriod = 0; _client->setRxTimeout(0); @@ -611,30 +613,107 @@ void AsyncWebSocketClient::_onData(void *pbuf, size_t plen){ _lastMessageTime = millis(); uint8_t *data = (uint8_t*)pbuf; while(plen > 0){ - if(!_pstate){ - const uint8_t *fdata = data; + if(!_pstate) { + ssize_t dataPayloadOffset = 0; + const uint8_t *headerBuf = data; + + // plen is backed up to initialPlen because, in case we receive a partial header, we would like to undo all of our + // parsing and copy all of what we have of the header into a buffer for later use. + // plen is modified during the parsing attempt, so if we don't back it up we won't know how much we need to copy. + // partialHeaderLen is also backed up for the same reason. + size_t initialPlen = plen; + size_t partialHeaderLen = 0; + + if (_partialHeaderLen > 0) { + // We previously received a truncated header. Recover it by doing the following: + // - Copy the new header chunk into the previous partial header, filling the buffer. It is allocated as a + // buffer in a class field. + // - Change *headerBuf to point to said buffer + // - Update the length counters so that: + // - The initialPlen and plen, which refer to the length of the remaining packet data, also accounts for the + // previously received truncated header + // - The dataPayloadOffset, which is the offset after the header at which the payload begins, so that it + // refers to a point potentially before the beginning of the buffer. As we parse the header we increment it, + // and we can pretty much guarantee it will go back to being positive unless there is a major bug. + // - The class _partialHeaderLen is back to zero since we took ownership of the contained data. + memcpy(_partialHeader + _partialHeaderLen, data, + std::min(plen, (size_t) WS_MAX_HEADER_LEN - _partialHeaderLen)); + headerBuf = _partialHeader; + initialPlen += _partialHeaderLen; + plen += _partialHeaderLen; + dataPayloadOffset -= _partialHeaderLen; + partialHeaderLen = _partialHeaderLen; + + _partialHeaderLen = 0; + } + + // The following series of gotos could have been a try-catch but we are likely being built with -fno-exceptions + if (plen < 2) + goto _exceptionHandleFailPartialHeader; + _pinfo.index = 0; - _pinfo.final = (fdata[0] & 0x80) != 0; - _pinfo.opcode = fdata[0] & 0x0F; - _pinfo.masked = (fdata[1] & 0x80) != 0; - _pinfo.len = fdata[1] & 0x7F; - data += 2; + _pinfo.final = (headerBuf[0] & 0x80) != 0; + _pinfo.opcode = headerBuf[0] & 0x0F; + _pinfo.masked = (headerBuf[1] & 0x80) != 0; + _pinfo.len = headerBuf[1] & 0x7F; + dataPayloadOffset += 2; plen -= 2; - if(_pinfo.len == 126){ - _pinfo.len = fdata[3] | (uint16_t)(fdata[2]) << 8; - data += 2; + + if (_pinfo.len == 126) { + if (plen < 2) + goto _exceptionHandleFailPartialHeader; + + _pinfo.len = headerBuf[3] | (uint16_t)(headerBuf[2]) << 8; + dataPayloadOffset += 2; plen -= 2; - } else if(_pinfo.len == 127){ - _pinfo.len = fdata[9] | (uint16_t)(fdata[8]) << 8 | (uint32_t)(fdata[7]) << 16 | (uint32_t)(fdata[6]) << 24 | (uint64_t)(fdata[5]) << 32 | (uint64_t)(fdata[4]) << 40 | (uint64_t)(fdata[3]) << 48 | (uint64_t)(fdata[2]) << 56; - data += 8; + } else if (_pinfo.len == 127) { + if (plen < 8) + goto _exceptionHandleFailPartialHeader; + + _pinfo.len = headerBuf[9] | (uint16_t)(headerBuf[8]) << 8 | (uint32_t)(headerBuf[7]) << 16 | + (uint32_t)(headerBuf[6]) << 24 | (uint64_t)(headerBuf[5]) << 32 | (uint64_t)(headerBuf[4]) << 40 | + (uint64_t)(headerBuf[3]) << 48 | (uint64_t)(headerBuf[2]) << 56; + dataPayloadOffset += 8; plen -= 8; } - if(_pinfo.masked){ - memcpy(_pinfo.mask, data, 4); - data += 4; + if (_pinfo.masked) { + if (plen < 4) + goto _exceptionHandleFailPartialHeader; + + memcpy(_pinfo.mask, headerBuf + dataPayloadOffset + partialHeaderLen, 4); + dataPayloadOffset += 4; plen -= 4; } + + // Yes I know the control flow here isn't 100% legible but we must support -fno-exceptions. + // If we got to this point it means we did NOT receive a truncated header, therefore we can skip the exception + // handling. + // Control flow resumes after the following block. + goto _headerParsingSuccessful; + + // We DID receive a truncated header: + // - We copy it to our buffer and set the _partialHeaderLen + // - We return early + // This will trigger the partial recovery at the next call of this method, once more data is received and we have + // a full header. + _exceptionHandleFailPartialHeader: + { + if (initialPlen <= WS_MAX_HEADER_LEN) { + // If initialPlen > WS_MAX_HEADER_LEN there must be something wrong with this code. It should never happen but + // but it's better safe than sorry. + memcpy(_partialHeader, headerBuf, initialPlen * sizeof(uint8_t)); + _partialHeaderLen = initialPlen; + } else { + DEBUGF("[AsyncWebSocketClient::_onData] initialPlen (= %d) > WS_MAX_HEADER_LEN (= %d)\n", initialPlen, + WS_MAX_HEADER_LEN); + } + return; + } + + _headerParsingSuccessful: + + data += dataPayloadOffset; } const size_t datalen = std::min((size_t)(_pinfo.len - _pinfo.index), plen); diff --git a/src/AsyncWebSocket.h b/src/AsyncWebSocket.h index 5b03ace..ed172f4 100644 --- a/src/AsyncWebSocket.h +++ b/src/AsyncWebSocket.h @@ -46,6 +46,8 @@ #define DEFAULT_MAX_WS_CLIENTS 4 #endif +#define WS_MAX_HEADER_LEN 16 + class AsyncWebSocket; class AsyncWebSocketResponse; class AsyncWebSocketClient; @@ -166,6 +168,8 @@ class AsyncWebSocketClient { uint8_t _pstate; AwsFrameInfo _pinfo; + uint8_t _partialHeader[WS_MAX_HEADER_LEN]; + uint8_t _partialHeaderLen; uint32_t _lastMessageTime; uint32_t _keepAlivePeriod;