feat(gateway)!: Resharding with workers (#4206)

* feat(gateway): Rework how shards are resharded

* Make bigbot example use the resharding

* Fix resharding not incrementing the lastShardId

From my testing it works; however, I don't know whether we should just do
this or whether there is a better way

* revert createGatewayManager type changes
It breaks if you want to disable resharding; I didn't think of that

* Fix typo (#4252)

* Apply suggestions from code review

Co-authored-by: Link <lts20050703@gmail.com>

---------

Co-authored-by: NotDemonix <90858555+NotDemonix@users.noreply.github.com>
Co-authored-by: Link <lts20050703@gmail.com>
This commit is contained in:
Fleny
2025-08-04 19:46:24 +02:00
committed by GitHub
parent 8fd277924f
commit 27ab08a61d
4 changed files with 217 additions and 82 deletions

View File

@@ -25,8 +25,51 @@ const gatewayManager = createGatewayManager({
shardsPerWorker: SHARDS_PER_WORKER,
totalShards: TOTAL_SHARDS,
totalWorkers: TOTAL_WORKERS,
resharding: {
getSessionInfo: restManager.getGatewayBot,
},
})
// Resharding hook: ask the worker identified by `workerId` to prepare `shardId`, and only
// return once that worker has answered with a matching 'ShardPrepared' message.
gatewayManager.resharding.tellWorkerToPrepare = async (workerId, shardId, bucketId) => {
  logger.info(`Tell worker to prepare, workerId: ${workerId}, shardId: ${shardId}, bucketId: ${bucketId}`)

  // Reuse the worker when it is already registered, otherwise create and register it.
  let target = workers.get(workerId)
  if (!target) {
    target = createWorker(workerId)
    workers.set(workerId, target)
  }

  const prepareRequest = {
    type: 'PrepareShard',
    shardId,
    totalShards: gatewayManager.totalShards,
  } satisfies WorkerMessage
  target.postMessage(prepareRequest)

  const { promise: shardPrepared, resolve: markPrepared } = promiseWithResolvers<void>()

  // The worker emits every kind of manager message on this channel, so filter for our shard.
  const onWorkerMessage = (reply: ManagerMessage): void => {
    if (reply.type !== 'ShardPrepared') return
    if (reply.shardId !== shardId) return

    markPrepared()
  }

  target.on('message', onWorkerMessage)
  await shardPrepared
  // Detach the listener so it does not pile up across prepare calls.
  target.off('message', onWorkerMessage)
}
// Resharding hook: broadcast a 'SwitchShards' message to every registered worker.
gatewayManager.resharding.onReshardingSwitch = async () => {
  logger.info('Resharding switch triggered, telling workers to switch the shards')

  const switchRequest = { type: 'SwitchShards' } satisfies WorkerMessage

  workers.forEach((worker) => {
    worker.postMessage(switchRequest)
  })
}
gatewayManager.tellWorkerToIdentify = async (workerId, shardId, bucketId) => {
logger.info(`Tell worker to identify, workerId: ${workerId}, shardId: ${shardId}, bucketId: ${bucketId}`)

View File

@@ -1,13 +1,30 @@
import type { DiscordUpdatePresence, ShardSocketRequest } from '@discordeno/bot'
export type ManagerMessage = ManagerRequestIdentify | ManagerShardIdentified | ManagerShardInfo
export type WorkerMessage = WorkerIdentifyShard | WorkerAllowIdentify | WorkerShardPayload | WorkerPresencesUpdate | WorkerShardInfo
/** Every message a worker may post to the manager (the worker sends these through `parentPort.postMessage`). */
export type ManagerMessage = ManagerRequestIdentify | ManagerShardIdentified | ManagerShardPrepared | ManagerShardInfo
/** Every message the manager may post to a worker (the manager sends these through `worker.postMessage`). */
export type WorkerMessage =
| WorkerIdentifyShard
| WorkerPrepareShard
| WorkerSwitchShards
| WorkerAllowIdentify
| WorkerShardPayload
| WorkerPresencesUpdate
| WorkerShardInfo
/**
 * Member of {@link WorkerMessage} tagged `'IdentifyShard'`.
 *
 * NOTE(review): presumably asks the worker to identify the given shard; the handler is not visible here, confirm.
 */
export interface WorkerIdentifyShard {
type: 'IdentifyShard'
/** Id of the shard this message refers to. */
shardId: number
}
/**
 * Resharding request sent to a worker: create the shard as a pending shard and identify it.
 * The worker answers with a `'ShardPrepared'` message once the shard has identified.
 */
export interface WorkerPrepareShard {
type: 'PrepareShard'
/** Id of the shard to prepare. */
shardId: number
/** Shard count the worker stores and uses for shards it creates from now on. */
totalShards: number
}
/**
 * Resharding request sent to a worker: promote its pending shards to active shards and close the shards
 * that were active until now. Carries no payload besides the tag.
 */
export interface WorkerSwitchShards {
type: 'SwitchShards'
}
export interface WorkerAllowIdentify {
type: 'AllowIdentify'
shardId: number
@@ -76,3 +93,8 @@ export interface ManagerShardIdentified {
type: 'ShardIdentified'
shardId: number
}
/** Sent by a worker to the manager after a shard requested through `'PrepareShard'` has finished identifying. */
export interface ManagerShardPrepared {
type: 'ShardPrepared'
/** Id of the shard that finished preparing. */
shardId: number
}

View File

@@ -1,7 +1,7 @@
import assert from 'node:assert'
import { createHash } from 'node:crypto'
import { workerData as _workerData, parentPort } from 'node:worker_threads'
import { DiscordenoShard, GatewayOpcodes, createLogger } from '@discordeno/bot'
import { type Camelize, type DiscordGatewayPayload, DiscordenoShard, GatewayOpcodes, ShardSocketCloseCodes, createLogger } from '@discordeno/bot'
import { type Channel as amqpChannel, connect as connectAmqp } from 'amqplib'
import { promiseWithResolvers } from '../../util.js'
import type { ManagerMessage, WorkerCreateData, WorkerMessage } from './types.js'
@@ -14,6 +14,9 @@ const logger = createLogger({ name: `Worker #${workerData.workerId}` })
// Callbacks keyed by shard id; an 'AllowIdentify' message invokes the matching callback and removes it.
const identifyPromises = new Map<number, () => void>()
// Active shards of this worker, keyed by shard id.
const shards = new Map<number, DiscordenoShard>()
// Shards created by 'PrepareShard' during resharding; 'SwitchShards' moves them into `shards`.
const pendingShards = new Map<number, DiscordenoShard>()
// Shard count handed to newly created shards. Starts from the worker's connection data and is overwritten by 'PrepareShard'.
let totalShards = workerData.connectionData.totalShards
// RabbitMQ channel used to publish gateway events; undefined while no channel is available.
let rabbitMQChannel: amqpChannel | undefined
@@ -38,6 +41,68 @@ parentPort.on('message', async (message: WorkerMessage) => {
return
}
// Resharding, step 1: bring the requested shard up as a pending shard and report back once it has identified.
if (message.type === 'PrepareShard') {
  const { shardId } = message
  logger.info(`Preparing shard #${shardId}`)

  // Store the new shard count first: createShard reads `totalShards` when it builds the shard.
  totalShards = message.totalShards

  const preparedShard = pendingShards.get(shardId) ?? createShard(shardId)
  pendingShards.set(shardId, preparedShard)

  // Ignore the events
  // TODO: If you need 'gateway.resharding.updateGuildsShardId', you can listen to only the ready event and use the data from that event for the function call
  preparedShard.events.message = () => {}

  await preparedShard.identify()

  const preparedNotice = { type: 'ShardPrepared', shardId } satisfies ManagerMessage
  parentPort.postMessage(preparedNotice)
  return
}
// Resharding, step 2: promote the pending shards to active shards and retire the shards that were active so far.
if (message.type === 'SwitchShards') {
logger.info('Switching shards')
// Pending shards were created with a no-op message handler; start forwarding their events now
for (const shard of pendingShards.values()) {
shard.events.message = handleShardMessageEvent
}
// Old shards stop processing events
for (const shard of shards.values()) {
// Keep the handler that was installed until now so the filtered wrapper below can still delegate to it
const oldHandler = shard.events.message
shard.events.message = async function (_, message) {
// Member checks need to continue but others can stop
if (message.t === 'GUILD_MEMBERS_CHUNK') {
oldHandler?.(shard, message)
}
}
}
// Snapshot the old shards before the active map is cleared, so they can still be closed below
const shardsToShutdown = Array.from(shards.values())
// Move the pending shards to the active shards
shards.clear()
for (const [shardId, shard] of pendingShards.entries()) {
shards.set(shardId, shard)
pendingShards.delete(shardId)
}
// Shutdown the old shards, all of them concurrently
const promises = shardsToShutdown.map(async (shard) => {
await shard.close(ShardSocketCloseCodes.Resharded, 'Shard is being resharded')
logger.info(`Shard #${shard.id} has been shutdown`)
})
// Wait until every old shard has been closed before finishing the message handling
await Promise.all(promises)
return
}
if (message.type === 'AllowIdentify') {
identifyPromises.get(message.shardId)?.()
identifyPromises.delete(message.shardId)
@@ -99,7 +164,7 @@ function createShard(shardId: number): DiscordenoShard {
device: 'Discordeno',
},
token: workerData.connectionData.token,
totalShards: workerData.connectionData.totalShards,
totalShards: totalShards,
url: workerData.connectionData.url,
version: workerData.connectionData.version,
transportCompression: null,
@@ -126,50 +191,52 @@ function createShard(shardId: number): DiscordenoShard {
shard.events.message?.(shard, packet)
}
shard.events.message = async (shard, payload) => {
const body = JSON.stringify({ payload, shardId: shard.id })
if (workerData.messageQueue.enabled) {
if (!rabbitMQChannel) {
logger.error('The RabbitMQ channel has not been created. The event will be lost')
return
}
const message = Buffer.from(body)
const discordData = JSON.stringify(payload.d)
const deduplicationHash = createHash('sha1')
deduplicationHash.update(discordData)
rabbitMQChannel.publish('gatewayMessage', '', message, {
contentType: 'application/json',
headers: {
'x-deduplication-header': deduplicationHash.digest('hex'),
},
})
return
}
const url = workerData.eventHandler.urls[shard.id % workerData.eventHandler.urls.length]
if (!url) {
logger.error('No url found to send events to')
return
}
await fetch(url, {
method: 'POST',
body,
headers: {
'Content-Type': 'application/json',
Authorization: workerData.eventHandler.authentication,
},
}).catch((error) => logger.error('Failed to send events to the bot code', error))
}
shard.events.message = handleShardMessageEvent
return shard
}
/**
 * Forwards a gateway payload received by a shard to the bot code.
 *
 * When the message queue is enabled, the payload is published through the RabbitMQ channel to `gatewayMessage`,
 * with the SHA-1 hex digest of the serialized `payload.d` as the `x-deduplication-header` header.
 * Otherwise the payload is sent with an HTTP POST to one of the configured event handler urls, picked by
 * `shard.id` modulo the number of urls.
 *
 * Failures of the HTTP delivery (network errors and non-2xx responses) are logged, not thrown.
 *
 * @param shard - The shard that received the payload; its id is sent along with the payload.
 * @param payload - The gateway payload to forward.
 */
async function handleShardMessageEvent(shard: DiscordenoShard, payload: Camelize<DiscordGatewayPayload>): Promise<void> {
  const body = JSON.stringify({ payload, shardId: shard.id })

  if (workerData.messageQueue.enabled) {
    if (!rabbitMQChannel) {
      logger.error('The RabbitMQ channel has not been created. The event will be lost')
      return
    }

    // Only the event data is hashed, so the same event delivered through different shards produces the same header.
    const deduplicationHash = createHash('sha1')
    deduplicationHash.update(JSON.stringify(payload.d))

    rabbitMQChannel.publish('gatewayMessage', '', Buffer.from(body), {
      contentType: 'application/json',
      headers: {
        'x-deduplication-header': deduplicationHash.digest('hex'),
      },
    })

    return
  }

  const url = workerData.eventHandler.urls[shard.id % workerData.eventHandler.urls.length]

  if (!url) {
    logger.error('No url found to send events to')
    return
  }

  try {
    const response = await fetch(url, {
      method: 'POST',
      body,
      headers: {
        'Content-Type': 'application/json',
        Authorization: workerData.eventHandler.authentication,
      },
    })

    // fetch only rejects on network-level failures; a 4xx/5xx answer resolves normally and would otherwise go unnoticed.
    if (!response.ok) {
      logger.error(`Failed to send events to the bot code, the event handler responded with status ${response.status}`)
    }
  } catch (error) {
    logger.error('Failed to send events to the bot code', error)
  }
}
async function connectToRabbitMQ(): Promise<void> {
rabbitMQChannel = undefined
const messageQueue = workerData.messageQueue