Skip to content

Commit

Permalink
durability impl - bincode encode, decode for Db types
Browse files Browse the repository at this point in the history
  • Loading branch information
justcoon committed Feb 2, 2025
1 parent ba41cce commit 9775ef1
Show file tree
Hide file tree
Showing 9 changed files with 57 additions and 40 deletions.
4 changes: 2 additions & 2 deletions golem-worker-executor-base/src/durable_host/rdbms/mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -593,8 +593,8 @@ impl From<mysql_types::DbValue> for DbValue {
}
}

impl From<crate::services::rdbms::DbRow<mysql_types::DbValue>> for DbRow {
fn from(value: crate::services::rdbms::DbRow<mysql_types::DbValue>) -> Self {
impl From<crate::services::rdbms::DbRow<MysqlType>> for DbRow {
fn from(value: crate::services::rdbms::DbRow<MysqlType>) -> Self {
Self {
values: value.values.into_iter().map(|v| v.into()).collect(),
}
Expand Down
4 changes: 2 additions & 2 deletions golem-worker-executor-base/src/durable_host/rdbms/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1292,7 +1292,7 @@ fn to_bound(
}

fn from_db_rows(
values: Vec<crate::services::rdbms::DbRow<postgres_types::DbValue>>,
values: Vec<crate::services::rdbms::DbRow<PostgresType>>,
resource_table: &mut ResourceTable,
) -> Result<Vec<DbRow>, String> {
let mut result: Vec<DbRow> = Vec::with_capacity(values.len());
Expand All @@ -1304,7 +1304,7 @@ fn from_db_rows(
}

fn from_db_row(
value: crate::services::rdbms::DbRow<postgres_types::DbValue>,
value: crate::services::rdbms::DbRow<PostgresType>,
resource_table: &mut ResourceTable,
) -> Result<DbRow, String> {
let mut values: Vec<DbValue> = Vec::with_capacity(value.values.len());
Expand Down
39 changes: 28 additions & 11 deletions golem-worker-executor-base/src/services/rdbms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use crate::services::golem_config::RdbmsConfig;
use crate::services::rdbms::mysql::MysqlType;
use crate::services::rdbms::postgres::PostgresType;
use async_trait::async_trait;
use bincode::{Decode, Encode};
use bincode::{BorrowDecode, Decode, Encode};
use golem_common::model::WorkerId;
use itertools::Itertools;
use std::collections::{HashMap, HashSet};
Expand All @@ -33,8 +33,24 @@ use std::sync::Arc;
use url::Url;

pub trait RdbmsType: Debug + Display + Default + Send {
type DbColumn: Clone + Send + Sync + PartialEq + Debug + Decode + Encode;
type DbValue: Clone + Send + Sync + PartialEq + Debug + Decode + Encode;
type DbColumn: Clone
+ Send
+ Sync
+ PartialEq
+ Debug
+ Decode
+ for<'de> BorrowDecode<'de>
+ Encode
+ 'static;
type DbValue: Clone
+ Send
+ Sync
+ PartialEq
+ Debug
+ Decode
+ for<'de> BorrowDecode<'de>
+ Encode
+ 'static;
}

#[derive(Clone)]
Expand Down Expand Up @@ -160,8 +176,9 @@ impl RdbmsService for RdbmsServiceDefault {
}
}

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[derive(Clone, Debug, PartialEq, Eq, Hash, Encode, Decode)]
pub struct RdbmsPoolKey {
#[bincode(with_serde)]
pub address: Url,
}

Expand Down Expand Up @@ -232,26 +249,26 @@ impl Display for RdbmsPoolKey {
}
}

#[derive(Clone, Debug, PartialEq)]
pub struct DbRow<V> {
pub values: Vec<V>,
#[derive(Clone, Debug, PartialEq, Encode, Decode)]
pub struct DbRow<T: RdbmsType> {
pub values: Vec<T::DbValue>,
}

#[async_trait]
pub trait DbResultStream<T: RdbmsType> {
async fn get_columns(&self) -> Result<Vec<T::DbColumn>, Error>;

async fn get_next(&self) -> Result<Option<Vec<DbRow<T::DbValue>>>, Error>;
async fn get_next(&self) -> Result<Option<Vec<DbRow<T>>>, Error>;
}

#[derive(Clone, Debug, PartialEq)]
pub struct DbResult<T: RdbmsType> {
pub columns: Vec<T::DbColumn>,
pub rows: Vec<DbRow<T::DbValue>>,
pub rows: Vec<DbRow<T>>,
}

impl<T: RdbmsType> DbResult<T> {
pub fn new(columns: Vec<T::DbColumn>, rows: Vec<DbRow<T::DbValue>>) -> Self {
pub fn new(columns: Vec<T::DbColumn>, rows: Vec<DbRow<T>>) -> Self {
Self { columns, rows }
}

Expand All @@ -264,7 +281,7 @@ impl<T: RdbmsType> DbResult<T> {
result_set: Arc<dyn DbResultStream<T> + Send + Sync>,
) -> Result<DbResult<T>, Error> {
let columns = result_set.get_columns().await?;
let mut rows: Vec<DbRow<T::DbValue>> = vec![];
let mut rows: Vec<DbRow<T>> = vec![];

while let Some(vs) = result_set.get_next().await? {
rows.extend(vs);
Expand Down
2 changes: 1 addition & 1 deletion golem-worker-executor-base/src/services/rdbms/mysql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use std::sync::Arc;

pub(crate) const MYSQL: &str = "mysql";

#[derive(Debug, Clone, Default)]
#[derive(Debug, Clone, Default, PartialEq)]
pub struct MysqlType;

impl MysqlType {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ fn bind_value(
}
}

impl TryFrom<&sqlx::mysql::MySqlRow> for DbRow<DbValue> {
impl TryFrom<&sqlx::mysql::MySqlRow> for DbRow<MysqlType> {
type Error = String;

fn try_from(value: &sqlx::mysql::MySqlRow) -> Result<Self, Self::Error> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use std::sync::Arc;

pub(crate) const POSTGRES: &str = "postgres";

#[derive(Debug, Clone, Default)]
#[derive(Debug, Clone, Default, PartialEq)]
pub struct PostgresType;

impl PostgresType {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ fn get_range<T>(value: PgCustomRange<T>, f: impl Fn(T) -> DbValue + Clone) -> Db
DbValue::Range(Range::new(name, value))
}

impl TryFrom<&sqlx::postgres::PgRow> for DbRow<DbValue> {
impl TryFrom<&sqlx::postgres::PgRow> for DbRow<PostgresType> {
type Error = String;

fn try_from(value: &sqlx::postgres::PgRow) -> Result<Self, Self::Error> {
Expand Down
14 changes: 7 additions & 7 deletions golem-worker-executor-base/src/services/rdbms/sqlx_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -736,20 +736,20 @@ where
pub struct SqlxDbResultStream<'q, T: RdbmsType, DB: Database> {
rdbms_type: T,
columns: Vec<T::DbColumn>,
first_rows: Arc<async_mutex::Mutex<Option<Vec<DbRow<T::DbValue>>>>>,
first_rows: Arc<async_mutex::Mutex<Option<Vec<DbRow<T>>>>>,
row_stream: Arc<async_mutex::Mutex<BoxStream<'q, Vec<Result<DB::Row, sqlx::Error>>>>>,
}

impl<'q, T, DB> SqlxDbResultStream<'q, T, DB>
where
T: RdbmsType + Sync,
DB: Database,
DbRow<T::DbValue>: for<'a> TryFrom<&'a DB::Row, Error = String>,
DbRow<T>: for<'a> TryFrom<&'a DB::Row, Error = String>,
T::DbColumn: for<'a> TryFrom<&'a DB::Column, Error = String>,
{
fn new(
columns: Vec<T::DbColumn>,
first_rows: Vec<DbRow<T::DbValue>>,
first_rows: Vec<DbRow<T>>,
row_stream: BoxStream<'q, Vec<Result<DB::Row, sqlx::Error>>>,
) -> Self {
let rdbms_type = T::default();
Expand Down Expand Up @@ -793,16 +793,16 @@ where
#[async_trait]
impl<T, DB> DbResultStream<T> for SqlxDbResultStream<'_, T, DB>
where
T: RdbmsType + Sync,
T: RdbmsType + Sync + Clone,
DB: Database,
DbRow<T::DbValue>: for<'a> TryFrom<&'a DB::Row, Error = String>,
DbRow<T>: for<'a> TryFrom<&'a DB::Row, Error = String>,
{
async fn get_columns(&self) -> Result<Vec<T::DbColumn>, Error> {
debug!(rdbms_type = self.rdbms_type.to_string(), "get columns");
Ok(self.columns.clone())
}

async fn get_next(&self) -> Result<Option<Vec<DbRow<T::DbValue>>>, Error> {
async fn get_next(&self) -> Result<Option<Vec<DbRow<T>>>, Error> {
let mut rows = self.first_rows.lock().await;
if rows.is_some() {
debug!(
Expand Down Expand Up @@ -843,7 +843,7 @@ pub(crate) fn create_db_result<T, DB>(rows: Vec<DB::Row>) -> Result<DbResult<T>,
where
T: RdbmsType + Sync,
DB: Database,
DbRow<T::DbValue>: for<'a> TryFrom<&'a DB::Row, Error = String>,
DbRow<T>: for<'a> TryFrom<&'a DB::Row, Error = String>,
T::DbColumn: for<'a> TryFrom<&'a DB::Column, Error = String>,
{
if rows.is_empty() {
Expand Down
28 changes: 14 additions & 14 deletions golem-worker-executor-base/src/services/rdbms/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,13 @@ enum StatementAction<T: RdbmsType + Clone> {
#[derive(Clone, Debug)]
struct ExpectedQueryResult<T: RdbmsType + Clone> {
expected_columns: Option<Vec<T::DbColumn>>,
expected_rows: Option<Vec<DbRow<T::DbValue>>>,
expected_rows: Option<Vec<DbRow<T>>>,
}

impl<T: RdbmsType + Clone> ExpectedQueryResult<T> {
fn new(
expected_columns: Option<Vec<T::DbColumn>>,
expected_rows: Option<Vec<DbRow<T::DbValue>>>,
expected_rows: Option<Vec<DbRow<T>>>,
) -> Self {
Self {
expected_rows,
Expand Down Expand Up @@ -124,7 +124,7 @@ impl<T: RdbmsType + Clone> StatementTest<T> {
statement: &'static str,
params: Vec<T::DbValue>,
expected_columns: Option<Vec<T::DbColumn>>,
expected_rows: Option<Vec<DbRow<T::DbValue>>>,
expected_rows: Option<Vec<DbRow<T>>>,
) -> Self {
Self {
action: StatementAction::Query(ExpectedQueryResult::new(
Expand All @@ -140,7 +140,7 @@ impl<T: RdbmsType + Clone> StatementTest<T> {
statement: &'static str,
params: Vec<T::DbValue>,
expected_columns: Option<Vec<T::DbColumn>>,
expected_rows: Option<Vec<DbRow<T::DbValue>>>,
expected_rows: Option<Vec<DbRow<T>>>,
) -> Self {
Self {
action: StatementAction::QueryStream(ExpectedQueryResult::new(
Expand All @@ -155,7 +155,7 @@ impl<T: RdbmsType + Clone> StatementTest<T> {
fn with_query_expected(
&self,
expected_columns: Option<Vec<T::DbColumn>>,
expected_rows: Option<Vec<DbRow<T::DbValue>>>,
expected_rows: Option<Vec<DbRow<T>>>,
) -> Self {
Self {
action: StatementAction::Query(ExpectedQueryResult::new(
Expand All @@ -169,7 +169,7 @@ impl<T: RdbmsType + Clone> StatementTest<T> {
fn with_query_stream_expected(
&self,
expected_columns: Option<Vec<T::DbColumn>>,
expected_rows: Option<Vec<DbRow<T::DbValue>>>,
expected_rows: Option<Vec<DbRow<T>>>,
) -> Self {
Self {
action: StatementAction::QueryStream(ExpectedQueryResult::new(
Expand Down Expand Up @@ -251,7 +251,7 @@ async fn postgres_transaction_tests(

let count = 60;

let mut rows: Vec<DbRow<postgres_types::DbValue>> = Vec::with_capacity(count);
let mut rows: Vec<DbRow<PostgresType>> = Vec::with_capacity(count);

let mut statements: Vec<StatementTest<PostgresType>> = Vec::with_capacity(count);

Expand Down Expand Up @@ -478,7 +478,7 @@ async fn postgres_create_insert_select_test(

let count = 4;

let mut rows: Vec<DbRow<postgres_types::DbValue>> = Vec::with_capacity(count);
let mut rows: Vec<DbRow<PostgresType>> = Vec::with_capacity(count);
let mut statements: Vec<StatementTest<PostgresType>> = Vec::with_capacity(count);
for i in 0..count {
let mut params: Vec<postgres_types::DbValue> =
Expand Down Expand Up @@ -1201,7 +1201,7 @@ async fn postgres_create_insert_select_array_test(

let count = 4;

let mut rows: Vec<DbRow<postgres_types::DbValue>> = Vec::with_capacity(count);
let mut rows: Vec<DbRow<PostgresType>> = Vec::with_capacity(count);
let mut statements: Vec<StatementTest<PostgresType>> = Vec::with_capacity(count);
for i in 0..count {
let mut params: Vec<postgres_types::DbValue> =
Expand Down Expand Up @@ -1933,7 +1933,7 @@ async fn mysql_transaction_tests(mysql: &DockerMysqlRdbs, rdbms_service: &RdbmsS

let count = 60;

let mut rows: Vec<DbRow<mysql_types::DbValue>> = Vec::with_capacity(count);
let mut rows: Vec<DbRow<MysqlType>> = Vec::with_capacity(count);

let mut statements: Vec<StatementTest<MysqlType>> = Vec::with_capacity(count);

Expand Down Expand Up @@ -2097,7 +2097,7 @@ async fn mysql_create_insert_select_test(

let count = 4;

let mut rows: Vec<DbRow<mysql_types::DbValue>> = Vec::with_capacity(count);
let mut rows: Vec<DbRow<MysqlType>> = Vec::with_capacity(count);
let mut statements: Vec<StatementTest<MysqlType>> = Vec::with_capacity(count);
for i in 0..count {
let mut params: Vec<mysql_types::DbValue> =
Expand Down Expand Up @@ -2519,7 +2519,7 @@ async fn execute_rdbms_test<T: RdbmsType + Clone>(
results
}

async fn rdbms_test<T: RdbmsType + Clone>(
async fn rdbms_test<T: RdbmsType + Clone + PartialEq>(
rdbms: Arc<dyn Rdbms<T> + Send + Sync>,
db_address: &str,
test: RdbmsTest<T>,
Expand All @@ -2537,7 +2537,7 @@ async fn rdbms_test<T: RdbmsType + Clone>(
check!(!exists);
}

fn check_test_results<T: RdbmsType + Clone>(
fn check_test_results<T: RdbmsType + Clone + PartialEq>(
worker_id: &WorkerId,
test: RdbmsTest<T>,
results: Vec<Result<StatementResult<T>, Error>>,
Expand Down Expand Up @@ -2601,7 +2601,7 @@ fn check_test_results<T: RdbmsType + Clone>(
}
}

async fn rdbms_par_test<T: RdbmsType + Clone + 'static>(
async fn rdbms_par_test<T: RdbmsType + Clone + PartialEq + 'static>(
rdbms: Arc<dyn Rdbms<T> + Send + Sync>,
db_addresses: Vec<String>,
count: u8,
Expand Down

0 comments on commit 9775ef1

Please sign in to comment.