Skip to content

Commit

Permalink
feat(Net): add non-blocking support to WebSocket
Browse files Browse the repository at this point in the history
  • Loading branch information
obiltschnig committed Nov 10, 2024
1 parent ea333d1 commit 3975f58
Show file tree
Hide file tree
Showing 38 changed files with 1,294 additions and 113 deletions.
8 changes: 6 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,14 @@ option(WEBTUNNELCLIENTLIB_MODULE "Build WebTunnelClientLib as a module" OFF)
option(WEBTUNNELAGENTLIB_SHARED "Build WebTunnelAgentLib as a shared library" OFF)
option(WEBTUNNELAGENTLIB_MODULE "Build WebTunnelAgentLib as a module" OFF)

find_package(OpenSSL)

if(WIN32)
option(ENABLE_NETSSL_WIN "Enable NetSSL Windows" ON)
option(ENABLE_NETSSL "Enable NetSSL" OFF)
option(ENABLE_CRYPTO "Enable Crypto" OFF)
else()
option(ENABLE_NETSSL_WIN "Enable NetSSL Windows" OFF)
find_package(OpenSSL)
if(OPENSSL_FOUND)
option(ENABLE_NETSSL "Enable NetSSL" ON)
option(ENABLE_CRYPTO "Enable Crypto" ON)
Expand Down Expand Up @@ -160,7 +161,7 @@ endif()
if(ENABLE_NETSSL_WIN)
set(ENABLE_UTIL ON CACHE BOOL "Enable Util" FORCE)
if(ENABLE_TESTS)
set(ENABLE_CRYPTO ON CACHE BOOL "Enable Crypto" FORCE)
set(ENABLE_CRYPTO OFF CACHE BOOL "Enable Crypto" FORCE)
endif()
endif()

Expand Down Expand Up @@ -225,13 +226,16 @@ if(WIN32 AND EXISTS ${PROJECT_SOURCE_DIR}/NetSSL_Win AND ENABLE_NETSSL_WIN)
endif(WIN32 AND EXISTS ${PROJECT_SOURCE_DIR}/NetSSL_Win AND ENABLE_NETSSL_WIN)

if(OPENSSL_FOUND)
message(STATUS "OpenSSL FOUND.")
if(EXISTS ${PROJECT_SOURCE_DIR}/NetSSL_OpenSSL AND ENABLE_NETSSL)
add_subdirectory(NetSSL_OpenSSL)
list(APPEND Poco_COMPONENTS "NetSSL_OpenSSL")
message(STATUS "NetSSL_OpenSSL added.")
endif()
if(EXISTS ${PROJECT_SOURCE_DIR}/Crypto AND ENABLE_CRYPTO)
add_subdirectory(Crypto)
list(APPEND Poco_COMPONENTS "Crypto")
message(STATUS "Crypto added.")
endif()
endif(OPENSSL_FOUND)

Expand Down
10 changes: 8 additions & 2 deletions Net/include/Poco/Net/WebSocket.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,15 +185,21 @@ class Net_API WebSocket: public StreamSocket
/// The other socket must be a WebSocket, otherwise a Poco::InvalidArgumentException
/// will be thrown.

void shutdown();
int shutdown();
/// Sends a Close control frame to the server end of
/// the connection to initiate an orderly shutdown
/// of the connection.
///
/// Returns the number of bytes sent or -1 if the socket
/// is non-blocking and the frame cannot be sent at this time.

void shutdown(Poco::UInt16 statusCode, const std::string& statusMessage = "");
int shutdown(Poco::UInt16 statusCode, const std::string& statusMessage = "");
/// Sends a Close control frame to the server end of
/// the connection to initiate an orderly shutdown
/// of the connection.
///
/// Returns the number of bytes sent or -1 if the socket
/// is non-blocking and the frame cannot be sent at this time.

int sendFrame(const void* buffer, int length, int flags = FRAME_TEXT);
/// Sends the contents of the given buffer through
Expand Down
34 changes: 27 additions & 7 deletions Net/include/Poco/Net/WebSocketImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,16 @@ class Net_API WebSocketImpl: public StreamSocketImpl
virtual void sendUrgent(unsigned char data);
virtual int available();
virtual bool secure() const;
virtual void setSendBufferSize(int size);
virtual int getSendBufferSize();
virtual void setReceiveBufferSize(int size);
virtual int getReceiveBufferSize();
virtual void setSendTimeout(const Poco::Timespan& timeout);
virtual Poco::Timespan getSendTimeout();
virtual void setReceiveTimeout(const Poco::Timespan& timeout);
virtual Poco::Timespan getReceiveTimeout();
virtual void setBlocking(bool flag);
virtual bool getBlocking() const;

// Internal
int frameFlags() const;
Expand All @@ -93,13 +99,27 @@ class Net_API WebSocketImpl: public StreamSocketImpl
enum
{
FRAME_FLAG_MASK = 0x80,
MAX_HEADER_LENGTH = 14
MAX_HEADER_LENGTH = 14,
MASK_LENGTH = 4
};

