mirror of
https://github.com/discordeno/discordeno.git
synced 2026-05-21 02:40:08 +00:00
feat(gateway)!: Resharding with workers (#4206)
* feat(gateway): Rework how shards are resharded * Make bigbot example use the resharding * fix reshard doesn't increment the lastShardId From my testing it works, however i don't know if we should just do this or is there a better way * revert createGatewayManager type changes It breaks if you want to disable resharding, didn't think of that * Fix typo (#4252) * Apply suggestions from code review Co-authored-by: Link <lts20050703@gmail.com> --------- Co-authored-by: NotDemonix <90858555+NotDemonix@users.noreply.github.com> Co-authored-by: Link <lts20050703@gmail.com>
This commit is contained in:
@@ -25,8 +25,51 @@ const gatewayManager = createGatewayManager({
|
||||
shardsPerWorker: SHARDS_PER_WORKER,
|
||||
totalShards: TOTAL_SHARDS,
|
||||
totalWorkers: TOTAL_WORKERS,
|
||||
resharding: {
|
||||
getSessionInfo: restManager.getGatewayBot,
|
||||
},
|
||||
})
|
||||
|
||||
gatewayManager.resharding.tellWorkerToPrepare = async (workerId, shardId, bucketId) => {
|
||||
logger.info(`Tell worker to prepare, workerId: ${workerId}, shardId: ${shardId}, bucketId: ${bucketId}`)
|
||||
|
||||
let worker = workers.get(workerId)
|
||||
if (!worker) {
|
||||
worker = createWorker(workerId)
|
||||
workers.set(workerId, worker)
|
||||
}
|
||||
|
||||
worker.postMessage({
|
||||
type: 'PrepareShard',
|
||||
shardId,
|
||||
totalShards: gatewayManager.totalShards,
|
||||
} satisfies WorkerMessage)
|
||||
|
||||
const { promise, resolve } = promiseWithResolvers<void>()
|
||||
|
||||
const waitForShardPrepared = (message: ManagerMessage) => {
|
||||
if (message.type === 'ShardPrepared' && message.shardId === shardId) {
|
||||
resolve()
|
||||
}
|
||||
}
|
||||
|
||||
worker.on('message', waitForShardPrepared)
|
||||
|
||||
await promise
|
||||
|
||||
worker.off('message', waitForShardPrepared)
|
||||
}
|
||||
|
||||
gatewayManager.resharding.onReshardingSwitch = async () => {
|
||||
logger.info('Resharding switch triggered, telling workers to switch the shards')
|
||||
|
||||
for (const worker of workers.values()) {
|
||||
worker.postMessage({
|
||||
type: 'SwitchShards',
|
||||
} satisfies WorkerMessage)
|
||||
}
|
||||
}
|
||||
|
||||
gatewayManager.tellWorkerToIdentify = async (workerId, shardId, bucketId) => {
|
||||
logger.info(`Tell worker to identify, workerId: ${workerId}, shardId: ${shardId}, bucketId: ${bucketId}`)
|
||||
|
||||
|
||||
@@ -1,13 +1,30 @@
|
||||
import type { DiscordUpdatePresence, ShardSocketRequest } from '@discordeno/bot'
|
||||
|
||||
export type ManagerMessage = ManagerRequestIdentify | ManagerShardIdentified | ManagerShardInfo
|
||||
export type WorkerMessage = WorkerIdentifyShard | WorkerAllowIdentify | WorkerShardPayload | WorkerPresencesUpdate | WorkerShardInfo
|
||||
export type ManagerMessage = ManagerRequestIdentify | ManagerShardIdentified | ManagerShardPrepared | ManagerShardInfo
|
||||
export type WorkerMessage =
|
||||
| WorkerIdentifyShard
|
||||
| WorkerPrepareShard
|
||||
| WorkerSwitchShards
|
||||
| WorkerAllowIdentify
|
||||
| WorkerShardPayload
|
||||
| WorkerPresencesUpdate
|
||||
| WorkerShardInfo
|
||||
|
||||
export interface WorkerIdentifyShard {
|
||||
type: 'IdentifyShard'
|
||||
shardId: number
|
||||
}
|
||||
|
||||
export interface WorkerPrepareShard {
|
||||
type: 'PrepareShard'
|
||||
shardId: number
|
||||
totalShards: number
|
||||
}
|
||||
|
||||
export interface WorkerSwitchShards {
|
||||
type: 'SwitchShards'
|
||||
}
|
||||
|
||||
export interface WorkerAllowIdentify {
|
||||
type: 'AllowIdentify'
|
||||
shardId: number
|
||||
@@ -76,3 +93,8 @@ export interface ManagerShardIdentified {
|
||||
type: 'ShardIdentified'
|
||||
shardId: number
|
||||
}
|
||||
|
||||
export interface ManagerShardPrepared {
|
||||
type: 'ShardPrepared'
|
||||
shardId: number
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import assert from 'node:assert'
|
||||
import { createHash } from 'node:crypto'
|
||||
import { workerData as _workerData, parentPort } from 'node:worker_threads'
|
||||
import { DiscordenoShard, GatewayOpcodes, createLogger } from '@discordeno/bot'
|
||||
import { type Camelize, type DiscordGatewayPayload, DiscordenoShard, GatewayOpcodes, ShardSocketCloseCodes, createLogger } from '@discordeno/bot'
|
||||
import { type Channel as amqpChannel, connect as connectAmqp } from 'amqplib'
|
||||
import { promiseWithResolvers } from '../../util.js'
|
||||
import type { ManagerMessage, WorkerCreateData, WorkerMessage } from './types.js'
|
||||
@@ -14,6 +14,9 @@ const logger = createLogger({ name: `Worker #${workerData.workerId}` })
|
||||
|
||||
const identifyPromises = new Map<number, () => void>()
|
||||
const shards = new Map<number, DiscordenoShard>()
|
||||
const pendingShards = new Map<number, DiscordenoShard>()
|
||||
|
||||
let totalShards = workerData.connectionData.totalShards
|
||||
|
||||
let rabbitMQChannel: amqpChannel | undefined
|
||||
|
||||
@@ -38,6 +41,68 @@ parentPort.on('message', async (message: WorkerMessage) => {
|
||||
|
||||
return
|
||||
}
|
||||
if (message.type === 'PrepareShard') {
|
||||
logger.info(`Preparing shard #${message.shardId}`)
|
||||
totalShards = message.totalShards
|
||||
let shard = pendingShards.get(message.shardId)
|
||||
if (!shard) {
|
||||
shard = createShard(message.shardId)
|
||||
pendingShards.set(message.shardId, shard)
|
||||
}
|
||||
|
||||
// Ignore the events
|
||||
// TODO: If you need 'gateway.resharding.updateGuildsShardId' it you can listen to only the ready event and use the data from that event for the function call
|
||||
shard.events.message = () => {}
|
||||
|
||||
await shard.identify()
|
||||
|
||||
parentPort.postMessage({
|
||||
type: 'ShardPrepared',
|
||||
shardId: message.shardId,
|
||||
} satisfies ManagerMessage)
|
||||
|
||||
return
|
||||
}
|
||||
if (message.type === 'SwitchShards') {
|
||||
logger.info('Switching shards')
|
||||
|
||||
// Change the message event for all shards
|
||||
for (const shard of pendingShards.values()) {
|
||||
shard.events.message = handleShardMessageEvent
|
||||
}
|
||||
|
||||
// Old shards stop processing events
|
||||
for (const shard of shards.values()) {
|
||||
const oldHandler = shard.events.message
|
||||
|
||||
shard.events.message = async function (_, message) {
|
||||
// Member checks need to continue but others can stop
|
||||
if (message.t === 'GUILD_MEMBERS_CHUNK') {
|
||||
oldHandler?.(shard, message)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown the old shards
|
||||
const shardsToShutdown = Array.from(shards.values())
|
||||
|
||||
// Move the pending shards to the active shards
|
||||
shards.clear()
|
||||
for (const [shardId, shard] of pendingShards.entries()) {
|
||||
shards.set(shardId, shard)
|
||||
pendingShards.delete(shardId)
|
||||
}
|
||||
|
||||
// Shutdown the old shards
|
||||
const promises = shardsToShutdown.map(async (shard) => {
|
||||
await shard.close(ShardSocketCloseCodes.Resharded, 'Shard is being resharded')
|
||||
logger.info(`Shard #${shard.id} has been shutdown`)
|
||||
})
|
||||
|
||||
await Promise.all(promises)
|
||||
|
||||
return
|
||||
}
|
||||
if (message.type === 'AllowIdentify') {
|
||||
identifyPromises.get(message.shardId)?.()
|
||||
identifyPromises.delete(message.shardId)
|
||||
@@ -99,7 +164,7 @@ function createShard(shardId: number): DiscordenoShard {
|
||||
device: 'Discordeno',
|
||||
},
|
||||
token: workerData.connectionData.token,
|
||||
totalShards: workerData.connectionData.totalShards,
|
||||
totalShards: totalShards,
|
||||
url: workerData.connectionData.url,
|
||||
version: workerData.connectionData.version,
|
||||
transportCompression: null,
|
||||
@@ -126,50 +191,52 @@ function createShard(shardId: number): DiscordenoShard {
|
||||
shard.events.message?.(shard, packet)
|
||||
}
|
||||
|
||||
shard.events.message = async (shard, payload) => {
|
||||
const body = JSON.stringify({ payload, shardId: shard.id })
|
||||
|
||||
if (workerData.messageQueue.enabled) {
|
||||
if (!rabbitMQChannel) {
|
||||
logger.error('The RabbitMQ channel has not been created. The event will be lost')
|
||||
return
|
||||
}
|
||||
|
||||
const message = Buffer.from(body)
|
||||
const discordData = JSON.stringify(payload.d)
|
||||
|
||||
const deduplicationHash = createHash('sha1')
|
||||
deduplicationHash.update(discordData)
|
||||
|
||||
rabbitMQChannel.publish('gatewayMessage', '', message, {
|
||||
contentType: 'application/json',
|
||||
headers: {
|
||||
'x-deduplication-header': deduplicationHash.digest('hex'),
|
||||
},
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
const url = workerData.eventHandler.urls[shard.id % workerData.eventHandler.urls.length]
|
||||
if (!url) {
|
||||
logger.error('No url found to send events to')
|
||||
return
|
||||
}
|
||||
|
||||
await fetch(url, {
|
||||
method: 'POST',
|
||||
body,
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: workerData.eventHandler.authentication,
|
||||
},
|
||||
}).catch((error) => logger.error('Failed to send events to the bot code', error))
|
||||
}
|
||||
shard.events.message = handleShardMessageEvent
|
||||
|
||||
return shard
|
||||
}
|
||||
|
||||
async function handleShardMessageEvent(shard: DiscordenoShard, payload: Camelize<DiscordGatewayPayload>) {
|
||||
const body = JSON.stringify({ payload, shardId: shard.id })
|
||||
|
||||
if (workerData.messageQueue.enabled) {
|
||||
if (!rabbitMQChannel) {
|
||||
logger.error('The RabbitMQ channel has not been created. The event will be lost')
|
||||
return
|
||||
}
|
||||
|
||||
const message = Buffer.from(body)
|
||||
const discordData = JSON.stringify(payload.d)
|
||||
|
||||
const deduplicationHash = createHash('sha1')
|
||||
deduplicationHash.update(discordData)
|
||||
|
||||
rabbitMQChannel.publish('gatewayMessage', '', message, {
|
||||
contentType: 'application/json',
|
||||
headers: {
|
||||
'x-deduplication-header': deduplicationHash.digest('hex'),
|
||||
},
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
const url = workerData.eventHandler.urls[shard.id % workerData.eventHandler.urls.length]
|
||||
if (!url) {
|
||||
logger.error('No url found to send events to')
|
||||
return
|
||||
}
|
||||
|
||||
await fetch(url, {
|
||||
method: 'POST',
|
||||
body,
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: workerData.eventHandler.authentication,
|
||||
},
|
||||
}).catch((error) => logger.error('Failed to send events to the bot code', error))
|
||||
}
|
||||
|
||||
async function connectToRabbitMQ(): Promise<void> {
|
||||
rabbitMQChannel = undefined
|
||||
const messageQueue = workerData.messageQueue
|
||||
|
||||
Reference in New Issue
Block a user