From eb24b21ebae3979154f65b8ef144857f9dcbef54 Mon Sep 17 00:00:00 2001 From: Allen Date: Tue, 8 Feb 2022 15:46:58 -0800 Subject: [PATCH] add cors --- .env.example | 2 ++ Cargo.lock | 26 ++++++++++++++++++++++++++ Cargo.toml | 8 ++++++-- src/auth.rs | 14 +++++++++----- src/handlers/auth.rs | 6 ++++-- src/server.rs | 32 +++++++++++++++++++++++++++++--- src/utils/mod.rs | 8 +++++--- 7 files changed, 81 insertions(+), 15 deletions(-) diff --git a/.env.example b/.env.example index ee483e3..f85ad57 100644 --- a/.env.example +++ b/.env.example @@ -7,3 +7,5 @@ SMTP_USERNAME= SMTP_PASSWORD= SMTP_SERVER= SMTP_EMAIL= + +ALLOWED_ORIGINS="https://foo.example, https://bar.example" diff --git a/Cargo.lock b/Cargo.lock index 4fcb01b..ae646ad 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -112,6 +112,7 @@ dependencies = [ "bitflags", "bytes", "futures-util", + "headers", "http", "http-body", "hyper", @@ -782,6 +783,31 @@ dependencies = [ "hashbrown", ] +[[package]] +name = "headers" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c84c647447a07ca16f5fbd05b633e535cc41a08d2d74ab1e08648df53be9cb89" +dependencies = [ + "base64", + "bitflags", + "bytes", + "headers-core", + "http", + "httpdate", + "mime", + "sha-1", +] + +[[package]] +name = "headers-core" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7f66481bfee273957b1f20485a4ff3362987f85b2c236580d81b4eb7a326429" +dependencies = [ + "http", +] + [[package]] name = "heck" version = "0.3.3" diff --git a/Cargo.toml b/Cargo.toml index dc0b27e..e98fbfa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ edition = "2018" [dependencies] anyhow = "1.0.53" -axum = "0.4.5" +axum = { version = "0.4.5", features = ["headers"] } bincode = "1.3.3" chrono = "0.4.19" dotenv = "0.15.0" @@ -47,7 +47,11 @@ tantivy = "0.16.1" thiserror = "1.0.30" tokio = { version = "1.16.1", features = ["rt-multi-thread", "macros", "sync"] } tower = "0.4.11" -tower-http = { version = "0.2.2", features = ["add-extension", "trace"] } +tower-http = { version = "0.2.2", features = [ + "add-extension", + "trace", + "cors", +] } tracing = "0.1.30" tracing-subscriber = "0.3.8" ulid = { version = "0.5.0", features = ["serde", "uuid"] } diff --git a/src/auth.rs b/src/auth.rs index 6ae02c8..10937f0 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -1,6 +1,6 @@ use axum::{ async_trait, - extract::{Extension, FromRequest, RequestParts}, + extract::{Extension, FromRequest, RequestParts, TypedHeader}, }; use redis::AsyncCommands; use serde::{Deserialize, Serialize}; @@ -8,7 +8,8 @@ use uuid::Uuid; use crate::{error::MixiniError, server::State}; -pub(crate) const AUTH_KEY_PREFIX: &str = "auth:"; +const MIXINI_SESSION_COOKIE_NAME: &str = "mixsession"; +pub(crate) const SESSION_KEY_PREFIX: &str = "session:"; #[derive(Debug)] pub(crate) enum Auth { @@ -36,14 +37,17 @@ where let headers = req.headers().expect("another extractor took the headers"); + // TODO: use cookies instead + match headers .get(http::header::AUTHORIZATION) .and_then(|value| value.to_str().ok()) .map(|value| value.to_string()) { - Some(auth) => { - let key = format!("{}{}", AUTH_KEY_PREFIX, &auth); - let maybe_value: Option> = state.redis_manager.clone().get(&key).await?; + Some(session_key) => { + let qualified_key = format!("{}{}", SESSION_KEY_PREFIX, &session_key); + let maybe_value: Option> = + state.redis_manager.clone().get(&qualified_key).await?; if let Some(raw_user_info) = maybe_value { let user_info: UserInfo = bincode::deserialize(&raw_user_info)?; diff --git a/src/handlers/auth.rs b/src/handlers/auth.rs index 1660295..094b6d3 100644 --- a/src/handlers/auth.rs +++ b/src/handlers/auth.rs @@ -6,7 +6,7 @@ use serde::{Deserialize, Serialize}; use std::sync::Arc; use validator::Validate; -use crate::auth::{UserInfo, AUTH_KEY_PREFIX}; +use crate::auth::{UserInfo, SESSION_KEY_PREFIX}; use crate::error::MixiniError; use crate::handlers::{ValidatedForm, RE_PASSWORD, RE_USERNAME}; use crate::models::User; @@ -16,6 +16,8 @@ use crate::utils::{ pass::{HASHER, PWD_SCHEME_VERSION}, }; +// TODO: possibly rework this as `POST /user/login` and `DELETE /user/logout`, also rename auth to session + /// The form input of a `POST /auth` request. #[derive(Debug, Validate, Deserialize)] pub(crate) struct NewAuthInput { @@ -87,7 +89,7 @@ pub(crate) async fn create_auth( } // create auth entry in redis - let key = generate_redis_key(AUTH_KEY_PREFIX); + let key = generate_redis_key(SESSION_KEY_PREFIX); let value: Vec = bincode::serialize(&UserInfo { id: user.id, name: user.name, diff --git a/src/server.rs b/src/server.rs index f7b53bf..ca66a53 100644 --- a/src/server.rs +++ b/src/server.rs @@ -12,7 +12,10 @@ use sqlx::PgPool; use std::{str::FromStr, sync::Arc}; use tokio::sync::Mutex; use tower::ServiceBuilder; -use tower_http::trace::TraceLayer; +use tower_http::{ + cors::{CorsLayer, Origin}, + trace::TraceLayer, +}; use crate::handlers; @@ -53,7 +56,7 @@ impl State { } /// Attempt to create a new oso instance for managing authorization schemes. -pub(crate) fn try_register_oso() -> Result { +fn try_register_oso() -> Result { use crate::models::*; let mut oso = Oso::new(); @@ -67,6 +70,28 @@ pub(crate) fn try_register_oso() -> Result { Ok(oso) } +/// Attempt to setup the CORS layer. +fn try_cors_layer() -> Result { + use http::Method; + + if cfg!(debug_assertions) { + Ok(CorsLayer::permissive()) + } else { + let origins = std::env::var("ALLOWED_ORIGINS")? + .split(',') + .map(|s| s.trim().parse()) + .collect::, _>>()?; + + Ok(CorsLayer::new() + // allow `GET`, `POST`, `PUT`, `DELETE` + .allow_methods(vec![Method::GET, Method::POST, Method::PUT, Method::DELETE]) + // allow credentials + .allow_credentials(true) + // allow requests from specified env origins + .allow_origin(Origin::list(origins))) + } +} + /// Run the server. pub(crate) async fn run() -> Result<()> { let addr = std::net::SocketAddr::from_str(&std::env::var("ADDR")?)?; @@ -83,7 +108,8 @@ async fn try_app() -> Result { let middleware_stack = ServiceBuilder::new() .layer(TraceLayer::new_for_http()) - .layer(AddExtensionLayer::new(state)); + .layer(AddExtensionLayer::new(state)) + .layer(try_cors_layer()?); Ok(Router::new() .route("/", get(|| async { "Hello, World!" })) diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 1055379..435775e 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -5,11 +5,13 @@ use rand::{thread_rng, Rng}; pub(crate) mod mail; pub(crate) mod pass; -/// Generate random key for use in Redis given the prefix -pub(crate) fn generate_redis_key(prefix: &str) -> String { +const KEY_LENGTH: usize = 32; + +/// Generate random key for use in Redis given a prefix +pub(crate) fn generate_redis_key(prefix: &'static str) -> String { let key: String = thread_rng() .sample_iter(&Alphanumeric) - .take(32) + .take(KEY_LENGTH) .map(char::from) .collect(); format!("{}{}", prefix, &key)