Skip to content

Commit

Permalink
Merge pull request #11 from thibault-cne/feature/rate-limiter
Browse files Browse the repository at this point in the history
Feature/rate limiter
  • Loading branch information
thibault-cne authored Apr 14, 2024
2 parents 60fcab3 + 24c88fe commit f5471fe
Show file tree
Hide file tree
Showing 26 changed files with 362 additions and 23 deletions.
35 changes: 35 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ sea-query = { version = "0.30.7", features = ["backend-mysql"] }
r2d2 = "0.8.10"
mysql = { version = "25.0.0", default-features = false, features = ["minimal"] }
mysql_common = { version = "0.32.2", default-features = false, features = ["chrono", "derive", "bigdecimal"] }
redis = "0.25.3"
1 change: 1 addition & 0 deletions bin/api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ edition = "2021"
rocket.workspace = true
dotenvy.workspace = true
mysql.workspace = true
chrono.workspace = true

api-lib.workspace = true
infrastructure.workspace = true
7 changes: 7 additions & 0 deletions bin/api/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@ fn rocket() -> _ {
dotenv().ok();

rocket::build()
.attach(api_lib::fairings::helmet::Formula1Helmet)
.attach(api_lib::fairings::rate_limiter::RateLimiter)
.mount("/api", api_lib::handlers::handlers())
.mount("/fallback", api_lib::fallbacks::handlers())
.manage(infrastructure::ConnectionPool::try_new().unwrap())
.manage(api_lib::fairings::rate_limiter::SlidingWindow::new(
10,
chrono::Duration::seconds(60),
))
}
2 changes: 2 additions & 0 deletions crates/api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ serde.workspace = true
r2d2.workspace = true
serde_json.workspace = true
mysql.workspace = true
chrono.workspace = true
redis.workspace = true

application.workspace = true
infrastructure.workspace = true
Expand Down
4 changes: 2 additions & 2 deletions crates/api/src/circuits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use infrastructure::ConnectionPool;
use shared::prelude::*;

