Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make the ColumnSynthesizer follow the sdtypes in the metadata (not the data's dtypes) #374

Merged
merged 2 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions sdgym/synthesizers/column.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
"""ColumnSynthesizer module."""

import logging

import pandas as pd
from rdt.hyper_transformer import HyperTransformer
from sklearn.mixture import GaussianMixture

from sdgym.synthesizers.base import BaselineSynthesizer

LOGGER = logging.getLogger(__name__)


class ColumnSynthesizer(BaselineSynthesizer):
"""Synthesizer that learns each column independently.
Expand All @@ -17,6 +21,21 @@ class ColumnSynthesizer(BaselineSynthesizer):
def _get_trained_synthesizer(self, real_data, metadata):
hyper_transformer = HyperTransformer()
hyper_transformer.detect_initial_config(real_data)
supported_sdtypes = hyper_transformer._get_supported_sdtypes()
config = {}
for column_name, column in metadata.columns.items():
sdtype = column['sdtype']
if sdtype in supported_sdtypes:
config[column_name] = sdtype
elif column.get('pii', False):
config[column_name] = 'pii'
else:
LOGGER.info(
f'Column {column} sdtype: {sdtype} is not supported, '
f'defaulting to inferred type.'
)

hyper_transformer.update_sdtypes(config)

# This is done to match the behavior of the synthesizer for SDGym <= 0.6.0
columns_to_remove = [
Expand Down
140 changes: 111 additions & 29 deletions tests/integration/synthesizers/test_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,114 @@
from sdgym.synthesizers.column import ColumnSynthesizer


def test_column_synthesizer():
    """Check that numerical, categorical, boolean and datetime columns can all be sampled.

    A four-column frame is generated at random, a ``ColumnSynthesizer`` is
    trained on it with empty metadata, and the sampled values are checked to
    fall inside the domain of each original column.
    """
    # Setup
    row_count = 10000
    skewed_weights = [0.1, 0.2, 0.7]
    candidate_dates = pd.to_datetime(['2020-01-01', '2020-02-01', '2020-03-01'])

    # The random draws happen in the same order as the columns are listed.
    data = pd.DataFrame({
        'num': np.random.normal(size=row_count),
        'cat': np.random.choice(['a', 'b', 'c'], size=row_count, p=skewed_weights),
        'bool': np.random.choice([True, False], size=row_count, p=[0.3, 0.7]),
        'date': np.random.choice(candidate_dates, size=row_count, p=skewed_weights),
    })

    synthesizer = ColumnSynthesizer()

    # Run
    trained = synthesizer.get_trained_synthesizer(data, {})
    samples = synthesizer.sample_from_synthesizer(trained, row_count)

    # Assert
    numbers_in_range = samples['num'].between(-10, 10)
    assert numbers_in_range.all()

    known_categories = samples['cat'].isin(['a', 'b', 'c'])
    assert known_categories.all()

    is_true = samples['bool'] == True  # noqa: E712
    is_false = samples['bool'] == False  # noqa: E712
    assert (is_true | is_false).all()

    earliest = pd.to_datetime('2019-01-01')
    latest = pd.to_datetime('2021-01-01')
    assert samples['date'].between(earliest, latest).all()
class TestUniformSynthesizer:
    """Integration tests that train and inspect a ``ColumnSynthesizer``.

    NOTE(review): the class name mentions "Uniform" although every test here
    exercises ``ColumnSynthesizer`` -- possibly a copy-paste leftover; confirm
    before renaming.
    """

    def test_column_synthesizer(self):
        """Check that numerical, categorical, boolean and datetime columns can all be sampled."""
        # Setup
        row_count = 10000
        skewed_weights = [0.1, 0.2, 0.7]
        candidate_dates = pd.to_datetime(['2020-01-01', '2020-02-01', '2020-03-01'])

        # The random draws happen in the same order as the columns are listed.
        data = pd.DataFrame({
            'num': np.random.normal(size=row_count),
            'cat': np.random.choice(['a', 'b', 'c'], size=row_count, p=skewed_weights),
            'bool': np.random.choice([True, False], size=row_count, p=[0.3, 0.7]),
            'date': np.random.choice(candidate_dates, size=row_count, p=skewed_weights),
        })

        synthesizer = ColumnSynthesizer()

        # Run
        trained = synthesizer.get_trained_synthesizer(data, {})
        samples = synthesizer.sample_from_synthesizer(trained, row_count)

        # Assert
        numbers_in_range = samples['num'].between(-10, 10)
        assert numbers_in_range.all()

        known_categories = samples['cat'].isin(['a', 'b', 'c'])
        assert known_categories.all()

        is_true = samples['bool'] == True  # noqa: E712
        is_false = samples['bool'] == False  # noqa: E712
        assert (is_true | is_false).all()

        earliest = pd.to_datetime('2019-01-01')
        latest = pd.to_datetime('2021-01-01')
        assert samples['date'].between(earliest, latest).all()

    def test_column_synthesizer_sdtypes(self):
        """Check that the transformer config uses the metadata sdtypes, not inferred ones (GH#249).

        Columns whose declared sdtype is listed as unknown below are expected
        to be configured as ``'pii'``; every other column must keep the sdtype
        declared in the metadata.
        """
        # Setup
        float_numerical = {'sdtype': 'numerical', 'computer_representation': 'Float'}
        day_month_year = {'sdtype': 'datetime', 'datetime_format': '%d %b %Y'}
        metadata = {
            'primary_key': 'guest_email',
            'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1',
            'columns': {
                'guest_email': {'sdtype': 'email', 'pii': True},
                'has_rewards': {'sdtype': 'boolean'},
                'room_type': {'sdtype': 'categorical'},
                'amenities_fee': dict(float_numerical),
                'checkin_date': dict(day_month_year),
                'checkout_date': dict(day_month_year),
                'room_rate': dict(float_numerical),
                'billing_address': {'sdtype': 'address', 'pii': True},
                'credit_card_number': {'sdtype': 'credit_card_number', 'pii': True},
            },
        }

        # Each column is keyed by row position 0..4, matching an explicit
        # ``{0: ..., 1: ...}`` mapping.
        data = {
            'guest_email': dict(enumerate([
                'michaelsanders@shaw.net',
                'randy49@brown.biz',
                'webermelissa@neal.com',
                'gsims@terry.com',
                'misty33@smith.biz',
            ])),
            'has_rewards': dict(enumerate([False, False, True, False, False])),
            'room_type': dict(enumerate(['BASIC', 'BASIC', 'DELUXE', 'BASIC', 'BASIC'])),
            'amenities_fee': dict(enumerate([37.89, 24.37, 0.0, np.nan, 16.45])),
            'checkin_date': dict(enumerate([
                '27 Dec 2020',
                '30 Dec 2020',
                '17 Sep 2020',
                '28 Dec 2020',
                '05 Apr 2020',
            ])),
            'checkout_date': dict(enumerate([
                '29 Dec 2020',
                '02 Jan 2021',
                '18 Sep 2020',
                '31 Dec 2020',
                np.nan,
            ])),
            'room_rate': dict(enumerate([131.23, 114.43, 368.33, 115.61, 122.41])),
            'billing_address': dict(enumerate([
                '49380 Rivers Street\nSpencerville, AK 68265',
                '88394 Boyle Meadows\nConleyberg, TN 22063',
                '0323 Lisa Station Apt. 208\nPort Thomas, LA 82585',
                '77 Massachusetts Ave\nCambridge, MA 02139',
                '1234 Corporate Drive\nBoston, MA 02116',
            ])),
            'credit_card_number': dict(enumerate([
                4075084747483975747,
                180072822063468,
                38983476971380,
                4969551998845740,
                3558512986488983,
            ])),
        }

        # Run
        real_data = pd.DataFrame(data)
        trained = ColumnSynthesizer().get_trained_synthesizer(real_data, metadata)
        # The first element of the trained synthesizer exposes ``get_config``.
        config_sdtypes = trained[0].get_config()['sdtypes']

        # Assert
        unknown_sdtypes = ['email', 'credit_card_number', 'address']
        for column_name, column_meta in metadata['columns'].items():
            declared = column_meta['sdtype']
            # Only data types that are known are overridden by metadata
            expected = 'pii' if declared in unknown_sdtypes else declared
            assert config_sdtypes[column_name] == expected
7 changes: 5 additions & 2 deletions tests/unit/synthesizers/test_column.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from unittest.mock import Mock, patch
from unittest.mock import patch

import pandas as pd
from sdv.metadata import SingleTableMetadata

from sdgym.synthesizers import ColumnSynthesizer

Expand All @@ -13,9 +14,11 @@ def test__get_trained_synthesizer(self, gm_mock):
column_synthesizer = ColumnSynthesizer()
column_synthesizer.length = 10
data = pd.DataFrame({'col1': [1, 2, 3, 4]})
metadata = SingleTableMetadata()
metadata.add_column('col1', sdtype='numerical')

# Run
column_synthesizer._get_trained_synthesizer(data, Mock())
column_synthesizer._get_trained_synthesizer(data, metadata)

# Assert
gm_mock.assert_called_once_with(4)
Loading