diff --git a/.github/workflows/sub-ci.yml b/.github/workflows/sub-ci.yml index bf8217ab..75eb29e4 100644 --- a/.github/workflows/sub-ci.yml +++ b/.github/workflows/sub-ci.yml @@ -52,6 +52,10 @@ jobs: --health-interval 10s --health-timeout 5s --health-retries 5 + redis: + image: redis:7-alpine + ports: + - 6379:6379 steps: - name: Checkout uses: actions/checkout@v3 diff --git a/Cargo.lock b/Cargo.lock index af6413c2..48d2ac8e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -836,19 +836,6 @@ dependencies = [ "libc", ] -[[package]] -name = "cerberus" -version = "0.2.0" -source = "git+https://github.com/WalletConnect/cerberus.git?tag=v0.9.1#e35d31285bcf89e1809192b3a0fe0406f9f40e22" -dependencies = [ - "async-trait", - "once_cell", - "regex", - "reqwest", - "serde", - "thiserror", -] - [[package]] name = "cfg-if" version = "1.0.0" @@ -2368,7 +2355,6 @@ dependencies = [ "bs58", "build-info", "build-info-build", - "cerberus", "chacha20poly1305", "chrono", "dashmap", @@ -2395,6 +2381,7 @@ dependencies = [ "rand 0.7.3", "rand_chacha 0.3.1", "rand_core 0.5.1", + "redis", "regex", "relay_client", "relay_rpc", @@ -2402,7 +2389,6 @@ dependencies = [ "ring 0.16.20", "rmp-serde", "serde", - "serde_bson", "serde_json", "sha2 0.10.8", "sha256", @@ -2415,8 +2401,6 @@ dependencies = [ "tower", "tower-http", "tracing", - "tracing-appender", - "tracing-opentelemetry", "tracing-subscriber", "tungstenite", "url", @@ -3050,6 +3034,7 @@ dependencies = [ "percent-encoding", "pin-project-lite", "ryu", + "sha1_smol", "tokio", "tokio-util", "url", @@ -3464,17 +3449,6 @@ dependencies = [ "serde_json", ] -[[package]] -name = "serde_bson" -version = "0.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f55a9c200c43d6e6e4db10568daef0764c945d4d1e80cad32e5b5708ca123aa0" -dependencies = [ - "bytes", - "serde", - "take_mut", -] - [[package]] name = "serde_derive" version = "1.0.190" @@ -3552,6 +3526,12 @@ dependencies = [ "digest 0.10.7", ] +[[package]] +name = "sha1_smol" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae1a47186c03a32177042e55dbc5fd5aee900b8e0069a8d70fba96a9375cd012" + [[package]] name = "sha2" version = "0.9.9" @@ -4002,12 +3982,6 @@ dependencies = [ "libc", ] -[[package]] -name = "take_mut" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f764005d11ee5f36500a149ace24e00e3da98b0158b3e2d53a7495660d3f4d60" - [[package]] name = "tap" version = "1.0.1" @@ -4327,17 +4301,6 @@ dependencies = [ "tracing-core", ] -[[package]] -name = "tracing-appender" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09d48f71a791638519505cefafe162606f706c25592e4bde4d97600c0195312e" -dependencies = [ - "crossbeam-channel", - "time", - "tracing-subscriber", -] - [[package]] name = "tracing-attributes" version = "0.1.27" @@ -4370,20 +4333,6 @@ dependencies = [ "tracing-core", ] -[[package]] -name = "tracing-opentelemetry" -version = "0.19.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00a39dcf9bfc1742fa4d6215253b33a6e474be78275884c216fc2a06267b3600" -dependencies = [ - "once_cell", - "opentelemetry", - "tracing", - "tracing-core", - "tracing-log", - "tracing-subscriber", -] - [[package]] name = "tracing-subscriber" version = "0.3.17" diff --git a/Cargo.toml b/Cargo.toml index c6fd1979..f5684049 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -88,6 +88,7 @@ once_cell = "1.18.0" lazy_static = "1.4.0" rmp-serde = "1.1.1" deadpool-redis = "0.12.0" +redis = { version = "0.23.3", default-features = false, features = ["script"] } rand_chacha = "0.3.1" sqlx = { version = "0.7.1", features = ["runtime-tokio-native-tls", "postgres", "chrono", "uuid"] } wiremock = "0.5.19" diff --git a/docker-compose.storage.yml b/docker-compose.storage.yml index 036d6167..b2f16576 100644 --- a/docker-compose.storage.yml +++ b/docker-compose.storage.yml @@ -7,6 +7,11 @@ services: # ports: # - "3001:16686" + redis: + image: redis:7-alpine + ports: + - "6379:6379" + postgres: image: postgres:16 environment: diff --git a/src/auth.rs b/src/auth.rs index 513be109..37e8a490 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -425,7 +425,7 @@ pub struct Authorization { pub domain: String, } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, PartialEq)] pub enum AuthorizedApp { Limited(String), Unlimited, @@ -490,21 +490,7 @@ pub async fn verify_identity( let app = { let statement = cacao.p.statement.ok_or(AuthError::CacaoStatementMissing)?; info!("CACAO statement: {statement}"); - if statement.contains("DAPP") - || statement == STATEMENT_THIS_DOMAIN_IDENTITY - || statement == STATEMENT_THIS_DOMAIN - { - AuthorizedApp::Limited(cacao.p.domain.clone()) - } else if statement.contains("WALLET") - || statement == STATEMENT - || statement == STATEMENT_ALL_DOMAINS_IDENTITY - || statement == STATEMENT_ALL_DOMAINS_OLD - || statement == STATEMENT_ALL_DOMAINS - { - AuthorizedApp::Unlimited - } else { - return Err(AuthError::CacaoStatementInvalid)?; - } + parse_cacao_statement(&statement, &cacao.p.domain)? }; if cacao.p.iss != sub { @@ -541,6 +527,24 @@ pub async fn verify_identity( }) } +fn parse_cacao_statement(statement: &str, domain: &str) -> Result { + if statement.contains("DAPP") + || statement == STATEMENT_THIS_DOMAIN_IDENTITY + || statement == STATEMENT_THIS_DOMAIN + { + Ok(AuthorizedApp::Limited(domain.to_owned())) + } else if statement.contains("WALLET") + || statement == STATEMENT + || statement == STATEMENT_ALL_DOMAINS_IDENTITY + || statement == STATEMENT_ALL_DOMAINS_OLD + || statement == STATEMENT_ALL_DOMAINS + { + Ok(AuthorizedApp::Unlimited) + } else { + return Err(AuthError::CacaoStatementInvalid)?; + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] struct KeyServerResponse { status: String, @@ -568,3 +572,69 @@ pub fn encode_subscribe_private_key(subscribe_key: &StaticSecret) -> String { pub fn encode_subscribe_public_key(subscribe_key: &StaticSecret) -> String { hex::encode(PublicKey::from(subscribe_key)) } + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn notify_all_domains() { + assert_eq!( + parse_cacao_statement(STATEMENT_ALL_DOMAINS, "app.example.com").unwrap(), + AuthorizedApp::Unlimited + ); + } + + #[test] + fn notify_all_domains_old() { + assert_eq!( + parse_cacao_statement(STATEMENT_ALL_DOMAINS_OLD, "app.example.com").unwrap(), + AuthorizedApp::Unlimited + ); + } + + #[test] + fn notify_this_domain() { + assert_eq!( + parse_cacao_statement(STATEMENT_THIS_DOMAIN, "app.example.com").unwrap(), + AuthorizedApp::Limited("app.example.com".to_owned()) + ); + } + + #[test] + fn notify_all_domains_identity() { + assert_eq!( + parse_cacao_statement(STATEMENT_ALL_DOMAINS_IDENTITY, "app.example.com").unwrap(), + AuthorizedApp::Unlimited + ); + } + + #[test] + fn notify_this_domain_identity() { + assert_eq!( + parse_cacao_statement(STATEMENT_THIS_DOMAIN_IDENTITY, "app.example.com").unwrap(), + AuthorizedApp::Limited("app.example.com".to_owned()) + ); + } + + #[test] + fn old_siwe_compatible() { + assert_eq!( + parse_cacao_statement(STATEMENT, "app.example.com").unwrap(), + AuthorizedApp::Unlimited + ); + } + + #[test] + fn old_old_siwe_compatible() { + assert_eq!( + parse_cacao_statement( + "I further authorize this DAPP to send and receive messages on my behalf for \ + this domain using my WalletConnect identity.", + "app.example.com" + ) + .unwrap(), + AuthorizedApp::Limited("app.example.com".to_owned()) + ); + } +} diff --git a/src/config/local.rs b/src/config/local.rs index 4a7faf0b..d324cf8a 100644 --- a/src/config/local.rs +++ b/src/config/local.rs @@ -25,6 +25,8 @@ pub struct LocalConfiguration { pub postgres_url: String, #[serde(default = "default_postgres_max_connections")] pub postgres_max_connections: u32, + #[serde(default = "default_redis_url")] + pub redis_url: String, #[serde(default = "default_keypair_seed")] pub keypair_seed: String, #[serde(default = "default_relay_url")] @@ -49,6 +51,10 @@ pub fn default_postgres_url() -> String { "postgres://postgres:postgres@localhost:5432/postgres".to_owned() } +pub fn default_redis_url() -> String { + "redis://localhost:6379/0".to_owned() +} + pub fn default_postgres_max_connections() -> u32 { 10 } diff --git a/src/error.rs b/src/error.rs index ffe2bf68..3878f8a7 100644 --- a/src/error.rs +++ b/src/error.rs @@ -183,6 +183,15 @@ pub enum Error { #[error("App domain in-use by another project")] AppDomainInUseByAnotherProject, + + #[error("Redis pool error: {0}")] + RedisPoolError(#[from] deadpool_redis::PoolError), + + #[error("Redis error: {0}")] + RedisError(#[from] redis::RedisError), + + #[error("Rate limit exceeded. Try again at {0}")] + TooManyRequests(u64), } impl IntoResponse for Error { @@ -201,6 +210,9 @@ impl IntoResponse for Error { Self::AppDomainInUseByAnotherProject => { (StatusCode::CONFLICT, self.to_string()).into_response() } + Self::TooManyRequests(_) => { + (StatusCode::TOO_MANY_REQUESTS, self.to_string()).into_response() + } error => { warn!("Error does not have response clause: {:?}", error); (StatusCode::INTERNAL_SERVER_ERROR, "Internal server error.").into_response() diff --git a/src/lib.rs b/src/lib.rs index d8b0f799..33eca314 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,7 @@ use { config::Configuration, error::Result, metrics::Metrics, + registry::storage::redis::Redis, relay_client_helpers::create_http_client, services::{ private_http_server, public_http_server, publisher_service, watcher_expiration_job, @@ -34,6 +35,7 @@ pub mod model; mod notify_keys; pub mod notify_message; pub mod publish_relay_message; +pub mod rate_limit; pub mod registry; pub mod relay_client_helpers; pub mod services; @@ -81,10 +83,19 @@ pub async fn bootstrap(mut shutdown: broadcast::Receiver<()>, config: Configurat let metrics = Some(Metrics::default()); + let redis = if let Some(redis_addr) = &config.auth_redis_addr() { + Some(Arc::new(Redis::new( + redis_addr, + config.redis_pool_size as usize, + )?)) + } else { + None + }; + let registry = Arc::new(registry::Registry::new( config.registry_url.clone(), &config.registry_auth_token, - &config, + redis.clone(), metrics.clone(), )?); @@ -97,6 +108,7 @@ pub async fn bootstrap(mut shutdown: broadcast::Receiver<()>, config: Configurat relay_ws_client.clone(), relay_http_client.clone(), metrics.clone(), + redis, registry, )?); diff --git a/src/rate_limit/mod.rs b/src/rate_limit/mod.rs new file mode 100644 index 00000000..9ed19f81 --- /dev/null +++ b/src/rate_limit/mod.rs @@ -0,0 +1,2 @@ +mod token_bucket; +pub use token_bucket::*; diff --git a/src/rate_limit/token_bucket.lua b/src/rate_limit/token_bucket.lua new file mode 100644 index 00000000..07ec7b15 --- /dev/null +++ b/src/rate_limit/token_bucket.lua @@ -0,0 +1,44 @@ +-- Adapted from https://github.com/upstash/ratelimit/blob/3a8cfb00e827188734ac347965cb743a75fcb98a/src/single.ts#L311 +local keys = KEYS -- identifier including prefixes +local maxTokens = tonumber(ARGV[1]) -- maximum number of tokens +local interval = tonumber(ARGV[2]) -- size of the window in milliseconds +local refillRate = tonumber(ARGV[3]) -- how many tokens are refilled after each interval +local now = tonumber(ARGV[4]) -- current timestamp in milliseconds + +local results = {} + +for i, key in ipairs(keys) do + local bucket = redis.call("HMGET", key, "refilledAt", "tokens") + + local refilledAt + local tokens + + if bucket[1] == false then + refilledAt = now + tokens = maxTokens + else + refilledAt = tonumber(bucket[1]) + tokens = tonumber(bucket[2]) + end + + if now >= refilledAt + interval then + local numRefills = math.floor((now - refilledAt) / interval) + tokens = math.min(maxTokens, tokens + numRefills * refillRate) + + refilledAt = refilledAt + numRefills * interval + end + + if tokens == 0 then + results[key] = {-1, refilledAt + interval} + else + local remaining = tokens - 1 + local expireAt = math.ceil(((maxTokens - remaining) / refillRate)) * interval + + redis.call("HSET", key, "refilledAt", refilledAt, "tokens", remaining) + redis.call("PEXPIRE", key, expireAt) + results[key] = {remaining, refilledAt + interval} + end +end + +-- Redis doesn't support Lua table responses: https://stackoverflow.com/a/24302613 +return cjson.encode(results) diff --git a/src/rate_limit/token_bucket.rs b/src/rate_limit/token_bucket.rs new file mode 100644 index 00000000..069d02d9 --- /dev/null +++ b/src/rate_limit/token_bucket.rs @@ -0,0 +1,47 @@ +use { + crate::{ + error::{Error, Result}, + registry::storage::redis::Redis, + }, + chrono::{Duration, Utc}, + redis::Script, + std::{collections::HashMap, sync::Arc}, +}; + +pub async fn token_bucket( + redis: &Arc, + key: String, + max_tokens: u32, + interval: Duration, + refill_rate: u32, +) -> Result<()> { + let result = + token_bucket_many(redis, vec![key.clone()], max_tokens, interval, refill_rate).await?; + let (remaining, reset) = result.get(&key).unwrap(); + if remaining.is_negative() { + Err(Error::TooManyRequests(reset / 1000)) + } else { + Ok(()) + } +} + +pub async fn token_bucket_many( + redis: &Arc, + keys: Vec, + max_tokens: u32, + interval: Duration, + refill_rate: u32, +) -> Result> { + // Remaining is number of tokens remaining. -1 for rate limited. + // Reset is the time at which there will be 1 more token than before. This could, for example, be used to cache a 0 token count. + Script::new(include_str!("token_bucket.lua")) + .key(keys) + .arg(max_tokens) + .arg(interval.num_milliseconds()) + .arg(refill_rate) + .arg(Utc::now().timestamp_millis()) + .invoke_async::<_, String>(&mut redis.write_pool().get().await?) + .await + .map_err(Into::into) + .map(|value| serde_json::from_str(&value).expect("Redis script should return valid JSON")) +} diff --git a/src/registry/mod.rs b/src/registry/mod.rs index f67a9b89..a0e899bb 100644 --- a/src/registry/mod.rs +++ b/src/registry/mod.rs @@ -1,5 +1,5 @@ use { - crate::{config::Configuration, error::Result, metrics::Metrics}, + crate::{error::Result, metrics::Metrics}, hyper::header, relay_rpc::domain::ProjectId, serde::{Deserialize, Serialize}, @@ -94,19 +94,10 @@ impl Registry { pub fn new( registry_url: Url, auth_token: &str, - config: &Configuration, + cache: Option>, metrics: Option, ) -> Result { let client = Arc::new(RegistryHttpClient::new(registry_url, auth_token, metrics)?); - - let cache = if let Some(redis_addr) = &config.auth_redis_addr() { - Some(Arc::new(Redis::new( - redis_addr, - config.redis_pool_size as usize, - )?)) - } else { - None - }; Ok(Self { client, cache }) } diff --git a/src/registry/storage/redis/mod.rs b/src/registry/storage/redis/mod.rs index bfc525e3..ac5fc740 100644 --- a/src/registry/storage/redis/mod.rs +++ b/src/registry/storage/redis/mod.rs @@ -91,6 +91,14 @@ impl Redis { }) } + pub fn write_pool(&self) -> Pool { + self.write_pool.clone() + } + + pub fn read_pool(&self) -> Pool { + self.read_pool.clone() + } + async fn set_internal( &self, key: &str, diff --git a/src/services/public_http_server/handlers/did_json.rs b/src/services/public_http_server/handlers/did_json.rs index 2d69fa99..331a9d28 100644 --- a/src/services/public_http_server/handlers/did_json.rs +++ b/src/services/public_http_server/handlers/did_json.rs @@ -7,6 +7,9 @@ use { tracing::info, }; +// No rate limit necessary since returning a fixed string is less computational intensive than tracking the rate limit + +// TODO generate this response at app startup to avoid unnecessary string allocations pub async fn handler(State(state): State>) -> Result { info!("Serving did.json"); diff --git a/src/services/public_http_server/handlers/get_subscribers_v1.rs b/src/services/public_http_server/handlers/get_subscribers_v1.rs index ed14f493..b6ec6892 100644 --- a/src/services/public_http_server/handlers/get_subscribers_v1.rs +++ b/src/services/public_http_server/handlers/get_subscribers_v1.rs @@ -2,10 +2,12 @@ use { crate::{ error::{Error, Result}, model::helpers::get_subscriber_accounts_and_scopes_by_project_id, - registry::extractor::AuthedProjectId, + rate_limit, + registry::{extractor::AuthedProjectId, storage::redis::Redis}, state::AppState, }, axum::{extract::State, http::StatusCode, response::IntoResponse, Json}, + relay_rpc::domain::ProjectId, std::sync::Arc, tracing::instrument, }; @@ -15,6 +17,10 @@ pub async fn handler( State(state): State>, AuthedProjectId(project_id, _): AuthedProjectId, ) -> Result { + if let Some(redis) = state.redis.as_ref() { + get_subscribers_rate_limit(redis, &project_id).await?; + } + let accounts = get_subscriber_accounts_and_scopes_by_project_id( project_id, &state.postgres, @@ -28,3 +34,14 @@ pub async fn handler( Ok((StatusCode::OK, Json(accounts)).into_response()) } + +pub async fn get_subscribers_rate_limit(redis: &Arc, project_id: &ProjectId) -> Result<()> { + rate_limit::token_bucket( + redis, + project_id.to_string(), + 5, + chrono::Duration::seconds(1), + 1, + ) + .await +} diff --git a/src/services/public_http_server/handlers/health.rs b/src/services/public_http_server/handlers/health.rs index 0d95428d..a182c619 100644 --- a/src/services/public_http_server/handlers/health.rs +++ b/src/services/public_http_server/handlers/health.rs @@ -4,6 +4,9 @@ use { std::sync::Arc, }; +// No rate limit necessary since returning a fixed string is less computational intensive than tracking the rate limit + +// TODO generate this response at app startup to avoid unnecessary string allocations pub async fn handler(State(state): State>) -> impl IntoResponse { ( StatusCode::OK, diff --git a/src/services/public_http_server/handlers/notify_v1.rs b/src/services/public_http_server/handlers/notify_v1.rs index 985486a2..1b1901e3 100644 --- a/src/services/public_http_server/handlers/notify_v1.rs +++ b/src/services/public_http_server/handlers/notify_v1.rs @@ -1,13 +1,13 @@ use { crate::{ - error, - error::Error, + error::{self, Error}, metrics::Metrics, model::{ helpers::{get_project_by_project_id, get_subscribers_for_project_in}, types::AccountId, }, - registry::extractor::AuthedProjectId, + rate_limit, + registry::{extractor::AuthedProjectId, storage::redis::Redis}, services::publisher_service::helpers::{ upsert_notification, upsert_subscriber_notifications, }, @@ -16,8 +16,13 @@ use { }, axum::{extract::State, http::StatusCode, response::IntoResponse, Json}, error::Result, + relay_rpc::domain::ProjectId, serde::{Deserialize, Serialize}, - std::{collections::HashSet, sync::Arc, time::Instant}, + std::{ + collections::{HashMap, HashSet}, + sync::Arc, + time::Instant, + }, tracing::{info, instrument}, uuid::Uuid, wc::metrics::otel::{Context, KeyValue}, @@ -64,12 +69,20 @@ pub async fn handler_impl( ) -> Result { let start = Instant::now(); + if let Some(redis) = state.redis.as_ref() { + notify_rate_limit(redis, &project_id).await?; + } + for notification in &body { notification.notification.validate()?; } info!("notification count: {}", body.len()); - let subscriber_notification_count = body.iter().map(|n| n.accounts.len()).sum::(); + let subscriber_notifications = body + .iter() + .flat_map(|n| n.accounts.clone()) + .collect::>(); + let subscriber_notification_count = subscriber_notifications.len(); info!("subscriber_notification_count: {subscriber_notification_count}"); const SUBSCRIBER_NOTIFICATION_COUNT_LIMIT: usize = 500; if subscriber_notification_count > SUBSCRIBER_NOTIFICATION_COUNT_LIMIT { @@ -119,7 +132,7 @@ pub async fn handler_impl( ) .await?; - let mut subscriber_ids = Vec::with_capacity(subscribers.len()); + let mut valid_subscribers = Vec::with_capacity(subscribers.len()); for subscriber in subscribers { let account = subscriber.account; response.not_found.remove(&account); @@ -132,8 +145,42 @@ pub async fn handler_impl( continue; } - info!("Sending notification for {account}"); - subscriber_ids.push(subscriber.id); + valid_subscribers.push((subscriber.id, account)); + } + + let valid_subscribers = if let Some(redis) = state.redis.as_ref() { + let result = subscriber_rate_limit( + redis, + &project_id, + valid_subscribers + .iter() + .map(|(subscriber_id, _account)| *subscriber_id), + ) + .await?; + + let mut valid_subscribers2 = Vec::with_capacity(valid_subscribers.len()); + for (subscriber_id, account) in valid_subscribers.into_iter() { + let key = subscriber_rate_limit_key(&project_id, &subscriber_id); + let (remaining, _reset) = result + .get(&key) + .expect("rate limit key expected in response"); + if remaining.is_negative() { + response.failed.insert(SendFailure { + account: account.clone(), + reason: "Rate limit exceeded".into(), + }); + } else { + valid_subscribers2.push((subscriber_id, account)); + } + } + valid_subscribers2 + } else { + valid_subscribers + }; + + let mut subscriber_ids = Vec::with_capacity(valid_subscribers.len()); + for (subscriber_id, account) in valid_subscribers { + subscriber_ids.push(subscriber_id); response.sent.insert(account); } @@ -176,3 +223,35 @@ fn send_metrics(metrics: &Metrics, response: &Response, start: Instant) { .notify_latency .record(&ctx, start.elapsed().as_millis().try_into().unwrap(), &[]) } + +pub async fn notify_rate_limit(redis: &Arc, project_id: &ProjectId) -> Result<()> { + rate_limit::token_bucket( + redis, + project_id.to_string(), + 20, + chrono::Duration::seconds(1), + 2, + ) + .await +} + +type SubscriberRateLimitKey = String; + +pub fn subscriber_rate_limit_key( + project_id: &ProjectId, + subscriber: &Uuid, +) -> SubscriberRateLimitKey { + format!("{}:{}", project_id, subscriber) +} + +pub async fn subscriber_rate_limit( + redis: &Arc, + project_id: &ProjectId, + subscribers: impl IntoIterator, +) -> Result> { + let keys = subscribers + .into_iter() + .map(|subscriber| subscriber_rate_limit_key(project_id, &subscriber)) + .collect::>(); + rate_limit::token_bucket_many(redis, keys, 50, chrono::Duration::hours(1), 2).await +} diff --git a/src/services/public_http_server/handlers/subscribe_topic.rs b/src/services/public_http_server/handlers/subscribe_topic.rs index e981ad49..e0f22ebb 100644 --- a/src/services/public_http_server/handlers/subscribe_topic.rs +++ b/src/services/public_http_server/handlers/subscribe_topic.rs @@ -1,6 +1,9 @@ use { crate::{ - error::Result, model::helpers::upsert_project, registry::extractor::AuthedProjectId, + error::Result, + model::helpers::upsert_project, + rate_limit, + registry::{extractor::AuthedProjectId, storage::redis::Redis}, state::AppState, }, axum::{self, extract::State, response::IntoResponse, Json}, @@ -8,7 +11,7 @@ use { hyper::StatusCode, once_cell::sync::Lazy, regex::Regex, - relay_rpc::domain::Topic, + relay_rpc::domain::{ProjectId, Topic}, serde::{Deserialize, Serialize}, serde_json::json, std::sync::Arc, @@ -40,6 +43,10 @@ pub async fn handler( // ) // .entered(); + if let Some(redis) = state.redis.as_ref() { + subscribe_topic_rate_limit(redis, &project_id).await?; + } + let app_domain = subscribe_topic_data.app_domain; if app_domain.len() > 253 { // Domains max at 253 chars according to: https://en.wikipedia.org/wiki/Hostname @@ -98,6 +105,17 @@ pub async fn handler( .into_response()) } +pub async fn subscribe_topic_rate_limit(redis: &Arc, project_id: &ProjectId) -> Result<()> { + rate_limit::token_bucket( + redis, + project_id.to_string(), + 100, + chrono::Duration::minutes(1), + 1, + ) + .await +} + fn is_domain(domain: &str) -> bool { static DOMAIN_REGEX: Lazy = Lazy::new(|| Regex::new(r"^[a-z0-9-_\.]+$").unwrap()); DOMAIN_REGEX.is_match(domain) diff --git a/src/services/websocket_server/handlers/notify_delete.rs b/src/services/websocket_server/handlers/notify_delete.rs index 7459291e..dfd27240 100644 --- a/src/services/websocket_server/handlers/notify_delete.rs +++ b/src/services/websocket_server/handlers/notify_delete.rs @@ -8,6 +8,8 @@ use { error::Error, model::helpers::{delete_subscriber, get_project_by_id, get_subscriber_by_topic}, publish_relay_message::publish_relay_message, + rate_limit, + registry::storage::redis::Redis, services::websocket_server::{ decode_key, handlers::{decrypt_message, notify_watch_subscriptions::update_subscription_watchers}, @@ -22,11 +24,11 @@ use { chrono::Utc, relay_client::websocket::{Client, PublishedMessage}, relay_rpc::{ - domain::DecodedClientId, + domain::{DecodedClientId, Topic}, rpc::{Publish, JSON_RPC_VERSION_STR}, }, serde_json::{json, Value}, - std::collections::HashSet, + std::{collections::HashSet, sync::Arc}, tracing::{info, warn}, }; @@ -35,6 +37,10 @@ pub async fn handle(msg: PublishedMessage, state: &AppState, client: &Client) -> let topic = msg.topic; let subscription_id = msg.subscription_id; + if let Some(redis) = state.redis.as_ref() { + notify_delete_rate_limit(redis, &topic).await?; + } + // TODO combine these two SQL queries let subscriber = get_subscriber_by_topic(topic.clone(), &state.postgres, state.metrics.as_ref()) @@ -183,3 +189,14 @@ pub async fn handle(msg: PublishedMessage, state: &AppState, client: &Client) -> Ok(()) } + +pub async fn notify_delete_rate_limit(redis: &Arc, topic: &Topic) -> Result<()> { + rate_limit::token_bucket( + redis, + format!("notify-delete-{topic}"), + 10, + chrono::Duration::hours(1), + 1, + ) + .await +} diff --git a/src/services/websocket_server/handlers/notify_subscribe.rs b/src/services/websocket_server/handlers/notify_subscribe.rs index 70b6507d..954a8d17 100644 --- a/src/services/websocket_server/handlers/notify_subscribe.rs +++ b/src/services/websocket_server/handlers/notify_subscribe.rs @@ -8,6 +8,8 @@ use { error::Error, model::helpers::{get_project_by_topic, upsert_subscriber}, publish_relay_message::publish_relay_message, + rate_limit, + registry::storage::redis::Redis, services::websocket_server::{ decode_key, derive_key, handlers::{decrypt_message, notify_watch_subscriptions::update_subscription_watchers}, @@ -26,16 +28,23 @@ use { rpc::{Publish, JSON_RPC_VERSION_STR}, }, serde_json::{json, Value}, - std::collections::HashSet, + std::{collections::HashSet, sync::Arc}, tracing::{info, instrument}, - x25519_dalek::StaticSecret, + x25519_dalek::{PublicKey, StaticSecret}, }; +// TODO limit each subscription to 15 notification types +// TODO limit each account to max 500 subscriptions + // TODO test idempotency (create subscriber a second time for the same account) #[instrument(name = "wc_notifySubscribe", skip_all)] pub async fn handle(msg: PublishedMessage, state: &AppState) -> Result<()> { let topic = msg.topic; + if let Some(redis) = state.redis.as_ref() { + notify_subscribe_project_rate_limit(redis, &topic).await?; + } + let project = get_project_by_topic(topic.clone(), &state.postgres, state.metrics.as_ref()) .await .map_err(|e| match e { @@ -48,11 +57,14 @@ pub async fn handle(msg: PublishedMessage, state: &AppState) -> Result<()> { base64::engine::general_purpose::STANDARD.decode(msg.message.to_string())?, )?; - let client_pubkey = envelope.pubkey(); - let client_pubkey = x25519_dalek::PublicKey::from(client_pubkey); + let client_public_key = x25519_dalek::PublicKey::from(envelope.pubkey()); + + if let Some(redis) = state.redis.as_ref() { + notify_subscribe_client_rate_limit(redis, &client_public_key).await?; + } let sym_key = derive_key( - &client_pubkey, + &client_public_key, &x25519_dalek::StaticSecret::from(decode_key(&project.subscribe_private_key)?), )?; let response_topic = sha256::digest(&sym_key); @@ -133,7 +145,7 @@ pub async fn handle(msg: PublishedMessage, state: &AppState) -> Result<()> { result: json!({ "responseAuth": response_auth }), // TODO use structure }; - let notify_key = derive_key(&client_pubkey, &secret)?; + let notify_key = derive_key(&client_public_key, &secret)?; let envelope = Envelope::::new(&sym_key, response)?; @@ -238,3 +250,31 @@ pub async fn handle(msg: PublishedMessage, state: &AppState) -> Result<()> { Ok(()) } + +pub async fn notify_subscribe_client_rate_limit( + redis: &Arc, + client_public_key: &PublicKey, +) -> Result<()> { + rate_limit::token_bucket( + redis, + format!( + "notify-subscribe-client-{}", + hex::encode(client_public_key.as_bytes()) + ), + 500, + chrono::Duration::days(1), + 100, + ) + .await +} + +pub async fn notify_subscribe_project_rate_limit(redis: &Arc, topic: &Topic) -> Result<()> { + rate_limit::token_bucket( + redis, + format!("notify-subscribe-project-{topic}"), + 50000, + chrono::Duration::seconds(1), + 1, + ) + .await +} diff --git a/src/services/websocket_server/handlers/notify_update.rs b/src/services/websocket_server/handlers/notify_update.rs index 6c6fabdf..0b3bda54 100644 --- a/src/services/websocket_server/handlers/notify_update.rs +++ b/src/services/websocket_server/handlers/notify_update.rs @@ -9,6 +9,8 @@ use { error::Error, model::helpers::{get_project_by_id, get_subscriber_by_topic, update_subscriber}, publish_relay_message::publish_relay_message, + rate_limit, + registry::storage::redis::Redis, services::websocket_server::{ decode_key, handlers::decrypt_message, NotifyRequest, NotifyResponse, NotifyUpdate, }, @@ -21,11 +23,11 @@ use { chrono::Utc, relay_client::websocket::PublishedMessage, relay_rpc::{ - domain::DecodedClientId, + domain::{DecodedClientId, Topic}, rpc::{Publish, JSON_RPC_VERSION_STR}, }, serde_json::{json, Value}, - std::collections::HashSet, + std::{collections::HashSet, sync::Arc}, tracing::info, }; @@ -33,6 +35,10 @@ use { pub async fn handle(msg: PublishedMessage, state: &AppState) -> Result<()> { let topic = msg.topic; + if let Some(redis) = state.redis.as_ref() { + notify_update_rate_limit(redis, &topic).await?; + } + // TODO combine these two SQL queries let subscriber = get_subscriber_by_topic(topic.clone(), &state.postgres, state.metrics.as_ref()) @@ -187,3 +193,14 @@ pub async fn handle(msg: PublishedMessage, state: &AppState) -> Result<()> { Ok(()) } + +pub async fn notify_update_rate_limit(redis: &Arc, topic: &Topic) -> Result<()> { + rate_limit::token_bucket( + redis, + format!("notify-update-{topic}"), + 100, + chrono::Duration::seconds(1), + 1, + ) + .await +} diff --git a/src/services/websocket_server/handlers/notify_watch_subscriptions.rs b/src/services/websocket_server/handlers/notify_watch_subscriptions.rs index caab1853..2de3f900 100644 --- a/src/services/websocket_server/handlers/notify_watch_subscriptions.rs +++ b/src/services/websocket_server/handlers/notify_watch_subscriptions.rs @@ -16,6 +16,8 @@ use { types::AccountId, }, publish_relay_message::publish_relay_message, + rate_limit, + registry::storage::redis::Redis, services::websocket_server::{ decode_key, derive_key, handlers::decrypt_message, NotifyRequest, NotifyResponse, NotifyWatchSubscriptions, @@ -38,7 +40,9 @@ use { }, serde_json::{json, Value}, sqlx::PgPool, + std::sync::Arc, tracing::{info, instrument}, + x25519_dalek::PublicKey, }; #[instrument(name = "wc_notifyWatchSubscriptions", skip_all)] @@ -53,6 +57,11 @@ pub async fn handle(msg: PublishedMessage, state: &AppState) -> Result<()> { let client_public_key = x25519_dalek::PublicKey::from(envelope.pubkey()); info!("client_public_key: {client_public_key:?}"); + + if let Some(redis) = state.redis.as_ref() { + notify_watch_subscriptions_rate_limit(redis, &client_public_key).await?; + } + let response_sym_key = derive_key(&client_public_key, &state.notify_keys.key_agreement_secret)?; let response_topic = sha256::digest(&response_sym_key); @@ -181,6 +190,23 @@ pub async fn handle(msg: PublishedMessage, state: &AppState) -> Result<()> { Ok(()) } +pub async fn notify_watch_subscriptions_rate_limit( + redis: &Arc, + client_public_key: &PublicKey, +) -> Result<()> { + rate_limit::token_bucket( + redis, + format!( + "notify-watch-subscriptions-{}", + hex::encode(client_public_key.as_bytes()) + ), + 100, + chrono::Duration::seconds(1), + 1, + ) + .await +} + #[instrument(skip(postgres, metrics))] pub async fn collect_subscriptions( account: AccountId, diff --git a/src/state.rs b/src/state.rs index a9aea5af..3baad5b2 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,7 +1,11 @@ use { crate::{ - analytics::NotifyAnalytics, error::Result, metrics::Metrics, notify_keys::NotifyKeys, - registry::Registry, Configuration, + analytics::NotifyAnalytics, + error::Result, + metrics::Metrics, + notify_keys::NotifyKeys, + registry::{storage::redis::Redis, Registry}, + Configuration, }, build_info::BuildInfo, relay_rpc::auth::ed25519_dalek::Keypair, @@ -20,6 +24,7 @@ pub struct AppState { pub keypair: Keypair, pub relay_ws_client: Arc, pub relay_http_client: Arc, + pub redis: Option>, pub registry: Arc, pub notify_keys: NotifyKeys, } @@ -37,6 +42,7 @@ impl AppState { relay_ws_client: Arc, relay_http_client: Arc, metrics: Option, + redis: Option>, registry: Arc, ) -> crate::Result { let build_info: &BuildInfo = build_info(); @@ -52,6 +58,7 @@ impl AppState { keypair, relay_ws_client, relay_http_client, + redis, registry, notify_keys, }) diff --git a/tests/deployment.rs b/tests/deployment.rs index 0533eee7..920d5662 100644 --- a/tests/deployment.rs +++ b/tests/deployment.rs @@ -1,5 +1,5 @@ use { - crate::utils::{create_client, verify_jwt, JWT_LEEWAY}, + crate::utils::{create_client, generate_account, verify_jwt, JWT_LEEWAY}, base64::Engine, chacha20poly1305::{ aead::{generic_array::GenericArray, Aead, OsRng}, @@ -15,12 +15,10 @@ use { SubscriptionDeleteRequestAuth, SubscriptionDeleteResponseAuth, SubscriptionRequestAuth, SubscriptionResponseAuth, SubscriptionUpdateRequestAuth, SubscriptionUpdateResponseAuth, WatchSubscriptionsChangedRequestAuth, - WatchSubscriptionsRequestAuth, WatchSubscriptionsResponseAuth, STATEMENT, - STATEMENT_ALL_DOMAINS, STATEMENT_ALL_DOMAINS_IDENTITY, STATEMENT_ALL_DOMAINS_OLD, - STATEMENT_THIS_DOMAIN, STATEMENT_THIS_DOMAIN_IDENTITY, + WatchSubscriptionsRequestAuth, WatchSubscriptionsResponseAuth, STATEMENT_ALL_DOMAINS, + STATEMENT_THIS_DOMAIN, }, jsonrpc::NotifyPayload, - model::types::AccountId, services::{ public_http_server::handlers::{ notify_v0::NotifyBody, @@ -340,16 +338,7 @@ async fn run_test(statement: String, watch_subscriptions_all_domains: bool) { (signing_key, client_did_key) }; - let account_signing_key = k256::ecdsa::SigningKey::random(&mut OsRng); - let address = &Keccak256::default() - .chain_update( - &account_signing_key - .verifying_key() - .to_encoded_point(false) - .as_bytes()[1..], - ) - .finalize()[12..]; - let account: AccountId = format!("eip155:1:0x{}", hex::encode(address)).into(); + let (account_signing_key, account) = generate_account(); let did_pkh = format!("did:pkh:{account}"); let app_domain = &format!("{}.walletconnect.com", vars.notify_project_id); @@ -1039,48 +1028,16 @@ async fn run_test(statement: String, watch_subscriptions_all_domains: bool) { } } -// TODO make into storage test #[tokio::test] async fn notify_all_domains() { run_test(STATEMENT_ALL_DOMAINS.to_owned(), true).await } -#[tokio::test] -async fn notify_all_domains_old() { - run_test(STATEMENT_ALL_DOMAINS_OLD.to_owned(), true).await -} - #[tokio::test] async fn notify_this_domain() { run_test(STATEMENT_THIS_DOMAIN.to_owned(), false).await } -#[tokio::test] -async fn notify_all_domains_identity() { - run_test(STATEMENT_ALL_DOMAINS_IDENTITY.to_owned(), true).await -} - -#[tokio::test] -async fn notify_this_domain_identity() { - run_test(STATEMENT_THIS_DOMAIN_IDENTITY.to_owned(), false).await -} - -#[tokio::test] -async fn old_siwe_compatible() { - run_test(STATEMENT.to_owned(), false).await -} - -#[tokio::test] -async fn old_old_siwe_compatible() { - run_test( - "I further authorize this DAPP to send and receive messages on my behalf for \ - this domain using my WalletConnect identity." - .to_owned(), - false, - ) - .await -} - pub fn encode_auth(auth: &T, signing_key: &SigningKey) -> String { let data = JwtHeader { typ: JWT_HEADER_TYP, diff --git a/tests/integration.rs b/tests/integration.rs index 6f6155e0..4b86b6a8 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -3,7 +3,7 @@ use { async_trait::async_trait, base64::Engine, chacha20poly1305::{aead::Aead, ChaCha20Poly1305, KeyInit}, - chrono::{Duration, Utc}, + chrono::{Duration, TimeZone, Utc}, hyper::StatusCode, notify_server::{ auth::{ @@ -23,11 +23,15 @@ use { }, types::AccountId, }, - registry::RegistryAuthResponse, + rate_limit, + registry::{storage::redis::Redis, RegistryAuthResponse}, services::{ public_http_server::handlers::{ notify_v0::NotifyBody, - notify_v1::NotifyBodyNotification, + notify_v1::{ + self, notify_rate_limit, subscriber_rate_limit, subscriber_rate_limit_key, + NotifyBodyNotification, + }, subscribe_topic::{SubscribeTopicRequestData, SubscribeTopicResponseData}, }, publisher_service::helpers::{ @@ -50,6 +54,7 @@ use { collections::HashSet, env, net::{IpAddr, Ipv4Addr, SocketAddr}, + sync::Arc, }, test_context::{test_context, AsyncTestContext}, tokio::{ @@ -59,7 +64,7 @@ use { }, tracing_subscriber::fmt::format::FmtSpan, url::Url, - utils::create_client, + utils::{create_client, generate_account}, uuid::Uuid, }; @@ -125,7 +130,7 @@ fn generate_authentication_key() -> ed25519_dalek::SigningKey { } fn generate_account_id() -> AccountId { - "eip155:1:0xfff".into() + generate_account().1 } #[tokio::test] @@ -792,6 +797,7 @@ struct NotifyServerContext { socket_addr: SocketAddr, url: Url, postgres: PgPool, + redis: Arc, } #[async_trait] @@ -834,8 +840,8 @@ impl AsyncTestContext for NotifyServerContext { relay_url: vars.relay_url.parse().unwrap(), notify_url: notify_url.clone(), registry_auth_token: "".to_owned(), - auth_redis_addr_read: None, - auth_redis_addr_write: None, + auth_redis_addr_read: Some("redis://localhost:6379/0".to_owned()), + auth_redis_addr_write: Some("redis://localhost:6379/0".to_owned()), redis_pool_size: 1, telemetry_prometheus_port: None, s3_endpoint: None, @@ -868,11 +874,20 @@ impl AsyncTestContext for NotifyServerContext { .await .unwrap(); + let redis = Arc::new( + Redis::new( + &config.auth_redis_addr().unwrap(), + config.redis_pool_size as usize, + ) + .unwrap(), + ); + Self { shutdown: signal, socket_addr, url: notify_url, postgres, + redis, } } @@ -1229,7 +1244,7 @@ async fn test_notify_v1(notify_server: &NotifyServerContext) { .url .join(&format!("/v1/{project_id}/notify")) .unwrap(); - assert_successful_response( + let response = assert_successful_response( reqwest::Client::new() .post(notify_url) .bearer_auth(Uuid::new_v4()) @@ -1238,7 +1253,13 @@ async fn test_notify_v1(notify_server: &NotifyServerContext) { .await .unwrap(), ) - .await; + .await + .json::() + .await + .unwrap(); + assert!(response.not_found.is_empty()); + assert!(response.failed.is_empty()); + assert_eq!(response.sent, HashSet::from([account.clone()])); let resp = rx.recv().await.unwrap(); let RelayClientEvent::Message(msg) = resp else { @@ -1286,7 +1307,67 @@ async fn test_notify_v1(notify_server: &NotifyServerContext) { #[test_context(NotifyServerContext)] #[tokio::test] -async fn test_notify_idempotent(notify_server: &NotifyServerContext) { +async fn test_notify_v1_response_not_found(notify_server: &NotifyServerContext) { + let project_id = ProjectId::generate(); + let app_domain = generate_app_domain(); + let topic = Topic::generate(); + let subscribe_key = generate_subscribe_key(); + let authentication_key = generate_authentication_key(); + upsert_project( + project_id.clone(), + &app_domain, + topic, + &authentication_key, + &subscribe_key, + ¬ify_server.postgres, + None, + ) + .await + .unwrap(); + + let account = generate_account_id(); + let notification_type = Uuid::new_v4(); + + let notification = Notification { + r#type: notification_type, + title: "title".to_owned(), + body: "body".to_owned(), + icon: Some("icon".to_owned()), + url: Some("url".to_owned()), + }; + + let notification_body = NotifyBodyNotification { + notification_id: None, + notification: notification.clone(), + accounts: vec![account.clone()], + }; + let notify_body = vec![notification_body]; + + let notify_url = notify_server + .url + .join(&format!("/v1/{project_id}/notify")) + .unwrap(); + let response = assert_successful_response( + reqwest::Client::new() + .post(notify_url) + .bearer_auth(Uuid::new_v4()) + .json(¬ify_body) + .send() + .await + .unwrap(), + ) + .await + .json::() + .await + .unwrap(); + assert_eq!(response.not_found, HashSet::from([account.clone()])); + assert!(response.failed.is_empty()); + assert!(response.sent.is_empty()); +} + +#[test_context(NotifyServerContext)] +#[tokio::test] +async fn test_notify_v1_response_not_subscribed_to_scope(notify_server: &NotifyServerContext) { let project_id = ProjectId::generate(); let app_domain = generate_app_domain(); let topic = Topic::generate(); @@ -1325,21 +1406,100 @@ async fn test_notify_idempotent(notify_server: &NotifyServerContext) { .unwrap(); let notification = Notification { - r#type: notification_type, + r#type: Uuid::new_v4(), title: "title".to_owned(), body: "body".to_owned(), icon: Some("icon".to_owned()), url: Some("url".to_owned()), }; - let notification_id = Uuid::new_v4().to_string(); let notification_body = NotifyBodyNotification { - notification_id: Some(notification_id), + notification_id: None, notification: notification.clone(), accounts: vec![account.clone()], }; let notify_body = vec![notification_body]; + let notify_url = notify_server + .url + .join(&format!("/v1/{project_id}/notify")) + .unwrap(); + let response = assert_successful_response( + reqwest::Client::new() + .post(notify_url) + .bearer_auth(Uuid::new_v4()) + .json(¬ify_body) + .send() + .await + .unwrap(), + ) + .await + .json::() + .await + .unwrap(); + assert!(response.not_found.is_empty()); + assert_eq!( + response.failed, + HashSet::from([notify_v1::SendFailure { + account: account.clone(), + reason: "Client is not subscribed to this notification type".into(), + }]) + ); + assert!(response.sent.is_empty()); +} + +#[test_context(NotifyServerContext)] +#[tokio::test] +async fn test_notify_idempotent(notify_server: &NotifyServerContext) { + let project_id = ProjectId::generate(); + let app_domain = generate_app_domain(); + let topic = Topic::generate(); + let subscribe_key = generate_subscribe_key(); + let authentication_key = generate_authentication_key(); + upsert_project( + project_id.clone(), + &app_domain, + topic, + &authentication_key, + &subscribe_key, + ¬ify_server.postgres, + None, + ) + .await + .unwrap(); + let project = get_project_by_project_id(project_id.clone(), ¬ify_server.postgres, None) + .await + .unwrap(); + + let account = generate_account_id(); + let notification_type = Uuid::new_v4(); + let scope = HashSet::from([notification_type]); + let notify_key = rand::Rng::gen::<[u8; 32]>(&mut rand::thread_rng()); + let notify_topic: Topic = sha256::digest(¬ify_key).into(); + upsert_subscriber( + project.id, + account.clone(), + scope.clone(), + ¬ify_key, + notify_topic.clone(), + ¬ify_server.postgres, + None, + ) + .await + .unwrap(); + + let notify_body = vec![NotifyBodyNotification { + notification_id: Some(Uuid::new_v4().to_string()), + notification: Notification { + r#type: notification_type, + title: "title".to_owned(), + body: "body".to_owned(), + icon: Some("icon".to_owned()), + url: Some("url".to_owned()), + }, + accounts: vec![account.clone()], + }]; + let notify_url = notify_server .url .join(&format!("/v1/{project_id}/notify")) @@ -1367,6 +1527,393 @@ async fn test_notify_idempotent(notify_server: &NotifyServerContext) { .await; } +#[test_context(NotifyServerContext)] +#[tokio::test] +async fn test_token_bucket(notify_server: &NotifyServerContext) { + let key = Uuid::new_v4(); + let max_tokens = 2; + let refill_interval = chrono::Duration::milliseconds(500); + let refill_rate = 1; + let rate_limit = || async { + rate_limit::token_bucket_many( + ¬ify_server.redis, + vec![key.to_string()], + max_tokens, + refill_interval, + refill_rate, + ) + .await + .unwrap() + .get(&key.to_string()) + .unwrap() + .to_owned() + }; + + let burst = || async { + for tokens_remaining in (0..max_tokens).rev() { + let result = rate_limit().await; + assert_eq!(result.0, tokens_remaining as i64); + } + + // Do it again, but fail, wait half a second and then it works again 1 time + for _ in 0..2 { + let result = rate_limit().await; + assert!(result.0.is_negative()); + println!("result.1: {}", result.1); + let refill_in = Utc + .from_local_datetime( + &chrono::NaiveDateTime::from_timestamp_millis(result.1 as i64).unwrap(), + ) + .unwrap() + .signed_duration_since(Utc::now()); + println!("refill_in: {refill_in}"); + assert!(refill_in > chrono::Duration::zero()); + assert!(refill_in < refill_interval); + + tokio::time::sleep(refill_interval.to_std().unwrap()).await; + + let result = rate_limit().await; + assert_eq!(result.0, 0); + } + }; + + burst().await; + + // Let burst ability recover + tokio::time::sleep( + (refill_interval * (max_tokens / refill_rate) as i32) + .to_std() + .unwrap(), + ) + .await; + + burst().await; +} + +#[test_context(NotifyServerContext)] +#[tokio::test] +async fn test_token_bucket_separate_keys(notify_server: &NotifyServerContext) { + let rate_limit = |key: String| async move { + rate_limit::token_bucket_many( + ¬ify_server.redis, + vec![key.clone()], + 2, + chrono::Duration::milliseconds(500), + 1, + ) + .await + .unwrap() + .get(&key) + .unwrap() + .to_owned() + }; + + let key1 = Uuid::new_v4(); + let key2 = Uuid::new_v4(); + + let result = rate_limit(key1.to_string()).await; + assert_eq!(result.0, 1); + let result = rate_limit(key2.to_string()).await; + assert_eq!(result.0, 1); + let result = rate_limit(key2.to_string()).await; + assert_eq!(result.0, 0); + let result = rate_limit(key1.to_string()).await; + assert_eq!(result.0, 0); + let result = rate_limit(key2.to_string()).await; + assert_eq!(result.0, -1); + let result = rate_limit(key1.to_string()).await; + assert_eq!(result.0, -1); +} + +#[test_context(NotifyServerContext)] +#[tokio::test] +async fn test_notify_rate_limit(notify_server: &NotifyServerContext) { + let project_id = ProjectId::generate(); + let app_domain = generate_app_domain(); + let topic = Topic::generate(); + let subscribe_key = generate_subscribe_key(); + let authentication_key = generate_authentication_key(); + upsert_project( + project_id.clone(), + &app_domain, + topic, + &authentication_key, + &subscribe_key, + ¬ify_server.postgres, + None, + ) + .await + .unwrap(); + + let notification_type = Uuid::new_v4(); + let notify_body = vec![NotifyBodyNotification { + notification_id: None, + notification: Notification { + r#type: notification_type, + title: "title".to_owned(), + body: "body".to_owned(), + icon: None, + url: None, + }, + accounts: vec![], + }]; + + let notify_url = notify_server + .url + .join(&format!("/v1/{project_id}/notify")) + .unwrap(); + let notify = || async { + reqwest::Client::new() + .post(notify_url.clone()) + .bearer_auth(Uuid::new_v4()) + .json(¬ify_body) + .send() + .await + .unwrap() + }; + + // Use up the rate limit + for _ in 0..20 { + notify_rate_limit(¬ify_server.redis, &project_id) + .await + .unwrap(); + } + + // No longer successful + let response = notify().await; + let status = response.status(); + if status != StatusCode::TOO_MANY_REQUESTS { + panic!( + "expected too many requests response, got {status}: {:?}", + response.text().await + ); + } +} + +#[test_context(NotifyServerContext)] +#[tokio::test] +async fn test_notify_subscriber_rate_limit(notify_server: &NotifyServerContext) { + let project_id = ProjectId::generate(); + let app_domain = generate_app_domain(); + let topic = Topic::generate(); + let subscribe_key = generate_subscribe_key(); + let authentication_key = generate_authentication_key(); + upsert_project( + project_id.clone(), + &app_domain, + topic, + &authentication_key, + &subscribe_key, + ¬ify_server.postgres, + None, + ) + .await + .unwrap(); + let project = get_project_by_project_id(project_id.clone(), ¬ify_server.postgres, None) + .await + .unwrap(); + + let account = generate_account_id(); + let notification_type = Uuid::new_v4(); + let scope = HashSet::from([notification_type]); + let notify_key = rand::Rng::gen::<[u8; 32]>(&mut rand::thread_rng()); + let notify_topic: Topic = sha256::digest(¬ify_key).into(); + let subscriber_id = upsert_subscriber( + project.id, + account.clone(), + scope.clone(), + ¬ify_key, + notify_topic.clone(), + ¬ify_server.postgres, + None, + ) + .await + .unwrap(); + + let notify_body = vec![NotifyBodyNotification { + notification_id: None, + notification: Notification { + r#type: notification_type, + title: "title".to_owned(), + body: "body".to_owned(), + icon: None, + url: None, + }, + accounts: vec![account.clone()], + }]; + + let notify_url = notify_server + .url + .join(&format!("/v1/{project_id}/notify")) + .unwrap(); + let notify = || async { + reqwest::Client::new() + .post(notify_url.clone()) + .bearer_auth(Uuid::new_v4()) + .json(¬ify_body) + .send() + .await + .unwrap() + }; + + for _ in 0..49 { + let result = subscriber_rate_limit(¬ify_server.redis, &project_id, [subscriber_id]) + .await + .unwrap(); + assert!(result + .get(&subscriber_rate_limit_key(&project_id, &subscriber_id)) + .unwrap() + .0 + .is_positive()); + } + + let response = assert_successful_response(notify().await) + .await + .json::() + .await + .unwrap(); + assert!(response.not_found.is_empty()); + assert!(response.failed.is_empty()); + assert_eq!(response.sent, HashSet::from([account.clone()])); + + let response = assert_successful_response(notify().await) + .await + .json::() + .await + .unwrap(); + assert!(response.not_found.is_empty()); + assert_eq!( + response.failed, + HashSet::from([notify_v1::SendFailure { + account: account.clone(), + reason: "Rate limit exceeded".into(), + }]) + ); + assert!(response.sent.is_empty()); +} + +#[test_context(NotifyServerContext)] +#[tokio::test] +async fn test_notify_subscriber_rate_limit_single(notify_server: &NotifyServerContext) { + let project_id = ProjectId::generate(); + let app_domain = generate_app_domain(); + let topic = Topic::generate(); + let subscribe_key = generate_subscribe_key(); + let authentication_key = generate_authentication_key(); + upsert_project( + project_id.clone(), + &app_domain, + topic, + &authentication_key, + &subscribe_key, + ¬ify_server.postgres, + None, + ) + .await + .unwrap(); + let project = get_project_by_project_id(project_id.clone(), ¬ify_server.postgres, None) + .await + .unwrap(); + + let notification_type = Uuid::new_v4(); + + let account1 = generate_account_id(); + let scope = HashSet::from([notification_type]); + let notify_key = rand::Rng::gen::<[u8; 32]>(&mut rand::thread_rng()); + let notify_topic: Topic = sha256::digest(¬ify_key).into(); + let subscriber_id1 = upsert_subscriber( + project.id, + account1.clone(), + scope.clone(), + ¬ify_key, + notify_topic.clone(), + ¬ify_server.postgres, + None, + ) + .await + .unwrap(); + + let account2 = generate_account_id(); + let scope = HashSet::from([notification_type]); + let notify_key = rand::Rng::gen::<[u8; 32]>(&mut rand::thread_rng()); + let notify_topic: Topic = sha256::digest(¬ify_key).into(); + let _subscriber_id2 = upsert_subscriber( + project.id, + account2.clone(), + scope.clone(), + ¬ify_key, + notify_topic.clone(), + ¬ify_server.postgres, + None, + ) + .await + .unwrap(); + + let notify_body = vec![NotifyBodyNotification { + notification_id: None, + notification: Notification { + r#type: notification_type, + title: "title".to_owned(), + body: "body".to_owned(), + icon: None, + url: None, + }, + accounts: vec![account1.clone(), account2.clone()], + }]; + + let notify_url = notify_server + .url + .join(&format!("/v1/{project_id}/notify")) + .unwrap(); + let notify = || async { + reqwest::Client::new() + .post(notify_url.clone()) + .bearer_auth(Uuid::new_v4()) + .json(¬ify_body) + .send() + .await + .unwrap() + }; + + for _ in 0..49 { + let result = subscriber_rate_limit(¬ify_server.redis, &project_id, [subscriber_id1]) + .await + .unwrap(); + assert!(result + .get(&subscriber_rate_limit_key(&project_id, &subscriber_id1)) + .unwrap() + .0 + .is_positive()); + } + + let response = assert_successful_response(notify().await) + .await + .json::() + .await + .unwrap(); + assert!(response.not_found.is_empty()); + assert!(response.failed.is_empty()); + assert_eq!( + response.sent, + HashSet::from([account1.clone(), account2.clone()]) + ); + + let response = assert_successful_response(notify().await) + .await + .json::() + .await + .unwrap(); + assert!(response.not_found.is_empty()); + assert_eq!( + response.failed, + HashSet::from([notify_v1::SendFailure { + account: account1.clone(), + reason: "Rate limit exceeded".into(), + }]) + ); + assert_eq!(response.sent, HashSet::from([account2.clone()])); +} + #[test_context(NotifyServerContext)] #[tokio::test] async fn test_ignores_invalid_scopes(notify_server: &NotifyServerContext) { diff --git a/tests/utils/mod.rs b/tests/utils/mod.rs index 4538719f..53ba532a 100644 --- a/tests/utils/mod.rs +++ b/tests/utils/mod.rs @@ -1,16 +1,21 @@ use { base64::Engine, ed25519_dalek::VerifyingKey, + k256::ecdsa::SigningKey, notify_server::{ auth::AuthError, + model::types::AccountId, notify_message::JwtMessage, relay_client_helpers::create_ws_connect_options, services::websocket_server::relay_ws_client::{RelayClientEvent, RelayConnectionHandler}, }, rand::rngs::StdRng, + rand_chacha::rand_core::OsRng, rand_core::SeedableRng, relay_client::websocket, relay_rpc::{auth::ed25519_dalek::Keypair, domain::ProjectId}, + sha2::Digest, + sha3::Keccak256, std::sync::Arc, tokio::sync::mpsc::UnboundedReceiver, url::Url, @@ -74,3 +79,17 @@ pub fn verify_jwt(jwt: &str, key: &VerifyingKey) -> notify_server::error::Result Ok(false) | Err(_) => Err(AuthError::InvalidSignature)?, } } + +pub fn generate_account() -> (SigningKey, AccountId) { + let account_signing_key = k256::ecdsa::SigningKey::random(&mut OsRng); + let address = &Keccak256::default() + .chain_update( + &account_signing_key + .verifying_key() + .to_encoded_point(false) + .as_bytes()[1..], + ) + .finalize()[12..]; + let account = format!("eip155:1:0x{}", hex::encode(address)).into(); + (account_signing_key, account) +}