#[get("/<series>/circuits?<circuit_ref>", rank = 1)]
pub fn circuits_ref(
fn circuits_ref(
db: &State<ConnectionPool>,
series: Series,
circuit_ref: shared::parameters::CircuitRef,
Expand All @@ -24,7 +24,7 @@ pub fn circuits_ref(
}

#[get("/<series>/circuits?<param..>", rank = 2)]
pub fn circuits(
fn circuits(
db: &State<ConnectionPool>,
series: Series,
param: shared::parameters::GetCircuitsParameter,
Expand Down
2 changes: 1 addition & 1 deletion crates/api/src/constructor_standings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use infrastructure::ConnectionPool;
use shared::prelude::*;

#[get("/<series>/constructors/standing?<param..>")]
pub fn constructor_standings(
fn constructor_standings(
db: &State<ConnectionPool>,
series: Series,
param: shared::parameters::GetConstructorStandingsParameter,
Expand Down
4 changes: 2 additions & 2 deletions crates/api/src/constructors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use infrastructure::ConnectionPool;
use shared::prelude::*;

#[get("/<series>/constructors?<constructor_ref>", rank = 1)]
pub fn constructors_ref(
fn constructors_ref(
db: &State<ConnectionPool>,
series: Series,
constructor_ref: shared::parameters::ConstructorRef,
Expand All @@ -25,7 +25,7 @@ pub fn constructors_ref(
}

#[get("/<series>/constructors?<param..>", rank = 2)]
pub fn constructors(
fn constructors(
db: &State<ConnectionPool>,
series: Series,
param: shared::parameters::GetConstructorsParameter,
Expand Down
2 changes: 1 addition & 1 deletion crates/api/src/driver_standings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use infrastructure::ConnectionPool;
use shared::prelude::*;

#[get("/<series>/drivers/standing?<param..>")]
pub fn driver_standings(
fn driver_standings(
db: &State<ConnectionPool>,
series: Series,
param: shared::parameters::GetDriverStandingsParameter,
Expand Down
4 changes: 2 additions & 2 deletions crates/api/src/drivers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use infrastructure::ConnectionPool;
use shared::prelude::*;

#[get("/<series>/drivers?<driver_ref>", rank = 1)]
pub fn drivers_ref(
fn drivers_ref(
db: &State<ConnectionPool>,
series: Series,
driver_ref: shared::parameters::DriverRef,
Expand All @@ -24,7 +24,7 @@ pub fn drivers_ref(
}

#[get("/<series>/drivers?<param..>", rank = 2)]
pub fn drivers(
fn drivers(
db: &State<ConnectionPool>,
series: Series,
param: shared::parameters::GetDriversParameter,
Expand Down
34 changes: 34 additions & 0 deletions crates/api/src/fairings/helmet.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
use rocket::fairing::{Fairing, Info, Kind};
use rocket::{uri, Data, Orbit, Request, Rocket};

use infrastructure::ConnectionPool;

use crate::fairings::rate_limiter::SlidingWindow;
use crate::fallbacks::rocket_uri_macro_internal_ressource;

pub struct Formula1Helmet;

#[rocket::async_trait]
impl Fairing for Formula1Helmet {
fn info(&self) -> Info {
Info {
name: "Formula1Helmet",
kind: Kind::Liftoff | Kind::Request,
}
}

async fn on_liftoff(&self, rocket: &Rocket<Orbit>) {
if rocket.config().ip_header.is_none()
|| rocket.state::<ConnectionPool>().is_none()
|| rocket.state::<SlidingWindow>().is_none()
{
rocket.shutdown().notify()
}
}

async fn on_request(&self, req: &mut Request<'_>, _: &mut Data<'_>) {
if req.uri().path().as_str().starts_with("/fallback") {
req.set_uri(uri!("/fallback", internal_ressource))
}
}
}
2 changes: 2 additions & 0 deletions crates/api/src/fairings/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pub mod helmet;
pub mod rate_limiter;
116 changes: 116 additions & 0 deletions crates/api/src/fairings/rate_limiter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
use chrono::{DateTime, Duration, Utc};
use redis::Commands;
use rocket::fairing::{Fairing, Info, Kind};
use rocket::{uri, Data, Request};

use infrastructure::ConnectionPool;

use crate::fallbacks::rocket_uri_macro_rate_limiter_fallback;

const RATE_LIMITER_KEY_PREFIX: &str = "RATE_LIMITER_";

pub struct RateLimiter;

#[rocket::async_trait]
impl Fairing for RateLimiter {
fn info(&self) -> Info {
Info {
name: "RateLimiter",
kind: Kind::Request,
}
}

async fn on_request(&self, req: &mut Request<'_>, _: &mut Data<'_>) {
// SAFETY: This values should always be defined
let ip_header = req
.rocket()
.config()
.ip_header
.as_ref()
.unwrap()
.to_string();
let pool = req.rocket().state::<ConnectionPool>().unwrap();
let window = req.rocket().state::<SlidingWindow>().unwrap();
let rate_limiter = &mut pool.cache.get().unwrap();

let ip_addr = match req.real_ip() {
Some(ip_addr) => ip_addr,
None => {
req.set_uri(uri!("/fallback", rate_limiter_fallback(Some(ip_header), _)));
return;
}
};

let key = format!("{}{}", RATE_LIMITER_KEY_PREFIX, ip_addr);
let now = Utc::now();
let timestamps: String = match rate_limiter.get(&key) {
Ok(res) => res,
Err(_) => {
rate_limiter
.set::<String, String, ()>(
key,
serde_json::to_string(&[(now.timestamp(), now.timestamp_subsec_nanos())])
.unwrap(),
)
.unwrap();
return;
}
};

let window_validation = |date| now - date > window.duration;
let mut timestamps = cleanup(
serde_json::from_str(&timestamps).unwrap(),
window_validation,
);

if timestamps.len() == window.request_num {
// SAFETY: timestamps is not empty and secs and nsecs
// are from DateTime::timestamp and DateTime::timestamp_subsec_nanos
let first = timestamps
.first()
.map(|&(secs, nsecs)| DateTime::from_timestamp(secs, nsecs).unwrap())
.unwrap();
let time_to_wait = (window.duration - (now - first)).num_seconds();
req.set_uri(uri!(
"/fallback",
rate_limiter_fallback(_, Some(time_to_wait))
));
return;
}

timestamps.push((now.timestamp(), now.timestamp_subsec_nanos()));
rate_limiter
.set::<String, String, ()>(key, serde_json::to_string(&timestamps).unwrap())
.unwrap();
}
}

pub struct SlidingWindow {
request_num: usize,
duration: Duration,
}

impl SlidingWindow {
/// Create a sliding window of `request_num` requests over a given `duration`.
pub fn new(request_num: usize, duration: Duration) -> Self {
Self {
request_num,
duration,
}
}
}

fn cleanup<F>(timestamps: Vec<(i64, u32)>, mut f: F) -> Vec<(i64, u32)>
where
F: FnMut(DateTime<Utc>) -> bool,
{
timestamps
.into_iter()
.skip_while(|&(secs, nsecs)| {
// SAFETY: secs and nsecs are from DateTime::timestamp
// and DateTime::timestamp_subsec_nanos
let date = DateTime::from_timestamp(secs, nsecs).unwrap();
f(date)
})
.collect()
}
36 changes: 36 additions & 0 deletions crates/api/src/fallbacks.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
use rocket::{get, routes, Route};

use shared::{error, prelude::*};

#[get("/rate_limiter?<header_not_found>&<too_many_requests>")]
fn rate_limiter_fallback(
header_not_found: Option<String>,
too_many_requests: Option<i64>,
) -> Result<()> {
if let Some(ip_header) = header_not_found {
return Err(
error!(IpHeaderNotFound => "ip header not found, expected to find it under the `{}` header", ip_header),
);
}

if let Some(time_to_wait) = too_many_requests {
return Err(
error!(RateLimitReached => "you reached the rate limit, please wait `{}s` before your next request", time_to_wait),
);
}

Err(
error!(InternalServerError => "rate_limiter_fallback should not be called without any parameters, please open an issue"),
)
}

#[get("/internal_ressource")]
fn internal_ressource() -> Result<()> {
Err(
error!(InternalRessource => "this ressource is intended for internal purposes you can't access it"),
)
}

pub fn handlers() -> Vec<Route> {
routes![rate_limiter_fallback, internal_ressource]
}
1 change: 1 addition & 0 deletions crates/api/src/guards/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

2 changes: 1 addition & 1 deletion crates/api/src/laps.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use infrastructure::ConnectionPool;
use shared::prelude::*;

#[get("/<series>/laps?<param..>")]
pub fn laps(
fn laps(
db: &State<ConnectionPool>,
series: Series,
param: shared::parameters::GetLapsParameter,
Expand Down
Loading

0 comments on commit f5471fe

Please sign in to comment.