-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathgenerator.py
473 lines (373 loc) · 13 KB
/
generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
import random
from abc import ABC, abstractmethod
from typing import Iterable, List, Optional, Union
import numpy as np
import pandas as pd
import torch
import pyspark.sql.types as st
from pyspark.sql import DataFrame
from pyspark.sql import functions as sf
from sdv.tabular import CopulaGAN, CTGAN, GaussianCopula, TVAE
from sim4rec.utils.session_handler import State
from sim4rec.params import (
HasLabel, HasDevice, HasDataSize,
HasParallelizationLevel, HasSeedSequence, HasWeights
)
from sim4rec.utils import (
NotFittedError, EmptyDataFrameError, save, load
)
class GeneratorBase(ABC, HasLabel, HasDataSize, HasSeedSequence):
    """
    Abstract parent for all data generators. Tracks the generator
    label, the size of the last generated dataframe and a
    reproducible seed sequence shared by repeated calls.
    """
    def __init__(
        self,
        label : str,
        seed : int = None
    ):
        """
        :param label: Generator string label
        :param seed: Fixes seed sequence to use during multiple
            generator calls, defaults to None
        """
        super().__init__()
        self.setLabel(label)
        self.setDataSize(0)
        self.initSeedSequence(seed)
        # set by subclasses once fit() has succeeded
        self._fit_called = False
        # holds the most recently generated dataframe (if any)
        self._df = None

    # deliberately not @abstractmethod: some subclasses (e.g. a
    # composite of already-fitted generators) never need fitting
    def fit(
        self,
        df : DataFrame
    ):
        """
        Fits generator on passed dataframe

        :param df: Source dataframe to fit on
        """
        raise NotImplementedError()

    @abstractmethod
    def generate(
        self,
        num_samples : int
    ):
        """
        Generates num_samples from fitted model or saved dataframe

        :param num_samples: Number of samples to generate
        """
        raise NotImplementedError()

    def sample(
        self,
        sample_frac : float
    ) -> DataFrame:
        """
        Samples a fraction of rows from a dataframe, generated with
        generate() call

        :param sample_frac: Fraction of rows
        :returns: Sampled dataframe
        :raises EmptyDataFrameError: if generate() has not produced
            a dataframe yet
        """
        if self._df is None:
            raise EmptyDataFrameError(
                'Dataframe is empty. Maybe the generate() was never called?'
            )
        # advance the seed sequence so repeated sample() calls differ
        return self._df.sample(sample_frac, seed=self.getNextSeed())
class RealDataGenerator(GeneratorBase, HasParallelizationLevel):
    """
    Generator backed by real data: generate() draws a random subset
    of the fitted dataframe and keeps it cached for later sampling
    """
    _source_df : DataFrame

    def __init__(
        self,
        label : str,
        parallelization_level : int = 1,
        seed : int = None
    ):
        """
        :param label: Generator string label
        :param parallelization_level: Parallelization level, defaults to 1
        :param seed: Fixes seed sequence to use during multiple
            generator calls, defaults to None
        """
        super().__init__(label=label, seed=seed)
        self.setParallelizationLevel(parallelization_level)

    def fit(
        self,
        df : DataFrame
    ) -> None:
        """
        Remembers (and caches) the dataframe to generate from.

        :param df: Dataframe for generation and sampling
        """
        # refitting replaces the cached source, so release the old one
        if self._fit_called:
            self._source_df.unpersist()

        self._source_df = df.cache()
        self._fit_called = True

    def generate(
        self,
        num_samples : int
    ) -> DataFrame:
        """
        Generates a number of samples from fitted dataframe
        and keeps it for sampling

        :param num_samples: Number of samples to generate
        :returns: Generated dataframe
        :raises NotFittedError: if fit() was never called
        :raises ValueError: if num_samples exceeds the source size
        """
        if not self._fit_called:
            raise NotFittedError()

        if num_samples > self._source_df.count():
            raise ValueError('Not enough samples in fitted dataframe')

        # drop the previously generated dataframe before caching a new one
        if self._df is not None:
            self._df.unpersist()

        # random shuffle + limit == sampling without replacement
        shuffled = self._source_df.orderBy(sf.rand(seed=self.getNextSeed()))
        self._df = shuffled\
            .limit(num_samples)\
            .repartition(self.getParallelizationLevel())\
            .cache()

        self.setDataSize(self._df.count())
        return self._df
def set_sdv_seed(seed : int = None):
    """
    Fixes the random state of every RNG the SDV library draws from
    (numpy, stdlib random and torch).

    :param seed: Seed value; with None the torch generator is
        re-randomized while numpy/random are reset with None
    """
    # SDV exposes no seeding API of its own, so seed its
    # underlying RNG sources directly
    np.random.seed(seed)
    random.seed(seed)
    if seed is None:
        # torch.seed() picks a fresh non-deterministic seed
        torch.manual_seed(torch.seed())
    else:
        torch.manual_seed(seed)
