diff --git a/packages/rest/src/manager.ts b/packages/rest/src/manager.ts index 5f3c353f9..06b9d60a3 100644 --- a/packages/rest/src/manager.ts +++ b/packages/rest/src/manager.ts @@ -89,6 +89,7 @@ export function createRestManager(options: CreateRestManagerOptions): RestManage globallyRateLimited: false, invalidBucket: createInvalidRequestBucket({ logger: options.logger }), isProxied: !baseUrl.startsWith(DISCORD_API_URL), + updateBearerTokenEndpoint: options.proxy?.updateBearerTokenEndpoint, maxRetryCount: Infinity, processingRateLimitedPaths: false, queues: new Map(), @@ -105,10 +106,8 @@ export function createRestManager(options: CreateRestManagerOptions): RestManage } }, - checkRateLimits(url, headers) { - const authHeader = headers?.authorization ?? '' - - const ratelimited = rest.rateLimitedPaths.get(`${authHeader}${url}`) + checkRateLimits(url, requestAuthorization) { + const ratelimited = rest.rateLimitedPaths.get(`${requestAuthorization}${url}`) const global = rest.rateLimitedPaths.get('global') const now = Date.now() @@ -124,6 +123,71 @@ export function createRestManager(options: CreateRestManagerOptions): RestManage return false }, + async updateTokenQueues(oldToken, newToken) { + if (rest.isProxied) { + if (!rest.updateBearerTokenEndpoint) { + throw new Error( + "The 'proxy.updateBearerTokenEndpoint' option needs to be set when using a rest proxy and needed to call 'updateTokenQueues'", + ) + } + + const headers = { + 'content-type': 'application/json', + } as Record + + if (rest.authorization !== undefined) { + headers[rest.authorizationHeader] = rest.authorization + } + + await fetch(`${rest.baseUrl}/${rest.updateBearerTokenEndpoint}`, { + method: 'POST', + body: JSON.stringify({ + oldToken, + newToken, + }), + headers, + }) + + return + } + + const newAuthorization = `Bearer ${newToken}` + + // Update all the queues + for (const [key, queue] of rest.queues.entries()) { + if (!key.startsWith(`Bearer ${oldToken}`)) continue + + rest.queues.delete(key) + queue.requestAuthorization = newAuthorization + + const newKey = `${newAuthorization}${queue.url}` + const newQueue = rest.queues.get(newKey) + + // Merge the queues + if (newQueue) { + newQueue.waiting.unshift(...queue.waiting) + newQueue.pending.unshift(...queue.pending) + + queue.waiting = [] + queue.pending = [] + + queue.cleanup() + } else { + rest.queues.set(newKey, queue) + } + } + + for (const [key, ratelimitPath] of rest.rateLimitedPaths.entries()) { + if (!key.startsWith(`Bearer ${oldToken}`)) continue + + rest.rateLimitedPaths.set(`${newAuthorization}${ratelimitPath.url}`, ratelimitPath) + + if (ratelimitPath.bucketId) { + rest.rateLimitedPaths.set(`${newAuthorization}${ratelimitPath.bucketId}`, ratelimitPath) + } + } + }, + changeToDiscordFormat(obj: any): any { if (obj === null) return null @@ -320,7 +384,7 @@ export function createRestManager(options: CreateRestManagerOptions): RestManage }) if (bucketId) { - rest.rateLimitedPaths.set(`${requestAuthorization}${bucketId}`, { + rest.rateLimitedPaths.set(requestAuthorization, { url: 'global', resetTimestamp: globalReset, bucketId, @@ -340,10 +404,9 @@ export function createRestManager(options: CreateRestManagerOptions): RestManage const loggingHeaders = { ...payload.headers } - const authenticationScheme = payload.headers.authorization?.split(' ')[0] - if (payload.headers.authorization) { - loggingHeaders.authorization = `${authenticationScheme} tokenhere` + const authorizationScheme = payload.headers.authorization?.split(' ')[0] + loggingHeaders.authorization = `${authorizationScheme} tokenhere` } rest.logger.debug(`sending request to ${url}`, 'with payload:', { ...payload, headers: loggingHeaders }) @@ -364,11 +427,8 @@ export function createRestManager(options: CreateRestManagerOptions): RestManage rest.invalidBucket.handleCompletedRequest(response.status, response.headers.get(RATE_LIMIT_SCOPE_HEADER) === 'shared') // Set the bucket id if it was available on the headers - const bucketId = rest.processHeaders( - rest.simplifyUrl(options.route, options.method), - response.headers, - authenticationScheme === 'Bearer' ? payload.headers.authorization : '', - ) + const bucketId = rest.processHeaders(rest.simplifyUrl(options.route, options.method), response.headers, payload.headers.authorization) + if (bucketId) options.bucketId = bucketId if (response.status < HttpResponseCode.Success || response.status >= HttpResponseCode.Error) { @@ -443,20 +503,21 @@ export function createRestManager(options: CreateRestManagerOptions): RestManage return } - const authHeader = request.requestBodyOptions?.headers?.authorization ?? '' + const authorization = request.requestBodyOptions?.headers?.authorization ?? `Bot ${rest.token}` - const queue = rest.queues.get(`${authHeader}${url}`) + const queue = rest.queues.get(`${authorization}${url}`) if (queue !== undefined) { queue.makeRequest(request) } else { // CREATES A NEW QUEUE - const bucketQueue = new Queue(rest, { url, deleteQueueDelay: rest.deleteQueueDelay, authentication: authHeader }) + const bucketQueue = new Queue(rest, { url, deleteQueueDelay: rest.deleteQueueDelay, requestAuthorization: authorization }) + + // Save queue + rest.queues.set(`${authorization}${url}`, bucketQueue) // Add request to queue bucketQueue.makeRequest(request) - // Save queue - rest.queues.set(`${authHeader}${url}`, bucketQueue) } }, diff --git a/packages/rest/src/queue.ts b/packages/rest/src/queue.ts index f2fc9da4b..073b68dc2 100644 --- a/packages/rest/src/queue.ts +++ b/packages/rest/src/queue.ts @@ -28,19 +28,26 @@ export class Queue { frozenAt: number = 0 /** The time in milliseconds to wait before deleting this queue if it is empty. Defaults to 60000(one minute). */ deleteQueueDelay: number = 60000 - /** The authentication header used for the OAuth2 request. Defaults to an empty string for non-OAuth2 requests */ - authentication: string = '' + /** The timeout for the deletion of this queue */ + deleteQueueTimeout?: NodeJS.Timeout + /** + * The authorization being used for the requests in this queue + * + * @remarks + * This is also used to get the key this queue is stored as in the queue mapping of the rest manager + */ + requestAuthorization: string constructor(rest: RestManager, options: QueueOptions) { this.rest = rest this.url = options.url + this.requestAuthorization = options.requestAuthorization if (options.interval) this.interval = options.interval if (options.max) this.max = options.max if (options.remaining) this.remaining = options.remaining if (options.timeoutId) this.timeoutId = options.timeoutId if (options.deleteQueueDelay) this.deleteQueueDelay = options.deleteQueueDelay - if (options.authentication) this.authentication = options.authentication } /** Check if there is any remaining requests that are allowed. */ @@ -71,7 +78,7 @@ export class Queue { this.processing = true while (this.waiting.length > 0) { - this.rest.logger.debug(`[Queue] ${this.isOauth2Queue() ? '' : 'Bearer '}${this.url} process waiting while loop ran.`) + this.rest.logger.debug(`[Queue] ${this.getQueueType()} ${this.url} process waiting while loop ran.`) if (this.isRequestAllowed()) { // Resolve the next item in the queue this.waiting.shift()?.() @@ -93,7 +100,7 @@ export class Queue { this.processingPending = true while (this.pending.length > 0) { - this.rest.logger.debug(`Queue ${this.isOauth2Queue() ? '' : 'Bearer '}${this.url} process pending while loop ran with ${this.pending.length}.`) + this.rest.logger.debug(`Queue ${this.getQueueType()} ${this.url} process pending while loop ran with ${this.pending.length}.`) if (!this.firstRequest && !this.isRequestAllowed()) { const now = Date.now() const future = this.frozenAt + this.interval @@ -106,18 +113,18 @@ export class Queue { const basicURL = this.rest.simplifyUrl(request.route, request.method) // If this url is still rate limited, try again - const urlResetIn = this.rest.checkRateLimits(basicURL, request.requestBodyOptions?.headers) + const urlResetIn = this.rest.checkRateLimits(basicURL, this.requestAuthorization) if (urlResetIn) await delay(urlResetIn) // IF A BUCKET EXISTS, CHECK THE BUCKET'S RATE LIMITS - const bucketResetIn = request.bucketId ? this.rest.checkRateLimits(request.bucketId, request.requestBodyOptions?.headers) : false + const bucketResetIn = request.bucketId ? this.rest.checkRateLimits(request.bucketId, this.requestAuthorization) : false if (bucketResetIn) await delay(bucketResetIn) this.firstRequest = false this.remaining-- - if (this.timeoutId && this.remaining === 0 && this.interval !== 0) { - this.timeoutId = setTimeout(() => { + if (this.remaining === 0 && this.interval !== 0) { + this.timeoutId ??= setTimeout(() => { this.remaining = this.max this.timeoutId = undefined }, this.interval) @@ -128,6 +135,8 @@ export class Queue { // Check if this request is able to be made globally await this.rest.invalidBucket.waitUntilRequestAvailable() + if (request.requestBodyOptions?.headers?.authorization) request.requestBodyOptions.headers.authorization = this.requestAuthorization + await this.rest .sendRequest(request) // Should be handled in sendRequest, this catch just prevents bots from dying @@ -135,7 +144,7 @@ export class Queue { } } - this.rest.logger.debug(`Queue ${this.isOauth2Queue() ? '' : 'Bearer '}${this.url} process pending while loop exited with ${this.pending.length}.`) + this.rest.logger.debug(`Queue ${this.getQueueType()} ${this.url} process pending while loop exited with ${this.pending.length}.`) // Mark as false so next pending request can be triggered by new loop. this.processingPending = false @@ -153,7 +162,7 @@ export class Queue { if (headers.remaining !== undefined) this.remaining = headers.remaining if (this.remaining <= 1) { - this.timeoutId = setTimeout(() => { + this.timeoutId ??= setTimeout(() => { this.remaining = this.max this.timeoutId = undefined }, headers.interval) @@ -174,24 +183,27 @@ export class Queue { return } - this.rest.logger.debug(`[Queue] ${this.isOauth2Queue() ? '' : 'Bearer '}${this.url}. Delaying delete for ${this.deleteQueueDelay}ms`) + this.rest.logger.debug(`[Queue] ${this.getQueueType()} ${this.url}. Delaying delete for ${this.deleteQueueDelay}ms`) + // Delete in a minute giving a bit of time to allow new requests that may reuse this queue - setTimeout(async () => { + clearTimeout(this.deleteQueueTimeout) + this.deleteQueueTimeout = setTimeout(() => { if (!this.isQueueClearable()) { - this.rest.logger.debug(`[Queue] ${this.isOauth2Queue() ? '' : 'Bearer '}${this.url}. is not clearable. Restarting processing of queue.`) + this.rest.logger.debug(`[Queue] ${this.getQueueType()} ${this.url}. is not clearable. Restarting processing of queue.`) this.processPending() return } - this.rest.logger.debug(`[Queue] ${this.url}. Deleting`) + this.rest.logger.debug(`[Queue] ${this.getQueueType()} ${this.url}. Deleting`) + if (this.timeoutId) clearTimeout(this.timeoutId) + // No requests have been requested for this queue so we nuke this queue - this.rest.queues.delete(`${this.authentication}${this.url}`) + this.rest.queues.delete(`${this.requestAuthorization}${this.url}`) this.rest.logger.debug( - `[Queue] ${this.url}. Deleted! Remaining: (${this.rest.queues.size})`, - [...this.rest.queues.values()].map((queue) => `${queue.isOauth2Queue() ? '' : 'Bearer '}${queue.url}`), + `[Queue] ${this.getQueueType()} ${this.url}. Deleted! Remaining: (${this.rest.queues.size})`, + [...this.rest.queues.values()].map((queue) => `${queue.getQueueType()}${queue.url}`), ) - if (this.rest.queues.size) this.processPending() }, this.deleteQueueDelay) } @@ -200,15 +212,14 @@ export class Queue { if (this.firstRequest) return false if (this.waiting.length > 0) return false if (this.pending.length > 0) return false - if (this.interval === 0) return false if (this.processing) return false if (this.processingPending) return false return true } - isOauth2Queue(): boolean { - return this.authentication === '' + getQueueType(): string { + return this.requestAuthorization.split(' ')[0] } } @@ -225,6 +236,6 @@ export interface QueueOptions { url: string /** The time in milliseconds to wait before deleting this queue if it is empty. Defaults to 60000(one minute). */ deleteQueueDelay?: number - /** Authentication used for the request. In non-OAuth2 situations should be an empty string. Defaults to an empty string */ - authentication?: string + /** The base key that identifies this queue in the rest manager */ + requestAuthorization: string } diff --git a/packages/rest/src/types.ts b/packages/rest/src/types.ts index f740ec7af..78c3810c6 100644 --- a/packages/rest/src/types.ts +++ b/packages/rest/src/types.ts @@ -165,6 +165,15 @@ export interface CreateRestManagerOptions { * @default "authorization" // For compatibility purposes */ authorizationHeader?: string + /** + * The endpoint to use in the rest proxy to update the bearer tokens + * + * @remarks + * Should not include a `/` in the start + * + * This value is actually required if you want to use `updateTokenQueues` + */ + updateBearerTokenEndpoint?: string } /** * The api versions which can be used to make requests. @@ -201,6 +210,8 @@ export interface RestManager { authorization?: string /** The authorization header name to attach when sending requests to the proxy */ authorizationHeader: string + /** The endpoint to use for `updateTokenQueues` when working with a rest proxy */ + updateBearerTokenEndpoint?: string /** The maximum amount of times a request should be retried. Defaults to Infinity */ maxRetryCount: number /** Whether or not the manager is rate limited globally across all requests. Defaults to false. */ @@ -224,20 +235,17 @@ export interface RestManager { /** Whether or not the rest manager should keep objects in raw snake case from discord. */ preferSnakeCase: (enabled: boolean) => RestManager /** Check the rate limits for a url or a bucket. */ - checkRateLimits: (url: string, headers?: Record) => number | false + checkRateLimits: (url: string, requestAuthorization: string) => number | false + /* Update the queues and ratelimit information to adapt to the new token */ + updateTokenQueues: (oldToken: string, newToken: string) => Promise /** Reshapes and modifies the obj as needed to make it ready for discords api. */ changeToDiscordFormat: (obj: any) => any /** Creates the request body and headers that are necessary to send a request. Will handle different types of methods and everything necessary for discord. */ createRequestBody: (method: RequestMethods, options?: CreateRequestBodyOptions) => RequestBody /** This will create a infinite loop running in 1 seconds using tail recursion to keep rate limits clean. When a rate limit resets, this will remove it so the queue can proceed. */ processRateLimitedPaths: () => void - /** - * Processes the rate limit headers and determines if it needs to be rate limited and returns the bucket id if available - * - * @remarks - * The authenticationHeader should be defined ONLY if the request was done using a OAuth2 Access Token, in other cases it should be passed as an empty string - */ - processHeaders: (url: string, headers: Headers, authenticationHeader?: string) => string | undefined + /** Processes the rate limit headers and determines if it needs to be rate limited and returns the bucket id if available */ + processHeaders: (url: string, headers: Headers, requestAuthorization: string) => string | undefined /** Sends a request to the api. */ sendRequest: (options: SendRequestOptions) => Promise /** Split a url to separate rate limit buckets based on major/minor parameters. */ diff --git a/packages/rest/tests/unit/manager.spec.ts b/packages/rest/tests/unit/manager.spec.ts index 03e4956e0..0be294aab 100644 --- a/packages/rest/tests/unit/manager.spec.ts +++ b/packages/rest/tests/unit/manager.spec.ts @@ -110,24 +110,24 @@ describe('[rest] manager', () => { }) it('will return false for path without rate limited', () => { - expect(rest.checkRateLimits('/channel/555555555555555555')).to.be.equal(false) + expect(rest.checkRateLimits('/channel/555555555555555555', `Bot ${token}`)).to.be.equal(false) }) describe('With per URL rateLimitedPath', () => { it('Will return time until reset if before resetTimestamp', () => { - rest.rateLimitedPaths.set('/channel/555555555555555555', { + rest.rateLimitedPaths.set(`Bot ${token}/channel/555555555555555555`, { url: '/channel/555555555555555555', resetTimestamp: Date.now() + 6541, }) - expect(rest.checkRateLimits('/channel/555555555555555555')).to.be.equal(6541) + expect(rest.checkRateLimits('/channel/555555555555555555', `Bot ${token}`)).to.be.equal(6541) }) it('Will return false if before resetTimestamp', () => { - rest.rateLimitedPaths.set('/channel/555555555555555555', { + rest.rateLimitedPaths.set(`Bot ${token}/channel/555555555555555555`, { url: '/channel/555555555555555555', resetTimestamp: Date.now(), }) - expect(rest.checkRateLimits('/channel/555555555555555555')).to.be.equal(false) + expect(rest.checkRateLimits('/channel/555555555555555555', `Bot ${token}`)).to.be.equal(false) }) }) @@ -137,7 +137,7 @@ describe('[rest] manager', () => { url: '/channel/555555555555555555', resetTimestamp: Date.now() + 9849, }) - expect(rest.checkRateLimits('/channel/555555555555555555')).to.be.equal(9849) + expect(rest.checkRateLimits('/channel/555555555555555555', `Bot ${token}`)).to.be.equal(9849) }) it('Will return false if before resetTimestamp', () => { @@ -145,13 +145,13 @@ describe('[rest] manager', () => { url: '/channel/555555555555555555', resetTimestamp: Date.now(), }) - expect(rest.checkRateLimits('/channel/555555555555555555')).to.be.equal(false) + expect(rest.checkRateLimits('/channel/555555555555555555', `Bot ${token}`)).to.be.equal(false) }) }) describe('With both URL and Global rateLimitedPath', () => { it('Will return URL time first if before resetTimestamp', () => { - rest.rateLimitedPaths.set('/channel/555555555555555555', { + rest.rateLimitedPaths.set(`Bot ${token}/channel/555555555555555555`, { url: '/channel/555555555555555555', resetTimestamp: Date.now() + 6541, }) @@ -159,7 +159,7 @@ describe('[rest] manager', () => { url: '/channel/555555555555555555', resetTimestamp: Date.now() + 9849, }) - expect(rest.checkRateLimits('/channel/555555555555555555')).to.be.equal(6541) + expect(rest.checkRateLimits('/channel/555555555555555555', `Bot ${token}`)).to.be.equal(6541) }) }) })