int receiveHeader(char mask[4], bool& useMask);
int receivePayload(char *buffer, int payloadLength, char mask[4], bool useMask);
int receiveNBytes(void* buffer, int bytes);
int receiveSomeBytes(char* buffer, int bytes);
struct ReceiveState
{
int frameFlags = 0;
bool useMask = false;
char mask[MASK_LENGTH];
int headerLength = 0;
int payloadLength = 0;
int remainingPayloadLength = 0;
Poco::Buffer<char> payload{0};
};

int peekHeader(ReceiveState& receiveState);
void skipHeader(int headerLength);
int receivePayload(char *buffer, int payloadLength, char mask[MASK_LENGTH], bool useMask);
int receiveNBytes(void* buffer, int length);
int receiveSomeBytes(char* buffer, int length);
int peekSomeBytes(char* buffer, int length);
virtual ~WebSocketImpl();

private:
Expand All @@ -109,8 +129,8 @@ class Net_API WebSocketImpl: public StreamSocketImpl
int _maxPayloadSize;
Poco::Buffer<char> _buffer;
int _bufferOffset;
int _frameFlags;
bool _mustMaskPayload;
ReceiveState _receiveState;
Poco::Random _rnd;
};

Expand All @@ -120,7 +140,7 @@ class Net_API WebSocketImpl: public StreamSocketImpl
//
inline int WebSocketImpl::frameFlags() const
{
return _frameFlags;
return _receiveState.frameFlags;
}


Expand Down
88 changes: 69 additions & 19 deletions Net/src/SocketImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,24 +327,37 @@ void SocketImpl::checkBrokenTimeout(SelectMode mode)

int SocketImpl::sendBytes(const void* buffer, int length, int flags)
{
checkBrokenTimeout(SELECT_WRITE);

if (_blocking)
{
checkBrokenTimeout(SELECT_WRITE);
}
int rc;
do
{
if (_sockfd == POCO_INVALID_SOCKET) throw InvalidSocketException();
rc = ::send(_sockfd, reinterpret_cast<const char*>(buffer), length, flags);
}
while (_blocking && rc < 0 && lastError() == POCO_EINTR);
if (rc < 0) error();
if (rc < 0)
{
int err = lastError();
if (!_blocking && (err == POCO_EAGAIN || err == POCO_EWOULDBLOCK))
;
else if (err == POCO_EAGAIN || err == POCO_ETIMEDOUT)
throw TimeoutException(err);
else
error(err);
}
return rc;
}


int SocketImpl::sendBytes(const SocketBufVec& buffers, int flags)
{
checkBrokenTimeout(SELECT_WRITE);

if (_blocking)
{
checkBrokenTimeout(SELECT_WRITE);
}
int rc = 0;
do
{
Expand All @@ -361,15 +374,26 @@ int SocketImpl::sendBytes(const SocketBufVec& buffers, int flags)
#endif
}
while (_blocking && rc < 0 && lastError() == POCO_EINTR);
if (rc < 0) error();
if (rc < 0)
{
int err = lastError();
if (!_blocking && (err == POCO_EAGAIN || err == POCO_EWOULDBLOCK))
;
else if (err == POCO_EAGAIN || err == POCO_ETIMEDOUT)
throw TimeoutException(err);
else
error(err);
}
return rc;
}


