Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch from async-std and tide to tokio and axum #39

Merged
merged 5 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,283 changes: 168 additions & 1,115 deletions Cargo.lock

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,18 @@ license = "MIT OR Apache-2.0"
[dependencies]
a2 = { git = "https://github.com/WalletConnect/a2/", branch = "master" }
anyhow = "1.0.32"
async-std = { version = "1.9", features = ["tokio1", "attributes", "unstable"] }
axum = "0.7.5"
femme = "2.1.0"
humantime = "2.0.1"
log = "0.4.11"
prometheus-client = "0.22.2"
rand = "0.8.5"
reqwest = "0.12.4"
serde = { version = "1.0.114", features = ["derive"] }
serde_json = "1.0.122"
sled = "0.34.2"
structopt = "0.3.15"
tide = "0.16.0"
tokio = { version = "1.39.2", features = ["full"] }
yup-oauth2 = "9.0.0"

[dev-dependencies]
Expand Down
6 changes: 3 additions & 3 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ struct Opt {
fcm_key_path: String,
}

#[async_std::main]
#[tokio::main]
async fn main() -> Result<()> {
femme::start();

Expand All @@ -63,7 +63,7 @@ async fn main() -> Result<()> {

if let Some(metrics_address) = opt.metrics.clone() {
let state = state.clone();
async_std::task::spawn(async move { metrics::start(state, metrics_address).await });
tokio::task::spawn(async move { metrics::start(state, metrics_address).await });
}

// Setup mulitple parallel notifiers.
Expand All @@ -72,7 +72,7 @@ async fn main() -> Result<()> {
// and use the same HTTP/2 clients, one for production and one for sandbox server.
for _ in 0..50 {
let state = state.clone();
async_std::task::spawn(async move { notifier::start(state, interval).await });
tokio::task::spawn(async move { notifier::start(state, interval).await });
}

server::start(state, host, port).await?;
Expand Down
31 changes: 19 additions & 12 deletions src/metrics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@

use std::sync::atomic::AtomicI64;

use anyhow::Result;
use axum::http::{header, HeaderMap};
use axum::response::IntoResponse;
use axum::routing::get;
use prometheus_client::encoding::text::encode;
use prometheus_client::metrics::counter::Counter;
use prometheus_client::metrics::gauge::Gauge;
use prometheus_client::registry::Registry;

use anyhow::Result;

use crate::state::State;

#[derive(Debug, Default)]
Expand Down Expand Up @@ -86,18 +88,23 @@ impl Metrics {
}

pub async fn start(state: State, server: String) -> Result<()> {
let mut app = tide::with_state(state);
app.at("/metrics").get(metrics);
app.listen(server).await?;
let app = axum::Router::new()
.route("/metrics", get(metrics))
.with_state(state);
let listener = tokio::net::TcpListener::bind(server).await?;
axum::serve(listener, app).await?;
Ok(())
}

async fn metrics(req: tide::Request<State>) -> tide::Result<tide::Response> {
async fn metrics(axum::extract::State(state): axum::extract::State<State>) -> impl IntoResponse {
let mut encoded = String::new();
encode(&mut encoded, &req.state().metrics().registry).unwrap();
let response = tide::Response::builder(tide::StatusCode::Ok)
.body(encoded)
.content_type("application/openmetrics-text; version=1.0.0; charset=utf-8")
.build();
Ok(response)
encode(&mut encoded, &state.metrics().registry).unwrap();
let mut headers = HeaderMap::new();
headers.insert(
header::CONTENT_TYPE,
"application/openmetrics-text; version=1.0.0; charset=utf-8"
.parse()
.unwrap(),
);
(headers, encoded)
}
6 changes: 3 additions & 3 deletions src/notifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ pub async fn start(state: State, interval: std::time::Duration) -> Result<()> {

let Some((timestamp, token)) = schedule.pop()? else {
info!("No tokens to notify, sleeping for a minute.");
async_std::task::sleep(Duration::from_secs(60)).await;
tokio::time::sleep(Duration::from_secs(60)).await;
continue;
};

Expand All @@ -49,7 +49,7 @@ pub async fn start(state: State, interval: std::time::Duration) -> Result<()> {
"Sleeping for {} before next notification.",
humantime::format_duration(delay)
);
async_std::task::sleep(delay).await;
tokio::time::sleep(delay).await;
}

if let Err(err) = wakeup(
Expand All @@ -66,7 +66,7 @@ pub async fn start(state: State, interval: std::time::Duration) -> Result<()> {

// Sleep to avoid busy looping and flooding APNS
// with requests in case of database errors.
async_std::task::sleep(Duration::from_secs(60)).await;
tokio::time::sleep(Duration::from_secs(60)).await;
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/schedule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ mod tests {

use tempfile::tempdir;

#[async_std::test]
#[tokio::test]
async fn test_schedule() -> Result<()> {
let dir = tempdir()?;
let db_path = dir.path().join("db.sled");
Expand Down
108 changes: 67 additions & 41 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ use a2::{
Priority, PushType,
};
use anyhow::{bail, Error, Result};
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::routing::{get, post};
use log::*;
use serde::Deserialize;
use std::str::FromStr;
Expand All @@ -11,13 +14,13 @@ use crate::metrics::Metrics;
use crate::state::State;

pub async fn start(state: State, server: String, port: u16) -> Result<()> {
let mut app = tide::with_state(state);
app.at("/").get(|_| async { Ok("Hello, world!") });
app.at("/register").post(register_device);
app.at("/notify").post(notify_device);

info!("Listening on {server}:port");
app.listen((server, port)).await?;
let app = axum::Router::new()
.route("/", get(|| async { "Hello, world!" }))
.route("/register", post(register_device))
.route("/notify", post(notify_device))
.with_state(state);
let listener = tokio::net::TcpListener::bind((server, port)).await?;
axum::serve(listener, app).await?;
Ok(())
}

Expand All @@ -26,20 +29,44 @@ struct DeviceQuery {
token: String,
}

struct AppError(anyhow::Error);

impl<E> From<E> for AppError
where
E: Into<anyhow::Error>,
{
fn from(err: E) -> Self {
Self(err.into())
}
}

impl IntoResponse for AppError {
fn into_response(self) -> Response {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Something went wrong: {}", self.0),
)
.into_response()
}
}

/// Registers a device for heartbeat notifications.
async fn register_device(mut req: tide::Request<State>) -> tide::Result<tide::Response> {
let query: DeviceQuery = req.body_json().await?;
async fn register_device(
axum::extract::State(state): axum::extract::State<State>,
body: String,
) -> Result<(), AppError> {
let query: DeviceQuery = serde_json::from_str(&body)?;
info!("register_device {}", query.token);

let schedule = req.state().schedule();
let schedule = state.schedule();
schedule.insert_token_now(&query.token)?;

// Flush database to ensure we don't lose this token in case of restart.
schedule.flush().await?;

req.state().metrics().heartbeat_registrations_total.inc();
state.metrics().heartbeat_registrations_total.inc();

Ok(tide::Response::new(tide::StatusCode::Ok))
Ok(())
}

enum NotificationToken {
Expand Down Expand Up @@ -90,17 +117,17 @@ async fn notify_fcm(
_package_name: &str,
token: &str,
metrics: &Metrics,
) -> tide::Result<tide::Response> {
) -> Result<StatusCode> {
let Some(fcm_api_key) = fcm_api_key else {
warn!("Cannot notify FCM because key is not set");
return Ok(tide::Response::new(tide::StatusCode::InternalServerError));
return Ok(StatusCode::INTERNAL_SERVER_ERROR);
};

if !token
.chars()
.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == ':' || c == '-')
{
return Ok(tide::Response::new(tide::StatusCode::Gone));
return Ok(StatusCode::GONE);
}

let url = "https://fcm.googleapis.com/v1/projects/delta-chat-fcm/messages:send";
Expand All @@ -118,23 +145,19 @@ async fn notify_fcm(
warn!("Failed to deliver FCM notification to {token}");
warn!("BODY: {body:?}");
warn!("RES: {res:?}");
return Ok(tide::Response::new(tide::StatusCode::Gone));
return Ok(StatusCode::GONE);
}
if status.is_server_error() {
warn!("Internal server error while attempting to deliver FCM notification to {token}");
return Ok(tide::Response::new(tide::StatusCode::InternalServerError));
return Ok(StatusCode::INTERNAL_SERVER_ERROR);
}
info!("Delivered notification to FCM token {token}");
metrics.fcm_notifications_total.inc();
Ok(tide::Response::new(tide::StatusCode::Ok))
Ok(StatusCode::OK)
}

async fn notify_apns(
req: tide::Request<State>,
client: a2::Client,
device_token: String,
) -> tide::Result<tide::Response> {
let schedule = req.state().schedule();
async fn notify_apns(state: State, client: a2::Client, device_token: String) -> Result<StatusCode> {
let schedule = state.schedule();
let payload = DefaultNotificationBuilder::new()
.set_title("New messages")
.set_title_loc_key("new_messages") // Localization key for the title.
Expand All @@ -148,7 +171,7 @@ async fn notify_apns(
// High priority (10).
// <https://developer.apple.com/documentation/usernotifications/sending-notification-requests-to-apns>
apns_priority: Some(Priority::High),
apns_topic: req.state().topic(),
apns_topic: state.topic(),
apns_push_type: Some(PushType::Alert),
..Default::default()
},
Expand All @@ -159,14 +182,14 @@ async fn notify_apns(
match res.code {
200 => {
info!("delivered notification for {}", device_token);
req.state().metrics().direct_notifications_total.inc();
state.metrics().direct_notifications_total.inc();
}
_ => {
warn!("unexpected status: {:?}", res);
}
}

Ok(tide::Response::new(tide::StatusCode::Ok))
Ok(StatusCode::OK)
}
Err(ResponseError(res)) => {
info!("Removing token {} due to error {:?}.", &device_token, res);
Expand All @@ -179,21 +202,23 @@ async fn notify_apns(
error!("failed to remove {}: {:?}", &device_token, err);
}
// Return 410 Gone response so email server can remove the token.
Ok(tide::Response::new(tide::StatusCode::Gone))
Ok(StatusCode::GONE)
} else {
Ok(tide::Response::new(tide::StatusCode::InternalServerError))
Ok(StatusCode::INTERNAL_SERVER_ERROR)
}
}
Err(err) => {
error!("failed to send notification: {}, {:?}", device_token, err);
Ok(tide::Response::new(tide::StatusCode::InternalServerError))
Ok(StatusCode::INTERNAL_SERVER_ERROR)
}
}
}

/// Notifies a single device with a visible notification.
async fn notify_device(mut req: tide::Request<State>) -> tide::Result<tide::Response> {
let device_token = req.body_string().await?;
async fn notify_device(
axum::extract::State(state): axum::extract::State<State>,
device_token: String,
) -> Result<StatusCode, AppError> {
info!("Got direct notification for {device_token}.");

let device_token: NotificationToken = device_token.as_str().parse()?;
Expand All @@ -203,27 +228,28 @@ async fn notify_device(mut req: tide::Request<State>) -> tide::Result<tide::Resp
package_name,
token,
} => {
let client = req.state().fcm_client().clone();
let Ok(fcm_token) = req.state().fcm_token().await else {
return Ok(tide::Response::new(tide::StatusCode::InternalServerError));
let client = state.fcm_client().clone();
let Ok(fcm_token) = state.fcm_token().await else {
return Ok(StatusCode::INTERNAL_SERVER_ERROR);
};
let metrics = req.state().metrics();
let metrics = state.metrics();
notify_fcm(
&client,
fcm_token.as_deref(),
&package_name,
&token,
metrics,
)
.await
.await?;
}
NotificationToken::ApnsSandbox(token) => {
let client = req.state().sandbox_client().clone();
notify_apns(req, client, token).await
let client = state.sandbox_client().clone();
notify_apns(state, client, token).await?;
}
NotificationToken::ApnsProduction(token) => {
let client = req.state().production_client().clone();
notify_apns(req, client, token).await
let client = state.production_client().clone();
notify_apns(state, client, token).await?;
}
}
Ok(StatusCode::OK)
}
2 changes: 1 addition & 1 deletion src/state.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use std::io::Seek;
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;

use a2::{Client, Endpoint};
use anyhow::{Context as _, Result};
use async_std::sync::Arc;

use crate::metrics::Metrics;
use crate::schedule::Schedule;
Expand Down
Loading