# pylint: disable=too-many-ancestors
class SDVDataGenerator(GeneratorBase, HasParallelizationLevel, HasDevice):
    """
    Synthetic data generator with a bunch of models from SDV library
    """
    # fitted SDV model; populated by fit() or by load()
    _model : Optional[Union[CopulaGAN, CTGAN, GaussianCopula, TVAE]] = None

    # service column carrying a per-partition seed through generate();
    # dropped before the generated rows are returned
    SEED_COLUMN_NAME = '__seed'

    # mapping from user-facing model name to SDV model class
    _sdv_model_dict = {
        'copulagan' : CopulaGAN,
        'ctgan' : CTGAN,
        'gaussiancopula' : GaussianCopula,
        'tvae' : TVAE
    }

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        label : str,
        id_column_name : str,
        model_name : str = 'gaussiancopula',
        parallelization_level : int = 1,
        device_name : str = 'cpu',
        seed : int = None
    ):
        """
        :param label: Generator string label
        :param id_column_name: Column name for identifier
        :param model_name: Name of a SDV model. Possible values are:
            ['copulagan', 'ctgan', 'gaussiancopula', 'tvae'],
            defaults to 'gaussiancopula'
        :param parallelization_level: Parallelization level, defaults to 1
        :param device_name: PyTorch device name, defaults to 'cpu'
        :param seed: Fixes seed sequence to use during multiple
            generator calls, defaults to None
        """
        super().__init__(label=label, seed=seed)

        self.setParallelizationLevel(parallelization_level)
        self._id_col_name = id_column_name
        self._model_name = model_name
        self.setDevice(device_name)
        # output schema of generate(); filled in by fit()/load()
        self._schema = None

    def fit(
        self,
        df : DataFrame
    ) -> None:
        """
        Fits a generation model with a passed dataframe. The caller
        should pass only feature columns

        :param df: Dataframe to fit on
        """
        # GaussianCopula is CPU-only and takes no 'cuda' argument
        model_params = {'cuda' : self.getDevice()}\
            if self._model_name != 'gaussiancopula' else {}
        self._model = self._sdv_model_dict[self._model_name](**model_params)

        # the identifier is synthesized at generate() time, so it must
        # never be part of the fitted feature set
        if self._id_col_name in df.columns:
            df = df.drop(self._id_col_name)

        # output schema = generated string id column + fitted features
        self._schema = st.StructType(
            [
                st.StructField(self._id_col_name, st.StringType())
            ] + df.schema.fields
        )
        # NOTE(review): toPandas() collects the whole dataframe to the
        # driver — assumes the training data fits in driver memory
        self._model.fit(df.toPandas())
        self._fit_called = True

    def setDevice(
        self,
        value : str
    ) -> None:
        """
        Changes the current device. Note, that for gaussiancopula
        model, only cpu is supported

        :param value: PyTorch device name
        """
        super().setDevice(value)
        if self._model_name != 'gaussiancopula' and self._fit_called:
            # NOTE(review): reaches into SDV's private wrapped model —
            # may break across SDV versions
            self._model._model.set_device(torch.device(value))

    def generate(
        self,
        num_samples : int
    ) -> DataFrame:
        """
        Generates a number of samples from fitted dataframe
        and keeps it for sampling

        :param num_samples: Number of samples to generate
        :returns: Generated dataframe
        :raises NotFittedError: if fit() was never called
        :raises ValueError: if num_samples is negative
        """
        if not self._fit_called:
            raise NotFittedError('Fit was never called')

        if num_samples < 0:
            raise ValueError('num_samples must be non-negative value')

        if self._df is not None:
            self._df.unpersist()

        label = self.getLabel()
        pl = self.getParallelizationLevel()
        seed = self.getNextSeed()

        # skeleton dataframe with num_samples rows spread over pl partitions
        result_df = State().session.range(
            start=0, end=num_samples, numPartitions=pl
        )
        # build '<label>_<row id>' identifiers and attach a per-partition
        # seed so each partition samples deterministically but differently
        result_df = result_df\
            .withColumnRenamed('id', self._id_col_name)\
            .withColumn(
                self._id_col_name,
                sf.concat(sf.lit(f'{label}_'), sf.col(self._id_col_name))
            )\
            .withColumn(self.SEED_COLUMN_NAME, sf.spark_partition_id() + sf.lit(seed))

        # capture plain locals (not self) so the closure pickles cleanly
        # for the executors
        model = self._model
        seed_col = self.SEED_COLUMN_NAME
        def generate_pdf(iterator):
            # runs on executors: one pandas batch at a time
            for pdf in iterator:
                # derive a 32-bit seed from the partition's seed column
                seed = hash(pdf[seed_col][0]) & 0xffffffff
                set_sdv_seed(seed)
                sampled_df = model.sample(len(pdf), output_file_path='disable')
                # replace the service seed column with the sampled features
                yield pd.concat([pdf.drop(columns=seed_col), sampled_df], axis=1)

        # reset the driver-side global RNG state (mapInPandas is lazy and
        # runs on executors, so the seeding above never touches the driver)
        set_sdv_seed()
        result_df = result_df.mapInPandas(generate_pdf, self._schema)

        self._df = result_df.cache()
        self.setDataSize(self._df.count())

        return self._df

    def save_model(
        self,
        filename : str
    ):
        """
        Saves generator model to file. Note, that it saves only
        fitted model, but not the generated dataframe

        :param filename: Path to the file
        :raises NotFittedError: if fit() was never called
        """
        if not self._fit_called:
            raise NotFittedError('Fit was never called')

        # serialize from CPU so the file can be loaded on GPU-less hosts;
        # the original device name is stored alongside for restore
        save_device = self.getDevice()
        self.setDevice('cpu')

        # tuple layout must stay in sync with load() below
        generator_data = (
            self.getLabel(),
            self._id_col_name,
            self._model_name,
            self.getParallelizationLevel(),
            save_device,
            self.getInitSeed(),
            self._model,
            self._schema
        )
        save(generator_data, filename)

    @staticmethod
    def load(filename : str):
        """
        Loads the generator model from the file

        :param filename: Path to the file
        :return: Generator instance with restored model
        """
        # unpack in the exact order written by save_model()
        label, id_col_name, model_name, p_level,\
            device_name, init_seed, model, schema = load(filename)

        generator = SDVDataGenerator(
            label=label,
            id_column_name=id_col_name,
            model_name=model_name,
            parallelization_level=p_level,
            device_name=device_name,
            seed=init_seed
        )
        generator._model = model
        generator._fit_called = True
        generator._schema = schema

        # move the model back to its saved device; fall back to CPU if
        # that device is unavailable on this host
        try:
            generator.setDevice(device_name)
        except RuntimeError:
            print(f'Cannot load model to device {device_name}. Setting cpu instead')
            generator.setDevice('cpu')

        return generator
