From d304b03fca92fa3afcadc6c3320287a472d1b36e Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Fri, 8 Dec 2023 21:11:49 +0100 Subject: [PATCH] feat: timeout on connection opening --- src/ic-websocket.test.ts | 55 +++++++++++++++++++++++++++++++++++++--- src/ic-websocket.ts | 25 ++++++++++++++++++ src/utils.ts | 2 +- 3 files changed, 77 insertions(+), 5 deletions(-) diff --git a/src/ic-websocket.test.ts b/src/ic-websocket.test.ts index 6c4d079..e5fbac4 100644 --- a/src/ic-websocket.test.ts +++ b/src/ic-websocket.test.ts @@ -201,6 +201,37 @@ describe("IcWebsocket class", () => { expect(icWs["_isConnectionEstablished"]).toEqual(false); }); + it("closes the connection if the open message is not received in time", async () => { + const onOpen = jest.fn(); + const onError = jest.fn(); + const onClose = jest.fn(); + const icWs = new IcWebSocket(wsGatewayAddress, undefined, icWebsocketConfig); + expect(icWs).toBeDefined(); + icWs.onopen = onOpen; + icWs.onerror = onError; + icWs.onclose = onClose; + await mockWsServer.connected; + + jest.useFakeTimers(); + mockWsServer.send(encodeHandshakeMessage(VALID_HANDSHAKE_MESSAGE_FROM_GATEWAY)); + + // advance the open timeout + await jest.advanceTimersByTimeAsync(2 * MAX_ALLOWED_NETWORK_LATENCY_MS); + + expect(icWs["_isConnectionEstablished"]).toEqual(false); + expect(onOpen).not.toHaveBeenCalled(); + const openError = new Error("Open timeout expired before receiving the open message"); + expect(onError).toHaveBeenCalledWith(new ErrorEvent("error", { error: openError })); + + await jest.runAllTimersAsync(); + await expect(mockWsServer.closed).resolves.not.toThrow(); + + expect(onClose).toHaveBeenCalled(); + expect(icWs.readyState).toEqual(WebSocket.CLOSED); + + jest.useRealTimers(); + }); + it("creates a new instance and sends the open message", async () => { const icWs = new IcWebSocket(wsGatewayAddress, undefined, icWebsocketConfig); expect(icWs).toBeDefined(); @@ -230,30 +261,46 @@ describe("IcWebsocket class", () => { it("onopen is called when open message from canister is received", async () => { const onOpen = jest.fn(); const onMessage = jest.fn(); + const onError = jest.fn(); const icWs = new IcWebSocket(wsGatewayAddress, undefined, icWebsocketConfig); expect(icWs).toBeDefined(); + // workaround: simulate the client identity + icWs["_clientKey"] = client1Key; icWs.onopen = onOpen; icWs.onmessage = onMessage; + icWs.onerror = onError; await mockWsServer.connected; - await sendHandshakeMessage(VALID_HANDSHAKE_MESSAGE_FROM_GATEWAY); + + jest.useFakeTimers(); + mockWsServer.send(encodeHandshakeMessage(VALID_HANDSHAKE_MESSAGE_FROM_GATEWAY)); expect(onOpen).not.toHaveBeenCalled(); expect(icWs["_isConnectionEstablished"]).toEqual(false); + expect(onError).not.toHaveBeenCalled(); expect(onMessage).not.toHaveBeenCalled(); // wait for the open message from the client + await jest.advanceTimersToNextTimerAsync(); // needed just to advance the mockWsServer timeouts await mockWsServer.nextMessage; - // workaround to simulate the client identity - icWs["_clientKey"] = client1Key; // send the open confirmation message from the canister mockWsServer.send(Cbor.encode(VALID_OPEN_MESSAGE)); - await sleep(100); + jest.runAllTicks(); + + // advance the open timeout so that it expires + // workaround: call the advanceTimers twice + // to make message processing happening in the meantime + await jest.advanceTimersByTimeAsync(MAX_ALLOWED_NETWORK_LATENCY_MS); + await jest.advanceTimersByTimeAsync(MAX_ALLOWED_NETWORK_LATENCY_MS); expect(onOpen).toHaveBeenCalled(); expect(icWs["_isConnectionEstablished"]).toEqual(true); + expect(onError).not.toHaveBeenCalled(); + expect(icWs.readyState).toEqual(WebSocket.OPEN); // make sure onmessage callback is not called when receiving the first message expect(onMessage).not.toHaveBeenCalled(); + + jest.useRealTimers(); }); it("onmessage is called when a valid message is received", async () => { diff --git a/src/ic-websocket.ts b/src/ic-websocket.ts index f7a623e..86dfc0b 100644 --- a/src/ic-websocket.ts +++ b/src/ic-websocket.ts @@ -109,6 +109,7 @@ export default class IcWebSocket< private _clientKey: ClientKey; private _gatewayPrincipal: Principal | null = null; private _maxCertificateAgeInMinutes = 5; + private _openTimeout: NodeJS.Timeout | null = null; onclose: ((this: IcWebSocket, ev: CloseEvent) => any) | null = null; onerror: ((this: IcWebSocket, ev: ErrorEvent) => any) | null = null; @@ -231,6 +232,27 @@ export default class IcWebSocket< this._incomingMessagesQueue.addAndProcess(event.data); } + private _startOpenTimeout() { + // the timeout is double the maximum allowed network latency, + // because opening the connection involves a message sent by the client and one by the canister + this._openTimeout = setTimeout(() => { + if (!this._isConnectionEstablished) { + logger.error("[onWsOpen] Error: Open timeout expired before receiving the open message"); + this._callOnErrorCallback(new Error("Open timeout expired before receiving the open message")); + this._wsInstance.close(4000, "Open connection timeout"); + } + + this._openTimeout = null; + }, 2 * MAX_ALLOWED_NETWORK_LATENCY_MS); + } + + private _cancelOpenTimeout() { + if (this._openTimeout) { + clearTimeout(this._openTimeout); + this._openTimeout = null; + } + } + private async _handleHandshakeMessage(handshakeMessage: GatewayHandshakeMessage): Promise { // at this point, we're sure that the gateway_principal is valid // because the isGatewayHandshakeMessage function checks it @@ -239,6 +261,8 @@ export default class IcWebSocket< try { await this._sendOpenMessage(); + + this._startOpenTimeout(); } catch (error) { logger.error("[onWsMessage] Handshake message error:", error); // if a handshake message fails, we can't continue @@ -340,6 +364,7 @@ export default class IcWebSocket< } this._isConnectionEstablished = true; + this._cancelOpenTimeout(); this._callOnOpenCallback(); diff --git a/src/utils.ts b/src/utils.ts index f4f5090..165893a 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -84,7 +84,7 @@ export const safeExecute = async ( warnMessage: string ): Promise => { try { - return await fn(); + return await Promise.resolve(fn()); } catch (error) { logger.warn(warnMessage, error); }