diff --git a/index.js b/index.js index e57f75c6..70792edc 100644 --- a/index.js +++ b/index.js @@ -64,11 +64,18 @@ async function fastifyRateLimit (fastify, settings) { : defaultMax // Global time window - globalParams.timeWindow = typeof settings.timeWindow === 'string' - ? ms.parse(settings.timeWindow) - : typeof settings.timeWindow === 'number' && Number.isFinite(settings.timeWindow) && settings.timeWindow >= 0 - ? Math.trunc(settings.timeWindow) - : defaultTimeWindow + const twType = typeof settings.timeWindow + globalParams.timeWindow = defaultTimeWindow + if (twType === 'function') { + globalParams.timeWindow = settings.timeWindow + } else if (twType === 'string') { + globalParams.timeWindow = ms.parse(settings.timeWindow) + } else if ( + twType === 'number' && + Number.isFinite(settings.timeWindow) && settings.timeWindow >= 0 + ) { + globalParams.timeWindow = Math.trunc(settings.timeWindow) + } globalParams.hook = settings.hook || defaultHook globalParams.allowList = settings.allowList || settings.whitelist || null @@ -147,7 +154,7 @@ function mergeParams (...params) { result.timeWindow = ms.parse(result.timeWindow) } else if (typeof result.timeWindow === 'number' && Number.isFinite(result.timeWindow) && result.timeWindow >= 0) { result.timeWindow = Math.trunc(result.timeWindow) - } else { + } else if (typeof result.timeWindow !== 'function') { result.timeWindow = defaultTimeWindow } @@ -180,7 +187,6 @@ function addRouteRateHook (pluginComponent, params, routeOptions) { function rateLimitRequestHandler (pluginComponent, params) { const { rateLimitRan, store } = pluginComponent - const timeWindowString = ms.format(params.timeWindow, true) return async (req, res) => { if (req[rateLimitRan]) { @@ -204,6 +210,7 @@ function rateLimitRequestHandler (pluginComponent, params) { } const max = typeof params.max === 'number' ? params.max : await params.max(req, key) + const timeWindow = typeof params.timeWindow === 'number' ? params.timeWindow : await params.timeWindow(req, key) let current = 0 let ttl = 0 let timeLeftInSeconds = 0 @@ -213,7 +220,7 @@ function rateLimitRequestHandler (pluginComponent, params) { const res = await new Promise((resolve, reject) => { store.incr(key, (err, res) => { err ? reject(err) : resolve(res) - }, max) + }, timeWindow, max) }) current = res.current @@ -248,7 +255,7 @@ function rateLimitRequestHandler (pluginComponent, params) { ban: false, max, ttl, - after: timeWindowString + after: ms.format(timeWindow, true) } if (params.ban !== -1 && current - max > params.ban) { diff --git a/store/LocalStore.js b/store/LocalStore.js index 9fed5006..e3024bbf 100644 --- a/store/LocalStore.js +++ b/store/LocalStore.js @@ -4,21 +4,20 @@ const { LruMap: Lru } = require('toad-cache') function LocalStore (cache = 5000, timeWindow, continueExceeding) { this.lru = new Lru(cache) - this.timeWindow = timeWindow this.continueExceeding = continueExceeding } -LocalStore.prototype.incr = function (ip, cb, max) { +LocalStore.prototype.incr = function (ip, cb, timeWindow, max) { const nowInMs = Date.now() let current = this.lru.get(ip) if (!current) { // Item doesn't exist - current = { current: 1, ttl: this.timeWindow, iterationStartMs: nowInMs } - } else if (current.iterationStartMs + this.timeWindow <= nowInMs) { + current = { current: 1, ttl: timeWindow, iterationStartMs: nowInMs } + } else if (current.iterationStartMs + timeWindow <= nowInMs) { // Item has expired current.current = 1 - current.ttl = this.timeWindow + current.ttl = timeWindow current.iterationStartMs = nowInMs } else { // Item is alive @@ -26,10 +25,10 @@ LocalStore.prototype.incr = function (ip, cb, max) { // Reset TLL if max has been exceeded and `continueExceeding` is enabled if (this.continueExceeding && current.current > max) { - current.ttl = this.timeWindow + current.ttl = timeWindow current.iterationStartMs = nowInMs } else { - current.ttl = this.timeWindow - (nowInMs - current.iterationStartMs) + current.ttl = timeWindow - (nowInMs - current.iterationStartMs) } } diff --git a/store/RedisStore.js b/store/RedisStore.js index ebef6786..e3c5b2c9 100644 --- a/store/RedisStore.js +++ b/store/RedisStore.js @@ -39,8 +39,8 @@ function RedisStore (redis, key = 'fastify-rate-limit-', timeWindow, continueExc } } -RedisStore.prototype.incr = function (ip, cb, max) { - this.redis.rateLimit(this.key + ip, this.timeWindow, max, this.continueExceeding, (err, result) => { +RedisStore.prototype.incr = function (ip, cb, timeWindow, max) { + this.redis.rateLimit(this.key + ip, timeWindow, max, this.continueExceeding, (err, result) => { err ? cb(err, null) : cb(null, { current: result[0], ttl: result[1] }) }) } diff --git a/test/global-rate-limit.test.js b/test/global-rate-limit.test.js index 05a9362d..8bdc0694 100644 --- a/test/global-rate-limit.test.js +++ b/test/global-rate-limit.test.js @@ -106,6 +106,54 @@ test('With text timeWindow', async t => { }) }) +test('With function timeWindow', async t => { + t.plan(15) + t.context.clock = FakeTimers.install() + const fastify = Fastify() + await fastify.register(rateLimit, { max: 2, timeWindow: (_, __) => 1000 }) + + fastify.get('/', async (req, reply) => 'hello!') + + let res + + res = await fastify.inject('/') + + t.equal(res.statusCode, 200) + t.equal(res.headers['x-ratelimit-limit'], '2') + t.equal(res.headers['x-ratelimit-remaining'], '1') + + res = await fastify.inject('/') + + t.equal(res.statusCode, 200) + t.equal(res.headers['x-ratelimit-limit'], '2') + t.equal(res.headers['x-ratelimit-remaining'], '0') + + res = await fastify.inject('/') + + t.equal(res.statusCode, 429) + t.equal(res.headers['content-type'], 'application/json; charset=utf-8') + t.equal(res.headers['x-ratelimit-limit'], '2') + t.equal(res.headers['x-ratelimit-remaining'], '0') + t.equal(res.headers['retry-after'], '1') + t.same({ + statusCode: 429, + error: 'Too Many Requests', + message: 'Rate limit exceeded, retry in 1 second' + }, JSON.parse(res.payload)) + + t.context.clock.tick(1100) + + res = await fastify.inject('/') + + t.equal(res.statusCode, 200) + t.equal(res.headers['x-ratelimit-limit'], '2') + t.equal(res.headers['x-ratelimit-remaining'], '1') + + t.teardown(() => { + t.context.clock.uninstall() + }) +}) + test('When passing NaN to the timeWindow property then the timeWindow should be the default value - 60 seconds', async t => { t.plan(5) diff --git a/test/route-rate-limit.test.js b/test/route-rate-limit.test.js index c3029293..2336a5f2 100644 --- a/test/route-rate-limit.test.js +++ b/test/route-rate-limit.test.js @@ -92,7 +92,69 @@ test('With text timeWindow', async t => { await fastify.register(rateLimit, { global: false }) fastify.get('/', { - config: defaultRouteConfig + config: { + rateLimit: { + max: 2, + timeWindow: '1s' + }, + someOtherPlugin: { + someValue: 1 + } + } + }, async (req, reply) => 'hello!') + + let res + + res = await fastify.inject('/') + t.equal(res.statusCode, 200) + t.equal(res.headers['x-ratelimit-limit'], '2') + t.equal(res.headers['x-ratelimit-remaining'], '1') + + res = await fastify.inject('/') + t.equal(res.statusCode, 200) + t.equal(res.headers['x-ratelimit-limit'], '2') + t.equal(res.headers['x-ratelimit-remaining'], '0') + + res = await fastify.inject('/') + t.equal(res.statusCode, 429) + t.equal(res.headers['content-type'], 'application/json; charset=utf-8') + t.equal(res.headers['x-ratelimit-limit'], '2') + t.equal(res.headers['x-ratelimit-remaining'], '0') + t.equal(res.headers['retry-after'], '1') + t.same(JSON.parse(res.payload), { + statusCode: 429, + error: 'Too Many Requests', + message: 'Rate limit exceeded, retry in 1 second' + }) + + t.context.clock.tick(1100) + + res = await fastify.inject('/') + t.equal(res.statusCode, 200) + t.equal(res.headers['x-ratelimit-limit'], '2') + t.equal(res.headers['x-ratelimit-remaining'], '1') + + t.teardown(() => { + t.context.clock.uninstall() + }) +}) + +test('With function timeWindow', async t => { + t.plan(15) + t.context.clock = FakeTimers.install() + const fastify = Fastify() + await fastify.register(rateLimit, { global: false }) + + fastify.get('/', { + config: { + rateLimit: { + max: 2, + timeWindow: (_, __) => 1000 + }, + someOtherPlugin: { + someValue: 1 + } + } }, async (req, reply) => 'hello!') let res diff --git a/types/index.d.ts b/types/index.d.ts index f3ddcf84..b50bc878 100644 --- a/types/index.d.ts +++ b/types/index.d.ts @@ -80,10 +80,14 @@ declare namespace fastifyRateLimit { export interface RateLimitOptions { max?: - | number - | ((req: FastifyRequest, key: string) => number) - | ((req: FastifyRequest, key: string) => Promise); - timeWindow?: number | string; + | number + | ((req: FastifyRequest, key: string) => number) + | ((req: FastifyRequest, key: string) => Promise); + timeWindow?: + | number + | string + | ((req: FastifyRequest, key: string) => number) + | ((req: FastifyRequest, key: string) => Promise); hook?: RateLimitHook; cache?: number; store?: FastifyRateLimitStoreCtor; diff --git a/types/index.test-d.ts b/types/index.test-d.ts index e18dc614..5e070a5e 100644 --- a/types/index.test-d.ts +++ b/types/index.test-d.ts @@ -114,6 +114,22 @@ const options6: RateLimitPluginOptions = { hook: 'preHandler' } +const options7: RateLimitPluginOptions = { + global: true, + max: (req: FastifyRequest, key: string) => 42, + timeWindow: (req: FastifyRequest, key: string) => 5000, + store: CustomStore, + hook: 'preValidation' +} + +const options8: RateLimitPluginOptions = { + global: true, + max: (req: FastifyRequest, key: string) => 42, + timeWindow: (req: FastifyRequest, key: string) => Promise.resolve(5000), + store: CustomStore, + hook: 'preValidation' +} + appWithImplicitHttp.register(fastifyRateLimit, options1) appWithImplicitHttp.register(fastifyRateLimit, options2) appWithImplicitHttp.register(fastifyRateLimit, options5) @@ -144,6 +160,8 @@ appWithHttp2.register(fastifyRateLimit, options1) appWithHttp2.register(fastifyRateLimit, options2) appWithHttp2.register(fastifyRateLimit, options3) appWithHttp2.register(fastifyRateLimit, options5) +appWithHttp2.register(fastifyRateLimit, options7) +appWithHttp2.register(fastifyRateLimit, options8) appWithHttp2.get('/public', { config: {