diff --git a/examples/bigbot/src/gateway/gatewayManager.ts b/examples/bigbot/src/gateway/gatewayManager.ts index 9471cd075..fb57c2b31 100644 --- a/examples/bigbot/src/gateway/gatewayManager.ts +++ b/examples/bigbot/src/gateway/gatewayManager.ts @@ -25,8 +25,51 @@ const gatewayManager = createGatewayManager({ shardsPerWorker: SHARDS_PER_WORKER, totalShards: TOTAL_SHARDS, totalWorkers: TOTAL_WORKERS, + resharding: { + getSessionInfo: restManager.getGatewayBot, + }, }) +gatewayManager.resharding.tellWorkerToPrepare = async (workerId, shardId, bucketId) => { + logger.info(`Tell worker to prepare, workerId: ${workerId}, shardId: ${shardId}, bucketId: ${bucketId}`) + + let worker = workers.get(workerId) + if (!worker) { + worker = createWorker(workerId) + workers.set(workerId, worker) + } + + worker.postMessage({ + type: 'PrepareShard', + shardId, + totalShards: gatewayManager.totalShards, + } satisfies WorkerMessage) + + const { promise, resolve } = promiseWithResolvers() + + const waitForShardPrepared = (message: ManagerMessage) => { + if (message.type === 'ShardPrepared' && message.shardId === shardId) { + resolve() + } + } + + worker.on('message', waitForShardPrepared) + + await promise + + worker.off('message', waitForShardPrepared) +} + +gatewayManager.resharding.onReshardingSwitch = async () => { + logger.info('Resharding switch triggered, telling workers to switch the shards') + + for (const worker of workers.values()) { + worker.postMessage({ + type: 'SwitchShards', + } satisfies WorkerMessage) + } +} + gatewayManager.tellWorkerToIdentify = async (workerId, shardId, bucketId) => { logger.info(`Tell worker to identify, workerId: ${workerId}, shardId: ${shardId}, bucketId: ${bucketId}`) diff --git a/examples/bigbot/src/gateway/worker/types.ts b/examples/bigbot/src/gateway/worker/types.ts index 7f47bfbae..fa9c46141 100644 --- a/examples/bigbot/src/gateway/worker/types.ts +++ b/examples/bigbot/src/gateway/worker/types.ts @@ -1,13 +1,30 @@ import type { 
DiscordUpdatePresence, ShardSocketRequest } from '@discordeno/bot' -export type ManagerMessage = ManagerRequestIdentify | ManagerShardIdentified | ManagerShardInfo -export type WorkerMessage = WorkerIdentifyShard | WorkerAllowIdentify | WorkerShardPayload | WorkerPresencesUpdate | WorkerShardInfo +export type ManagerMessage = ManagerRequestIdentify | ManagerShardIdentified | ManagerShardPrepared | ManagerShardInfo +export type WorkerMessage = + | WorkerIdentifyShard + | WorkerPrepareShard + | WorkerSwitchShards + | WorkerAllowIdentify + | WorkerShardPayload + | WorkerPresencesUpdate + | WorkerShardInfo export interface WorkerIdentifyShard { type: 'IdentifyShard' shardId: number } +export interface WorkerPrepareShard { + type: 'PrepareShard' + shardId: number + totalShards: number +} + +export interface WorkerSwitchShards { + type: 'SwitchShards' +} + export interface WorkerAllowIdentify { type: 'AllowIdentify' shardId: number @@ -76,3 +93,8 @@ export interface ManagerShardIdentified { type: 'ShardIdentified' shardId: number } + +export interface ManagerShardPrepared { + type: 'ShardPrepared' + shardId: number +} diff --git a/examples/bigbot/src/gateway/worker/worker.ts b/examples/bigbot/src/gateway/worker/worker.ts index cb6e42fa0..5dbf87a85 100644 --- a/examples/bigbot/src/gateway/worker/worker.ts +++ b/examples/bigbot/src/gateway/worker/worker.ts @@ -1,7 +1,7 @@ import assert from 'node:assert' import { createHash } from 'node:crypto' import { workerData as _workerData, parentPort } from 'node:worker_threads' -import { DiscordenoShard, GatewayOpcodes, createLogger } from '@discordeno/bot' +import { type Camelize, type DiscordGatewayPayload, DiscordenoShard, GatewayOpcodes, ShardSocketCloseCodes, createLogger } from '@discordeno/bot' import { type Channel as amqpChannel, connect as connectAmqp } from 'amqplib' import { promiseWithResolvers } from '../../util.js' import type { ManagerMessage, WorkerCreateData, WorkerMessage } from './types.js' @@ -14,6 +14,9 @@ 
const logger = createLogger({ name: `Worker #${workerData.workerId}` }) const identifyPromises = new Map void>() const shards = new Map() +const pendingShards = new Map() + +let totalShards = workerData.connectionData.totalShards let rabbitMQChannel: amqpChannel | undefined @@ -38,6 +41,68 @@ parentPort.on('message', async (message: WorkerMessage) => { return } + if (message.type === 'PrepareShard') { + logger.info(`Preparing shard #${message.shardId}`) + totalShards = message.totalShards + let shard = pendingShards.get(message.shardId) + if (!shard) { + shard = createShard(message.shardId) + pendingShards.set(message.shardId, shard) + } + + // Ignore the events + // TODO: If you need 'gateway.resharding.updateGuildsShardId', you can listen only to the READY event and use the data from that event for the function call + shard.events.message = () => {} + + await shard.identify() + + parentPort.postMessage({ + type: 'ShardPrepared', + shardId: message.shardId, + } satisfies ManagerMessage) + + return + } + if (message.type === 'SwitchShards') { + logger.info('Switching shards') + + // Change the message event for all shards + for (const shard of pendingShards.values()) { + shard.events.message = handleShardMessageEvent + } + + // Old shards stop processing events + for (const shard of shards.values()) { + const oldHandler = shard.events.message + + shard.events.message = async function (_, message) { + // Member checks need to continue but others can stop + if (message.t === 'GUILD_MEMBERS_CHUNK') { + oldHandler?.(shard, message) + } + } + } + + // Keep a reference to the old shards so they can be shut down after the swap + const shardsToShutdown = Array.from(shards.values()) + + // Move the pending shards to the active shards + shards.clear() + for (const [shardId, shard] of pendingShards.entries()) { + shards.set(shardId, shard) + pendingShards.delete(shardId) + } + + // Shutdown the old shards + const promises = shardsToShutdown.map(async (shard) => { + await shard.close(ShardSocketCloseCodes.Resharded, 'Shard 
is being resharded') + logger.info(`Shard #${shard.id} has been shutdown`) + }) + + await Promise.all(promises) + + return + } if (message.type === 'AllowIdentify') { identifyPromises.get(message.shardId)?.() identifyPromises.delete(message.shardId) @@ -99,7 +164,7 @@ function createShard(shardId: number): DiscordenoShard { device: 'Discordeno', }, token: workerData.connectionData.token, - totalShards: workerData.connectionData.totalShards, + totalShards: totalShards, url: workerData.connectionData.url, version: workerData.connectionData.version, transportCompression: null, @@ -126,50 +191,52 @@ function createShard(shardId: number): DiscordenoShard { shard.events.message?.(shard, packet) } - shard.events.message = async (shard, payload) => { - const body = JSON.stringify({ payload, shardId: shard.id }) - - if (workerData.messageQueue.enabled) { - if (!rabbitMQChannel) { - logger.error('The RabbitMQ channel has not been created. The event will be lost') - return - } - - const message = Buffer.from(body) - const discordData = JSON.stringify(payload.d) - - const deduplicationHash = createHash('sha1') - deduplicationHash.update(discordData) - - rabbitMQChannel.publish('gatewayMessage', '', message, { - contentType: 'application/json', - headers: { - 'x-deduplication-header': deduplicationHash.digest('hex'), - }, - }) - - return - } - - const url = workerData.eventHandler.urls[shard.id % workerData.eventHandler.urls.length] - if (!url) { - logger.error('No url found to send events to') - return - } - - await fetch(url, { - method: 'POST', - body, - headers: { - 'Content-Type': 'application/json', - Authorization: workerData.eventHandler.authentication, - }, - }).catch((error) => logger.error('Failed to send events to the bot code', error)) - } + shard.events.message = handleShardMessageEvent return shard } +async function handleShardMessageEvent(shard: DiscordenoShard, payload: Camelize) { + const body = JSON.stringify({ payload, shardId: shard.id }) + + if 
(workerData.messageQueue.enabled) { + if (!rabbitMQChannel) { + logger.error('The RabbitMQ channel has not been created. The event will be lost') + return + } + + const message = Buffer.from(body) + const discordData = JSON.stringify(payload.d) + + const deduplicationHash = createHash('sha1') + deduplicationHash.update(discordData) + + rabbitMQChannel.publish('gatewayMessage', '', message, { + contentType: 'application/json', + headers: { + 'x-deduplication-header': deduplicationHash.digest('hex'), + }, + }) + + return + } + + const url = workerData.eventHandler.urls[shard.id % workerData.eventHandler.urls.length] + if (!url) { + logger.error('No url found to send events to') + return + } + + await fetch(url, { + method: 'POST', + body, + headers: { + 'Content-Type': 'application/json', + Authorization: workerData.eventHandler.authentication, + }, + }).catch((error) => logger.error('Failed to send events to the bot code', error)) +} + async function connectToRabbitMQ(): Promise { rabbitMQChannel = undefined const messageQueue = workerData.messageQueue diff --git a/packages/gateway/src/manager.ts b/packages/gateway/src/manager.ts index 943f19398..7ceb613de 100644 --- a/packages/gateway/src/manager.ts +++ b/packages/gateway/src/manager.ts @@ -63,8 +63,7 @@ export function createGatewayManager(options: CreateGatewayManagerOptions): Gate enabled: options.resharding?.enabled ?? true, shardsFullPercentage: options.resharding?.shardsFullPercentage ?? 80, checkInterval: options.resharding?.checkInterval ?? 
28800000, // 8 hours - shards: new Collection(), - pendingShards: new Collection(), + shards: new Map(), getSessionInfo: options.resharding?.getSessionInfo, updateGuildsShardId: options.resharding?.updateGuildsShardId, async checkIfReshardingIsNeeded() { @@ -123,6 +122,8 @@ export function createGatewayManager(options: CreateGatewayManagerOptions): Gate if (typeof info.firstShardId === 'number') gateway.firstShardId = info.firstShardId // Set last shard id if provided in info if (typeof info.lastShardId === 'number') gateway.lastShardId = info.lastShardId + // If we didn't get any lastShardId, we assume all the shards are to be used + else gateway.lastShardId = gateway.totalShards - 1 gateway.logger.info(`[Resharding] Starting the reshard process. New total shards: ${gateway.totalShards}`) // Resetting buckets @@ -130,14 +131,20 @@ export function createGatewayManager(options: CreateGatewayManagerOptions): Gate // Refilling buckets with new values gateway.prepareBuckets() - // SPREAD THIS OUT TO DIFFERENT WORKERS TO BEGIN STARTING UP - gateway.buckets.forEach(async (bucket, bucketId) => { + // Call all the buckets and tell their workers & shards to identify + const promises = Array.from(gateway.buckets.entries()).map(async ([bucketId, bucket]) => { for (const worker of bucket.workers) { for (const shardId of worker.queue) { await gateway.resharding.tellWorkerToPrepare(worker.id, shardId, bucketId) } } }) + + await Promise.all(promises) + + gateway.logger.info(`[Resharding] All shards are now online.`) + + await gateway.resharding.onReshardingSwitch() }, async tellWorkerToPrepare(workerId, shardId, bucketId) { gateway.logger.debug(`[Resharding] Telling worker to prepare. 
Worker: ${workerId} | Shard: ${shardId} | Bucket: ${bucketId}.`) @@ -153,9 +160,9 @@ export function createGatewayManager(options: CreateGatewayManagerOptions): Gate url: gateway.url, version: gateway.version, }, - // Ignore events until we are ready events: { async message(_shard, payload) { + // Ignore all events until we switch from the old shards to the new ones. if (payload.t === 'READY') { await gateway.resharding.updateGuildsShardId?.( (payload.d as DiscordReady).guilds.map((g) => g.id), @@ -177,26 +184,16 @@ export function createGatewayManager(options: CreateGatewayManagerOptions): Gate gateway.resharding.shards.set(shardId, shard) - const bucket = gateway.buckets.get(shardId % gateway.connection.sessionStartLimit.maxConcurrency) - if (!bucket) return + await shard.identify() - await shard.identify().then(() => gateway.resharding.shardIsPending(shard)) + gateway.logger.debug(`[Resharding] Shard #${shardId} identified.`) }, - async shardIsPending(shard) { - // Save this in pending at the moment, until all shards are online - gateway.resharding.pendingShards.set(shard.id, shard) - gateway.logger.debug(`[Resharding] Shard #${shard.id} is now pending.`) + async onReshardingSwitch() { + gateway.logger.debug(`[Resharding] Making the switch from the old shards to the new ones.`) - // Check if all shards are now online. - if (gateway.lastShardId - gateway.firstShardId >= gateway.resharding.pendingShards.size) return - - gateway.logger.info(`[Resharding] All shards are now online.`) - - // New shards start processing events + // Move the events from the old shards to the new ones for (const shard of gateway.resharding.shards.values()) { - for (const event in options.events) { - shard.events[event as keyof ShardEvents] = options.events[event as keyof ShardEvents] as (...args: unknown[]) => unknown - } + shard.events = options.events ?? 
{} } // Old shards stop processing events @@ -216,17 +213,11 @@ export function createGatewayManager(options: CreateGatewayManagerOptions): Gate } gateway.logger.info(`[Resharding] Shutting down old shards.`) - // Close old shards await gateway.shutdown(ShardSocketCloseCodes.Resharded, 'Resharded!', false) gateway.logger.info(`[Resharding] Completed.`) - - // Replace old shards - gateway.shards = new Collection(gateway.resharding.shards) - - // Clear our collections and keep only one reference to the shards, the one in gateway.shards + gateway.shards = new Map(gateway.resharding.shards) gateway.resharding.shards.clear() - gateway.resharding.pendingShards.clear() }, }, @@ -736,22 +727,34 @@ export interface GatewayManager extends Required { logger: Pick /** Everything related to resharding. */ resharding: CreateGatewayManagerOptions['resharding'] & { - /** - * The interval id of the check interval. This is used to clear the interval when the manager is shutdown. - */ + /** The interval id of the check interval. This is used to clear the interval when the manager is shutdown. */ checkIntervalId?: NodeJS.Timeout | undefined /** Holds the shards that resharding has created. Once resharding is done, this replaces the gateway.shards */ - shards: Collection - /** Holds the pending shards that have been created and are pending all shards finish loading. */ - pendingShards: Collection + shards: Map /** Handler to check if resharding is necessary. */ checkIfReshardingIsNeeded: () => Promise<{ needed: boolean; info?: Camelize }> - /** Handler to begin resharding. */ + /** + * Handler to begin resharding. + * + * @remarks + * This function will resolve once the resharding is done. + * So when all the calls to {@link tellWorkerToPrepare} and {@link onReshardingSwitch} are done. + */ reshard: (info: Camelize & { firstShardId?: number; lastShardId?: number }) => Promise - /** Handler to communicate to a worker that a shard needs to be created. 
*/ + /** + * Handler to communicate to a worker that it needs to spawn a new shard and identify it for the resharding. + * + * @remarks + * This handler works in the same way as the {@link tellWorkerToIdentify} handler. + * So you should wait for the worker to have identified the shard before resolving the promise + */ tellWorkerToPrepare: (workerId: number, shardId: number, bucketId: number) => Promise - /** Handler to alert the gateway that a shard(resharded) is online. It should now wait for all shards to be pending before shutting off old shards. */ - shardIsPending: (shard: Shard) => Promise + /** + * Handler called when all the workers have finished preparing for the resharding. + * + * This should make the new resharded shards become the active ones and shut down the old ones. + */ + onReshardingSwitch: () => Promise } /** Determine max number of shards to use based upon the max concurrency. */ calculateTotalShards: () => number