-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ad42551
commit 21506a7
Showing
7 changed files
with
255 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
[package] | ||
name = "rate_limit" | ||
version = "0.1.0" | ||
edition = "2021" | ||
|
||
[dependencies] | ||
chrono = { version = "0.4", features = ["serde"] } | ||
deadpool-redis = "0.12" | ||
redis = { version = "0.23", default-features = false, features = ["script"] } | ||
serde = { version = "1.0", features = ["derive"] } | ||
serde_json = "1.0" | ||
thiserror = "1.0" | ||
tracing = "0.1" | ||
|
||
[dev-dependencies] | ||
anyhow = "1" | ||
tokio = { version = "1", features = ["full"] } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
version: "3.9" | ||
services: | ||
redis: | ||
image: redis:7.0 | ||
ports: | ||
- "6379:6379" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
use { | ||
chrono::{DateTime, Duration, Utc}, | ||
core::fmt, | ||
deadpool_redis::{Pool, PoolError}, | ||
redis::{RedisError, Script}, | ||
std::{collections::HashMap, sync::Arc}, | ||
}; | ||
|
||
pub type Clock = Option<Arc<dyn ClockImpl>>; | ||
pub trait ClockImpl: fmt::Debug + Send + Sync { | ||
fn now(&self) -> DateTime<Utc>; | ||
} | ||
|
||
#[derive(Debug, thiserror::Error)] | ||
#[error("Rate limit exceeded. Try again at {reset}")] | ||
pub struct RateLimitExceeded { | ||
reset: u64, | ||
} | ||
|
||
#[derive(Debug, thiserror::Error)] | ||
pub enum InternalRateLimitError { | ||
#[error("Redis pool error {0}")] | ||
Pool(PoolError), | ||
|
||
#[error("Redis error: {0}")] | ||
Redis(RedisError), | ||
} | ||
|
||
#[derive(Debug, thiserror::Error)] | ||
pub enum RateLimitError { | ||
#[error(transparent)] | ||
RateLimitExceeded(RateLimitExceeded), | ||
|
||
#[error("Internal error: {0}")] | ||
Internal(InternalRateLimitError), | ||
} | ||
|
||
pub async fn token_bucket( | ||
redis_write_pool: &Arc<Pool>, | ||
key: String, | ||
max_tokens: u32, | ||
interval: Duration, | ||
refill_rate: u32, | ||
) -> Result<(), RateLimitError> { | ||
let result = token_bucket_many( | ||
redis_write_pool, | ||
vec![key.clone()], | ||
max_tokens, | ||
interval, | ||
refill_rate, | ||
) | ||
.await | ||
.map_err(RateLimitError::Internal)?; | ||
|
||
let (remaining, reset) = result.get(&key).expect("Should contain the key"); | ||
if remaining.is_negative() { | ||
Err(RateLimitError::RateLimitExceeded(RateLimitExceeded { | ||
reset: reset / 1000, | ||
})) | ||
} else { | ||
Ok(()) | ||
} | ||
} | ||
|
||
pub async fn token_bucket_many( | ||
redis_write_pool: &Arc<Pool>, | ||
keys: Vec<String>, | ||
max_tokens: u32, | ||
interval: Duration, | ||
refill_rate: u32, | ||
) -> Result<HashMap<String, (i64, u64)>, InternalRateLimitError> { | ||
let now = Utc::now(); | ||
|
||
// Remaining is number of tokens remaining. -1 for rate limited. | ||
// Reset is the time at which there will be 1 more token than before. This | ||
// could, for example, be used to cache a 0 token count. | ||
Script::new(include_str!("token_bucket.lua")) | ||
.key(keys) | ||
.arg(max_tokens) | ||
.arg(interval.num_milliseconds()) | ||
.arg(refill_rate) | ||
.arg(now.timestamp_millis()) | ||
.invoke_async::<_, String>( | ||
&mut redis_write_pool | ||
.clone() | ||
.get() | ||
.await | ||
.map_err(InternalRateLimitError::Pool)?, | ||
) | ||
.await | ||
.map_err(InternalRateLimitError::Redis) | ||
.map(|value| serde_json::from_str(&value).expect("Redis script should return valid JSON")) | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use { | ||
super::*, | ||
deadpool_redis::{ | ||
redis::{cmd, FromRedisValue}, | ||
Check warning on line 100 in crates/rate_limit/src/lib.rs
|
||
Config, | ||
Runtime, | ||
}, | ||
redis::AsyncCommands, | ||
std::{ | ||
env, | ||
time::{Duration, SystemTime}, | ||
}, | ||
tokio::time::sleep, | ||
}; | ||
|
||
async fn clear_redis(conn_uri: &str, keys: &[String]) { | ||
let client = redis::Client::open(conn_uri).unwrap(); | ||
let mut conn = client.get_async_connection().await.unwrap(); | ||
for key in keys { | ||
let _: () = conn.del(key).await.unwrap(); | ||
} | ||
} | ||
|
||
#[tokio::test] | ||
async fn test_token_bucket() { | ||
let redis_uri = ""; | ||
let cfg = Config::from_url(redis_uri); | ||
let pool = Arc::new(cfg.create_pool(Some(Runtime::Tokio1)).unwrap()); | ||
let key = "test_token_bucket".to_string(); | ||
|
||
// Before running the test, ensure the test keys are cleared | ||
clear_redis(redis_uri, &[key.clone()]).await; | ||
|
||
// Note: max_tokens, refill_interval, and refill_rate must be set properly to | ||
// avoid flaky tests. Although a custom clock is used, the lua script | ||
// still has expiration logic based on the clock of Redis not this custom one. | ||
// The formula for the expiration is: math.ceil(((max_tokens - remaining) / | ||
// refillRate)) * interval If the result of this expression is less than | ||
// the time between the Redis calls, then the key can expire. Setting | ||
// refill_duration to 10 seconds and refill_rate to 1 should be enough to avoid | ||
// this. | ||
let max_tokens = 10; | ||
let refill_interval = chrono::Duration::try_milliseconds(100).unwrap(); | ||
let refill_rate = 1; | ||
let rate_limit = || async { | ||
token_bucket_many( | ||
&pool, | ||
vec![key.clone()], | ||
max_tokens, | ||
refill_interval, | ||
refill_rate, | ||
) | ||
.await | ||
.unwrap() | ||
.get(&key.clone()) | ||
.unwrap() | ||
.to_owned() | ||
}; | ||
|
||
for i in 0..=max_tokens { | ||
let curr_iter = max_tokens as i64 - i as i64 - 1; | ||
let result = rate_limit().await; | ||
println!("result: {:?}", result); | ||
assert_eq!(result.0, curr_iter); | ||
} | ||
|
||
// sleep for refill and try again | ||
sleep((refill_interval * max_tokens as i32).to_std().unwrap()).await; | ||
|
||
for i in 0..=max_tokens { | ||
let curr_iter = max_tokens as i64 - i as i64 - 1; | ||
let result = rate_limit().await; | ||
println!("result: {:?}", result); | ||
assert_eq!(result.0, curr_iter); | ||
} | ||
|
||
clear_redis(redis_uri, &[key.clone()]).await; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
-- Adapted from https://github.com/upstash/ratelimit/blob/3a8cfb00e827188734ac347965cb743a75fcb98a/src/single.ts#L311 | ||
local keys = KEYS -- identifier including prefixes | ||
local maxTokens = tonumber(ARGV[1]) -- maximum number of tokens | ||
local interval = tonumber(ARGV[2]) -- size of the window in milliseconds | ||
local refillRate = tonumber(ARGV[3]) -- how many tokens are refilled after each interval | ||
local now = tonumber(ARGV[4]) -- current timestamp in milliseconds | ||
|
||
local results = {} | ||
|
||
for i, key in ipairs(keys) do | ||
local bucket = redis.call("HMGET", key, "refilledAt", "tokens") | ||
|
||
local refilledAt | ||
local tokens | ||
|
||
if bucket[1] == false then | ||
refilledAt = now | ||
tokens = maxTokens | ||
else | ||
refilledAt = tonumber(bucket[1]) | ||
tokens = tonumber(bucket[2]) | ||
end | ||
|
||
if now >= refilledAt + interval then | ||
local numRefills = math.floor((now - refilledAt) / interval) | ||
tokens = math.min(maxTokens, tokens + numRefills * refillRate) | ||
|
||
refilledAt = refilledAt + numRefills * interval | ||
end | ||
|
||
if tokens == 0 then | ||
results[key] = {-1, refilledAt + interval} | ||
else | ||
local remaining = tokens - 1 | ||
local expireAt = math.ceil(((maxTokens - remaining) / refillRate)) * interval | ||
|
||
redis.call("HSET", key, "refilledAt", refilledAt, "tokens", remaining) | ||
redis.call("PEXPIRE", key, expireAt) | ||
results[key] = {remaining, refilledAt + interval} | ||
end | ||
end | ||
|
||
-- Redis doesn't support Lua table responses: https://stackoverflow.com/a/24302613 | ||
return cjson.encode(results) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters