Skip to content

Commit

Permalink
Support Metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed Jan 22, 2025
1 parent 2564a8a commit aa50092
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 4 deletions.
9 changes: 8 additions & 1 deletion sdgym/synthesizers/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import pandas as pd
from rdt.hyper_transformer import HyperTransformer
from sdv.metadata import Metadata
from sklearn.mixture import GaussianMixture

from sdgym.synthesizers.base import BaselineSynthesizer
Expand All @@ -23,7 +24,13 @@ def _get_trained_synthesizer(self, real_data, metadata):
hyper_transformer.detect_initial_config(real_data)
supported_sdtypes = hyper_transformer._get_supported_sdtypes()
config = {}
for column_name, column in metadata.columns.items():
if isinstance(metadata, Metadata):
table_name = metadata._get_single_table_name()
columns = metadata.tables[table_name].columns
else:
columns = metadata.columns

for column_name, column in columns.items():
sdtype = column['sdtype']
if sdtype in supported_sdtypes:
config[column_name] = sdtype
Expand Down
23 changes: 20 additions & 3 deletions tests/unit/synthesizers/test_column.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from unittest.mock import patch

import pandas as pd
from sdv.metadata import SingleTableMetadata
from sdv.metadata import Metadata, SingleTableMetadata

from sdgym.synthesizers import ColumnSynthesizer

Expand All @@ -13,9 +13,26 @@ def test__get_trained_synthesizer(self, gm_mock):
# Setup
column_synthesizer = ColumnSynthesizer()
column_synthesizer.length = 10
data = pd.DataFrame({'col1': [1, 2, 3, 4]})
data = pd.DataFrame({'col': [1, 2, 3, 4]})
metadata = Metadata()
metadata.add_table('table')
metadata.add_column('col', 'table', sdtype='numerical')

# Run
column_synthesizer._get_trained_synthesizer(data, metadata)

# Assert
gm_mock.assert_called_once_with(4)

@patch('sdgym.synthesizers.column.GaussianMixture')
def test__get_trained_synthesizer_single_table_metadata(self, gm_mock):
"""Expect that GaussianMixture is instantiated with 4 components."""
# Setup
column_synthesizer = ColumnSynthesizer()
column_synthesizer.length = 10
data = pd.DataFrame({'col': [1, 2, 3, 4]})
metadata = SingleTableMetadata()
metadata.add_column('col1', sdtype='numerical')
metadata.add_column('col', sdtype='numerical')

# Run
column_synthesizer._get_trained_synthesizer(data, metadata)
Expand Down

0 comments on commit aa50092

Please sign in to comment.