diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index c9f0c0a..86b243d 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -44,6 +44,14 @@ jobs: rust: stable env: RUST_BACKTRACE: full + services: + redis: + image: redis:7.2-alpine + options: >- + --health-cmd "redis-cli ping" + --health-interval 10s + --health-timeout 5s + --health-retries 5 steps: - uses: actions/checkout@v3 diff --git a/Cargo.toml b/Cargo.toml index e8c9dfa..fcd526f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ full = [ "geoip", "http", "metrics", + "rate_limit", ] alloc = ["dep:alloc"] analytics = ["dep:analytics"] @@ -33,6 +34,7 @@ geoip = ["dep:geoip"] http = [] metrics = ["dep:metrics", "future/metrics", "alloc/metrics", "http/metrics"] profiler = ["alloc/profiler"] +rate_limit = ["dep:rate_limit"] [workspace.dependencies] aws-sdk-s3 = "1.13" @@ -45,6 +47,7 @@ future = { path = "./crates/future", optional = true } geoip = { path = "./crates/geoip", optional = true } http = { path = "./crates/http", optional = true } metrics = { path = "./crates/metrics", optional = true } +rate_limit = { path = "./crates/rate_limit", optional = true } [dev-dependencies] anyhow = "1" diff --git a/crates/rate_limit/Cargo.toml b/crates/rate_limit/Cargo.toml new file mode 100644 index 0000000..8b4a7db --- /dev/null +++ b/crates/rate_limit/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "rate_limit" +version = "0.1.0" +edition = "2021" + +[dependencies] +chrono = { version = "0.4", features = ["serde"] } +deadpool-redis = "0.12" +redis = { version = "0.23", default-features = false, features = ["script"] } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +thiserror = "1.0" +tracing = "0.1" + +[dev-dependencies] +anyhow = "1" +tokio = { version = "1", features = ["full"] } diff --git a/crates/rate_limit/docker-compose.yml b/crates/rate_limit/docker-compose.yml new file mode 100644 index 0000000..0da745b --- /dev/null +++ 
b/crates/rate_limit/docker-compose.yml @@ -0,0 +1,6 @@ +version: "3.9" +services: + redis: + image: redis:7.0 + ports: + - "6379:6379" diff --git a/crates/rate_limit/src/lib.rs b/crates/rate_limit/src/lib.rs new file mode 100644 index 0000000..c8bf5aa --- /dev/null +++ b/crates/rate_limit/src/lib.rs @@ -0,0 +1,175 @@ +use { + chrono::{DateTime, Duration, Utc}, + core::fmt, + deadpool_redis::{Pool, PoolError}, + redis::{RedisError, Script}, + std::{collections::HashMap, sync::Arc}, +}; + +pub type Clock = Option<Arc<dyn ClockImpl>>; +pub trait ClockImpl: fmt::Debug + Send + Sync { + fn now(&self) -> DateTime<Utc>; +} + +#[derive(Debug, thiserror::Error)] +#[error("Rate limit exceeded. Try again at {reset}")] +pub struct RateLimitExceeded { + reset: u64, +} + +#[derive(Debug, thiserror::Error)] +pub enum InternalRateLimitError { + #[error("Redis pool error {0}")] + Pool(PoolError), + + #[error("Redis error: {0}")] + Redis(RedisError), +} + +#[derive(Debug, thiserror::Error)] +pub enum RateLimitError { + #[error(transparent)] + RateLimitExceeded(RateLimitExceeded), + + #[error("Internal error: {0}")] + Internal(InternalRateLimitError), +} + +pub async fn token_bucket( + redis_write_pool: &Arc<Pool>, + key: String, + max_tokens: u32, + interval: Duration, + refill_rate: u32, +) -> Result<(), RateLimitError> { + let result = token_bucket_many( + redis_write_pool, + vec![key.clone()], + max_tokens, + interval, + refill_rate, + ) + .await + .map_err(RateLimitError::Internal)?; + + let (remaining, reset) = result.get(&key).expect("Should contain the key"); + if remaining.is_negative() { + Err(RateLimitError::RateLimitExceeded(RateLimitExceeded { + reset: reset / 1000, + })) + } else { + Ok(()) + } +} + +pub async fn token_bucket_many( + redis_write_pool: &Arc<Pool>, + keys: Vec<String>, + max_tokens: u32, + interval: Duration, + refill_rate: u32, +) -> Result<HashMap<String, (i64, u64)>, InternalRateLimitError> { + let now = Utc::now(); + + // Remaining is number of tokens remaining. -1 for rate limited. 
+ // Reset is the time at which there will be 1 more token than before. This + // could, for example, be used to cache a 0 token count. + Script::new(include_str!("token_bucket.lua")) + .key(keys) + .arg(max_tokens) + .arg(interval.num_milliseconds()) + .arg(refill_rate) + .arg(now.timestamp_millis()) + .invoke_async::<_, String>( + &mut redis_write_pool + .clone() + .get() + .await + .map_err(InternalRateLimitError::Pool)?, + ) + .await + .map_err(InternalRateLimitError::Redis) + .map(|value| serde_json::from_str(&value).expect("Redis script should return valid JSON")) +} + +#[cfg(test)] +mod tests { + use { + super::*, + deadpool_redis::{ + redis::{cmd, FromRedisValue}, + Config, + Runtime, + }, + redis::AsyncCommands, + std::{ + env, + time::{Duration, SystemTime}, + }, + tokio::time::sleep, + }; + + async fn clear_redis(conn_uri: &str, keys: &[String]) { + let client = redis::Client::open(conn_uri).unwrap(); + let mut conn = client.get_async_connection().await.unwrap(); + for key in keys { + let _: () = conn.del(key).await.unwrap(); + } + } + + #[tokio::test] + async fn test_token_bucket() { + let redis_uri = "redis://localhost:6379/0"; + let cfg = Config::from_url(redis_uri); + let pool = Arc::new(cfg.create_pool(Some(Runtime::Tokio1)).unwrap()); + let key = "test_token_bucket".to_string(); + + // Before running the test, ensure the test keys are cleared + clear_redis(redis_uri, &[key.clone()]).await; + + // Note: max_tokens, refill_interval, and refill_rate must be set properly to + // avoid flaky tests. Although a custom clock is used, the lua script + // still has expiration logic based on the clock of Redis not this custom one. + // The formula for the expiration is: math.ceil(((max_tokens - remaining) / + // refillRate)) * interval If the result of this expression is less than + // the time between the Redis calls, then the key can expire. Setting + // refill_duration to 10 seconds and refill_rate to 1 should be enough to avoid + // this. 
+ let max_tokens = 10; + let refill_interval = chrono::Duration::try_milliseconds(100).unwrap(); + let refill_rate = 1; + let rate_limit = || async { + token_bucket_many( + &pool, + vec![key.clone()], + max_tokens, + refill_interval, + refill_rate, + ) + .await + .unwrap() + .get(&key.clone()) + .unwrap() + .to_owned() + }; + + for i in 0..=max_tokens { + let curr_iter = max_tokens as i64 - i as i64 - 1; + let result = rate_limit().await; + println!("result: {:?}", result); + assert_eq!(result.0, curr_iter); + } + + // sleep for refill and try again + sleep((refill_interval * max_tokens as i32).to_std().unwrap()).await; + + for i in 0..=max_tokens { + let curr_iter = max_tokens as i64 - i as i64 - 1; + let result = rate_limit().await; + println!("result: {:?}", result); + assert_eq!(result.0, curr_iter); + } + + clear_redis(redis_uri, &[key.clone()]).await; + } +} diff --git a/crates/rate_limit/src/token_bucket.lua b/crates/rate_limit/src/token_bucket.lua new file mode 100644 index 0000000..07ec7b1 --- /dev/null +++ b/crates/rate_limit/src/token_bucket.lua @@ -0,0 +1,44 @@ +-- Adapted from https://github.com/upstash/ratelimit/blob/3a8cfb00e827188734ac347965cb743a75fcb98a/src/single.ts#L311 +local keys = KEYS -- identifier including prefixes +local maxTokens = tonumber(ARGV[1]) -- maximum number of tokens +local interval = tonumber(ARGV[2]) -- size of the window in milliseconds +local refillRate = tonumber(ARGV[3]) -- how many tokens are refilled after each interval +local now = tonumber(ARGV[4]) -- current timestamp in milliseconds + +local results = {} + +for i, key in ipairs(keys) do + local bucket = redis.call("HMGET", key, "refilledAt", "tokens") + + local refilledAt + local tokens + + if bucket[1] == false then + refilledAt = now + tokens = maxTokens + else + refilledAt = tonumber(bucket[1]) + tokens = tonumber(bucket[2]) + end + + if now >= refilledAt + interval then + local numRefills = math.floor((now - refilledAt) / interval) + tokens = math.min(maxTokens, 
tokens + numRefills * refillRate) + + refilledAt = refilledAt + numRefills * interval + end + + if tokens == 0 then + results[key] = {-1, refilledAt + interval} + else + local remaining = tokens - 1 + local expireAt = math.ceil(((maxTokens - remaining) / refillRate)) * interval + + redis.call("HSET", key, "refilledAt", refilledAt, "tokens", remaining) + redis.call("PEXPIRE", key, expireAt) + results[key] = {remaining, refilledAt + interval} + end +end + +-- Redis doesn't support Lua table responses: https://stackoverflow.com/a/24302613 +return cjson.encode(results) diff --git a/src/lib.rs b/src/lib.rs index 5c99951..a16c71a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,3 +13,5 @@ pub use geoip; pub use http; #[cfg(feature = "metrics")] pub use metrics; +#[cfg(feature = "rate_limit")] +pub use rate_limit;