Skip to content

Commit a0cd86c

Browse files
authored
Merge pull request #50 from thibault-cne/issue/39
Added a cache layer
2 parents b97c8ac + e2f9236 commit a0cd86c

File tree

7 files changed

+129
-23
lines changed

7 files changed

+129
-23
lines changed

config.yml

+4
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,7 @@ middlewares:
1616
requests: 100
1717
seconds: 60
1818
type: sliding_window
19+
- cache:
20+
enabled: true
21+
# Time to live in seconds
22+
ttl: 3600

crates/api/Cargo.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ tokio.workspace = true
1313
serde_json.workspace = true
1414
redis.workspace = true
1515

16+
http-body-util = "0.1.0"
17+
1618
application = { path = "../application" }
1719
infrastructure = { path = "../infrastructure" }
1820
shared = { path = "../shared" }
@@ -21,4 +23,3 @@ shared = { path = "../shared" }
2123
serde.workspace = true
2224

2325
tower = "0.4.13"
24-
http-body-util = "0.1.0"

crates/api/src/lib.rs

+9-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use axum::{middleware, Extension, Router};
22

33
use infrastructure::config::{Config, MiddlewareConfig};
4+
use middlewares::cache::Cache;
45
use shared::error::Result;
56

67
use crate::middlewares::rate_limiter::RateLimiter;
@@ -85,8 +86,8 @@ struct ServiceBuilder<'c> {
8586
impl<'c> ServiceBuilder<'c> {
8687
fn middlewares(self) -> Result<Router> {
8788
if let Some(middlewares) = &self.config.middlewares {
88-
let router = middlewares.iter().fold(self.router, |router, m| match m {
89-
&MiddlewareConfig::RateLimiter {
89+
let router = middlewares.iter().fold(self.router, |router, m| match *m {
90+
MiddlewareConfig::RateLimiter {
9091
enabled,
9192
ty,
9293
seconds,
@@ -95,6 +96,12 @@ impl<'c> ServiceBuilder<'c> {
9596
RateLimiter::new(ty, requests, seconds),
9697
middlewares::rate_limiter::mw_rate_limiter,
9798
)),
99+
MiddlewareConfig::Cache { enabled, ttl } if enabled => {
100+
router.route_layer(middleware::from_fn_with_state(
101+
Cache::new(ttl),
102+
middlewares::cache::mw_cache_layer,
103+
))
104+
}
98105
_ => router,
99106
});
100107
Ok(router)

crates/api/src/middlewares/cache.rs

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
use axum::body::{Body, Bytes};
2+
use axum::extract::State;
3+
use axum::http::{header, StatusCode};
4+
use axum::Extension;
5+
use axum::{http::Request, middleware::Next, response::Response};
6+
use http_body_util::BodyExt;
7+
use redis::Commands;
8+
9+
use infrastructure::ConnectionPool;
10+
use shared::prelude::*;
11+
12+
const CACHE_KEY_PREFIX: &str = "cache";
13+
14+
#[derive(Clone, Copy)]
15+
pub struct Cache {
16+
ttl: u64,
17+
}
18+
19+
pub async fn mw_cache_layer(
20+
Extension(conn): Extension<ConnectionPool>,
21+
State(cache): State<Cache>,
22+
req: Request<Body>,
23+
next: Next,
24+
) -> Result<Response> {
25+
let conn = &mut conn.cache.get()?;
26+
let key = format!("{}:{}", CACHE_KEY_PREFIX, req.uri());
27+
28+
let val: String = match conn.get(&key) {
29+
Ok(val) => val,
30+
Err(_) => {
31+
log::info!("missed cache, forwarding the request");
32+
let res = next.run(req).await;
33+
let (parts, body) = res.into_parts();
34+
let (val, bytes) = buffer_and_print(body).await?;
35+
let res = Response::from_parts(parts, Body::from(bytes));
36+
37+
match conn.set_ex(key, val, cache.ttl) {
38+
Ok(()) => (),
39+
Err(e) => return Err(e.into()),
40+
}
41+
log::info!("added response to the cache with ttl: {}", cache.ttl);
42+
43+
return Ok(res);
44+
}
45+
};
46+
log::info!("retrieved value from cache");
47+
let bytes = Bytes::from(val);
48+
let response = Response::builder()
49+
.status(StatusCode::OK)
50+
.header(header::CONTENT_TYPE, "application/json")
51+
.body(Body::from(bytes))?;
52+
53+
Ok(response)
54+
}
55+
56+
async fn buffer_and_print<B>(body: B) -> Result<(String, Bytes)>
57+
where
58+
B: axum::body::HttpBody<Data = Bytes>,
59+
B::Error: std::fmt::Display,
60+
{
61+
let bytes = match body.collect().await {
62+
Ok(collected) => collected.to_bytes(),
63+
Err(_) => todo!(),
64+
};
65+
66+
let body = String::from_utf8(bytes.clone().into())?;
67+
Ok((body, bytes))
68+
}
69+
70+
impl Cache {
71+
pub fn new(ttl: u64) -> Cache {
72+
Cache { ttl }
73+
}
74+
}

crates/api/src/middlewares/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1+
pub mod cache;
12
pub mod rate_limiter;

crates/infrastructure/src/config.rs

+5
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@ pub enum MiddlewareConfig {
4646
seconds: i64,
4747
requests: usize,
4848
},
49+
Cache {
50+
#[serde(default)]
51+
enabled: bool,
52+
ttl: u64,
53+
},
4954
}
5055

5156
#[derive(Serialize, Deserialize, Debug, Clone, Copy)]

crates/shared/src/error.rs

+34-20
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,15 @@ pub enum ErrorKind {
2727
EntityNotFound,
2828
ResourceNotFound,
2929
IpHeaderNotFound,
30+
3031
RateLimitReached,
3132
InternalServer,
3233
MissingEnvVar,
3334
ConnectionPool,
35+
36+
// Std errors
3437
ParseInt,
38+
FromUtf8,
3539

3640
// External error kinds
3741
R2D2,
@@ -49,24 +53,28 @@ macros::error_from!(Figment => figment::Error);
4953
macros::error_from!(Serde => serde_json::Error);
5054
macros::error_from!(Axum => axum::Error);
5155
macros::error_from!(Axum => axum::http::Error);
56+
macros::error_from!(FromUtf8 => std::string::FromUtf8Error);
5257

5358
impl std::fmt::Display for ErrorKind {
5459
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
5560
use ErrorKind::*;
5661

5762
match self {
5863
InvalidParameter => write!(f, "invalid parameter"),
59-
R2D2 => write!(f, "r2d2 error"),
60-
Mysql => write!(f, "mysql error"),
6164
EntityNotFound => write!(f, "entity not found"),
6265
IpHeaderNotFound => write!(f, "ip header not found"),
6366
RateLimitReached => write!(f, "rate limit reached on given the sliding window"),
6467
InternalServer => write!(f, "an unexpected error occured"),
6568
ResourceNotFound => write!(f, "the queried resource was not found"),
66-
Redis => write!(f, "redis error"),
6769
MissingEnvVar => write!(f, "an environment variable is missing"),
6870
ConnectionPool => write!(f, "an error occured while setting-up a connection pool"),
71+
6972
ParseInt => write!(f, "can't parse int"),
73+
FromUtf8 => write!(f, "invalid utf8 string"),
74+
75+
R2D2 => write!(f, "r2d2 error"),
76+
Mysql => write!(f, "mysql error"),
77+
Redis => write!(f, "redis error"),
7078
Figment => write!(f, "figment error"),
7179
Serde => write!(f, "serde error"),
7280
Axum => write!(f, "axum error"),
@@ -83,20 +91,23 @@ impl Serialize for ErrorKind {
8391

8492
match self {
8593
InvalidParameter => s.serialize_unit_variant("ErrorKind", 0, "InvalidParameter"),
86-
R2D2 => s.serialize_unit_variant("ErrorKind", 1, "R2D2"),
87-
Mysql => s.serialize_unit_variant("ErrorKind", 2, "Mysql"),
88-
EntityNotFound => s.serialize_unit_variant("ErrorKind", 3, "EntityNotFound"),
89-
IpHeaderNotFound => s.serialize_unit_variant("ErrorKind", 4, "IpHeaderNotFound"),
90-
RateLimitReached => s.serialize_unit_variant("ErrorKind", 5, "RateLimitReached"),
91-
InternalServer => s.serialize_unit_variant("ErrorKind", 6, "InternalServerError"),
92-
ResourceNotFound => s.serialize_unit_variant("ErrorKind", 7, "ResourceNotFound"),
93-
Redis => s.serialize_unit_variant("ErrorKind", 8, "Redis"),
94-
MissingEnvVar => s.serialize_unit_variant("ErrorKind", 9, "MissingEnvVar"),
95-
ConnectionPool => s.serialize_unit_variant("ErrorKind", 10, "ConnectionPool"),
96-
ParseInt => s.serialize_unit_variant("ErrorKind", 11, "ParseInt"),
97-
Figment => s.serialize_unit_variant("ErrorKind", 11, "Figment"),
98-
Serde => s.serialize_unit_variant("ErrorKind", 12, "Serde"),
99-
Axum => s.serialize_unit_variant("ErrorKind", 13, "Axum"),
94+
EntityNotFound => s.serialize_unit_variant("ErrorKind", 1, "EntityNotFound"),
95+
IpHeaderNotFound => s.serialize_unit_variant("ErrorKind", 2, "IpHeaderNotFound"),
96+
RateLimitReached => s.serialize_unit_variant("ErrorKind", 3, "RateLimitReached"),
97+
InternalServer => s.serialize_unit_variant("ErrorKind", 4, "InternalServerError"),
98+
ResourceNotFound => s.serialize_unit_variant("ErrorKind", 5, "ResourceNotFound"),
99+
MissingEnvVar => s.serialize_unit_variant("ErrorKind", 6, "MissingEnvVar"),
100+
ConnectionPool => s.serialize_unit_variant("ErrorKind", 7, "ConnectionPool"),
101+
102+
ParseInt => s.serialize_unit_variant("ErrorKind", 8, "ParseInt"),
103+
FromUtf8 => s.serialize_unit_variant("ErrorKind", 9, "FromUtf8"),
104+
105+
R2D2 => s.serialize_unit_variant("ErrorKind", 10, "R2D2"),
106+
Mysql => s.serialize_unit_variant("ErrorKind", 11, "Mysql"),
107+
Redis => s.serialize_unit_variant("ErrorKind", 12, "Redis"),
108+
Figment => s.serialize_unit_variant("ErrorKind", 13, "Figment"),
109+
Serde => s.serialize_unit_variant("ErrorKind", 14, "Serde"),
110+
Axum => s.serialize_unit_variant("ErrorKind", 15, "Axum"),
100111
}
101112
}
102113
}
@@ -109,17 +120,20 @@ impl From<ErrorKind> for StatusCode {
109120

110121
match kind {
111122
InvalidParameter => Self::BAD_REQUEST,
112-
R2D2 => Self::INTERNAL_SERVER_ERROR,
113-
Mysql => Self::INTERNAL_SERVER_ERROR,
114123
EntityNotFound => Self::NOT_FOUND,
115124
IpHeaderNotFound => Self::BAD_REQUEST,
116125
RateLimitReached => Self::TOO_MANY_REQUESTS,
117126
InternalServer => Self::INTERNAL_SERVER_ERROR,
118127
ResourceNotFound => Self::NOT_FOUND,
119-
Redis => Self::INTERNAL_SERVER_ERROR,
120128
MissingEnvVar => Self::INTERNAL_SERVER_ERROR,
121129
ConnectionPool => Self::INTERNAL_SERVER_ERROR,
130+
122131
ParseInt => Self::INTERNAL_SERVER_ERROR,
132+
FromUtf8 => Self::INTERNAL_SERVER_ERROR,
133+
134+
R2D2 => Self::INTERNAL_SERVER_ERROR,
135+
Mysql => Self::INTERNAL_SERVER_ERROR,
136+
Redis => Self::INTERNAL_SERVER_ERROR,
123137
Figment => Self::INTERNAL_SERVER_ERROR,
124138
Serde => Self::INTERNAL_SERVER_ERROR,
125139
Axum => Self::INTERNAL_SERVER_ERROR,

0 commit comments

Comments
 (0)