diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 07e0e3014864..17d412a427be 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -14,7 +14,7 @@ use tracing::{debug, info}; use super::conn_pool::poll_client; use super::conn_pool_lib::{Client, ConnInfo, GlobalConnPool}; use super::http_conn_pool::{self, poll_http2_client, Send}; -use super::local_conn_pool::{self, LocalClient, LocalConnPool, EXT_NAME, EXT_SCHEMA, EXT_VERSION}; +use super::local_conn_pool::{self, LocalConnPool, EXT_NAME, EXT_SCHEMA, EXT_VERSION}; use crate::auth::backend::local::StaticAuthRules; use crate::auth::backend::{ComputeCredentials, ComputeUserInfo}; use crate::auth::{self, check_peer_addr_is_in_list, AuthError}; @@ -207,7 +207,7 @@ impl PoolingBackend { conn_info: ConnInfo, ) -> Result, HttpConnError> { info!("pool: looking for an existing connection"); - if let Some(client) = self.http_conn_pool.get(ctx, &conn_info) { + if let Ok(Some(client)) = self.http_conn_pool.get(ctx, &conn_info) { return Ok(client); } @@ -250,7 +250,7 @@ impl PoolingBackend { &self, ctx: &RequestMonitoring, conn_info: ConnInfo, - ) -> Result, HttpConnError> { + ) -> Result, HttpConnError> { if let Some(client) = self.local_pool.get(ctx, &conn_info)? { return Ok(client); } diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index 7fa3357b5bf9..1845603bf738 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -18,7 +18,9 @@ use { std::{sync::atomic, time::Duration}, }; -use super::conn_pool_lib::{Client, ClientInnerExt, ConnInfo, GlobalConnPool}; +use super::conn_pool_lib::{ + Client, ClientDataEnum, ClientInnerCommon, ClientInnerExt, ConnInfo, GlobalConnPool, +}; use crate::context::RequestMonitoring; use crate::control_plane::messages::MetricsAuxInfo; use crate::metrics::Metrics; @@ -152,53 +154,30 @@ pub(crate) fn poll_client( } .instrument(span)); - let inner = ClientInnerRemote { + let inner = ClientInnerCommon { inner: client, - session: tx, - cancel, aux, conn_id, + data: ClientDataEnum::Remote(ClientDataRemote { + session: tx, + cancel, + }), }; + Client::new(inner, conn_info, pool_clone) } -pub(crate) struct ClientInnerRemote { - inner: C, +pub(crate) struct ClientDataRemote { session: tokio::sync::watch::Sender, cancel: CancellationToken, - aux: MetricsAuxInfo, - conn_id: uuid::Uuid, } -impl ClientInnerRemote { - pub(crate) fn inner_mut(&mut self) -> &mut C { - &mut self.inner - } - - pub(crate) fn inner(&self) -> &C { - &self.inner - } - - pub(crate) fn session(&mut self) -> &mut tokio::sync::watch::Sender { +impl ClientDataRemote { + pub fn session(&mut self) -> &mut tokio::sync::watch::Sender { &mut self.session } - pub(crate) fn aux(&self) -> &MetricsAuxInfo { - &self.aux - } - - pub(crate) fn get_conn_id(&self) -> uuid::Uuid { - self.conn_id - } - - pub(crate) fn is_closed(&self) -> bool { - self.inner.is_closed() - } -} - -impl Drop for ClientInnerRemote { - fn drop(&mut self) { - // on client drop, tell the conn to shut down + pub fn cancel(&mut self) { self.cancel.cancel(); } } @@ -228,15 +207,13 @@ mod tests { } } - fn create_inner() -> ClientInnerRemote { + fn create_inner() -> ClientInnerCommon { create_inner_with(MockClient::new(false)) } - fn create_inner_with(client: MockClient) -> ClientInnerRemote { - ClientInnerRemote { + fn create_inner_with(client: MockClient) -> ClientInnerCommon { + ClientInnerCommon { inner: client, - session: 
tokio::sync::watch::Sender::new(uuid::Uuid::new_v4()),
-            cancel: CancellationToken::new(),
             aux: MetricsAuxInfo {
                 endpoint_id: (&EndpointId::from("endpoint")).into(),
                 project_id: (&ProjectId::from("project")).into(),
@@ -244,6 +221,10 @@ mod tests {
                 cold_start_info: crate::control_plane::messages::ColdStartInfo::Warm,
             },
             conn_id: uuid::Uuid::new_v4(),
+            data: ClientDataEnum::Remote(ClientDataRemote {
+                session: tokio::sync::watch::Sender::new(uuid::Uuid::new_v4()),
+                cancel: CancellationToken::new(),
+            }),
         }
     }
 
@@ -280,7 +261,7 @@ mod tests {
         {
             let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
             assert_eq!(0, pool.get_global_connections_count());
-            client.inner_mut().1.discard();
+            client.inner().1.discard();
             // Discard should not add the connection from the pool.
             assert_eq!(0, pool.get_global_connections_count());
         }
diff --git a/proxy/src/serverless/conn_pool_lib.rs b/proxy/src/serverless/conn_pool_lib.rs
index 8830cddf0c12..00a8ac47681d 100644
--- a/proxy/src/serverless/conn_pool_lib.rs
+++ b/proxy/src/serverless/conn_pool_lib.rs
@@ -11,10 +11,13 @@ use tokio_postgres::ReadyForQueryStatus;
 use tracing::{debug, info, Span};
 
 use super::backend::HttpConnError;
-use super::conn_pool::ClientInnerRemote;
+use super::conn_pool::ClientDataRemote;
+use super::http_conn_pool::ClientDataHttp;
+use super::local_conn_pool::ClientDataLocal;
 use crate::auth::backend::ComputeUserInfo;
 use crate::context::RequestMonitoring;
 use crate::control_plane::messages::ColdStartInfo;
+use crate::control_plane::messages::MetricsAuxInfo;
 use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
 use crate::types::{DbName, EndpointCacheKey, RoleName};
 use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
@@ -41,8 +44,46 @@ impl ConnInfo {
     }
 }
 
+pub(crate) enum ClientDataEnum {
+    Remote(ClientDataRemote),
+    Local(ClientDataLocal),
+    #[allow(dead_code)]
+    Http(ClientDataHttp),
+}
+
+pub(crate) struct ClientInnerCommon<C: ClientInnerExt> {
+    pub(crate) inner: C,
+    pub(crate) aux: MetricsAuxInfo,
+    pub(crate) conn_id: uuid::Uuid,
+    pub(crate) data: ClientDataEnum, // custom client data like session, key, jti
+}
+
+impl<C: ClientInnerExt> Drop for ClientInnerCommon<C> {
+    fn drop(&mut self) {
+        match &mut self.data {
+            ClientDataEnum::Remote(remote_data) => {
+                remote_data.cancel();
+            }
+            ClientDataEnum::Local(local_data) => {
+                local_data.cancel();
+            }
+            ClientDataEnum::Http(_http_data) => (),
+        }
+    }
+}
+
+impl<C: ClientInnerExt> ClientInnerCommon<C> {
+    pub(crate) fn get_conn_id(&self) -> uuid::Uuid {
+        self.conn_id
+    }
+
+    pub(crate) fn get_data(&mut self) -> &mut ClientDataEnum {
+        &mut self.data
+    }
+}
+
 pub(crate) struct ConnPoolEntry<C: ClientInnerExt> {
-    pub(crate) conn: ClientInnerRemote<C>,
+    pub(crate) conn: ClientInnerCommon<C>,
     pub(crate) _last_access: std::time::Instant,
 }
 
@@ -55,10 +96,33 @@ pub(crate) struct EndpointConnPool<C: ClientInnerExt> {
     _guard: HttpEndpointPoolsGuard<'static>,
     global_connections_count: Arc<AtomicUsize>,
     global_pool_size_max_conns: usize,
+    pool_name: String,
 }
 
 impl<C: ClientInnerExt> EndpointConnPool<C> {
-    fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option<ConnPoolEntry<C>> {
+    pub(crate) fn new(
+        hmap: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
+        tconns: usize,
+        max_conns_per_endpoint: usize,
+        global_connections_count: Arc<AtomicUsize>,
+        max_total_conns: usize,
+        pname: String,
+    ) -> Self {
+        Self {
+            pools: hmap,
+            total_conns: tconns,
+            max_conns: max_conns_per_endpoint,
+            _guard: Metrics::get().proxy.http_endpoint_pools.guard(),
+            global_connections_count,
+            global_pool_size_max_conns: max_total_conns,
+            pool_name: pname,
+        }
+    }
+
+    pub(crate) fn get_conn_entry(
+        &mut self,
+        db_user: (DbName, RoleName),
+    ) -> Option<ConnPoolEntry<C>> {
         let Self {
             pools,
             total_conns,
@@ -84,9 +148,10 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
             ..
         } = self;
         if let Some(pool) = pools.get_mut(&db_user) {
-            let old_len = pool.conns.len();
-            pool.conns.retain(|conn| conn.conn.get_conn_id() != conn_id);
-            let new_len = pool.conns.len();
+            let old_len = pool.get_conns().len();
+            pool.get_conns()
+                .retain(|conn| conn.conn.get_conn_id() != conn_id);
+            let new_len = pool.get_conns().len();
             let removed = old_len - new_len;
             if removed > 0 {
                 global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
@@ -103,11 +168,26 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
         }
     }
 
-    pub(crate) fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInnerRemote<C>) {
-        let conn_id = client.get_conn_id();
+    pub(crate) fn get_name(&self) -> &str {
+        &self.pool_name
+    }
+
+    pub(crate) fn get_pool(&self, db_user: (DbName, RoleName)) -> Option<&DbUserConnPool<C>> {
+        self.pools.get(&db_user)
+    }
 
-        if client.is_closed() {
-            info!(%conn_id, "pool: throwing away connection '{conn_info}' because connection is closed");
+    pub(crate) fn get_pool_mut(
+        &mut self,
+        db_user: (DbName, RoleName),
+    ) -> Option<&mut DbUserConnPool<C>> {
+        self.pools.get_mut(&db_user)
+    }
+
+    pub(crate) fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInnerCommon<C>) {
+        let conn_id = client.get_conn_id();
+        let pool_name = pool.read().get_name().to_string();
+        if client.inner.is_closed() {
+            info!(%conn_id, "{}: throwing away connection '{conn_info}' because connection is closed", pool_name);
             return;
         }
 
@@ -118,7 +198,7 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
             .load(atomic::Ordering::Relaxed)
             >= global_max_conn
         {
-            info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full");
+            info!(%conn_id, "{}: throwing away connection '{conn_info}' because pool is full", pool_name);
             return;
         }
 
@@ -130,13 +210,13 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
 
         if pool.total_conns < pool.max_conns {
             let pool_entries = pool.pools.entry(conn_info.db_and_user()).or_default();
-            pool_entries.conns.push(ConnPoolEntry {
+            pool_entries.get_conns().push(ConnPoolEntry {
                 conn: client,
                 _last_access: std::time::Instant::now(),
             });
 
             returned = true;
-            per_db_size = pool_entries.conns.len();
+            per_db_size = pool_entries.get_conns().len();
 
             pool.total_conns += 1;
             pool.global_connections_count
@@ -153,9 +233,9 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
 
         // do logging outside of the mutex
         if returned {
-            info!(%conn_id, "pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
+            info!(%conn_id, "{pool_name}: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
         } else {
-            info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
+            info!(%conn_id, "{pool_name}: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
         }
     }
 }
@@ -176,19 +256,39 @@ impl<C: ClientInnerExt> Drop for EndpointConnPool<C> {
 
 pub(crate) struct DbUserConnPool<C: ClientInnerExt> {
     pub(crate) conns: Vec<ConnPoolEntry<C>>,
+    pub(crate) initialized: Option<bool>, // a bit ugly, exists only for local pools
 }
 
 impl<C: ClientInnerExt> Default for DbUserConnPool<C> {
     fn default() -> Self {
-        Self { conns: Vec::new() }
+        Self {
+            conns: Vec::new(),
+            initialized: None,
+        }
     }
 }
 
-impl<C: ClientInnerExt> DbUserConnPool<C> {
+pub(crate) trait DbUserConn<C: ClientInnerExt>: Default {
+    fn set_initialized(&mut self);
+    fn is_initialized(&self) -> bool;
+    fn clear_closed_clients(&mut self, conns: &mut usize) -> usize;
+    fn get_conn_entry(&mut self, conns: &mut usize) -> (Option<ConnPoolEntry<C>>, usize);
+    fn 
get_conns(&mut self) -> &mut Vec>; +} + +impl DbUserConn for DbUserConnPool { + fn set_initialized(&mut self) { + self.initialized = Some(true); + } + + fn is_initialized(&self) -> bool { + self.initialized.unwrap_or(false) + } + fn clear_closed_clients(&mut self, conns: &mut usize) -> usize { let old_len = self.conns.len(); - self.conns.retain(|conn| !conn.conn.is_closed()); + self.conns.retain(|conn| !conn.conn.inner.is_closed()); let new_len = self.conns.len(); let removed = old_len - new_len; @@ -196,10 +296,7 @@ impl DbUserConnPool { removed } - pub(crate) fn get_conn_entry( - &mut self, - conns: &mut usize, - ) -> (Option>, usize) { + fn get_conn_entry(&mut self, conns: &mut usize) -> (Option>, usize) { let mut removed = self.clear_closed_clients(conns); let conn = self.conns.pop(); if conn.is_some() { @@ -215,6 +312,10 @@ impl DbUserConnPool { (conn, removed) } + + fn get_conns(&mut self) -> &mut Vec> { + &mut self.conns + } } pub(crate) struct GlobalConnPool { @@ -278,6 +379,60 @@ impl GlobalConnPool { self.config.pool_options.idle_timeout } + pub(crate) fn get( + self: &Arc, + ctx: &RequestMonitoring, + conn_info: &ConnInfo, + ) -> Result>, HttpConnError> { + let mut client: Option> = None; + let Some(endpoint) = conn_info.endpoint_cache_key() else { + return Ok(None); + }; + + let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint); + if let Some(entry) = endpoint_pool + .write() + .get_conn_entry(conn_info.db_and_user()) + { + client = Some(entry.conn); + } + let endpoint_pool = Arc::downgrade(&endpoint_pool); + + // ok return cached connection if found and establish a new one otherwise + if let Some(mut client) = client { + if client.inner.is_closed() { + info!("pool: cached connection '{conn_info}' is closed, opening a new one"); + return Ok(None); + } + tracing::Span::current() + .record("conn_id", tracing::field::display(client.get_conn_id())); + tracing::Span::current().record( + "pid", + tracing::field::display(client.inner.get_process_id()), + ); + info!( + cold_start_info = ColdStartInfo::HttpPoolHit.as_str(), + "pool: reusing connection '{conn_info}'" + ); + + match client.get_data() { + ClientDataEnum::Local(data) => { + data.session().send(ctx.session_id())?; + } + + ClientDataEnum::Remote(data) => { + data.session().send(ctx.session_id())?; + } + ClientDataEnum::Http(_) => (), + } + + ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit); + ctx.success(); + return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool))); + } + Ok(None) + } + pub(crate) fn shutdown(&self) { // drops all strong references to endpoint-pools self.global_pool.clear(); @@ -374,6 +529,7 @@ impl GlobalConnPool { _guard: Metrics::get().proxy.http_endpoint_pools.guard(), global_connections_count: self.global_connections_count.clone(), global_pool_size_max_conns: self.config.pool_options.max_total_conns, + pool_name: String::from("remote"), })); // find or create a pool for this endpoint @@ -400,55 +556,23 @@ impl GlobalConnPool { pool } +} - pub(crate) fn get( - self: &Arc, - ctx: &RequestMonitoring, - conn_info: &ConnInfo, - ) -> Result>, HttpConnError> { - let mut client: Option> = None; - let Some(endpoint) = conn_info.endpoint_cache_key() else { - return Ok(None); - }; - - let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint); - if let Some(entry) = endpoint_pool - .write() - .get_conn_entry(conn_info.db_and_user()) - { - client = Some(entry.conn); - } - let endpoint_pool = Arc::downgrade(&endpoint_pool); - - // ok return cached connection if found and establish a 
new one otherwise
-        if let Some(mut client) = client {
-            if client.is_closed() {
-                info!("pool: cached connection '{conn_info}' is closed, opening a new one");
-                return Ok(None);
-            }
-            tracing::Span::current()
-                .record("conn_id", tracing::field::display(client.get_conn_id()));
-            tracing::Span::current().record(
-                "pid",
-                tracing::field::display(client.inner().get_process_id()),
-            );
-            info!(
-                cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
-                "pool: reusing connection '{conn_info}'"
-            );
 
+pub(crate) struct Client<C: ClientInnerExt> {
+    span: Span,
+    inner: Option<ClientInnerCommon<C>>,
+    conn_info: ConnInfo,
+    pool: Weak<RwLock<EndpointConnPool<C>>>,
+}
+
-            client.session().send(ctx.session_id())?;
-            ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
-            ctx.success();
-            return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool)));
-        }
-        Ok(None)
-    }
+pub(crate) struct Discard<'a, C: ClientInnerExt> {
+    conn_info: &'a ConnInfo,
+    pool: &'a mut Weak<RwLock<EndpointConnPool<C>>>,
 }
 
 impl<C: ClientInnerExt> Client<C> {
     pub(crate) fn new(
-        inner: ClientInnerRemote<C>,
+        inner: ClientInnerCommon<C>,
         conn_info: ConnInfo,
         pool: Weak<RwLock<EndpointConnPool<C>>>,
     ) -> Self {
@@ -460,7 +584,18 @@ impl<C: ClientInnerExt> Client<C> {
         }
     }
 
-    pub(crate) fn inner_mut(&mut self) -> (&mut C, Discard<'_, C>) {
+    pub(crate) fn client_inner(&mut self) -> (&mut ClientInnerCommon<C>, Discard<'_, C>) {
+        let Self {
+            inner,
+            pool,
+            conn_info,
+            span: _,
+        } = self;
+        let inner_m = inner.as_mut().expect("client inner should not be removed");
+        (inner_m, Discard { conn_info, pool })
+    }
+
+    pub(crate) fn inner(&mut self) -> (&mut C, Discard<'_, C>) {
         let Self {
             inner,
             pool,
@@ -468,12 +603,11 @@ impl<C: ClientInnerExt> Client<C> {
             span: _,
         } = self;
         let inner = inner.as_mut().expect("client inner should not be removed");
-        let inner_ref = inner.inner_mut();
-        (inner_ref, Discard { conn_info, pool })
+        (&mut inner.inner, Discard { conn_info, pool })
     }
 
     pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
-        let aux = &self.inner.as_ref().unwrap().aux();
+        let aux = &self.inner.as_ref().unwrap().aux;
         USAGE_METRICS.register(Ids {
             endpoint_id: aux.endpoint_id,
             branch_id: aux.branch_id,
@@ -498,13 +632,6 @@ impl<C: ClientInnerExt> Client<C> {
     }
 }
 
-pub(crate) struct Client<C: ClientInnerExt> {
-    span: Span,
-    inner: Option<ClientInnerRemote<C>>,
-    conn_info: ConnInfo,
-    pool: Weak<RwLock<EndpointConnPool<C>>>,
-}
-
 impl<C: ClientInnerExt> Drop for Client<C> {
     fn drop(&mut self) {
         if let Some(drop) = self.do_drop() {
@@ -517,10 +644,11 @@ impl<C: ClientInnerExt> Deref for Client<C> {
     type Target = C;
 
     fn deref(&self) -> &Self::Target {
-        self.inner
+        &self
+            .inner
             .as_ref()
             .expect("client inner should not be removed")
-            .inner()
+            .inner
     }
 }
 
@@ -539,11 +667,6 @@ impl ClientInnerExt for tokio_postgres::Client {
     }
 }
 
-pub(crate) struct Discard<'a, C: ClientInnerExt> {
-    conn_info: &'a ConnInfo,
-    pool: &'a mut Weak<RwLock<EndpointConnPool<C>>>,
-}
-
 impl<C: ClientInnerExt> Discard<'_, C> {
     pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) {
         let conn_info = &self.conn_info;
diff --git a/proxy/src/serverless/http_conn_pool.rs b/proxy/src/serverless/http_conn_pool.rs
index 934a50c14ff5..5052498f9942 100644
--- a/proxy/src/serverless/http_conn_pool.rs
+++ b/proxy/src/serverless/http_conn_pool.rs
@@ -7,9 +7,11 @@ use hyper::client::conn::http2;
 use hyper_util::rt::{TokioExecutor, TokioIo};
 use parking_lot::RwLock;
 use rand::Rng;
+use std::result::Result::Ok;
 use tokio::net::TcpStream;
 use tracing::{debug, error, info, info_span, Instrument};
 
+use super::backend::HttpConnError;
 use super::conn_pool_lib::{ClientInnerExt, ConnInfo};
 use crate::context::RequestMonitoring;
 use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
@@ -28,6 +30,8 @@ pub(crate) struct ConnPoolEntry<C: ClientInnerExt> {
     aux: MetricsAuxInfo,
 }
 
+pub(crate) struct ClientDataHttp();
+
 // Per-endpoint connection pool
// Number of open connections is limited by the `max_conns_per_endpoint`. pub(crate) struct EndpointConnPool { @@ -206,14 +210,22 @@ impl GlobalConnPool { } } + #[expect(unused_results)] pub(crate) fn get( self: &Arc, ctx: &RequestMonitoring, conn_info: &ConnInfo, - ) -> Option> { - let endpoint = conn_info.endpoint_cache_key()?; + ) -> Result>, HttpConnError> { + let result: Result>, HttpConnError>; + let Some(endpoint) = conn_info.endpoint_cache_key() else { + result = Ok(None); + return result; + }; let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint); - let client = endpoint_pool.write().get_conn_entry()?; + let Some(client) = endpoint_pool.write().get_conn_entry() else { + result = Ok(None); + return result; + }; tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id)); info!( @@ -222,7 +234,7 @@ impl GlobalConnPool { ); ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit); ctx.success(); - Some(Client::new(client.conn, client.aux)) + Ok(Some(Client::new(client.conn, client.aux))) } fn get_or_create_endpoint_pool( diff --git a/proxy/src/serverless/local_conn_pool.rs b/proxy/src/serverless/local_conn_pool.rs index 064e7db7b3a7..99d4329f8811 100644 --- a/proxy/src/serverless/local_conn_pool.rs +++ b/proxy/src/serverless/local_conn_pool.rs @@ -11,7 +11,8 @@ use std::collections::HashMap; use std::pin::pin; -use std::sync::{Arc, Weak}; +use std::sync::atomic::AtomicUsize; +use std::sync::Arc; use std::task::{ready, Poll}; use std::time::Duration; @@ -26,177 +27,42 @@ use signature::Signer; use tokio::time::Instant; use tokio_postgres::tls::NoTlsStream; use tokio_postgres::types::ToSql; -use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket}; +use tokio_postgres::{AsyncMessage, Socket}; use tokio_util::sync::CancellationToken; -use tracing::{error, info, info_span, warn, Instrument, Span}; +use tracing::{error, info, info_span, warn, Instrument}; use super::backend::HttpConnError; -use super::conn_pool_lib::{ClientInnerExt, ConnInfo}; +use super::conn_pool_lib::{ + Client, ClientDataEnum, ClientInnerCommon, ClientInnerExt, ConnInfo, DbUserConn, + EndpointConnPool, +}; use crate::context::RequestMonitoring; use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo}; use crate::metrics::Metrics; -use crate::types::{DbName, RoleName}; -use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS}; pub(crate) const EXT_NAME: &str = "pg_session_jwt"; pub(crate) const EXT_VERSION: &str = "0.1.2"; pub(crate) const EXT_SCHEMA: &str = "auth"; -struct ConnPoolEntry { - conn: ClientInner, - _last_access: std::time::Instant, -} - -// Per-endpoint connection pool, (dbname, username) -> DbUserConnPool -// Number of open connections is limited by the `max_conns_per_endpoint`. -pub(crate) struct EndpointConnPool { - pools: HashMap<(DbName, RoleName), DbUserConnPool>, - total_conns: usize, - max_conns: usize, - global_pool_size_max_conns: usize, -} - -impl EndpointConnPool { - fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option> { - let Self { - pools, total_conns, .. - } = self; - pools - .get_mut(&db_user) - .and_then(|pool_entries| pool_entries.get_conn_entry(total_conns)) - } - - fn remove_client(&mut self, db_user: (DbName, RoleName), conn_id: uuid::Uuid) -> bool { - let Self { - pools, total_conns, .. 
- } = self; - if let Some(pool) = pools.get_mut(&db_user) { - let old_len = pool.conns.len(); - pool.conns.retain(|conn| conn.conn.conn_id != conn_id); - let new_len = pool.conns.len(); - let removed = old_len - new_len; - if removed > 0 { - Metrics::get() - .proxy - .http_pool_opened_connections - .get_metric() - .dec_by(removed as i64); - } - *total_conns -= removed; - removed > 0 - } else { - false - } - } - - fn put(pool: &RwLock, conn_info: &ConnInfo, client: ClientInner) { - let conn_id = client.conn_id; - - if client.is_closed() { - info!(%conn_id, "local_pool: throwing away connection '{conn_info}' because connection is closed"); - return; - } - let global_max_conn = pool.read().global_pool_size_max_conns; - if pool.read().total_conns >= global_max_conn { - info!(%conn_id, "local_pool: throwing away connection '{conn_info}' because pool is full"); - return; - } - - // return connection to the pool - let mut returned = false; - let mut per_db_size = 0; - let total_conns = { - let mut pool = pool.write(); - - if pool.total_conns < pool.max_conns { - let pool_entries = pool.pools.entry(conn_info.db_and_user()).or_default(); - pool_entries.conns.push(ConnPoolEntry { - conn: client, - _last_access: std::time::Instant::now(), - }); - - returned = true; - per_db_size = pool_entries.conns.len(); - - pool.total_conns += 1; - Metrics::get() - .proxy - .http_pool_opened_connections - .get_metric() - .inc(); - } - - pool.total_conns - }; - - // do logging outside of the mutex - if returned { - info!(%conn_id, "local_pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}"); - } else { - info!(%conn_id, "local_pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}"); - } - } -} - -impl Drop for EndpointConnPool { - fn drop(&mut self) { - if self.total_conns > 0 { - Metrics::get() - .proxy - .http_pool_opened_connections - .get_metric() - .dec_by(self.total_conns as i64); - } - } -} - -pub(crate) struct DbUserConnPool { - conns: Vec>, - - // true if we have definitely installed the extension and - // granted the role access to the auth schema. 
- initialized: bool, -} - -impl Default for DbUserConnPool { - fn default() -> Self { - Self { - conns: Vec::new(), - initialized: false, - } - } +pub(crate) struct ClientDataLocal { + session: tokio::sync::watch::Sender, + cancel: CancellationToken, + key: SigningKey, + jti: u64, } -impl DbUserConnPool { - fn clear_closed_clients(&mut self, conns: &mut usize) -> usize { - let old_len = self.conns.len(); - - self.conns.retain(|conn| !conn.conn.is_closed()); - - let new_len = self.conns.len(); - let removed = old_len - new_len; - *conns -= removed; - removed +impl ClientDataLocal { + pub fn session(&mut self) -> &mut tokio::sync::watch::Sender { + &mut self.session } - fn get_conn_entry(&mut self, conns: &mut usize) -> Option> { - let mut removed = self.clear_closed_clients(conns); - let conn = self.conns.pop(); - if conn.is_some() { - *conns -= 1; - removed += 1; - } - Metrics::get() - .proxy - .http_pool_opened_connections - .get_metric() - .dec_by(removed as i64); - conn + pub fn cancel(&mut self) { + self.cancel.cancel(); } } pub(crate) struct LocalConnPool { - global_pool: RwLock>, + global_pool: Arc>>, config: &'static crate::config::HttpConfig, } @@ -204,12 +70,14 @@ pub(crate) struct LocalConnPool { impl LocalConnPool { pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc { Arc::new(Self { - global_pool: RwLock::new(EndpointConnPool { - pools: HashMap::new(), - total_conns: 0, - max_conns: config.pool_options.max_conns_per_endpoint, - global_pool_size_max_conns: config.pool_options.max_total_conns, - }), + global_pool: Arc::new(RwLock::new(EndpointConnPool::new( + HashMap::new(), + 0, + config.pool_options.max_conns_per_endpoint, + Arc::new(AtomicUsize::new(0)), + config.pool_options.max_total_conns, + String::from("local_pool"), + ))), config, }) } @@ -222,7 +90,7 @@ impl LocalConnPool { self: &Arc, ctx: &RequestMonitoring, conn_info: &ConnInfo, - ) -> Result>, HttpConnError> { + ) -> Result>, HttpConnError> { let client = self .global_pool .write() @@ -230,12 +98,14 @@ impl LocalConnPool { .map(|entry| entry.conn); // ok return cached connection if found and establish a new one otherwise - if let Some(client) = client { - if client.is_closed() { + if let Some(mut client) = client { + if client.inner.is_closed() { info!("local_pool: cached connection '{conn_info}' is closed, opening a new one"); return Ok(None); } - tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id)); + + tracing::Span::current() + .record("conn_id", tracing::field::display(client.get_conn_id())); tracing::Span::current().record( "pid", tracing::field::display(client.inner.get_process_id()), @@ -244,47 +114,59 @@ impl LocalConnPool { cold_start_info = ColdStartInfo::HttpPoolHit.as_str(), "local_pool: reusing connection '{conn_info}'" ); - client.session.send(ctx.session_id())?; + + match client.get_data() { + ClientDataEnum::Local(data) => { + data.session().send(ctx.session_id())?; + } + + ClientDataEnum::Remote(data) => { + data.session().send(ctx.session_id())?; + } + ClientDataEnum::Http(_) => (), + } + ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit); ctx.success(); - return Ok(Some(LocalClient::new( + + return Ok(Some(Client::new( client, conn_info.clone(), - Arc::downgrade(self), + Arc::downgrade(&self.global_pool), ))); } Ok(None) } pub(crate) fn initialized(self: &Arc, conn_info: &ConnInfo) -> bool { - self.global_pool - .read() - .pools - .get(&conn_info.db_and_user()) - .map_or(false, |pool| pool.initialized) + if let Some(pool) = 
self.global_pool.read().get_pool(conn_info.db_and_user()) { + return pool.is_initialized(); + } + false } pub(crate) fn set_initialized(self: &Arc, conn_info: &ConnInfo) { - self.global_pool + if let Some(pool) = self + .global_pool .write() - .pools - .entry(conn_info.db_and_user()) - .or_default() - .initialized = true; + .get_pool_mut(conn_info.db_and_user()) + { + pool.set_initialized(); + } } } #[allow(clippy::too_many_arguments)] -pub(crate) fn poll_client( - global_pool: Arc>, +pub(crate) fn poll_client( + global_pool: Arc>, ctx: &RequestMonitoring, conn_info: ConnInfo, - client: tokio_postgres::Client, + client: C, mut connection: tokio_postgres::Connection, key: SigningKey, conn_id: uuid::Uuid, aux: MetricsAuxInfo, -) -> LocalClient { +) -> Client { let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol()); let mut session_id = ctx.session_id(); let (tx, mut rx) = tokio::sync::watch::channel(session_id); @@ -377,111 +259,47 @@ pub(crate) fn poll_client( } .instrument(span)); - let inner = ClientInner { + let inner = ClientInnerCommon { inner: client, - session: tx, - cancel, aux, conn_id, - key, - jti: 0, + data: ClientDataEnum::Local(ClientDataLocal { + session: tx, + cancel, + key, + jti: 0, + }), }; - LocalClient::new(inner, conn_info, pool_clone) -} - -pub(crate) struct ClientInner { - inner: C, - session: tokio::sync::watch::Sender, - cancel: CancellationToken, - aux: MetricsAuxInfo, - conn_id: uuid::Uuid, - // needed for pg_session_jwt state - key: SigningKey, - jti: u64, -} - -impl Drop for ClientInner { - fn drop(&mut self) { - // on client drop, tell the conn to shut down - self.cancel.cancel(); - } + Client::new( + inner, + conn_info, + Arc::downgrade(&pool_clone.upgrade().unwrap().global_pool), + ) } -impl ClientInner { - pub(crate) fn is_closed(&self) -> bool { - self.inner.is_closed() - } -} - -impl ClientInner { +impl ClientInnerCommon { pub(crate) async fn set_jwt_session(&mut self, payload: &[u8]) -> Result<(), HttpConnError> { - self.jti += 1; - let token = resign_jwt(&self.key, payload, self.jti)?; - - // initiates the auth session - self.inner.simple_query("discard all").await?; - self.inner - .query( - "select auth.jwt_session_init($1)", - &[&token as &(dyn ToSql + Sync)], - ) - .await?; - - let pid = self.inner.get_process_id(); - info!(pid, jti = self.jti, "user session state init"); - - Ok(()) - } -} - -pub(crate) struct LocalClient { - span: Span, - inner: Option>, - conn_info: ConnInfo, - pool: Weak>, -} - -pub(crate) struct Discard<'a, C: ClientInnerExt> { - conn_info: &'a ConnInfo, - pool: &'a mut Weak>, -} - -impl LocalClient { - pub(self) fn new( - inner: ClientInner, - conn_info: ConnInfo, - pool: Weak>, - ) -> Self { - Self { - inner: Some(inner), - span: Span::current(), - conn_info, - pool, + if let ClientDataEnum::Local(local_data) = &mut self.data { + local_data.jti += 1; + let token = resign_jwt(&local_data.key, payload, local_data.jti)?; + + // initiates the auth session + self.inner.simple_query("discard all").await?; + self.inner + .query( + "select auth.jwt_session_init($1)", + &[&token as &(dyn ToSql + Sync)], + ) + .await?; + + let pid = self.inner.get_process_id(); + info!(pid, jti = local_data.jti, "user session state init"); + Ok(()) + } else { + panic!("unexpected client data type"); } } - - pub(crate) fn client_inner(&mut self) -> (&mut ClientInner, Discard<'_, C>) { - let Self { - inner, - pool, - conn_info, - span: _, - } = self; - let inner_m = inner.as_mut().expect("client inner should not be removed"); - 
(inner_m, Discard { conn_info, pool }) - } - - pub(crate) fn inner(&mut self) -> (&mut C, Discard<'_, C>) { - let Self { - inner, - pool, - conn_info, - span: _, - } = self; - let inner = inner.as_mut().expect("client inner should not be removed"); - (&mut inner.inner, Discard { conn_info, pool }) - } } /// implements relatively efficient in-place json object key upserting @@ -547,58 +365,6 @@ fn sign_jwt(sk: &SigningKey, payload: &[u8]) -> String { jwt } -impl LocalClient { - pub(crate) fn metrics(&self) -> Arc { - let aux = &self.inner.as_ref().unwrap().aux; - USAGE_METRICS.register(Ids { - endpoint_id: aux.endpoint_id, - branch_id: aux.branch_id, - }) - } - - fn do_drop(&mut self) -> Option> { - let conn_info = self.conn_info.clone(); - let client = self - .inner - .take() - .expect("client inner should not be removed"); - if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() { - let current_span = self.span.clone(); - // return connection to the pool - return Some(move || { - let _span = current_span.enter(); - EndpointConnPool::put(&conn_pool.global_pool, &conn_info, client); - }); - } - None - } -} - -impl Drop for LocalClient { - fn drop(&mut self) { - if let Some(drop) = self.do_drop() { - tokio::task::spawn_blocking(drop); - } - } -} - -impl Discard<'_, C> { - pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) { - let conn_info = &self.conn_info; - if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 { - info!( - "local_pool: throwing away connection '{conn_info}' because connection is not idle" - ); - } - } - pub(crate) fn discard(&mut self) { - let conn_info = &self.conn_info; - if std::mem::take(self.pool).strong_count() > 0 { - info!("local_pool: throwing away connection '{conn_info}' because connection is potentially in a broken state"); - } - } -} - #[cfg(test)] mod tests { use p256::ecdsa::SigningKey; diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 8e2d4c126ae2..eec660bb61a9 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -29,7 +29,6 @@ use super::conn_pool::{AuthData, ConnInfoWithAuth}; use super::conn_pool_lib::{self, ConnInfo}; use super::http_util::json_response; use super::json::{json_to_pg_text, pg_text_row_to_json, JsonConversionError}; -use super::local_conn_pool; use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo}; use crate::auth::{endpoint_sni, ComputeUserInfoParseError}; use crate::config::{AuthenticationConfig, HttpConfig, ProxyConfig, TlsConfig}; @@ -1024,12 +1023,12 @@ async fn query_to_json( enum Client { Remote(conn_pool_lib::Client), - Local(local_conn_pool::LocalClient), + Local(conn_pool_lib::Client), } enum Discard<'a> { Remote(conn_pool_lib::Discard<'a, tokio_postgres::Client>), - Local(local_conn_pool::Discard<'a, tokio_postgres::Client>), + Local(conn_pool_lib::Discard<'a, tokio_postgres::Client>), } impl Client { @@ -1043,7 +1042,7 @@ impl Client { fn inner(&mut self) -> (&mut tokio_postgres::Client, Discard<'_>) { match self { Client::Remote(client) => { - let (c, d) = client.inner_mut(); + let (c, d) = client.inner(); (c, Discard::Remote(d)) } Client::Local(local_client) => {