diff --git a/sdgym/synthesizers/generate.py b/sdgym/synthesizers/generate.py index 20c9bd5..0ea678d 100644 --- a/sdgym/synthesizers/generate.py +++ b/sdgym/synthesizers/generate.py @@ -10,6 +10,7 @@ ) from sdgym.synthesizers.base import BaselineSynthesizer, MultiSingleTableBaselineSynthesizer +from sdgym.synthesizers.realtabformer import RealTabFormerSynthesizer from sdgym.synthesizers.sdv import SDVRelationalSynthesizer, SDVTabularSynthesizer SYNTHESIZER_MAPPING = { @@ -19,6 +20,7 @@ 'TVAESynthesizer': TVAESynthesizer, 'PARSynthesizer': PARSynthesizer, 'HMASynthesizer': HMASynthesizer, + 'RealTabFormerSynthesizer': RealTabFormerSynthesizer, } @@ -55,6 +57,8 @@ def create_sdv_synthesizer_variant(display_name, synthesizer_class, synthesizer_ baseclass = SDVTabularSynthesizer if synthesizer_class == 'HMASynthesizer': baseclass = SDVRelationalSynthesizer + if synthesizer_class == 'RealTabFormerSynthesizer': + baseclass = RealTabFormerSynthesizer class NewSynthesizer(baseclass): """New Synthesizer class. diff --git a/tests/integration/test_benchmark.py b/tests/integration/test_benchmark.py index 6c4098f..3af1d30 100644 --- a/tests/integration/test_benchmark.py +++ b/tests/integration/test_benchmark.py @@ -55,9 +55,15 @@ def test_benchmark_single_table_basic_synthsizers(): def test_benchmark_single_table_realtabformer_no_metrics(): """Test it without metrics.""" # Run + custom_synthesizer = create_sdv_synthesizer_variant( + display_name='RealTabFormerSynthesizer', + synthesizer_class='RealTabFormerSynthesizer', + synthesizer_parameters={'epochs': 2}, + ) output = sdgym.benchmark_single_table( - synthesizers=['RealTabFormerSynthesizer'], - sdv_datasets=['student_placements'], + synthesizers=[], + custom_synthesizers=[custom_synthesizer], + sdv_datasets=['fake_companies'], sdmetrics=[], )