feat: rate_limiting token bucket sub module #14
Changes from 5 commits
b0889c1
c9bde58
729ec96
c8af646
85530b0
b9ecde9
9befeaa
@@ -0,0 +1,19 @@
[package]
name = "rate_limit"
version = "0.1.0"
edition = "2021"

[dependencies]
chrono = { version = "0.4", features = ["serde"] }
deadpool-redis = "0.12"
moka = { version = "0.12", features = ["future"] }
redis = { version = "0.23", default-features = false, features = ["script"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
thiserror = "1.0"
tracing = "0.1"

[dev-dependencies]
futures = "0.3"
tokio = { version = "1", features = ["full"] }
uuid = "1.8"
@@ -0,0 +1,6 @@
version: "3.9"
services:
  redis:
    image: redis:7.0
    ports:
      - "6379:6379"
@@ -0,0 +1,279 @@
use {
    chrono::{DateTime, Duration, Utc},
    deadpool_redis::{Pool, PoolError},
    moka::future::Cache,
    redis::{RedisError, Script},
    std::{collections::HashMap, sync::Arc},
};

#[derive(Debug, thiserror::Error)]
#[error("Rate limit exceeded. Try again at {reset}")]
pub struct RateLimitExceeded {
    reset: u64,
}

#[derive(Debug, thiserror::Error)]
pub enum InternalRateLimitError {
    #[error("Redis pool error {0}")]
    Pool(PoolError),

    #[error("Redis error: {0}")]
    Redis(RedisError),
}

#[derive(Debug, thiserror::Error)]
pub enum RateLimitError {
    #[error(transparent)]
    RateLimitExceeded(RateLimitExceeded),

    #[error("Internal error: {0}")]
    Internal(InternalRateLimitError),
}

/// Rate limit check using a token bucket algorithm for one key, with an in-memory
/// cache for rate-limited keys. The `mem_cache` TTL must be set to the same value
/// as the refill interval.
pub async fn token_bucket(
    mem_cache: &Cache<String, u64>,
    redis_write_pool: &Arc<Pool>,
    key: String,
    max_tokens: u32,
    interval: Duration,
    refill_rate: u32,
    now_millis: DateTime<Utc>,
) -> Result<(), RateLimitError> {
    // Check if the key is in the in-memory cache of rate-limited keys
    // to skip the Redis round trip in case of a flood.
    if let Some(reset) = mem_cache.get(&key).await {
        return Err(RateLimitError::RateLimitExceeded(RateLimitExceeded {
            reset,
        }));
    }

    let result = token_bucket_many(
        redis_write_pool,
        vec![key.clone()],
        max_tokens,
        interval,
        refill_rate,
        now_millis,
    )
    .await
    .map_err(RateLimitError::Internal)?;

    let (remaining, reset) = result.get(&key).expect("Should contain the key");
    if remaining.is_negative() {
        let reset_interval = reset / 1000;

        // Insert the rate-limited key into the in-memory cache to avoid the Redis
        // round trip in case of a flood.
        mem_cache.insert(key, reset_interval).await;

        Err(RateLimitError::RateLimitExceeded(RateLimitExceeded {
            reset: reset_interval,
        }))
    } else {
        Ok(())
    }
}

/// Rate limit check using a token bucket algorithm for many keys.
pub async fn token_bucket_many(
    redis_write_pool: &Arc<Pool>,
    keys: Vec<String>,
    max_tokens: u32,
    interval: Duration,
    refill_rate: u32,
    now_millis: DateTime<Utc>,
) -> Result<HashMap<String, (i64, u64)>, InternalRateLimitError> {
    // Remaining is the number of tokens remaining; -1 means rate limited.
    // Reset is the time at which there will be one more token than before. This
    // could, for example, be used to cache a 0 token count.
    Script::new(include_str!("token_bucket.lua"))
        .key(keys)
        .arg(max_tokens)
        .arg(interval.num_milliseconds())
        .arg(refill_rate)
        .arg(now_millis.timestamp_millis())
        .invoke_async::<_, String>(
            &mut redis_write_pool
                .clone()
                .get()
                .await
                .map_err(InternalRateLimitError::Pool)?,
        )
        .await
        .map_err(InternalRateLimitError::Redis)
        .map(|value| serde_json::from_str(&value).expect("Redis script should return valid JSON"))
}

#[cfg(test)]
mod tests {
    const REDIS_URI: &str = "redis://localhost:6379";
    const REFILL_INTERVAL_MILLIS: i64 = 100;
    const MAX_TOKENS: u32 = 5;
    const REFILL_RATE: u32 = 1;

    use {
        super::*,
        chrono::Utc,
        deadpool_redis::{Config, Runtime},
        redis::AsyncCommands,
        tokio::time::sleep,
        uuid::Uuid,
    };

    async fn redis_clear_keys(conn_uri: &str, keys: &[String]) {
        let client = redis::Client::open(conn_uri).unwrap();
        let mut conn = client.get_async_connection().await.unwrap();
        for key in keys {
            let _: () = conn.del(key).await.unwrap();
        }
    }

    async fn test_rate_limiting(key: String) {
        let cfg = Config::from_url(REDIS_URI);
        let pool = Arc::new(cfg.create_pool(Some(Runtime::Tokio1)).unwrap());
        let refill_interval = chrono::Duration::try_milliseconds(REFILL_INTERVAL_MILLIS).unwrap();
        let rate_limit = |now_millis| {
            let key = key.clone();
            let pool = pool.clone();
            async move {
                token_bucket_many(
                    &pool,
                    vec![key.clone()],
                    MAX_TOKENS,
                    refill_interval,
                    REFILL_RATE,
                    now_millis,
                )
                .await
                .unwrap()
                .get(&key)
                .unwrap()
                .to_owned()
            }
        };
        // Helper to call the rate limit multiple times and assert the results
        // for the token count and the reset timestamp.
        let call_rate_limit_loop = |loop_iterations| async move {
            let first_call_millis = Utc::now();
            for i in 0..=loop_iterations {
                let curr_iter = loop_iterations as i64 - i as i64 - 1;

                // Use the first call timestamp for the first call and the current
                // timestamp for the subsequent calls.
                let result = if i == 0 {
                    rate_limit(first_call_millis).await
                } else {
                    rate_limit(Utc::now()).await
                };

                // Assert the remaining tokens count.
                assert_eq!(result.0, curr_iter);
                // Assert that the reset timestamp is the first call timestamp plus
                // the refill interval.
                assert_eq!(
                    result.1,
                    (first_call_millis.timestamp_millis() + REFILL_INTERVAL_MILLIS) as u64
                );
            }
            // Return the refill timestamp.
            first_call_millis.timestamp_millis() + REFILL_INTERVAL_MILLIS
        };

        // Call the rate limit until the max tokens limit is reached.
        call_rate_limit_loop(MAX_TOKENS).await;

        // Sleep for the full refill and try again.
        // The token counts should be the same as in the previous iteration because
        // the bucket was fully refilled.
        sleep((refill_interval * MAX_TOKENS as i32).to_std().unwrap()).await;
        let last_timestamp = call_rate_limit_loop(MAX_TOKENS).await;

        // Sleep for just one refill and try again.
        // Exactly one token should have been refilled (and is consumed by this call),
        // so the remaining count is 0 and the reset timestamp is the last
        // full-iteration call timestamp plus the refill interval.
        sleep((refill_interval).to_std().unwrap()).await;
        let result = rate_limit(Utc::now()).await;
        assert_eq!(result.0, 0);
        assert_eq!(result.1, (last_timestamp + REFILL_INTERVAL_MILLIS) as u64);
    }

    #[tokio::test]
    async fn test_token_bucket_many() {
        const KEYS_NUMBER_TO_TEST: usize = 3;
        let keys = (0..KEYS_NUMBER_TO_TEST)
            .map(|_| Uuid::new_v4().to_string())
            .collect::<Vec<String>>();

        // Before running the test, ensure the test keys are cleared.
        redis_clear_keys(REDIS_URI, &keys).await;

        // Start an async test for each key and wait for all of them to complete.
        let tasks = keys.iter().map(|key| test_rate_limiting(key.clone()));
        futures::future::join_all(tasks).await;

        // Clear the keys after the test.
        redis_clear_keys(REDIS_URI, &keys).await;
    }

    #[tokio::test]
    async fn test_token_bucket() {
        // Create a Moka cache with a TTL equal to the refill interval.
        let cache: Cache<String, u64> = Cache::builder()
            .time_to_live(std::time::Duration::from_millis(
                REFILL_INTERVAL_MILLIS as u64,
            ))
            .build();

        let cfg = Config::from_url(REDIS_URI);
        let pool = Arc::new(cfg.create_pool(Some(Runtime::Tokio1)).unwrap());
        let key = Uuid::new_v4().to_string();

        // Before running the test, ensure the test keys are cleared.
        redis_clear_keys(REDIS_URI, &[key.clone()]).await;

        let refill_interval = chrono::Duration::try_milliseconds(REFILL_INTERVAL_MILLIS).unwrap();
        let rate_limit = |now_millis| {
            let key = key.clone();
            let pool = pool.clone();
            let cache = cache.clone();
            async move {
                token_bucket(
                    &cache,
                    &pool,
                    key.clone(),
                    MAX_TOKENS,
                    refill_interval,
                    REFILL_RATE,
                    now_millis,
                )
                .await
            }
        };
        let call_rate_limit_loop = |now_millis| async move {
            for i in 0..=MAX_TOKENS {
                let result = rate_limit(now_millis).await;
                if i == MAX_TOKENS {
                    assert!(result
                        .err()
                        .unwrap()
                        .to_string()
                        .contains("Rate limit exceeded"));
                } else {
                    assert!(result.is_ok());
                }
            }
        };

        // Call the rate limit until the max tokens limit is reached.
        call_rate_limit_loop(Utc::now()).await;

        // Sleep for the refill and try again.
        sleep((refill_interval * MAX_TOKENS as i32).to_std().unwrap()).await;
        call_rate_limit_loop(Utc::now()).await;

        // Clear the keys after the test.
        redis_clear_keys(REDIS_URI, &[key.clone()]).await;
    }
}
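For orientation, a minimal usage sketch of the module above follows. It is not part of this PR: the `rate_limit::` paths assume the functions are exposed at the crate root, and the Redis URL, key name, and limits are illustrative assumptions. It also presumes a Tokio runtime and the dependencies from the Cargo.toml above.

use {
    chrono::{Duration, Utc},
    deadpool_redis::{Config, Runtime},
    moka::future::Cache,
    std::sync::Arc,
};

#[tokio::main]
async fn main() {
    // Hypothetical limits: a burst of 10 requests, ~10 requests/second sustained.
    let refill_interval = Duration::try_milliseconds(100).unwrap();

    // Per the doc comment on `token_bucket`, the in-memory cache TTL must equal
    // the refill interval, so rate-limited keys expire when a token may be available.
    let mem_cache: Cache<String, u64> = Cache::builder()
        .time_to_live(std::time::Duration::from_millis(100))
        .build();

    let pool = Arc::new(
        Config::from_url("redis://localhost:6379")
            .create_pool(Some(Runtime::Tokio1))
            .expect("Redis pool should be created"),
    );

    match rate_limit::token_bucket(
        &mem_cache,
        &pool,
        "client:42".to_string(), // illustrative per-client key
        10,                      // max_tokens
        refill_interval,
        1,                       // refill_rate: tokens added per interval
        Utc::now(),
    )
    .await
    {
        Ok(()) => println!("request allowed"),
        Err(rate_limit::RateLimitError::RateLimitExceeded(e)) => println!("{e}"),
        Err(e) => eprintln!("internal error: {e}"),
    }
}

As the comments in `token_bucket` note, the in-memory cache only holds keys that are already rate limited, so the Redis round trip is skipped during a flood.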
@@ -0,0 +1,29 @@
-- Adapted from https://github.com/upstash/ratelimit/blob/3a8cfb00e827188734ac347965cb743a75fcb98a/src/single.ts#L311
local keys = KEYS -- identifier including prefixes
local maxTokens = tonumber(ARGV[1]) -- maximum number of tokens
local interval = tonumber(ARGV[2]) -- size of the window in milliseconds
local refillRate = tonumber(ARGV[3]) -- how many tokens are refilled after each interval
local now = tonumber(ARGV[4]) -- current timestamp in milliseconds

local results = {}
for i, key in ipairs(KEYS) do

Review reply: Reverted it in b9ecde9, since the minified version of the Lua script is hard to debug and it is not a huge optimization. Let's stick to the initial one and move forward with it.

    local bucket = redis.call("HMGET", key, "refilledAt", "tokens")
    local refilledAt = (bucket[1] == false and tonumber(now) or tonumber(bucket[1]))
    local tokens = (bucket[1] == false and tonumber(maxTokens) or tonumber(bucket[2]))

    if tonumber(now) >= refilledAt + interval then
        tokens = math.min(tonumber(maxTokens), tokens + math.floor((tonumber(now) - refilledAt) / interval) * tonumber(refillRate))
        refilledAt = refilledAt + math.floor((tonumber(now) - refilledAt) / interval) * interval
    end

    if tokens > 0 then
        tokens = tokens - 1
        redis.call("HSET", key, "refilledAt", refilledAt, "tokens", tokens)
        redis.call("PEXPIRE", key, math.ceil(((tonumber(maxTokens) - tokens) / tonumber(refillRate)) * interval))
        results[key] = {tokens, refilledAt + interval}
    else
        results[key] = {-1, refilledAt + interval}
    end
end
-- Redis doesn't support Lua table responses: https://stackoverflow.com/a/24302613
return cjson.encode(results)
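For reference, `cjson.encode(results)` returns a JSON object mapping each key to a two-element array of the remaining tokens (-1 when rate limited) and the reset timestamp in milliseconds, which is the payload `token_bucket_many` deserializes into `HashMap<String, (i64, u64)>`. A minimal sketch of that contract, with a made-up key and timestamp:

use std::collections::HashMap;

fn main() {
    // Illustrative payload in the shape produced by token_bucket.lua.
    let raw = r#"{"client:42":[4,1714000000100]}"#;
    let parsed: HashMap<String, (i64, u64)> =
        serde_json::from_str(raw).expect("valid JSON in the script's shape");
    assert_eq!(parsed["client:42"], (4, 1_714_000_000_100));
}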
Review comment: What does this mean? Does `Cache` support different keys with different refill intervals, e.g. for what the Notify Server needs?

Reply: The TTL is the same for all records. If different refill rates are required for different cases, that can be achieved with a single (lowest) refill interval and per-case refill rates tuned to that interval, without adding an additional TTL per key.
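To make that concrete, here is a small sketch of expressing different effective rates against one shared (lowest) refill interval by tuning only `refill_rate`, so a single `mem_cache` TTL still matches the interval. The tier names and numbers are invented for illustration and are not taken from this PR.

use chrono::Duration;

struct Limit {
    max_tokens: u32,
    refill_rate: u32, // tokens added per shared interval
}

fn main() {
    // The single (lowest) refill interval used for every key.
    let shared_interval = Duration::try_milliseconds(100).unwrap();

    // Hypothetical tiers expressed against the same interval by tuning
    // refill_rate only: ~10 tokens/s and ~50 tokens/s.
    let tiers = [
        ("standard", Limit { max_tokens: 10, refill_rate: 1 }),
        ("premium", Limit { max_tokens: 50, refill_rate: 5 }),
    ];

    for (name, limit) in tiers {
        let per_second =
            limit.refill_rate as i64 * 1_000 / shared_interval.num_milliseconds();
        println!(
            "{name}: burst of {} tokens, refills {per_second} tokens/second",
            limit.max_tokens
        );
    }
}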