From 2564a8a54c02a5ff004d3ec245b1654771f5492c Mon Sep 17 00:00:00 2001 From: Felipe Date: Mon, 20 Jan 2025 19:23:27 -0800 Subject: [PATCH 1/2] Fix ColumnSynthesizer --- sdgym/synthesizers/column.py | 19 +++ tests/integration/synthesizers/test_column.py | 140 ++++++++++++++---- tests/unit/synthesizers/test_column.py | 7 +- 3 files changed, 135 insertions(+), 31 deletions(-) diff --git a/sdgym/synthesizers/column.py b/sdgym/synthesizers/column.py index 020a3731..740e47e0 100644 --- a/sdgym/synthesizers/column.py +++ b/sdgym/synthesizers/column.py @@ -1,11 +1,15 @@ """ColumnSynthesizer module.""" +import logging + import pandas as pd from rdt.hyper_transformer import HyperTransformer from sklearn.mixture import GaussianMixture from sdgym.synthesizers.base import BaselineSynthesizer +LOGGER = logging.getLogger(__name__) + class ColumnSynthesizer(BaselineSynthesizer): """Synthesizer that learns each column independently. @@ -17,6 +21,21 @@ class ColumnSynthesizer(BaselineSynthesizer): def _get_trained_synthesizer(self, real_data, metadata): hyper_transformer = HyperTransformer() hyper_transformer.detect_initial_config(real_data) + supported_sdtypes = hyper_transformer._get_supported_sdtypes() + config = {} + for column_name, column in metadata.columns.items(): + sdtype = column['sdtype'] + if sdtype in supported_sdtypes: + config[column_name] = sdtype + elif column.get('pii', False): + config[column_name] = 'pii' + else: + LOGGER.info( + f'Column {column} sdtype: {sdtype} is not supported, ' + f'defaulting to inferred type.' + ) + + hyper_transformer.update_sdtypes(config) # This is done to match the behavior of the synthesizer for SDGym <= 0.6.0 columns_to_remove = [ diff --git a/tests/integration/synthesizers/test_column.py b/tests/integration/synthesizers/test_column.py index 161e7f75..e22c1196 100644 --- a/tests/integration/synthesizers/test_column.py +++ b/tests/integration/synthesizers/test_column.py @@ -6,32 +6,114 @@ from sdgym.synthesizers.column import ColumnSynthesizer -def test_column_synthesizer(): - """Ensure all sdtypes can be sampled.""" - # Setup - n_samples = 10000 - num_values = np.random.normal(size=n_samples) - cat_values = np.random.choice(['a', 'b', 'c'], size=n_samples, p=[0.1, 0.2, 0.7]) - bool_values = np.random.choice([True, False], size=n_samples, p=[0.3, 0.7]) - - dates = pd.to_datetime(['2020-01-01', '2020-02-01', '2020-03-01']) - date_values = np.random.choice(dates, size=n_samples, p=[0.1, 0.2, 0.7]) - - data = pd.DataFrame({ - 'num': num_values, - 'cat': cat_values, - 'bool': bool_values, - 'date': date_values, - }) - - column_synthesizer = ColumnSynthesizer() - - # Run - trained_synthesizer = column_synthesizer.get_trained_synthesizer(data, {}) - samples = column_synthesizer.sample_from_synthesizer(trained_synthesizer, n_samples) - - # Assert - assert samples['num'].between(-10, 10).all() - assert ((samples['cat'] == 'a') | (samples['cat'] == 'b') | (samples['cat'] == 'c')).all() - assert ((samples['bool'] == True) | (samples['bool'] == False)).all() # noqa: E712 - assert samples['date'].between(pd.to_datetime('2019-01-01'), pd.to_datetime('2021-01-01')).all() +class TestUniformSynthesizer: + def test_column_synthesizer(self): + """Ensure all sdtypes can be sampled.""" + # Setup + n_samples = 10000 + num_values = np.random.normal(size=n_samples) + cat_values = np.random.choice(['a', 'b', 'c'], size=n_samples, p=[0.1, 0.2, 0.7]) + bool_values = np.random.choice([True, False], size=n_samples, p=[0.3, 0.7]) + + dates = pd.to_datetime(['2020-01-01', '2020-02-01', '2020-03-01']) + date_values = np.random.choice(dates, size=n_samples, p=[0.1, 0.2, 0.7]) + + data = pd.DataFrame({ + 'num': num_values, + 'cat': cat_values, + 'bool': bool_values, + 'date': date_values, + }) + + column_synthesizer = ColumnSynthesizer() + + # Run + trained_synthesizer = column_synthesizer.get_trained_synthesizer(data, {}) + samples = column_synthesizer.sample_from_synthesizer(trained_synthesizer, n_samples) + + # Assert + assert samples['num'].between(-10, 10).all() + assert ((samples['cat'] == 'a') | (samples['cat'] == 'b') | (samples['cat'] == 'c')).all() + assert ((samples['bool'] == True) | (samples['bool'] == False)).all() # noqa: E712 + assert ( + samples['date'] + .between(pd.to_datetime('2019-01-01'), pd.to_datetime('2021-01-01')) + .all() + ) + + def test_column_synthesizer_sdtypes(self): + """Ensure that sdtypes are taken from metadata instead of inferred GH#249.""" + # Setup + metadata = { + 'primary_key': 'guest_email', + 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1', + 'columns': { + 'guest_email': {'sdtype': 'email', 'pii': True}, + 'has_rewards': {'sdtype': 'boolean'}, + 'room_type': {'sdtype': 'categorical'}, + 'amenities_fee': {'sdtype': 'numerical', 'computer_representation': 'Float'}, + 'checkin_date': {'sdtype': 'datetime', 'datetime_format': '%d %b %Y'}, + 'checkout_date': {'sdtype': 'datetime', 'datetime_format': '%d %b %Y'}, + 'room_rate': {'sdtype': 'numerical', 'computer_representation': 'Float'}, + 'billing_address': {'sdtype': 'address', 'pii': True}, + 'credit_card_number': {'sdtype': 'credit_card_number', 'pii': True}, + }, + } + + data = { + 'guest_email': { + 0: 'michaelsanders@shaw.net', + 1: 'randy49@brown.biz', + 2: 'webermelissa@neal.com', + 3: 'gsims@terry.com', + 4: 'misty33@smith.biz', + }, + 'has_rewards': {0: False, 1: False, 2: True, 3: False, 4: False}, + 'room_type': {0: 'BASIC', 1: 'BASIC', 2: 'DELUXE', 3: 'BASIC', 4: 'BASIC'}, + 'amenities_fee': {0: 37.89, 1: 24.37, 2: 0.0, 3: np.nan, 4: 16.45}, + 'checkin_date': { + 0: '27 Dec 2020', + 1: '30 Dec 2020', + 2: '17 Sep 2020', + 3: '28 Dec 2020', + 4: '05 Apr 2020', + }, + 'checkout_date': { + 0: '29 Dec 2020', + 1: '02 Jan 2021', + 2: '18 Sep 2020', + 3: '31 Dec 2020', + 4: np.nan, + }, + 'room_rate': {0: 131.23, 1: 114.43, 2: 368.33, 3: 115.61, 4: 122.41}, + 'billing_address': { + 0: '49380 Rivers Street\nSpencerville, AK 68265', + 1: '88394 Boyle Meadows\nConleyberg, TN 22063', + 2: '0323 Lisa Station Apt. 208\nPort Thomas, LA 82585', + 3: '77 Massachusetts Ave\nCambridge, MA 02139', + 4: '1234 Corporate Drive\nBoston, MA 02116', + }, + 'credit_card_number': { + 0: 4075084747483975747, + 1: 180072822063468, + 2: 38983476971380, + 3: 4969551998845740, + 4: 3558512986488983, + }, + } + + # Run + real_data = pd.DataFrame(data) + synthesizer = ColumnSynthesizer().get_trained_synthesizer(real_data, metadata) + hyper_transformer_config = synthesizer[0].get_config() + + # Assert + config_sdtypes = hyper_transformer_config['sdtypes'] + unknown_sdtypes = ['email', 'credit_card_number', 'address'] + for column in metadata['columns']: + metadata_sdtype = metadata['columns'][column]['sdtype'] + # Only data types that are known are overridden by metadata + if metadata_sdtype not in unknown_sdtypes: + assert metadata_sdtype == config_sdtypes[column] + else: + assert config_sdtypes[column] == 'pii' diff --git a/tests/unit/synthesizers/test_column.py b/tests/unit/synthesizers/test_column.py index 88b9e158..a695cf5b 100644 --- a/tests/unit/synthesizers/test_column.py +++ b/tests/unit/synthesizers/test_column.py @@ -1,6 +1,7 @@ -from unittest.mock import Mock, patch +from unittest.mock import patch import pandas as pd +from sdv.metadata import SingleTableMetadata from sdgym.synthesizers import ColumnSynthesizer @@ -13,9 +14,11 @@ def test__get_trained_synthesizer(self, gm_mock): column_synthesizer = ColumnSynthesizer() column_synthesizer.length = 10 data = pd.DataFrame({'col1': [1, 2, 3, 4]}) + metadata = SingleTableMetadata() + metadata.add_column('col1', sdtype='numerical') # Run - column_synthesizer._get_trained_synthesizer(data, Mock()) + column_synthesizer._get_trained_synthesizer(data, metadata) # Assert gm_mock.assert_called_once_with(4) From aa500920fc241a985f3b846faf9580c40598cdee Mon Sep 17 00:00:00 2001 From: Felipe Date: Wed, 22 Jan 2025 10:34:09 -0800 Subject: [PATCH 2/2] Support Metadata --- sdgym/synthesizers/column.py | 9 ++++++++- tests/unit/synthesizers/test_column.py | 23 ++++++++++++++++++++--- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/sdgym/synthesizers/column.py b/sdgym/synthesizers/column.py index 740e47e0..69233283 100644 --- a/sdgym/synthesizers/column.py +++ b/sdgym/synthesizers/column.py @@ -4,6 +4,7 @@ import pandas as pd from rdt.hyper_transformer import HyperTransformer +from sdv.metadata import Metadata from sklearn.mixture import GaussianMixture from sdgym.synthesizers.base import BaselineSynthesizer @@ -23,7 +24,13 @@ def _get_trained_synthesizer(self, real_data, metadata): hyper_transformer.detect_initial_config(real_data) supported_sdtypes = hyper_transformer._get_supported_sdtypes() config = {} - for column_name, column in metadata.columns.items(): + if isinstance(metadata, Metadata): + table_name = metadata._get_single_table_name() + columns = metadata.tables[table_name].columns + else: + columns = metadata.columns + + for column_name, column in columns.items(): sdtype = column['sdtype'] if sdtype in supported_sdtypes: config[column_name] = sdtype diff --git a/tests/unit/synthesizers/test_column.py b/tests/unit/synthesizers/test_column.py index a695cf5b..8c32375f 100644 --- a/tests/unit/synthesizers/test_column.py +++ b/tests/unit/synthesizers/test_column.py @@ -1,7 +1,7 @@ from unittest.mock import patch import pandas as pd -from sdv.metadata import SingleTableMetadata +from sdv.metadata import Metadata, SingleTableMetadata from sdgym.synthesizers import ColumnSynthesizer @@ -13,9 +13,26 @@ def test__get_trained_synthesizer(self, gm_mock): # Setup column_synthesizer = ColumnSynthesizer() column_synthesizer.length = 10 - data = pd.DataFrame({'col1': [1, 2, 3, 4]}) + data = pd.DataFrame({'col': [1, 2, 3, 4]}) + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('col', 'table', sdtype='numerical') + + # Run + column_synthesizer._get_trained_synthesizer(data, metadata) + + # Assert + gm_mock.assert_called_once_with(4) + + @patch('sdgym.synthesizers.column.GaussianMixture') + def test__get_trained_synthesizer_single_table_metadata(self, gm_mock): + """Expect that GaussianMixture is instantiated with 4 components.""" + # Setup + column_synthesizer = ColumnSynthesizer() + column_synthesizer.length = 10 + data = pd.DataFrame({'col': [1, 2, 3, 4]}) metadata = SingleTableMetadata() - metadata.add_column('col1', sdtype='numerical') + metadata.add_column('col', sdtype='numerical') # Run column_synthesizer._get_trained_synthesizer(data, metadata)