# pylint: disable=too-many-ancestors
class CompositeGenerator(GeneratorBase, HasWeights):
    """
    Wrapper for sampling from multiple generators. Use weights
    parameter to control the sampling fraction for each of the
    generator
    """
    def __init__(
        self,
        generators : List[GeneratorBase],
        label : str,
        weights : Iterable = None,
    ):
        """
        :param generators: List of generators
        :param label: Generator string label
        :param weights: Weights for each of the generator. Weights
            must be normalized (sums to 1), defaults to None; when
            None, weights default to the generators' current data
            sizes, or uniform weights if no data was generated yet
        """
        super().__init__(label=label)
        self._generators = generators

        data_sizes = [g.getDataSize() for g in self._generators]
        data_sizes_sum = sum(data_sizes)
        self.setDataSize(data_sizes_sum)

        # BUGFIX: the original condition `weights is None and
        # data_sizes_sum != 0` made the uniform-weights fallback below
        # unreachable and left weights as None when all generators were
        # empty. Derive defaults whenever weights were not provided.
        if weights is None and self._generators:
            if data_sizes_sum > 0:
                # proportional to each generator's current data size
                weights = [s / data_sizes_sum for s in data_sizes]
            else:
                # nothing generated yet: fall back to uniform weights
                n = len(self._generators)
                weights = [1 / n] * n

        self.setWeights(weights)

    def generate(
        self,
        num_samples: int
    ) -> None:
        """
        For each generator calls generate() with number of samples,
        proportional to weights to generate num_samples in total. You
        can call this method to not perform generate() separately on
        each generator

        :param num_samples: Total number of samples to generate
        """
        weights = self.getWeights()
        # round() means the per-generator counts may sum to slightly
        # more or less than num_samples
        num_required_samples = [round(num_samples * w) for w in weights]

        for generator, n in zip(self._generators, num_required_samples):
            generator.generate(n)

        self.setDataSize(sum(g.getDataSize() for g in self._generators))

    def sample(
        self,
        sample_frac : float
    ) -> DataFrame:
        """
        Samples a fraction of rows from generators according to the weights.

        :param sample_frac: Fraction of rows
        :returns: Sampled dataframe
        :raises ValueError: if some generator holds fewer rows than
            its weighted share of the requested sample
        """
        weights = self.getWeights()
        data_sizes = [g.getDataSize() for g in self._generators]
        data_sizes_sum = sum(data_sizes)
        num_required_samples = [int(data_sizes_sum * sample_frac * w) for w in weights]

        # validate before touching any generator so we fail atomically
        for generator, required, available in zip(
            self._generators, num_required_samples, data_sizes
        ):
            if required > available:
                raise ValueError(
                    f'Not enough samples in generator {generator.getLabel()}'
                )

        # per-generator fraction; empty generators contribute nothing
        generator_fracs = [
            0.0 if s == 0 else n / s
            for n, s in zip(num_required_samples, data_sizes)
        ]

        result_df = self._generators[0].sample(generator_fracs[0])
        for generator, frac in zip(self._generators[1:], generator_fracs[1:]):
            result_df = result_df.unionByName(generator.sample(frac))

        return result_df