diff --git a/packages/gateway/src/Shard.ts b/packages/gateway/src/Shard.ts index 333f41143..267b5f3ac 100644 --- a/packages/gateway/src/Shard.ts +++ b/packages/gateway/src/Shard.ts @@ -20,13 +20,13 @@ const ZLIB_SYNC_FLUSH = new Uint8Array([0x0, 0x0, 0xff, 0xff]) let fzstd: typeof import('fzstd') -/** Since fzstd is an optional dependency, we need to import it lazily */ +/** Since fzstd is an optional dependency, we need to import it lazily. */ async function getFZStd() { return (fzstd ??= await import('fzstd')) } export class DiscordenoShard { - /** The id of the shard */ + /** The id of the shard. */ id: number /** The connection config details that this shard will used to connect to discord. */ connection: ShardGatewayConfig @@ -54,18 +54,25 @@ export class DiscordenoShard { resolves = new Map<'READY' | 'RESUMED' | 'INVALID_SESSION', (payload: DiscordGatewayPayload) => void>() /** Shard bucket. Only access this if you know what you are doing. Bucket for handling shard request rate limits. */ bucket: LeakyBucket - /** Logger for the bucket */ + /** Logger for the bucket. */ logger: Pick - /** Text decoder used for compressed payloads */ + /** Text decoder used for compressed payloads. */ textDecoder = new TextDecoder() - /** ZLib Inflate instance for ZLib-stream transport payloads */ + /** ZLib Inflate instance for ZLib-stream transport payloads. */ inflate?: Inflate - /** ZLib inflate buffer */ + /** ZLib inflate buffer. */ inflateBuffer: Uint8Array | null = null - /** ZStd Decompress instance for ZStd-stream transport payloads */ + /** ZStd Decompress instance for ZStd-stream transport payloads. */ zstdDecompress?: ZstdDecompress /** Queue for compressed payloads for Zstd Decompress */ decompressionPromisesQueue: ((data: DiscordGatewayPayload) => void)[] = [] + /** + * A function that will be called once the socket is closed and handleClose() has finished updating internal states. + * + * @internal + * This is for internal purposes only, and subject to breaking changes. + */ + resolveAfterClose?: (close: CloseEvent) => void constructor(options: ShardCreateOptions) { this.id = options.id @@ -120,10 +127,18 @@ export class DiscordenoShard { } /** Close the socket connection to discord if present. */ - close(code: number, reason: string): void { + async close(code: number, reason: string): Promise { if (this.socket?.readyState !== NodeWebSocket.OPEN) return this.socket?.close(code, reason) + + // We need to wait for the socket to be fully closed, otherwise there'll be race condition issues if we try to connect again, resulting in unexpected behavior. + await new Promise((resolve) => { + this.resolveAfterClose = resolve + }) + + // Reset the resolveAfterClose function after it has been resolved. + this.resolveAfterClose = undefined } /** Connect the shard with the gateway and start heartbeating. This will not identify the shard to the gateway. */ @@ -228,7 +243,7 @@ export class DiscordenoShard { // Therefore we need to close the old connection and heartbeating before creating a new one. if (this.isOpen()) { this.logger.debug(`[Shard] Identifying open Shard #${this.id}, closing the connection`) - this.close(ShardSocketCloseCodes.ReIdentifying, 'Re-identifying closure of old connection.') + await this.close(ShardSocketCloseCodes.ReIdentifying, 'Re-identifying closure of old connection.') } this.state = ShardState.Identifying @@ -285,7 +300,7 @@ export class DiscordenoShard { // It's possible that the shard is still connected with Discord's gateway therefore we need to forcefully close it. if (this.isOpen()) { this.logger.debug(`[Shard] Resuming open Shard #${this.id}, closing the connection`) - this.close(ShardSocketCloseCodes.ResumeClosingOldConnection, 'Reconnecting the shard, closing old connection.') + await this.close(ShardSocketCloseCodes.ResumeClosingOldConnection, 'Reconnecting the shard, closing old connection.') } // Shard has never identified, so we cannot resume. @@ -346,7 +361,7 @@ export class DiscordenoShard { /** Shutdown the this. Forcefully disconnect the shard from Discord. The shard may not attempt to reconnect with Discord. */ async shutdown(): Promise { - this.close(ShardSocketCloseCodes.Shutdown, 'Shard shutting down.') + await this.close(ShardSocketCloseCodes.Shutdown, 'Shard shutting down.') this.state = ShardState.Offline } @@ -368,6 +383,9 @@ export class DiscordenoShard { this.logger.debug(`[Shard] Shard #${this.id} closed with code ${close.code}${close.reason ? `, and reason: ${close.reason}` : ''}.`) + // Resolve the close promise if it exists + this.resolveAfterClose?.(close) + switch (close.code) { case ShardSocketCloseCodes.TestingFinished: { this.state = ShardState.Offline @@ -723,7 +741,7 @@ export class DiscordenoShard { // Reference: https://discord.com/developers/docs/topics/gateway#heartbeating-example-gateway-heartbeat-ack if (!this.heart.acknowledged) { this.logger.debug(`[Shard] Heartbeat not acknowledged for Shard #${this.id}. Assuming zombied connection.`) - this.close(ShardSocketCloseCodes.ZombiedConnection, 'Zombied connection, did not receive an heartbeat ACK in time.') + await this.close(ShardSocketCloseCodes.ZombiedConnection, 'Zombied connection, did not receive an heartbeat ACK in time.') await this.resume() return diff --git a/packages/gateway/src/manager.ts b/packages/gateway/src/manager.ts index 14634f051..47f952c36 100644 --- a/packages/gateway/src/manager.ts +++ b/packages/gateway/src/manager.ts @@ -355,11 +355,9 @@ export function createGatewayManager(options: CreateGatewayManagerOptions): Gate } }, async shutdown(code, reason, clearReshardingInterval = true) { - gateway.shards.forEach((shard) => shard.close(code, reason)) - if (clearReshardingInterval) clearInterval(gateway.resharding.checkIntervalId) - await delay(5000) + await Promise.all(Array.from(gateway.shards.values()).map((shard) => shard.close(code, reason))) }, async sendPayload(shardId, payload) { const shard = gateway.shards.get(shardId) diff --git a/packages/gateway/tests/integration/connection.spec.ts b/packages/gateway/tests/integration/connection.spec.ts index 781577624..4e31d5029 100644 --- a/packages/gateway/tests/integration/connection.spec.ts +++ b/packages/gateway/tests/integration/connection.spec.ts @@ -1,6 +1,6 @@ -import { type DiscordGatewayPayload, Intents } from '@discordeno/types' +import { type DiscordGatewayPayload, GatewayCloseEventCodes, GatewayOpcodes, Intents } from '@discordeno/types' import uWS from 'uWebSockets.js' -import { type GatewayManager, ShardSocketCloseCodes, createGatewayManager } from '../../src/index.js' +import { type GatewayManager, createGatewayManager } from '../../src/index.js' /** * This value needs to be AT LEAST `1017` @@ -30,10 +30,6 @@ function createGatewayManagerWithPort(port: number): GatewayManager { } async function createUws(options: CreateUwsOptions) { - options.onOpen ??= () => {} - options.onMessage ??= (_message: any) => {} - options.onClose ??= (_code: number, _message: string) => {} - options.closing ??= false let port: number const { promise, resolve, reject } = promiseWithResolvers<{ port: number; uwsToken: uWS.us_listen_socket }>() @@ -41,44 +37,36 @@ async function createUws(options: CreateUwsOptions) { const app = uWS.App() app.ws('/*', { - compression: uWS.SHARED_COMPRESSOR, - maxPayloadLength: 16 * 1024 * 1024, - idleTimeout: 10, open: async (ws) => { - if (options.closing) { - ws.end(ShardSocketCloseCodes.Shutdown) - return - } - ws.send( JSON.stringify({ - op: 10, + op: GatewayOpcodes.Hello, d: { heartbeat_interval: heartbeatInterval, }, }), ) - options.onOpen!() + options.onOpen?.() }, message: async (ws, message, _isBinary) => { const msg = JSON.parse(Buffer.from(message).toString()) - options.onMessage!(msg) + options.onMessage?.(msg) - if (msg.op === 1) { + if (msg.op === GatewayOpcodes.Heartbeat) { ws.send( JSON.stringify({ - op: 11, + op: GatewayOpcodes.HeartbeatACK, }), ) return } - if (msg.op === 2) { + if (msg.op === GatewayOpcodes.Identify) { ws.send( JSON.stringify({ t: 'READY', s: 1, - op: 0, + op: GatewayOpcodes.Dispatch, d: { v: 10, user_settings: {}, @@ -110,13 +98,6 @@ async function createUws(options: CreateUwsOptions) { return } - if (msg.op === 6) { - // resume - } - }, - close: (_ws, code, message) => { - const msg = Buffer.from(message).toString() - options.onClose!(code, msg) }, }) @@ -144,7 +125,6 @@ describe('gateway', () => { const uwsOptions: CreateUwsOptions = { onOpen: resolveConnected, - closing: false, } const { port, uwsToken } = await createUws(uwsOptions) @@ -153,8 +133,8 @@ describe('gateway', () => { await gateway.spawnShards() await connected - uwsOptions.closing = true - await gateway.shutdown(ShardSocketCloseCodes.Shutdown, 'User requested bot stop', true) + // TODO: We should use ShardSocketCloseCodes.TestingFinished but there is an issue with sending 3xxx codes to uWS + await gateway.shutdown(GatewayCloseEventCodes.InvalidShard, 'User requested bot stop', true) uWS.us_listen_socket_close(uwsToken) }) @@ -172,7 +152,6 @@ describe('gateway', () => { resolveHeartbeat() }, - closing: false, } const { port, uwsToken } = await createUws(uwsOptions) @@ -189,8 +168,8 @@ describe('gateway', () => { clearTimeout(timeout) - uwsOptions.closing = true - await gateway.shutdown(ShardSocketCloseCodes.Shutdown, 'User requested bot stop', true) + // TODO: We should use ShardSocketCloseCodes.TestingFinished but there is an issue with sending 3xxx codes to uWS + await gateway.shutdown(GatewayCloseEventCodes.InvalidShard, 'User requested bot stop', true) uWS.us_listen_socket_close(uwsToken) }) @@ -216,6 +195,4 @@ function promiseWithResolvers() { interface CreateUwsOptions { onOpen?: () => any onMessage?: (message: DiscordGatewayPayload) => any - onClose?: (code: number, message: string) => any - closing?: boolean }