int SocketImpl::receiveBytes(void* buffer, int length, int flags)
{
checkBrokenTimeout(SELECT_READ);

if (_blocking)
{
checkBrokenTimeout(SELECT_READ);
}
int rc;
do
{
Expand All @@ -380,7 +404,7 @@ int SocketImpl::receiveBytes(void* buffer, int length, int flags)
if (rc < 0)
{
int err = lastError();
if (err == POCO_EAGAIN && !_blocking)
if (!_blocking && (err == POCO_EAGAIN || err == POCO_EWOULDBLOCK))
;
else if (err == POCO_EAGAIN || err == POCO_ETIMEDOUT)
throw TimeoutException(err);
Expand All @@ -393,8 +417,10 @@ int SocketImpl::receiveBytes(void* buffer, int length, int flags)

int SocketImpl::receiveBytes(SocketBufVec& buffers, int flags)
{
checkBrokenTimeout(SELECT_READ);

if (_blocking)
{
checkBrokenTimeout(SELECT_READ);
}
int rc = 0;
do
{
Expand All @@ -414,7 +440,7 @@ int SocketImpl::receiveBytes(SocketBufVec& buffers, int flags)
if (rc < 0)
{
int err = lastError();
if (err == POCO_EAGAIN && !_blocking)
if (!_blocking && (err == POCO_EAGAIN || err == POCO_EWOULDBLOCK))
;
else if (err == POCO_EAGAIN || err == POCO_ETIMEDOUT)
throw TimeoutException(err);
Expand Down Expand Up @@ -442,7 +468,7 @@ int SocketImpl::receiveBytes(Poco::Buffer<char>& buffer, int flags, const Poco::
if (rc < 0)
{
int err = lastError();
if (err == POCO_EAGAIN && !_blocking)
if (!_blocking && (err == POCO_EAGAIN || err == POCO_EWOULDBLOCK))
;
else if (err == POCO_EAGAIN || err == POCO_ETIMEDOUT)
throw TimeoutException(err);
Expand All @@ -468,7 +494,16 @@ int SocketImpl::sendTo(const void* buffer, int length, const SocketAddress& addr
#endif
}
while (_blocking && rc < 0 && lastError() == POCO_EINTR);
if (rc < 0) error();
if (rc < 0)
{
int err = lastError();
if (!_blocking && (err == POCO_EAGAIN || err == POCO_EWOULDBLOCK))
;
else if (err == POCO_EAGAIN || err == POCO_ETIMEDOUT)
throw TimeoutException(err);
else
error(err);
}
return rc;
}

Expand Down Expand Up @@ -500,7 +535,16 @@ int SocketImpl::sendTo(const SocketBufVec& buffers, const SocketAddress& address
#endif
}
while (_blocking && rc < 0 && lastError() == POCO_EINTR);
if (rc < 0) error();
if (rc < 0)
{
int err = lastError();
if (!_blocking && (err == POCO_EAGAIN || err == POCO_EWOULDBLOCK))
;
else if (err == POCO_EAGAIN || err == POCO_ETIMEDOUT)
throw TimeoutException(err);
else
error(err);
}
return rc;
}

Expand All @@ -522,7 +566,10 @@ int SocketImpl::receiveFrom(void* buffer, int length, SocketAddress& address, in

int SocketImpl::receiveFrom(void* buffer, int length, struct sockaddr** ppSA, poco_socklen_t** ppSALen, int flags)
{
checkBrokenTimeout(SELECT_READ);
if (_blocking)
{
checkBrokenTimeout(SELECT_READ);
}
int rc;
do
{
Expand All @@ -533,7 +580,7 @@ int SocketImpl::receiveFrom(void* buffer, int length, struct sockaddr** ppSA, po
if (rc < 0)
{
int err = lastError();
if (err == POCO_EAGAIN && !_blocking)
if (!_blocking && (err == POCO_EAGAIN || err == POCO_EWOULDBLOCK))
;
else if (err == POCO_EAGAIN || err == POCO_ETIMEDOUT)
throw TimeoutException(err);
Expand Down Expand Up @@ -561,7 +608,10 @@ int SocketImpl::receiveFrom(SocketBufVec& buffers, SocketAddress& address, int f

int SocketImpl::receiveFrom(SocketBufVec& buffers, struct sockaddr** pSA, poco_socklen_t** ppSALen, int flags)
{
checkBrokenTimeout(SELECT_READ);
if (_blocking)
{
checkBrokenTimeout(SELECT_READ);
}
int rc = 0;
do
{
Expand Down Expand Up @@ -590,7 +640,7 @@ int SocketImpl::receiveFrom(SocketBufVec& buffers, struct sockaddr** pSA, poco_s
if (rc < 0)
{
int err = lastError();
if (err == POCO_EAGAIN && !_blocking)
if (!_blocking && (err == POCO_EAGAIN || err == POCO_EWOULDBLOCK))
;
else if (err == POCO_EAGAIN || err == POCO_ETIMEDOUT)
throw TimeoutException(err);
Expand Down
8 changes: 4 additions & 4 deletions Net/src/WebSocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,20 +80,20 @@ WebSocket& WebSocket::operator = (const Socket& socket)
}


void WebSocket::shutdown()
int WebSocket::shutdown()
{
shutdown(WS_NORMAL_CLOSE);
return shutdown(WS_NORMAL_CLOSE);
}


void WebSocket::shutdown(Poco::UInt16 statusCode, const std::string& statusMessage)
int WebSocket::shutdown(Poco::UInt16 statusCode, const std::string& statusMessage)
{
Poco::Buffer<char> buffer(statusMessage.size() + 2);
Poco::MemoryOutputStream ostr(buffer.begin(), buffer.size());
Poco::BinaryWriter writer(ostr, Poco::BinaryWriter::NETWORK_BYTE_ORDER);
writer << statusCode;
writer.writeRaw(statusMessage);
sendFrame(buffer.begin(), static_cast<int>(ostr.charsWritten()), FRAME_FLAG_FIN | FRAME_OP_CLOSE);
return sendFrame(buffer.begin(), static_cast<int>(ostr.charsWritten()), FRAME_FLAG_FIN | FRAME_OP_CLOSE);
}


Expand Down
Loading

0 comments on commit 3975f58

Please sign in to comment.