Skip to content

Commit

Permalink
add cors
Browse files Browse the repository at this point in the history
  • Loading branch information
fairingrey committed Feb 8, 2022
1 parent 50b1418 commit eb24b21
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 15 deletions.
2 changes: 2 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@ SMTP_USERNAME=
SMTP_PASSWORD=
SMTP_SERVER=
SMTP_EMAIL=

ALLOWED_ORIGINS="https://foo.example, https://bar.example"
26 changes: 26 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 6 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ edition = "2018"

[dependencies]
anyhow = "1.0.53"
axum = "0.4.5"
axum = { version = "0.4.5", features = ["headers"] }
bincode = "1.3.3"
chrono = "0.4.19"
dotenv = "0.15.0"
Expand Down Expand Up @@ -47,7 +47,11 @@ tantivy = "0.16.1"
thiserror = "1.0.30"
tokio = { version = "1.16.1", features = ["rt-multi-thread", "macros", "sync"] }
tower = "0.4.11"
tower-http = { version = "0.2.2", features = ["add-extension", "trace"] }
tower-http = { version = "0.2.2", features = [
"add-extension",
"trace",
"cors",
] }
tracing = "0.1.30"
tracing-subscriber = "0.3.8"
ulid = { version = "0.5.0", features = ["serde", "uuid"] }
Expand Down
14 changes: 9 additions & 5 deletions src/auth.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
use axum::{
async_trait,
extract::{Extension, FromRequest, RequestParts},
extract::{Extension, FromRequest, RequestParts, TypedHeader},
};
use redis::AsyncCommands;
use serde::{Deserialize, Serialize};
use uuid::Uuid;

use crate::{error::MixiniError, server::State};

pub(crate) const AUTH_KEY_PREFIX: &str = "auth:";
const MIXINI_SESSION_COOKIE_NAME: &str = "mixsession";
pub(crate) const SESSION_KEY_PREFIX: &str = "session:";

#[derive(Debug)]
pub(crate) enum Auth {
Expand Down Expand Up @@ -36,14 +37,17 @@ where

let headers = req.headers().expect("another extractor took the headers");

// TODO: use cookies instead

match headers
.get(http::header::AUTHORIZATION)
.and_then(|value| value.to_str().ok())
.map(|value| value.to_string())
{
Some(auth) => {
let key = format!("{}{}", AUTH_KEY_PREFIX, &auth);
let maybe_value: Option<Vec<u8>> = state.redis_manager.clone().get(&key).await?;
Some(session_key) => {
let qualified_key = format!("{}{}", SESSION_KEY_PREFIX, &session_key);
let maybe_value: Option<Vec<u8>> =
state.redis_manager.clone().get(&qualified_key).await?;

if let Some(raw_user_info) = maybe_value {
let user_info: UserInfo = bincode::deserialize(&raw_user_info)?;
Expand Down
6 changes: 4 additions & 2 deletions src/handlers/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use serde::{Deserialize, Serialize};
use std::sync::Arc;
use validator::Validate;

use crate::auth::{UserInfo, AUTH_KEY_PREFIX};
use crate::auth::{UserInfo, SESSION_KEY_PREFIX};
use crate::error::MixiniError;
use crate::handlers::{ValidatedForm, RE_PASSWORD, RE_USERNAME};
use crate::models::User;
Expand All @@ -16,6 +16,8 @@ use crate::utils::{
pass::{HASHER, PWD_SCHEME_VERSION},
};

// TODO: possibly rework this as `POST /user/login` and `DELETE /user/logout`, also rename auth to session

/// The form input of a `POST /auth` request.
#[derive(Debug, Validate, Deserialize)]
pub(crate) struct NewAuthInput {
Expand Down Expand Up @@ -87,7 +89,7 @@ pub(crate) async fn create_auth(
}

// create auth entry in redis
let key = generate_redis_key(AUTH_KEY_PREFIX);
let key = generate_redis_key(SESSION_KEY_PREFIX);
let value: Vec<u8> = bincode::serialize(&UserInfo {
id: user.id,
name: user.name,
Expand Down
32 changes: 29 additions & 3 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ use sqlx::PgPool;
use std::{str::FromStr, sync::Arc};
use tokio::sync::Mutex;
use tower::ServiceBuilder;
use tower_http::trace::TraceLayer;
use tower_http::{
cors::{CorsLayer, Origin},
trace::TraceLayer,
};

use crate::handlers;

Expand Down Expand Up @@ -53,7 +56,7 @@ impl State {
}

/// Attempt to create a new oso instance for managing authorization schemes.
pub(crate) fn try_register_oso() -> Result<Oso> {
fn try_register_oso() -> Result<Oso> {
use crate::models::*;

let mut oso = Oso::new();
Expand All @@ -67,6 +70,28 @@ pub(crate) fn try_register_oso() -> Result<Oso> {
Ok(oso)
}

/// Attempt to setup the CORS layer.
fn try_cors_layer() -> Result<CorsLayer> {
use http::Method;

if cfg!(debug_assertions) {
Ok(CorsLayer::permissive())
} else {
let origins = std::env::var("ALLOWED_ORIGINS")?
.split(',')
.map(|s| s.trim().parse())
.collect::<Result<Vec<_>, _>>()?;

Ok(CorsLayer::new()
// allow `GET`, `POST`, `PUT`, `DELETE`
.allow_methods(vec![Method::GET, Method::POST, Method::PUT, Method::DELETE])
// allow credentials
.allow_credentials(true)
// allow requests from specified env origins
.allow_origin(Origin::list(origins)))
}
}

/// Run the server.
pub(crate) async fn run() -> Result<()> {
let addr = std::net::SocketAddr::from_str(&std::env::var("ADDR")?)?;
Expand All @@ -83,7 +108,8 @@ async fn try_app() -> Result<Router> {

let middleware_stack = ServiceBuilder::new()
.layer(TraceLayer::new_for_http())
.layer(AddExtensionLayer::new(state));
.layer(AddExtensionLayer::new(state))
.layer(try_cors_layer()?);

Ok(Router::new()
.route("/", get(|| async { "Hello, World!" }))
Expand Down
8 changes: 5 additions & 3 deletions src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ use rand::{thread_rng, Rng};
pub(crate) mod mail;
pub(crate) mod pass;

/// Generate random key for use in Redis given the prefix
pub(crate) fn generate_redis_key(prefix: &str) -> String {
const KEY_LENGTH: usize = 32;

/// Generate random key for use in Redis given a prefix
pub(crate) fn generate_redis_key(prefix: &'static str) -> String {
let key: String = thread_rng()
.sample_iter(&Alphanumeric)
.take(32)
.take(KEY_LENGTH)
.map(char::from)
.collect();
format!("{}{}", prefix, &key)
Expand Down

0 comments on commit eb24b21

Please sign in to comment.