From b0889c192f03fc884e180a61a9e21b0b9b439c04 Mon Sep 17 00:00:00 2001
From: Max Kalashnikoff
Date: Mon, 25 Mar 2024 10:15:14 +0100
Subject: [PATCH 1/7] feat: initial rate_limiting library

---
 .github/workflows/ci.yaml              |  10 ++
 Cargo.toml                             |   3 +
 crates/rate_limit/Cargo.toml           |  17 +++
 crates/rate_limit/docker-compose.yml   |   6 +
 crates/rate_limit/src/lib.rs           | 160 +++++++++++++++++++++++++
 crates/rate_limit/src/token_bucket.lua |  44 +++++++
 src/lib.rs                             |   2 +
 7 files changed, 242 insertions(+)
 create mode 100644 crates/rate_limit/Cargo.toml
 create mode 100644 crates/rate_limit/docker-compose.yml
 create mode 100644 crates/rate_limit/src/lib.rs
 create mode 100644 crates/rate_limit/src/token_bucket.lua

diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index c9f0c0a..6cc2c3d 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -44,6 +44,16 @@ jobs:
         rust: stable
     env:
       RUST_BACKTRACE: full
+    services:
+      redis:
+        image: redis:7.2-alpine
+        options: >-
+          --health-cmd "redis-cli ping"
+          --health-interval 10s
+          --health-timeout 5s
+          --health-retries 5
+        ports:
+          - 6379:6379
     steps:
       - uses: actions/checkout@v3
diff --git a/Cargo.toml b/Cargo.toml
index e8c9dfa..fcd526f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -23,6 +23,7 @@ full = [
     "geoip",
     "http",
     "metrics",
+    "rate_limit",
 ]
 alloc = ["dep:alloc"]
 analytics = ["dep:analytics"]
@@ -33,6 +34,7 @@ geoip = ["dep:geoip"]
 http = []
 metrics = ["dep:metrics", "future/metrics", "alloc/metrics", "http/metrics"]
 profiler = ["alloc/profiler"]
+rate_limit = ["dep:rate_limit"]

 [workspace.dependencies]
 aws-sdk-s3 = "1.13"
@@ -45,6 +47,7 @@ future = { path = "./crates/future", optional = true }
 geoip = { path = "./crates/geoip", optional = true }
 http = { path = "./crates/http", optional = true }
 metrics = { path = "./crates/metrics", optional = true }
+rate_limit = { path = "./crates/rate_limit", optional = true }

 [dev-dependencies]
 anyhow = "1"
diff --git a/crates/rate_limit/Cargo.toml b/crates/rate_limit/Cargo.toml
new file mode 100644
index 0000000..8b4a7db
--- /dev/null
+++ b/crates/rate_limit/Cargo.toml
@@ -0,0 +1,17 @@
+[package]
+name = "rate_limit"
+version = "0.1.0"
+edition = "2021"
+
+[dependencies]
+chrono = { version = "0.4", features = ["serde"] }
+deadpool-redis = "0.12"
+redis = { version = "0.23", default-features = false, features = ["script"] }
+serde = { version = "1.0", features = ["derive"] }
+serde_json = "1.0"
+thiserror = "1.0"
+tracing = "0.1"
+
+[dev-dependencies]
+anyhow = "1"
+tokio = { version = "1", features = ["full"] }
diff --git a/crates/rate_limit/docker-compose.yml b/crates/rate_limit/docker-compose.yml
new file mode 100644
index 0000000..0da745b
--- /dev/null
+++ b/crates/rate_limit/docker-compose.yml
@@ -0,0 +1,6 @@
+version: "3.9"
+services:
+  redis:
+    image: redis:7.0
+    ports:
+      - "6379:6379"
diff --git a/crates/rate_limit/src/lib.rs b/crates/rate_limit/src/lib.rs
new file mode 100644
index 0000000..288e7e1
--- /dev/null
+++ b/crates/rate_limit/src/lib.rs
@@ -0,0 +1,160 @@
+use {
+    chrono::{DateTime, Duration, Utc},
+    core::fmt,
+    deadpool_redis::{Pool, PoolError},
+    redis::{RedisError, Script},
+    std::{collections::HashMap, sync::Arc},
+};
+
+pub type Clock = Option<Arc<dyn ClockImpl>>;
+pub trait ClockImpl: fmt::Debug + Send + Sync {
+    fn now(&self) -> DateTime<Utc>;
+}
+
+#[derive(Debug, thiserror::Error)]
+#[error("Rate limit exceeded. Try again at {reset}")]
Try again at {reset}")] +pub struct RateLimitExceeded { + reset: u64, +} + +#[derive(Debug, thiserror::Error)] +pub enum InternalRateLimitError { + #[error("Redis pool error {0}")] + Pool(PoolError), + + #[error("Redis error: {0}")] + Redis(RedisError), +} + +#[derive(Debug, thiserror::Error)] +pub enum RateLimitError { + #[error(transparent)] + RateLimitExceeded(RateLimitExceeded), + + #[error("Internal error: {0}")] + Internal(InternalRateLimitError), +} + +pub async fn token_bucket( + redis_write_pool: &Arc, + key: String, + max_tokens: u32, + interval: Duration, + refill_rate: u32, +) -> Result<(), RateLimitError> { + let result = token_bucket_many( + redis_write_pool, + vec![key.clone()], + max_tokens, + interval, + refill_rate, + ) + .await + .map_err(RateLimitError::Internal)?; + + let (remaining, reset) = result.get(&key).expect("Should contain the key"); + if remaining.is_negative() { + Err(RateLimitError::RateLimitExceeded(RateLimitExceeded { + reset: reset / 1000, + })) + } else { + Ok(()) + } +} + +pub async fn token_bucket_many( + redis_write_pool: &Arc, + keys: Vec, + max_tokens: u32, + interval: Duration, + refill_rate: u32, +) -> Result, InternalRateLimitError> { + let now = Utc::now(); + + // Remaining is number of tokens remaining. -1 for rate limited. + // Reset is the time at which there will be 1 more token than before. This + // could, for example, be used to cache a 0 token count. + Script::new(include_str!("token_bucket.lua")) + .key(keys) + .arg(max_tokens) + .arg(interval.num_milliseconds()) + .arg(refill_rate) + .arg(now.timestamp_millis()) + .invoke_async::<_, String>( + &mut redis_write_pool + .clone() + .get() + .await + .map_err(InternalRateLimitError::Pool)?, + ) + .await + .map_err(InternalRateLimitError::Redis) + .map(|value| serde_json::from_str(&value).expect("Redis script should return valid JSON")) +} + +#[cfg(test)] +mod tests { + const REDIS_URI: &str = "redis://localhost:6379"; + use { + super::*, + deadpool_redis::{Config, Runtime}, + redis::AsyncCommands, + tokio::time::sleep, + }; + + async fn redis_clear_keys(conn_uri: &str, keys: &[String]) { + let client = redis::Client::open(conn_uri).unwrap(); + let mut conn = client.get_async_connection().await.unwrap(); + for key in keys { + let _: () = conn.del(key).await.unwrap(); + } + } + + #[tokio::test] + async fn test_token_bucket() { + let cfg = Config::from_url(REDIS_URI); + let pool = Arc::new(cfg.create_pool(Some(Runtime::Tokio1)).unwrap()); + let key = "test_token_bucket".to_string(); + + // Before running the test, ensure the test keys are cleared + redis_clear_keys(REDIS_URI, &[key.clone()]).await; + + let max_tokens = 10; + let refill_interval = chrono::Duration::try_milliseconds(100).unwrap(); + let refill_rate = 1; + let rate_limit = || async { + token_bucket_many( + &pool, + vec![key.clone()], + max_tokens, + refill_interval, + refill_rate, + ) + .await + .unwrap() + .get(&key.clone()) + .unwrap() + .to_owned() + }; + + // Iterate over the max tokens + for i in 0..=max_tokens { + let curr_iter = max_tokens as i64 - i as i64 - 1; + let result = rate_limit().await; + assert_eq!(result.0, curr_iter); + } + + // Sleep for refill and try again + // Tokens numbers should be the same as the previous iteration + sleep((refill_interval * max_tokens as i32).to_std().unwrap()).await; + + for i in 0..=max_tokens { + let curr_iter = max_tokens as i64 - i as i64 - 1; + let result = rate_limit().await; + assert_eq!(result.0, curr_iter); + } + + // Clear keys after the test + redis_clear_keys(REDIS_URI, 
+    }
+}
diff --git a/crates/rate_limit/src/token_bucket.lua b/crates/rate_limit/src/token_bucket.lua
new file mode 100644
index 0000000..07ec7b1
--- /dev/null
+++ b/crates/rate_limit/src/token_bucket.lua
@@ -0,0 +1,44 @@
+-- Adapted from https://github.com/upstash/ratelimit/blob/3a8cfb00e827188734ac347965cb743a75fcb98a/src/single.ts#L311
+local keys = KEYS -- identifiers including prefixes
+local maxTokens = tonumber(ARGV[1]) -- maximum number of tokens
+local interval = tonumber(ARGV[2]) -- size of the window in milliseconds
+local refillRate = tonumber(ARGV[3]) -- how many tokens are refilled after each interval
+local now = tonumber(ARGV[4]) -- current timestamp in milliseconds
+
+local results = {}
+
+for i, key in ipairs(keys) do
+  local bucket = redis.call("HMGET", key, "refilledAt", "tokens")
+
+  local refilledAt
+  local tokens
+
+  if bucket[1] == false then
+    refilledAt = now
+    tokens = maxTokens
+  else
+    refilledAt = tonumber(bucket[1])
+    tokens = tonumber(bucket[2])
+  end
+
+  if now >= refilledAt + interval then
+    local numRefills = math.floor((now - refilledAt) / interval)
+    tokens = math.min(maxTokens, tokens + numRefills * refillRate)
+
+    refilledAt = refilledAt + numRefills * interval
+  end
+
+  if tokens == 0 then
+    results[key] = {-1, refilledAt + interval}
+  else
+    local remaining = tokens - 1
+    local expireAt = math.ceil(((maxTokens - remaining) / refillRate)) * interval
+
+    redis.call("HSET", key, "refilledAt", refilledAt, "tokens", remaining)
+    redis.call("PEXPIRE", key, expireAt)
+    results[key] = {remaining, refilledAt + interval}
+  end
+end
+
+-- Redis doesn't support Lua table responses: https://stackoverflow.com/a/24302613
+return cjson.encode(results)
diff --git a/src/lib.rs b/src/lib.rs
index 5c99951..a16c71a 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -13,3 +13,5 @@ pub use geoip;
 pub use http;
 #[cfg(feature = "metrics")]
 pub use metrics;
+#[cfg(feature = "rate_limit")]
+pub use rate_limit;

From c9bde58a85576c607caae0fc4cdfe91b123a8787 Mon Sep 17 00:00:00 2001
From: Max Kalashnikoff
Date: Mon, 25 Mar 2024 23:00:00 +0100
Subject: [PATCH 2/7] feat: adding Moka caching for single-key calls

---
 crates/rate_limit/Cargo.toml |   1 +
 crates/rate_limit/src/lib.rs | 109 ++++++++++++++++++++++++++++++-----
 2 files changed, 94 insertions(+), 16 deletions(-)

diff --git a/crates/rate_limit/Cargo.toml b/crates/rate_limit/Cargo.toml
index 8b4a7db..1171b76 100644
--- a/crates/rate_limit/Cargo.toml
+++ b/crates/rate_limit/Cargo.toml
@@ -6,6 +6,7 @@ edition = "2021"
 [dependencies]
 chrono = { version = "0.4", features = ["serde"] }
 deadpool-redis = "0.12"
+moka = { version = "0.12", features = ["future"] }
 redis = { version = "0.23", default-features = false, features = ["script"] }
 serde = { version = "1.0", features = ["derive"] }
 serde_json = "1.0"
diff --git a/crates/rate_limit/src/lib.rs b/crates/rate_limit/src/lib.rs
index 288e7e1..b0dfed9 100644
--- a/crates/rate_limit/src/lib.rs
+++ b/crates/rate_limit/src/lib.rs
@@ -2,6 +2,7 @@ use {
     chrono::{DateTime, Duration, Utc},
     core::fmt,
     deadpool_redis::{Pool, PoolError},
+    moka::future::Cache,
     redis::{RedisError, Script},
     std::{collections::HashMap, sync::Arc},
 };
@@ -35,13 +36,25 @@ pub enum RateLimitError {
     Internal(InternalRateLimitError),
 }

+/// Rate limit check using a token bucket algorithm for one key and in-memory
+/// cache for rate-limited keys. `mem_cache` TTL must be set to the same value
+/// as the refill interval.
 pub async fn token_bucket(
+    mem_cache: &Cache<String, u64>,
     redis_write_pool: &Arc<Pool>,
     key: String,
     max_tokens: u32,
     interval: Duration,
     refill_rate: u32,
 ) -> Result<(), RateLimitError> {
+    // Check if the key is in the memory cache of rate-limited keys
+    // to avoid the redis RTT in case of a flood
+    if let Some(reset) = mem_cache.get(&key).await {
+        return Err(RateLimitError::RateLimitExceeded(RateLimitExceeded {
+            reset,
+        }));
+    }
+
     let result = token_bucket_many(
         redis_write_pool,
         vec![key.clone()],
@@ -54,14 +67,21 @@ pub async fn token_bucket(

     let (remaining, reset) = result.get(&key).expect("Should contain the key");
     if remaining.is_negative() {
+        let reset_interval = reset / 1000;
+
+        // Insert the rate-limited key into the memory cache to avoid the redis RTT in
+        // case of a flood
+        mem_cache.insert(key, reset_interval).await;
+
         Err(RateLimitError::RateLimitExceeded(RateLimitExceeded {
-            reset: reset / 1000,
+            reset: reset_interval,
         }))
     } else {
         Ok(())
     }
 }

+/// Rate limit check using a token bucket algorithm for many keys.
 pub async fn token_bucket_many(
     redis_write_pool: &Arc<Pool>,
     keys: Vec<String>,
     max_tokens: u32,
     interval: Duration,
     refill_rate: u32,
@@ -95,6 +115,8 @@ pub async fn token_bucket_many(
 #[cfg(test)]
 mod tests {
     const REDIS_URI: &str = "redis://localhost:6379";
+    const REFILL_INTERVAL_MILLIS: i64 = 100;
+
     use {
         super::*,
         deadpool_redis::{Config, Runtime},
@@ -111,16 +133,16 @@ mod tests {
     }

     #[tokio::test]
-    async fn test_token_bucket() {
+    async fn test_token_bucket_many() {
         let cfg = Config::from_url(REDIS_URI);
         let pool = Arc::new(cfg.create_pool(Some(Runtime::Tokio1)).unwrap());
-        let key = "test_token_bucket".to_string();
+        let key = "token_bucket_many_test_key".to_string();

         // Before running the test, ensure the test keys are cleared
         redis_clear_keys(REDIS_URI, &[key.clone()]).await;

         let max_tokens = 10;
-        let refill_interval = chrono::Duration::try_milliseconds(100).unwrap();
+        let refill_interval = chrono::Duration::try_milliseconds(REFILL_INTERVAL_MILLIS).unwrap();
         let refill_rate = 1;
         let rate_limit = || async {
             token_bucket_many(
@@ -136,23 +158,78 @@ mod tests {
             .unwrap()
             .to_owned()
         };
+        let call_rate_limit_loop = || async {
+            for i in 0..=max_tokens {
+                let curr_iter = max_tokens as i64 - i as i64 - 1;
+                let result = rate_limit().await;
+                assert_eq!(result.0, curr_iter);
+            }
+        };

-        // Iterate over the max tokens
-        for i in 0..=max_tokens {
-            let curr_iter = max_tokens as i64 - i as i64 - 1;
-            let result = rate_limit().await;
-            assert_eq!(result.0, curr_iter);
-        }
+        // Call rate limit until max tokens limit is reached
+        call_rate_limit_loop().await;

         // Sleep for refill and try again
-        // Token counts should be the same as in the previous iteration
+        // Token counts should be the same as in the previous iteration because
+        // they were refilled
         sleep((refill_interval * max_tokens as i32).to_std().unwrap()).await;
+        call_rate_limit_loop().await;

-        for i in 0..=max_tokens {
-            let curr_iter = max_tokens as i64 - i as i64 - 1;
-            let result = rate_limit().await;
-            assert_eq!(result.0, curr_iter);
-        }
+        // Clear keys after the test
+        redis_clear_keys(REDIS_URI, &[key.clone()]).await;
+    }
+
+    #[tokio::test]
+    async fn test_token_bucket() {
+        // Create Moka cache with a TTL of the refill interval
+        let cache: Cache<String, u64> = Cache::builder()
+            .time_to_live(std::time::Duration::from_millis(
+                REFILL_INTERVAL_MILLIS as u64,
+            ))
+            .build();
+
+        let cfg = Config::from_url(REDIS_URI);
+        let pool = Arc::new(cfg.create_pool(Some(Runtime::Tokio1)).unwrap());
+        let key = "token_bucket_test_key".to_string();
+
+        // Before running the test, ensure the test keys are cleared
+        redis_clear_keys(REDIS_URI, &[key.clone()]).await;
+
+        let max_tokens = 10;
+        let refill_interval = chrono::Duration::try_milliseconds(REFILL_INTERVAL_MILLIS).unwrap();
+        let refill_rate = 1;
+        let rate_limit = || async {
+            token_bucket(
+                &cache,
+                &pool,
+                key.clone(),
+                max_tokens,
+                refill_interval,
+                refill_rate,
+            )
+            .await
+        };
+        let call_rate_limit_loop = || async {
+            for i in 0..=max_tokens {
+                let result = rate_limit().await;
+                if i == max_tokens {
+                    assert!(result
+                        .err()
+                        .unwrap()
+                        .to_string()
+                        .contains("Rate limit exceeded"));
+                } else {
+                    assert!(result.is_ok());
+                }
+            }
+        };
+
+        // Call rate limit until max tokens limit is reached
+        call_rate_limit_loop().await;
+
+        // Sleep for refill and try again
+        sleep((refill_interval * max_tokens as i32).to_std().unwrap()).await;
+        call_rate_limit_loop().await;

         // Clear keys after the test
         redis_clear_keys(REDIS_URI, &[key.clone()]).await;

From 729ec960a436fb748cd70f9fc9a24f3259d75cd7 Mon Sep 17 00:00:00 2001
From: Max Kalashnikoff
Date: Wed, 27 Mar 2024 16:35:35 +0100
Subject: [PATCH 3/7] chore: updating tests to pass millis, multiple keys,
 refill time check

---
 crates/rate_limit/Cargo.toml |   3 +-
 crates/rate_limit/src/lib.rs | 166 ++++++++++++++++++++++-------------
 2 files changed, 106 insertions(+), 63 deletions(-)

diff --git a/crates/rate_limit/Cargo.toml b/crates/rate_limit/Cargo.toml
index 1171b76..efd7813 100644
--- a/crates/rate_limit/Cargo.toml
+++ b/crates/rate_limit/Cargo.toml
@@ -14,5 +14,6 @@ thiserror = "1.0"
 tracing = "0.1"

 [dev-dependencies]
-anyhow = "1"
+futures = "0.3"
 tokio = { version = "1", features = ["full"] }
+uuid = "1.8"
diff --git a/crates/rate_limit/src/lib.rs b/crates/rate_limit/src/lib.rs
index b0dfed9..9374e5a 100644
--- a/crates/rate_limit/src/lib.rs
+++ b/crates/rate_limit/src/lib.rs
@@ -1,17 +1,11 @@
 use {
-    chrono::{DateTime, Duration, Utc},
-    core::fmt,
+    chrono::Duration,
     deadpool_redis::{Pool, PoolError},
     moka::future::Cache,
     redis::{RedisError, Script},
     std::{collections::HashMap, sync::Arc},
 };

-pub type Clock = Option<Arc<dyn ClockImpl>>;
-pub trait ClockImpl: fmt::Debug + Send + Sync {
-    fn now(&self) -> DateTime<Utc>;
-}
-
 #[derive(Debug, thiserror::Error)]
 #[error("Rate limit exceeded. Try again at {reset}")]
 pub struct RateLimitExceeded {
@@ -46,6 +40,7 @@ pub async fn token_bucket(
     max_tokens: u32,
     interval: Duration,
     refill_rate: u32,
+    now_millis: i64,
 ) -> Result<(), RateLimitError> {
     // Check if the key is in the memory cache of rate-limited keys
     // to avoid the redis RTT in case of a flood
@@ -61,6 +56,7 @@ pub async fn token_bucket(
         max_tokens,
         interval,
         refill_rate,
+        now_millis,
     )
     .await
     .map_err(RateLimitError::Internal)?;
@@ -88,9 +84,8 @@ pub async fn token_bucket_many(
     max_tokens: u32,
     interval: Duration,
     refill_rate: u32,
+    now_millis: i64,
 ) -> Result<HashMap<String, (i64, u64)>, InternalRateLimitError> {
-    let now = Utc::now();
-
     // Remaining is number of tokens remaining. -1 for rate limited.
     // Reset is the time at which there will be 1 more token than before. This
     // could, for example, be used to cache a 0 token count.
@@ -99,7 +94,7 @@ pub async fn token_bucket_many(
         .arg(max_tokens)
         .arg(interval.num_milliseconds())
         .arg(refill_rate)
-        .arg(now.timestamp_millis())
+        .arg(now_millis)
         .invoke_async::<_, String>(
             &mut redis_write_pool
                 .clone()
@@ -116,12 +111,16 @@
 mod tests {
     const REDIS_URI: &str = "redis://localhost:6379";
     const REFILL_INTERVAL_MILLIS: i64 = 100;
+    const MAX_TOKENS: u32 = 5;
+    const REFILL_RATE: u32 = 1;

     use {
         super::*,
+        chrono::Utc,
         deadpool_redis::{Config, Runtime},
         redis::AsyncCommands,
         tokio::time::sleep,
+        uuid::Uuid,
     };

     async fn redis_clear_keys(conn_uri: &str, keys: &[String]) {
@@ -132,51 +131,90 @@ mod tests {
         }
     }

-    #[tokio::test]
-    async fn test_token_bucket_many() {
+    async fn test_rate_limiting(key: String) {
         let cfg = Config::from_url(REDIS_URI);
         let pool = Arc::new(cfg.create_pool(Some(Runtime::Tokio1)).unwrap());
-        let key = "token_bucket_many_test_key".to_string();
-
-        // Before running the test, ensure the test keys are cleared
-        redis_clear_keys(REDIS_URI, &[key.clone()]).await;
-
-        let max_tokens = 10;
         let refill_interval = chrono::Duration::try_milliseconds(REFILL_INTERVAL_MILLIS).unwrap();
-        let refill_rate = 1;
-        let rate_limit = || async {
-            token_bucket_many(
-                &pool,
-                vec![key.clone()],
-                max_tokens,
-                refill_interval,
-                refill_rate,
-            )
-            .await
-            .unwrap()
-            .get(&key.clone())
-            .unwrap()
-            .to_owned()
+        let rate_limit = |now_millis: i64| {
+            let key = key.clone();
+            let pool = pool.clone();
+            async move {
+                token_bucket_many(
+                    &pool,
+                    vec![key.clone()],
+                    MAX_TOKENS,
+                    refill_interval,
+                    REFILL_RATE,
+                    now_millis,
+                )
+                .await
+                .unwrap()
+                .get(&key)
+                .unwrap()
+                .to_owned()
+            }
         };
-        let call_rate_limit_loop = || async {
-            for i in 0..=max_tokens {
-                let curr_iter = max_tokens as i64 - i as i64 - 1;
-                let result = rate_limit().await;
+
+        // Function to call rate limit multiple times and assert results
+        // for tokens count and reset timestamp
+        let call_rate_limit_loop = |loop_iterations| async move {
+            let first_call_millis = Utc::now().timestamp_millis();
+            for i in 0..=loop_iterations {
+                let curr_iter = loop_iterations as i64 - i as i64 - 1;
+
+                // Use the first call timestamp for the first call, or the current
+                // timestamp otherwise
+                let result = if i == 0 {
+                    rate_limit(first_call_millis).await
+                } else {
+                    rate_limit(Utc::now().timestamp_millis()).await
+                };
+
+                // Assert the remaining tokens count
                 assert_eq!(result.0, curr_iter);
+                // Assert the reset timestamp should be the first call timestamp + refill
+                // interval
+                assert_eq!(
+                    result.1,
+                    (first_call_millis + REFILL_INTERVAL_MILLIS) as u64
+                );
             }
+            // Returning the refill timestamp
+            first_call_millis + REFILL_INTERVAL_MILLIS
         };

-        // Call rate limit until max tokens limit is reached
-        call_rate_limit_loop().await;
+        // Call rate limit until max tokens limit is reached
+        call_rate_limit_loop(MAX_TOKENS).await;

-        // Sleep for refill and try again
-        // Token counts should be the same as in the previous iteration because
-        // they were refilled
-        sleep((refill_interval * max_tokens as i32).to_std().unwrap()).await;
-        call_rate_limit_loop().await;
+        // Sleep for the full refill and try again
+        // Token counts should be the same as in the previous iteration because
+        // they were fully refilled
+        sleep((refill_interval * MAX_TOKENS as i32).to_std().unwrap()).await;
+        let last_timestamp = call_rate_limit_loop(MAX_TOKENS).await;
+
+        // Sleep for just one refill and try again
+        // Exactly one token should have been refilled, and the reset timestamp should be
+        // the last full iteration call timestamp + refill interval
+        sleep((refill_interval).to_std().unwrap()).await;
+        let result = rate_limit(Utc::now().timestamp_millis()).await;
+        assert_eq!(result.0, 0);
+        assert_eq!(result.1, (last_timestamp + REFILL_INTERVAL_MILLIS) as u64);
+    }
+
+    #[tokio::test]
+    async fn test_token_bucket_many() {
+        const KEYS_NUMBER_TO_TEST: usize = 3;
+        let keys = (0..KEYS_NUMBER_TO_TEST)
+            .map(|_| Uuid::new_v4().to_string())
+            .collect::<Vec<String>>();
+
+        // Before running the test, ensure the test keys are cleared
+        redis_clear_keys(REDIS_URI, &keys).await;
+
+        // Start async test for each key and wait for all to complete
+        let tasks = keys.iter().map(|key| test_rate_limiting(key.clone()));
+        futures::future::join_all(tasks).await;

         // Clear keys after the test
-        redis_clear_keys(REDIS_URI, &[key.clone()]).await;
+        redis_clear_keys(REDIS_URI, &keys).await;
     }

     #[tokio::test]
     async fn test_token_bucket() {
         // Create Moka cache with a TTL of the refill interval
         let cache: Cache<String, u64> = Cache::builder()
             .time_to_live(std::time::Duration::from_millis(
                 REFILL_INTERVAL_MILLIS as u64,
             ))
             .build();

         let cfg = Config::from_url(REDIS_URI);
         let pool = Arc::new(cfg.create_pool(Some(Runtime::Tokio1)).unwrap());
-        let key = "token_bucket_test_key".to_string();
+        let key = Uuid::new_v4().to_string();

         // Before running the test, ensure the test keys are cleared
         redis_clear_keys(REDIS_URI, &[key.clone()]).await;

-        let max_tokens = 10;
         let refill_interval = chrono::Duration::try_milliseconds(REFILL_INTERVAL_MILLIS).unwrap();
-        let refill_rate = 1;
-        let rate_limit = || async {
-            token_bucket(
-                &cache,
-                &pool,
-                key.clone(),
-                max_tokens,
-                refill_interval,
-                refill_rate,
-            )
-            .await
+        let rate_limit = |now_millis| {
+            let key = key.clone();
+            let pool = pool.clone();
+            let cache = cache.clone();
+            async move {
+                token_bucket(
+                    &cache,
+                    &pool,
+                    key.clone(),
+                    MAX_TOKENS,
+                    refill_interval,
+                    REFILL_RATE,
+                    now_millis,
+                )
+                .await
+            }
         };
-        let call_rate_limit_loop = || async {
-            for i in 0..=max_tokens {
-                let result = rate_limit().await;
-                if i == max_tokens {
+        let call_rate_limit_loop = |now_millis| async move {
+            for i in 0..=MAX_TOKENS {
+                let result = rate_limit(now_millis).await;
+                if i == MAX_TOKENS {
                     assert!(result
                         .err()
                         .unwrap()
                         .to_string()
                         .contains("Rate limit exceeded"));
                 } else {
                     assert!(result.is_ok());
                 }
             }
         };

         // Call rate limit until max tokens limit is reached
-        call_rate_limit_loop().await;
+        call_rate_limit_loop(Utc::now().timestamp_millis()).await;

         // Sleep for refill and try again
-        sleep((refill_interval * max_tokens as i32).to_std().unwrap()).await;
-        call_rate_limit_loop().await;
+        sleep((refill_interval * MAX_TOKENS as i32).to_std().unwrap()).await;
+        call_rate_limit_loop(Utc::now().timestamp_millis()).await;

         // Clear keys after the test
         redis_clear_keys(REDIS_URI, &[key.clone()]).await;

From c8af64658a8ea7db122c5e782388bd2c983c6a46 Mon Sep 17 00:00:00 2001
From: Max Kalashnikoff
Date: Wed, 27 Mar 2024 22:58:35 +0100
Subject: [PATCH 4/7] chore: passing DateTime type for millis instead of i64

---
 crates/rate_limit/src/lib.rs | 24 ++++++++++++------------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/crates/rate_limit/src/lib.rs b/crates/rate_limit/src/lib.rs
index 9374e5a..301ed0b 100644
--- a/crates/rate_limit/src/lib.rs
+++ b/crates/rate_limit/src/lib.rs
@@ -1,5 +1,5 @@
 use {
-    chrono::Duration,
+    chrono::{DateTime, Duration, Utc},
     deadpool_redis::{Pool, PoolError},
     moka::future::Cache,
     redis::{RedisError, Script},
@@ -40,7 +40,7 @@ pub async fn token_bucket(
     max_tokens: u32,
     interval: Duration,
     refill_rate: u32,
-    now_millis: i64,
+    now_millis: DateTime<Utc>,
 ) -> Result<(), RateLimitError> {
     // Check if the key is in the memory cache of rate-limited keys
     // to avoid the redis RTT in case of a flood
@@ -84,7 +84,7 @@ pub async fn token_bucket_many(
     max_tokens: u32,
     interval: Duration,
     refill_rate: u32,
-    now_millis: i64,
+    now_millis: DateTime<Utc>,
 ) -> Result<HashMap<String, (i64, u64)>, InternalRateLimitError> {
     // Remaining is number of tokens remaining. -1 for rate limited.
     // Reset is the time at which there will be 1 more token than before. This
@@ -94,7 +94,7 @@ pub async fn token_bucket_many(
         .arg(max_tokens)
         .arg(interval.num_milliseconds())
         .arg(refill_rate)
-        .arg(now_millis)
+        .arg(now_millis.timestamp_millis())
         .invoke_async::<_, String>(
             &mut redis_write_pool
                 .clone()
@@ -135,7 +135,7 @@ mod tests {
         let cfg = Config::from_url(REDIS_URI);
         let pool = Arc::new(cfg.create_pool(Some(Runtime::Tokio1)).unwrap());
         let refill_interval = chrono::Duration::try_milliseconds(REFILL_INTERVAL_MILLIS).unwrap();
-        let rate_limit = |now_millis: i64| {
+        let rate_limit = |now_millis| {
             let key = key.clone();
             let pool = pool.clone();
             async move {
@@ -157,7 +157,7 @@ mod tests {
         // Function to call rate limit multiple times and assert results
         // for tokens count and reset timestamp
         let call_rate_limit_loop = |loop_iterations| async move {
-            let first_call_millis = Utc::now().timestamp_millis();
+            let first_call_millis = Utc::now();
             for i in 0..=loop_iterations {
                 let curr_iter = loop_iterations as i64 - i as i64 - 1;

@@ -165,7 +165,7 @@ mod tests {
                 let result = if i == 0 {
                     rate_limit(first_call_millis).await
                 } else {
-                    rate_limit(Utc::now().timestamp_millis()).await
+                    rate_limit(Utc::now()).await
                 };

@@ -174,11 +174,11 @@ mod tests {
                 // Assert the remaining tokens count
                 assert_eq!(result.0, curr_iter);
                 // Assert the reset timestamp should be the first call timestamp + refill
                 // interval
                 assert_eq!(
                     result.1,
-                    (first_call_millis + REFILL_INTERVAL_MILLIS) as u64
+                    (first_call_millis.timestamp_millis() + REFILL_INTERVAL_MILLIS) as u64
                 );
             }
             // Returning the refill timestamp
-            first_call_millis + REFILL_INTERVAL_MILLIS
+            first_call_millis.timestamp_millis() + REFILL_INTERVAL_MILLIS
         };

@@ -194,7 +194,7 @@ mod tests {
         // Exactly one token should have been refilled, and the reset timestamp should be
         // the last full iteration call timestamp + refill interval
         sleep((refill_interval).to_std().unwrap()).await;
-        let result = rate_limit(Utc::now().timestamp_millis()).await;
+        let result = rate_limit(Utc::now()).await;
         assert_eq!(result.0, 0);
         assert_eq!(result.1, (last_timestamp + REFILL_INTERVAL_MILLIS) as u64);
     }
@@ -267,11 +267,11 @@ mod tests {
         };

         // Call rate limit until max tokens limit is reached
-        call_rate_limit_loop(Utc::now().timestamp_millis()).await;
+        call_rate_limit_loop(Utc::now()).await;

         // Sleep for refill and try again
         sleep((refill_interval * MAX_TOKENS as i32).to_std().unwrap()).await;
-        call_rate_limit_loop(Utc::now().timestamp_millis()).await;
+        call_rate_limit_loop(Utc::now()).await;

         // Clear keys after the test
         redis_clear_keys(REDIS_URI, &[key.clone()]).await;

From 85530b020d27c14b8e19baa9db96dde161f6bac8 Mon Sep 17 00:00:00 2001
From: Max Kalashnikoff
Date: Wed, 27 Mar 2024 22:59:43 +0100
Subject: [PATCH 5/7] feat: optimizing and minifying lua script

---
 crates/rate_limit/src/token_bucket.lua | 39 ++++++++------------------
 1 file changed, 12 insertions(+), 27 deletions(-)

diff --git a/crates/rate_limit/src/token_bucket.lua b/crates/rate_limit/src/token_bucket.lua
index 07ec7b1..0fdc994 100644
--- a/crates/rate_limit/src/token_bucket.lua
+++ b/crates/rate_limit/src/token_bucket.lua
@@ -6,39 +6,24 @@ local refillRate = tonumber(ARGV[3]) -- how many tokens are refilled after each
 local now = tonumber(ARGV[4]) -- current timestamp in milliseconds

 local results = {}
-
-for i, key in ipairs(keys) do
+for i, key in ipairs(KEYS) do
   local bucket = redis.call("HMGET", key, "refilledAt", 
"tokens") + local refilledAt = (bucket[1] == false and tonumber(now) or tonumber(bucket[1])) + local tokens = (bucket[1] == false and tonumber(maxTokens) or tonumber(bucket[2])) - local refilledAt - local tokens - - if bucket[1] == false then - refilledAt = now - tokens = maxTokens - else - refilledAt = tonumber(bucket[1]) - tokens = tonumber(bucket[2]) - end - - if now >= refilledAt + interval then - local numRefills = math.floor((now - refilledAt) / interval) - tokens = math.min(maxTokens, tokens + numRefills * refillRate) - - refilledAt = refilledAt + numRefills * interval + if tonumber(now) >= refilledAt + interval then + tokens = math.min(tonumber(maxTokens), tokens + math.floor((tonumber(now) - refilledAt) / interval) * tonumber(refillRate)) + refilledAt = refilledAt + math.floor((tonumber(now) - refilledAt) / interval) * interval end - if tokens == 0 then - results[key] = {-1, refilledAt + interval} + if tokens > 0 then + tokens = tokens - 1 + redis.call("HSET", key, "refilledAt", refilledAt, "tokens", tokens) + redis.call("PEXPIRE", key, math.ceil(((tonumber(maxTokens) - tokens) / tonumber(refillRate)) * interval)) + results[key] = {tokens, refilledAt + interval} else - local remaining = tokens - 1 - local expireAt = math.ceil(((maxTokens - remaining) / refillRate)) * interval - - redis.call("HSET", key, "refilledAt", refilledAt, "tokens", remaining) - redis.call("PEXPIRE", key, expireAt) - results[key] = {remaining, refilledAt + interval} + results[key] = {-1, refilledAt + interval} end end - -- Redis doesn't support Lua table responses: https://stackoverflow.com/a/24302613 return cjson.encode(results) From b9ecde904124b206b54d37e54f7b5672f3881fda Mon Sep 17 00:00:00 2001 From: Max Kalashnikoff Date: Thu, 28 Mar 2024 10:56:25 +0100 Subject: [PATCH 6/7] revert: optimizing and minifying lua script This reverts commit 85530b020d27c14b8e19baa9db96dde161f6bac8. 
--- crates/rate_limit/src/token_bucket.lua | 39 ++++++++++++++++++-------- 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/crates/rate_limit/src/token_bucket.lua b/crates/rate_limit/src/token_bucket.lua index 0fdc994..07ec7b1 100644 --- a/crates/rate_limit/src/token_bucket.lua +++ b/crates/rate_limit/src/token_bucket.lua @@ -6,24 +6,39 @@ local refillRate = tonumber(ARGV[3]) -- how many tokens are refilled after each local now = tonumber(ARGV[4]) -- current timestamp in milliseconds local results = {} -for i, key in ipairs(KEYS) do + +for i, key in ipairs(keys) do local bucket = redis.call("HMGET", key, "refilledAt", "tokens") - local refilledAt = (bucket[1] == false and tonumber(now) or tonumber(bucket[1])) - local tokens = (bucket[1] == false and tonumber(maxTokens) or tonumber(bucket[2])) - if tonumber(now) >= refilledAt + interval then - tokens = math.min(tonumber(maxTokens), tokens + math.floor((tonumber(now) - refilledAt) / interval) * tonumber(refillRate)) - refilledAt = refilledAt + math.floor((tonumber(now) - refilledAt) / interval) * interval - end + local refilledAt + local tokens - if tokens > 0 then - tokens = tokens - 1 - redis.call("HSET", key, "refilledAt", refilledAt, "tokens", tokens) - redis.call("PEXPIRE", key, math.ceil(((tonumber(maxTokens) - tokens) / tonumber(refillRate)) * interval)) - results[key] = {tokens, refilledAt + interval} + if bucket[1] == false then + refilledAt = now + tokens = maxTokens else + refilledAt = tonumber(bucket[1]) + tokens = tonumber(bucket[2]) + end + + if now >= refilledAt + interval then + local numRefills = math.floor((now - refilledAt) / interval) + tokens = math.min(maxTokens, tokens + numRefills * refillRate) + + refilledAt = refilledAt + numRefills * interval + end + + if tokens == 0 then results[key] = {-1, refilledAt + interval} + else + local remaining = tokens - 1 + local expireAt = math.ceil(((maxTokens - remaining) / refillRate)) * interval + + redis.call("HSET", key, "refilledAt", refilledAt, "tokens", remaining) + redis.call("PEXPIRE", key, expireAt) + results[key] = {remaining, refilledAt + interval} end end + -- Redis doesn't support Lua table responses: https://stackoverflow.com/a/24302613 return cjson.encode(results) From 9befeaace171ba4caaf98056a6c9bf7b39be1cda Mon Sep 17 00:00:00 2001 From: Max Kalashnikoff Date: Thu, 28 Mar 2024 20:56:16 +0100 Subject: [PATCH 7/7] feat: bumping redis and deadpool-redis versions --- crates/rate_limit/Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/rate_limit/Cargo.toml b/crates/rate_limit/Cargo.toml index efd7813..cdee74d 100644 --- a/crates/rate_limit/Cargo.toml +++ b/crates/rate_limit/Cargo.toml @@ -5,9 +5,9 @@ edition = "2021" [dependencies] chrono = { version = "0.4", features = ["serde"] } -deadpool-redis = "0.12" +deadpool-redis = "0.14" moka = { version = "0.12", features = ["future"] } -redis = { version = "0.23", default-features = false, features = ["script"] } +redis = { version = "0.24", default-features = false, features = ["script"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" thiserror = "1.0"
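
Usage note (reviewer sketch, not part of the patches): a minimal consumer of the final API in crates/rate_limit/src/lib.rs, as it stands after this series, might look like the code below. The Redis URI, the client key, and the limits are illustrative assumptions; the in-memory cache TTL is set equal to the refill interval, as the doc comment on `token_bucket` requires.

    use {
        moka::future::Cache,
        rate_limit::{token_bucket, RateLimitError},
        std::sync::Arc,
    };

    #[tokio::main]
    async fn main() -> Result<(), Box<dyn std::error::Error>> {
        // deadpool-redis pool; the URI mirrors the crate's docker-compose setup.
        let cfg = deadpool_redis::Config::from_url("redis://localhost:6379");
        let pool = Arc::new(cfg.create_pool(Some(deadpool_redis::Runtime::Tokio1))?);

        // Refill one token per interval; the Moka cache TTL must equal the interval.
        let refill_interval = chrono::Duration::try_seconds(1).unwrap();
        let mem_cache: Cache<String, u64> = Cache::builder()
            .time_to_live(std::time::Duration::from_millis(
                refill_interval.num_milliseconds() as u64,
            ))
            .build();

        // Allow a burst of up to 100 calls for a hypothetical per-client key.
        match token_bucket(
            &mem_cache,
            &pool,
            "client:203.0.113.7".to_string(),
            100,
            refill_interval,
            1,
            chrono::Utc::now(),
        )
        .await
        {
            Ok(()) => println!("request allowed"),
            Err(RateLimitError::RateLimitExceeded(e)) => println!("{e}"),
            Err(e) => return Err(e.into()),
        }

        Ok(())
    }

On a rate-limited key the error carries the reset time in unix seconds, and `token_bucket` also records the key in `mem_cache`, so a flood of calls for the same key skips the Redis round trip until the next refill.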