diff --git a/src/websocket/reconnectingWebSocket.ts b/src/websocket/reconnectingWebSocket.ts index 5956da87..2e0e3a51 100644 --- a/src/websocket/reconnectingWebSocket.ts +++ b/src/websocket/reconnectingWebSocket.ts @@ -17,6 +17,24 @@ import type { UnidirectionalStream, } from "./eventStreamConnection"; +/** + * Connection states for the ReconnectingWebSocket state machine. + */ +export enum ConnectionState { + /** Initial state, ready to connect */ + IDLE = "IDLE", + /** Actively running connect() - WS factory in progress */ + CONNECTING = "CONNECTING", + /** Socket is open and working */ + CONNECTED = "CONNECTED", + /** Waiting for backoff timer before attempting reconnection */ + AWAITING_RETRY = "AWAITING_RETRY", + /** Temporarily paused - user must call reconnect() to resume */ + DISCONNECTED = "DISCONNECTED", + /** Permanently closed - cannot be reused */ + DISPOSED = "DISPOSED", +} + export type SocketFactory = () => Promise>; export interface ReconnectingWebSocketOptions { @@ -46,10 +64,8 @@ export class ReconnectingWebSocket< #lastRoute = "unknown"; // Cached route for logging when socket is closed #backoffMs: number; #reconnectTimeoutId: NodeJS.Timeout | null = null; - #isDisconnected = false; // Temporary pause, can be resumed via reconnect() - #isDisposed = false; // Permanent disposal, cannot be resumed - #isConnecting = false; - #pendingReconnect = false; + #state: ConnectionState = ConnectionState.IDLE; + #pendingReconnect = false; // Queue reconnect during CONNECTING state #certRefreshAttempted = false; // Tracks if cert refresh was already attempted this connection cycle readonly #onDispose?: () => void; @@ -94,11 +110,10 @@ export class ReconnectingWebSocket< } /** - * Returns true if the socket is temporarily disconnected and not attempting to reconnect. - * Use reconnect() to resume. + * Returns the current connection state. */ - get isDisconnected(): boolean { - return this.#isDisconnected; + get state(): string { + return this.#state; } /** @@ -133,14 +148,14 @@ export class ReconnectingWebSocket< * Resumes the socket if previously disconnected via disconnect(). */ reconnect(): void { - if (this.#isDisconnected) { - this.#isDisconnected = false; - this.#backoffMs = this.#options.initialBackoffMs; - this.#certRefreshAttempted = false; // User-initiated reconnect, allow retry + if (this.#state === ConnectionState.DISPOSED) { + return; } - if (this.#isDisposed) { - return; + if (this.#state === ConnectionState.DISCONNECTED) { + this.#state = ConnectionState.IDLE; + this.#backoffMs = this.#options.initialBackoffMs; + this.#certRefreshAttempted = false; // User-initiated reconnect, allow retry } if (this.#reconnectTimeoutId !== null) { @@ -148,7 +163,7 @@ export class ReconnectingWebSocket< this.#reconnectTimeoutId = null; } - if (this.#isConnecting) { + if (this.#state === ConnectionState.CONNECTING) { this.#pendingReconnect = true; return; } @@ -161,16 +176,19 @@ export class ReconnectingWebSocket< * Temporarily disconnect the socket. Can be resumed via reconnect(). */ disconnect(code?: number, reason?: string): void { - if (this.#isDisposed || this.#isDisconnected) { + if ( + this.#state === ConnectionState.DISPOSED || + this.#state === ConnectionState.DISCONNECTED + ) { return; } - this.#isDisconnected = true; + this.#state = ConnectionState.DISCONNECTED; this.clearCurrentSocket(code, reason); } close(code?: number, reason?: string): void { - if (this.#isDisposed) { + if (this.#state === ConnectionState.DISPOSED) { return; } @@ -187,11 +205,16 @@ export class ReconnectingWebSocket< } private async connect(): Promise { - if (this.#isDisposed || this.#isDisconnected || this.#isConnecting) { + // Only allow connecting from IDLE, CONNECTED (reconnect), or AWAITING_RETRY states + if ( + this.#state === ConnectionState.DISPOSED || + this.#state === ConnectionState.DISCONNECTED || + this.#state === ConnectionState.CONNECTING + ) { return; } - this.#isConnecting = true; + this.#state = ConnectionState.CONNECTING; try { // Close any existing socket before creating a new one if (this.#currentSocket) { @@ -204,18 +227,20 @@ export class ReconnectingWebSocket< const socket = await this.#socketFactory(); - // Check if disconnected/disposed while waiting for factory - if (this.#isDisposed || this.#isDisconnected) { + // Check if state changed while waiting for factory (e.g., disconnect/dispose called) + if (this.#state !== ConnectionState.CONNECTING) { socket.close(WebSocketCloseCode.NORMAL, "Cancelled during connection"); return; } this.#currentSocket = socket; this.#lastRoute = this.#route; + this.#state = ConnectionState.CONNECTED; socket.addEventListener("open", (event) => { + // Reset backoff on successful connection this.#backoffMs = this.#options.initialBackoffMs; - this.#certRefreshAttempted = false; // Reset on successful connection + this.#certRefreshAttempted = false; this.executeHandlers("open", event); }); @@ -233,7 +258,10 @@ export class ReconnectingWebSocket< }); socket.addEventListener("close", (event) => { - if (this.#isDisposed || this.#isDisconnected) { + if ( + this.#state === ConnectionState.DISPOSED || + this.#state === ConnectionState.DISCONNECTED + ) { return; } @@ -256,8 +284,6 @@ export class ReconnectingWebSocket< } catch (error) { await this.handleConnectionError(error); } finally { - this.#isConnecting = false; - if (this.#pendingReconnect) { this.#pendingReconnect = false; this.reconnect(); @@ -267,13 +293,15 @@ export class ReconnectingWebSocket< private scheduleReconnect(): void { if ( - this.#isDisposed || - this.#isDisconnected || - this.#reconnectTimeoutId !== null + this.#state === ConnectionState.DISPOSED || + this.#state === ConnectionState.DISCONNECTED || + this.#state === ConnectionState.AWAITING_RETRY ) { return; } + this.#state = ConnectionState.AWAITING_RETRY; + const jitter = this.#backoffMs * this.#options.jitterFactor * (Math.random() * 2 - 1); const delayMs = Math.max(0, this.#backoffMs + jitter); @@ -354,7 +382,10 @@ export class ReconnectingWebSocket< * otherwise schedules a reconnect. */ private async handleConnectionError(error: unknown): Promise { - if (this.#isDisposed || this.#isDisconnected) { + if ( + this.#state === ConnectionState.DISPOSED || + this.#state === ConnectionState.DISCONNECTED + ) { return; } @@ -396,11 +427,11 @@ export class ReconnectingWebSocket< } private dispose(code?: number, reason?: string): void { - if (this.#isDisposed) { + if (this.#state === ConnectionState.DISPOSED) { return; } - this.#isDisposed = true; + this.#state = ConnectionState.DISPOSED; this.clearCurrentSocket(code, reason); for (const set of Object.values(this.#eventHandlers)) { diff --git a/test/unit/websocket/reconnectingWebSocket.test.ts b/test/unit/websocket/reconnectingWebSocket.test.ts index d81f4c1a..e06f1ec8 100644 --- a/test/unit/websocket/reconnectingWebSocket.test.ts +++ b/test/unit/websocket/reconnectingWebSocket.test.ts @@ -2,6 +2,7 @@ import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; import { WebSocketCloseCode, HttpStatusCode } from "@/websocket/codes"; import { + ConnectionState, ReconnectingWebSocket, type SocketFactory, } from "@/websocket/reconnectingWebSocket"; @@ -27,13 +28,17 @@ describe("ReconnectingWebSocket", () => { const { ws, sockets } = await createReconnectingWebSocket(); sockets[0].fireOpen(); + expect(ws.state).toBe(ConnectionState.CONNECTED); + sockets[0].fireClose({ code: WebSocketCloseCode.ABNORMAL, reason: "Network error", }); + expect(ws.state).toBe(ConnectionState.AWAITING_RETRY); await vi.advanceTimersByTimeAsync(300); expect(sockets).toHaveLength(2); + expect(ws.state).toBe(ConnectionState.CONNECTED); ws.close(); }); @@ -65,7 +70,10 @@ describe("ReconnectingWebSocket", () => { const { ws, sockets } = await createReconnectingWebSocket(); sockets[0].fireOpen(); + expect(ws.state).toBe(ConnectionState.CONNECTED); + sockets[0].fireClose({ code, reason: "Unrecoverable" }); + expect(ws.state).toBe(ConnectionState.DISCONNECTED); await vi.advanceTimersByTimeAsync(10000); expect(sockets).toHaveLength(1); @@ -98,7 +106,7 @@ describe("ReconnectingWebSocket", () => { ); // Should be disconnected after unrecoverable HTTP error - expect(ws.isDisconnected).toBe(true); + expect(ws.state).toBe(ConnectionState.DISCONNECTED); // Should not retry after unrecoverable HTTP error await vi.advanceTimersByTimeAsync(10000); @@ -121,6 +129,8 @@ describe("ReconnectingWebSocket", () => { sockets[0].fireError( new Error(`Unexpected server response: ${statusCode}`), ); + expect(ws.state).toBe(ConnectionState.DISCONNECTED); + sockets[0].fireClose({ code: WebSocketCloseCode.ABNORMAL, reason: "Connection failed", @@ -179,11 +189,13 @@ describe("ReconnectingWebSocket", () => { await createBlockingReconnectingWebSocket(); ws.reconnect(); + expect(ws.state).toBe(ConnectionState.CONNECTING); ws.reconnect(); // queued expect(sockets).toHaveLength(2); // This should cancel the queued request ws.disconnect(); + expect(ws.state).toBe(ConnectionState.DISCONNECTED); failConnection(new Error("No base URL")); await Promise.resolve(); @@ -200,10 +212,12 @@ describe("ReconnectingWebSocket", () => { // Start reconnect (will block on factory promise) ws.reconnect(); + expect(ws.state).toBe(ConnectionState.CONNECTING); expect(sockets).toHaveLength(2); // Disconnect while factory is still pending ws.disconnect(); + expect(ws.state).toBe(ConnectionState.DISCONNECTED); completeConnection(); await Promise.resolve(); @@ -274,6 +288,7 @@ describe("ReconnectingWebSocket", () => { it("preserves event handlers after suspend() and reconnect()", async () => { const { ws, sockets } = await createReconnectingWebSocket(); sockets[0].fireOpen(); + expect(ws.state).toBe(ConnectionState.CONNECTED); const handler = vi.fn(); ws.addEventListener("message", handler); @@ -282,12 +297,14 @@ describe("ReconnectingWebSocket", () => { // Suspend the socket ws.disconnect(); + expect(ws.state).toBe(ConnectionState.DISCONNECTED); // Reconnect (async operation) ws.reconnect(); await Promise.resolve(); // Wait for async connect() expect(sockets).toHaveLength(2); sockets[1].fireOpen(); + expect(ws.state).toBe(ConnectionState.CONNECTED); // Handler should still work after suspend/reconnect sockets[1].fireMessage({ test: 2 }); @@ -361,19 +378,26 @@ describe("ReconnectingWebSocket", () => { ); sockets[0].fireOpen(); + expect(ws.state).toBe(ConnectionState.CONNECTED); + sockets[0].fireClose({ code: WebSocketCloseCode.PROTOCOL_ERROR, reason: "Protocol error", }); // Should suspend, not dispose - allows recovery when credentials change + expect(ws.state).toBe(ConnectionState.DISCONNECTED); expect(disposeCount).toBe(0); // Should be able to reconnect after suspension ws.reconnect(); + await Promise.resolve(); expect(sockets).toHaveLength(2); + sockets[1].fireOpen(); + expect(ws.state).toBe(ConnectionState.CONNECTED); ws.close(); + expect(ws.state).toBe(ConnectionState.DISPOSED); }); it("does not call onDispose callback during reconnection", async () => { @@ -399,6 +423,7 @@ describe("ReconnectingWebSocket", () => { const { ws, sockets, setFactoryError } = await createReconnectingWebSocketWithErrorControl(); sockets[0].fireOpen(); + expect(ws.state).toBe(ConnectionState.CONNECTED); // Trigger reconnect that will fail with 403 setFactoryError( @@ -408,6 +433,7 @@ describe("ReconnectingWebSocket", () => { await Promise.resolve(); // Socket should be suspended - no automatic reconnection + expect(ws.state).toBe(ConnectionState.DISCONNECTED); await vi.advanceTimersByTimeAsync(10000); expect(sockets).toHaveLength(1); @@ -416,17 +442,23 @@ describe("ReconnectingWebSocket", () => { ws.reconnect(); await Promise.resolve(); expect(sockets).toHaveLength(2); + sockets[1].fireOpen(); + expect(ws.state).toBe(ConnectionState.CONNECTED); ws.close(); + expect(ws.state).toBe(ConnectionState.DISPOSED); }); it("reconnect() does nothing after close()", async () => { const { ws, sockets } = await createReconnectingWebSocket(); sockets[0].fireOpen(); + expect(ws.state).toBe(ConnectionState.CONNECTED); ws.close(); - ws.reconnect(); + expect(ws.state).toBe(ConnectionState.DISPOSED); + ws.reconnect(); + expect(ws.state).toBe(ConnectionState.DISPOSED); expect(sockets).toHaveLength(1); }); }); @@ -539,7 +571,9 @@ describe("ReconnectingWebSocket", () => { ); sockets[0].fireError(new Error("ssl alert certificate_expired")); - await vi.waitFor(() => expect(ws.isDisconnected).toBe(true)); + await vi.waitFor(() => + expect(ws.state).toBe(ConnectionState.DISCONNECTED), + ); expect(sockets).toHaveLength(1); ws.close(); @@ -556,7 +590,9 @@ describe("ReconnectingWebSocket", () => { await vi.waitFor(() => expect(sockets).toHaveLength(2)); sockets[1].fireError(new Error("ssl alert certificate_expired")); - await vi.waitFor(() => expect(ws.isDisconnected).toBe(true)); + await vi.waitFor(() => + expect(ws.state).toBe(ConnectionState.DISCONNECTED), + ); expect(refreshCount).toBe(1); ws.close(); @@ -583,7 +619,9 @@ describe("ReconnectingWebSocket", () => { ); sockets[0].fireError(new Error("ssl alert unknown_ca")); - await vi.waitFor(() => expect(ws.isDisconnected).toBe(true)); + await vi.waitFor(() => + expect(ws.state).toBe(ConnectionState.DISCONNECTED), + ); expect(refreshCallback).not.toHaveBeenCalled(); ws.close();