diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index b9ac6080..f9cfc957 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -48,4 +48,5 @@ jobs: with: files: ./coverage.xml name: dbldatagen - verbose: true \ No newline at end of file + verbose: true + token: ${{ secrets.CODECOV_TOKEN }} \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index e8866f1f..e89ad16d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ All notable changes to the Databricks Labs Data Generator will be documented in #### Changed * Fixed use of logger in _version.py and in spark_singleton.py +* Fixed template issues ### Version 0.3.2 diff --git a/dbldatagen/text_generators.py b/dbldatagen/text_generators.py index 96c03a4e..7a68024a 100644 --- a/dbldatagen/text_generators.py +++ b/dbldatagen/text_generators.py @@ -11,6 +11,7 @@ import numpy as np import pandas as pd +import logging #: list of hex digits for template generation _HEX_LOWER = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'] @@ -168,6 +169,7 @@ class TemplateGenerator(TextGenerator): # lgtm [py/missing-equals] :param escapeSpecialChars: By default special chars in the template have special meaning if unescaped If set to true, then the special meaning requires escape char ``\\`` :param extendedWordList: if provided, use specified word list instead of default word list + :param legacyEscapeTreatment: if True, sequences of escape char are treated as single escape, Defaults to True The template generator generates text from a template to allow for generation of synthetic account card numbers, VINs, IBANs and many other structured codes. @@ -218,16 +220,21 @@ class TemplateGenerator(TextGenerator): # lgtm [py/missing-equals] If set to True, then the template ``r"dr_\\v"`` will generate the values ``"dr_0"`` ... ``"dr_999"`` when applied to the values zero to 999. This conforms to the preferred style going forward + .. 
note:: + In previous versions, consecutive escape chars were treated as a single escape char, + so the templates r'\\w' and r'\w' have the same meaning and don't emit an escape char. If `legacyEscapeTreatment` + is False, r'\\w' will emit an escape char in the output string. + """ - def __init__(self, template, escapeSpecialChars=False, extendedWordList=None): + def __init__(self, template, escapeSpecialChars=False, extendedWordList=None, legacyEscapeTreatment=True): assert template is not None, "`template` must be specified" super().__init__() self._template = template self._escapeSpecialMeaning = bool(escapeSpecialChars) - template_str0 = self._template - self._templates = [x.replace('$__sep__', '|') for x in template_str0.replace(r'\|', '$__sep__').split('|')] + self._legacyEscapeTreatment = legacyEscapeTreatment + self._templates = self._splitTemplates(self._template) self._wordList = np.array(extendedWordList if extendedWordList is not None else _WORDS_LOWER) self._upperWordList = np.array([x.upper() for x in extendedWordList] if extendedWordList is not None else _WORDS_UPPER) @@ -243,18 +250,80 @@ def __init__(self, template, escapeSpecialChars=False, extendedWordList=None): self._np_letters_all = np.array(_LETTERS_ALL) self._lenWords = len(self._wordList) - # get the template metadata + # mappings must be mapping from string to tuple(length of mappings, mapping array or list) + self._templateMappings = { + 'a': (26, self._np_letters_lower), + 'A': (26, self._np_letters_upper), + 'x': (16, self._np_hex_lower), + 'X': (16, self._np_hex_upper), + 'd': (10, self._np_digits_zero), + 'D': (9, self._np_digits_non_zero), + 'k': (36, self._np_alnum_lower), + 'K': (36, self._np_alnum_upper) + } + + # ensure that each mapping is mapping from string to list or numpy array + for k, v in self._templateMappings.items(): + assert (k is not None) and isinstance(k, str) and len(k) > 0, "key must be non-empty string" + assert v is not None and isinstance(v, tuple) and len(v) == 2, "value 
must be tuple of length 2" + mapping_length, mappings = v + assert isinstance(mapping_length, int), "mapping length must be of type int" + assert isinstance(mappings, list) or isinstance(mappings, np.ndarray),\ + "mappings are lists or numpy arrays" + assert mapping_length == 0 or len(mappings) == mapping_length, "mappings must match mapping_length" + + self._templateEscapedMappings = { + 'n': (256, None), + 'N': (65536, None), + 'w': (self._lenWords, self._wordList), + 'W': (self._lenWords, self._upperWordList) + } + + # ensure that each escaped mapping is mapping from string to None, list or numpy array + for k, v in self._templateEscapedMappings.items(): + assert (k is not None) and isinstance(k, str) and len(k) > 0, "key must be non-empty string" + assert v is not None and isinstance(v, tuple) and len(v) == 2, "value must be tuple of length 2" + mapping_length, mappings = v + assert isinstance(mapping_length, int), "mapping length must be of type int" + assert mappings is None or isinstance(mappings, list) or isinstance(mappings, np.ndarray),\ + "mappings are lists or numpy arrays" + + # for escaped mappings, the mapping can be None in which case the mapping is to the number itself + # i.e mapping[4] = 4 + assert mappings is None or len(mappings) == mapping_length, "mappings must match mapping_length" + + # get the template metadata - this will be list of metadata entries for each template + # for each template, metadata will be tuple of number of placeholders followed by list of random bounds + # to be computed when replacing non static placeholder template_info = [self._prepareTemplateStrings(template, escapeSpecialMeaning=escapeSpecialChars) for template in self._templates] + + logger = logging.getLogger(__name__) + + #if logger.isEnabledFor(logging.DEBUG): + for ix, ti in template_info: + logger.info(f"templates - {ix} {ti}") + self._max_placeholders = max([ x[0] for x in template_info]) self._max_rnds_needed = max([ len(x[1]) for x in template_info]) 
self._placeholders_needed = [ x[0] for x in template_info] self._template_rnd_bounds = [ x[1] for x in template_info] - def __repr__(self): return f"TemplateGenerator(template='{self._template}')" + def _splitTemplates(self, templateStr): + """ Split template string into individual template strings + + :param templateStr: template string + :return: list of individual template strings + + + """ + tmp_template = templateStr.replace(r'\\', '$__escape__').replace(r'\|', '$__sep__') + results = [x.replace('$__escape__', r'\\').replace('$__sep__', '|') for x in tmp_template.split('|')] + return results + @property def templates(self): """ Get effective templates for text generator""" @@ -312,64 +381,27 @@ def _prepareTemplateStrings(self, genTemplate, escapeSpecialMeaning=False): char = genTemplate[i] following_char = genTemplate[i + 1] if i + 1 < template_len else None - if char == '\\': + if char == '\\' and (escape and not self._legacyEscapeTreatment): + escape = False + num_placeholders += 1 + elif char == '\\': escape = True elif use_value and ('0' <= char <= '9'): # val_index = int(char) # retval.append(str(baseValue[val_index])) num_placeholders += 1 use_value = False - elif char == 'x' and (not escape) ^ escapeSpecialMeaning: - retval.append(16) - num_placeholders += 1 - # used for retval.append(_HEX_LOWER[self._getRandomInt(0, 15, rndGenerator)]) - elif char == 'X' and (not escape) ^ escapeSpecialMeaning: - retval.append(16) - num_placeholders += 1 - # retval.append(_HEX_UPPER[self._getRandomInt(0, 15, rndGenerator)]) - elif char == 'd' and (not escape) ^ escapeSpecialMeaning: - retval.append(10) + elif (char in self._templateMappings.keys()) and (not escape) ^ escapeSpecialMeaning: + # handle case for ['a','A','k', 'K', 'x', 'X'] + bound, mappingArr = self._templateMappings[char] + retval.append(bound) num_placeholders += 1 - # retval.append(_DIGITS_ZERO[self._getRandomInt(0, 9, rndGenerator)]) - elif char == 'D' and (not escape) ^ escapeSpecialMeaning: - 
retval.append(9) - num_placeholders += 1 - # retval.append(_DIGITS_NON_ZERO[self._getRandomInt(0, 8, rndGenerator)]) - elif char == 'a' and (not escape) ^ escapeSpecialMeaning: - retval.append(26) - num_placeholders += 1 - # retval.append(_LETTERS_LOWER[self._getRandomInt(0, 25, rndGenerator)]) - elif char == 'A' and (not escape) ^ escapeSpecialMeaning: - retval.append(26) - num_placeholders += 1 - # retval.append(_LETTERS_UPPER[self._getRandomInt(0, 25, rndGenerator)]) - elif char == 'k' and (not escape) ^ escapeSpecialMeaning: - retval.append(26) - num_placeholders += 1 - # retval.append(_ALNUM_LOWER[self._getRandomInt(0, 35, rndGenerator)]) - elif char == 'K' and (not escape) ^ escapeSpecialMeaning: - retval.append(36) - num_placeholders += 1 - # retval.append(_ALNUM_UPPER[self._getRandomInt(0, 35, rndGenerator)]) - elif char == 'n' and escape: - retval.append(256) - num_placeholders += 1 - # retval.append(str(self._getRandomInt(0, 255, rndGenerator))) - escape = False - elif char == 'N' and escape: - retval.append(65536) - num_placeholders += 1 - # retval.append(str(self._getRandomInt(0, 65535, rndGenerator))) - escape = False - elif char == 'W' and escape: - retval.append(self._lenWords) - num_placeholders += 1 - # retval.append(self._upperWordList[self._getRandomWordOffset(self._lenWords, rndGenerator=rndGenerator)]) escape = False - elif char == 'w' and escape: - retval.append(self._lenWords) + elif (char in self._templateEscapedMappings.keys()) and escape: + # handle case for ['n', 'N', 'w', 'W'] + bound, mappingArr = self._templateEscapedMappings[char] + retval.append(bound) num_placeholders += 1 - # retval.append(self._wordList[self._getRandomWordOffset(self._lenWords, rndGenerator=rndGenerator)]) escape = False elif char == 'v' and escape: escape = False @@ -410,12 +442,23 @@ def _applyTemplateStringsForTemplate(self, baseValue, genTemplate, placeholders, `_escapeSpecialMeaning` parameter allows for backwards compatibility with old style syntax while 
allowing for preferred new style template syntax. Specify as True to force escapes for special meanings,. + .. note:: + Both `placeholders` and `rnds` are numpy masked arrays. If there are multiple templates in the template + generation source template, then this method will be called multiple times with each of + the distinct templates passed and the `placeholders` and `rnds` arrays masked so that the each call + will apply the template to rows to which that template applies. + + The template may be the empty string. + """ assert baseValue.shape[0] == placeholders.shape[0] assert baseValue.shape[0] == rnds.shape[0] _cached_values = {} + regularKeys = self._templateMappings.keys() + escapedKeys = self._templateEscapedMappings.keys() + def _get_values_as_np_array(): """Get baseValue which is pd.Series or Dataframe as a numpy array and cache it""" if "np_values" not in _cached_values: @@ -444,6 +487,15 @@ def _get_values_subelement(elem): num_placeholders = 0 rnd_offset = 0 + masked_rows = None + + assert isinstance(placeholders, np.ma.MaskedArray), "expecting MaskArray" + + # if template is empty, then nothing needs to be done + if template_len > 0 and isinstance(placeholders, np.ma.MaskedArray): + active_rows = ~placeholders.mask + masked_rows = active_rows[:, 0] + # in the following code, the construct `(not escape) ^ self._escapeSpecialMeaning` means apply # special meaning if either escape is not true or the option `self._escapeSpecialMeaning` is true. 
# This corresponds to the logical xor operation @@ -451,7 +503,12 @@ def _get_values_subelement(elem): char = genTemplate[i] following_char = genTemplate[i + 1] if i + 1 < template_len else None - if char == '\\': + if char == '\\' and (escape and not self._legacyEscapeTreatment): + escape = False + placeholders[:, num_placeholders] = char + # retval.append(char) + num_placeholders += 1 + elif char == '\\': escape = True elif use_value and ('0' <= char <= '9'): val_index = int(char) @@ -459,72 +516,37 @@ def _get_values_subelement(elem): #placeholders[:, num_placeholders] = pd_base_values.apply(lambda x: str(x[val_index])) num_placeholders += 1 use_value = False - elif char == 'x' and (not escape) ^ escapeSpecialMeaning: + elif char in regularKeys and (not escape) ^ escapeSpecialMeaning: # note vectorized lookup - `rnds[:, rnd_offset]` will get vertical column of # random numbers from `rnds` 2d array - placeholders[:, num_placeholders] = self._np_hex_lower[rnds[:, rnd_offset]] + bound, valueMappings = self._templateMappings[char] + + if masked_rows is not None: + placeholders[masked_rows, num_placeholders] = valueMappings[rnds[masked_rows, rnd_offset]] + else: + placeholders[:, num_placeholders] = valueMappings[rnds[:, rnd_offset]] + num_placeholders += 1 rnd_offset = rnd_offset + 1 + escape = False # used for retval.append(_HEX_LOWER[self._getRandomInt(0, 15, rndGenerator)]) - elif char == 'X' and (not escape) ^ escapeSpecialMeaning: - placeholders[:, num_placeholders] = self._np_hex_upper[rnds[:, rnd_offset]] - num_placeholders += 1 - rnd_offset = rnd_offset + 1 - # retval.append(_HEX_UPPER[self._getRandomInt(0, 15, rndGenerator)]) - elif char == 'd' and (not escape) ^ escapeSpecialMeaning: - placeholders[:, num_placeholders] = self._np_digits_zero[rnds[:, rnd_offset]] - num_placeholders += 1 - rnd_offset = rnd_offset + 1 - # retval.append(_DIGITS_ZERO[self._getRandomInt(0, 9, rndGenerator)]) - elif char == 'D' and (not escape) ^ escapeSpecialMeaning: - 
placeholders[:, num_placeholders] = self._np_digits_non_zero[rnds[:, rnd_offset]] - num_placeholders += 1 - rnd_offset = rnd_offset + 1 - # retval.append(_DIGITS_NON_ZERO[self._getRandomInt(0, 8, rndGenerator)]) - elif char == 'a' and (not escape) ^ escapeSpecialMeaning: - placeholders[:, num_placeholders] = self._np_letters_lower[rnds[:, rnd_offset]] - num_placeholders += 1 - rnd_offset = rnd_offset + 1 - # retval.append(_LETTERS_LOWER[self._getRandomInt(0, 25, rndGenerator)]) - elif char == 'A' and (not escape) ^ escapeSpecialMeaning: - placeholders[:, num_placeholders] = self._np_letters_upper[rnds[:, rnd_offset]] - num_placeholders += 1 - rnd_offset = rnd_offset + 1 - # retval.append(_LETTERS_UPPER[self._getRandomInt(0, 25, rndGenerator)]) - elif char == 'k' and (not escape) ^ escapeSpecialMeaning: - placeholders[:, num_placeholders] = self._np_alnum_lower[rnds[:, rnd_offset]] - num_placeholders += 1 - rnd_offset = rnd_offset + 1 - # retval.append(_ALNUM_LOWER[self._getRandomInt(0, 35, rndGenerator)]) - elif char == 'K' and (not escape) ^ escapeSpecialMeaning: - placeholders[:, num_placeholders] = self._np_alnum_upper[rnds[:, rnd_offset]] - num_placeholders += 1 - rnd_offset = rnd_offset + 1 - # retval.append(_ALNUM_UPPER[self._getRandomInt(0, 35, rndGenerator)]) - elif char == 'n' and escape: - placeholders[:, num_placeholders] = rnds[:, rnd_offset] + elif char in escapedKeys and escape: + bound, valueMappings = self._templateEscapedMappings[char] + + if valueMappings is not None: + if masked_rows is not None: + placeholders[masked_rows, num_placeholders] = valueMappings[rnds[masked_rows, rnd_offset]] + else: + placeholders[:, num_placeholders] = valueMappings[rnds[:, rnd_offset]] + else: + if masked_rows is not None: + placeholders[masked_rows, num_placeholders] = rnds[masked_rows, rnd_offset] + else: + placeholders[:, num_placeholders] = rnds[:, rnd_offset] num_placeholders += 1 rnd_offset = rnd_offset + 1 # retval.append(str(self._getRandomInt(0, 255, 
rndGenerator))) escape = False - elif char == 'N' and escape: - placeholders[:, num_placeholders] = rnds[:, rnd_offset] - num_placeholders += 1 - rnd_offset = rnd_offset + 1 - # retval.append(str(self._getRandomInt(0, 65535, rndGenerator))) - escape = False - elif char == 'W' and escape: - placeholders[:, num_placeholders] = self._upperWordList[rnds[:, rnd_offset]] - num_placeholders += 1 - rnd_offset = rnd_offset + 1 - # retval.append(self._upperWordList[self._getRandomWordOffset(self._lenWords, rndGenerator=rndGenerator)]) - escape = False - elif char == 'w' and escape: - placeholders[:, num_placeholders] = self._wordList[rnds[:, rnd_offset]] - num_placeholders += 1 - rnd_offset = rnd_offset + 1 - # retval.append(self._wordList[self._getRandomWordOffset(self._lenWords, rndGenerator=rndGenerator)]) - escape = False elif char == 'v' and escape: escape = False if following_char is not None and ('0' <= following_char <= '9'): @@ -637,7 +659,7 @@ def pandasGenerateText(self, v): for m in masked_matrices: np.ma.harden_mask(m) - # expand values into placeholders + # expand values into placeholders without affect masked values #self._applyTemplateStringsForTemplate(v.to_numpy(dtype=np.object_), #masked_base_values, self._applyTemplateStringsForTemplate(v, #masked_base_values, diff --git a/docs/source/APIDOCS.md b/docs/source/APIDOCS.md index c786458a..10679103 100644 --- a/docs/source/APIDOCS.md +++ b/docs/source/APIDOCS.md @@ -250,11 +250,11 @@ dataspec = (dg.DataGenerator(spark, rows=10000000, partitions=8, .withSchema(table_schema)) dataspec = (dataspec - .withColumnSpec("name", percentNulls=0.01, template=r'\\w \\w|\\w a. \\w') + .withColumnSpec("name", percentNulls=0.01, template=r'\w \w|\w a. 
\w') .withColumnSpec("serial_number", minValue=1000000, maxValue=10000000, prefix="dr", random=True) - .withColumnSpec("email", template=r'\\w.\\w@\\w.com') - .withColumnSpec("license_plate", template=r'\\n-\\n') + .withColumnSpec("email", template=r'\w.\w@\w.com') + .withColumnSpec("license_plate", template=r'\n-\n') ) df1 = dataspec.build() @@ -472,7 +472,7 @@ data_rows = 10000000 spark.conf.set("spark.sql.shuffle.partitions", shuffle_partitions_requested) dataspec = (dg.DataGenerator(spark, rows=data_rows, partitions=8, randomSeedMethod="hash_fieldname") - .withColumn("name", percentNulls=0.01, template=r'\\w \\w|\\w a. \\w') + .withColumn("name", percentNulls=0.01, template=r'\w \w|\w a. \w') .withColumn("payment_instrument_type", values=['paypal', 'visa', 'mastercard', 'amex'], random=True) .withColumn("int_payment_instrument", "int", minValue=0000, maxValue=9999, @@ -481,7 +481,7 @@ dataspec = (dg.DataGenerator(spark, rows=data_rows, partitions=8, randomSeedMeth .withColumn("payment_instrument", expr="format_number(int_payment_instrument, '**** ****** *####')", baseColumn="int_payment_instrument") - .withColumn("email", template=r'\\w.\\w@\\w.com') + .withColumn("email", template=r'\w.\w@\w.com') .withColumn("md5_payment_instrument", expr="md5(concat(payment_instrument_type, ':', payment_instrument))", baseColumn=['payment_instrument_type', 'payment_instrument']) @@ -524,7 +524,7 @@ spark.conf.set("spark.sql.shuffle.partitions", shuffle_partitions_requested) dataspec = ( dg.DataGenerator(spark, rows=data_rows, partitions=8, randomSeedMethod="hash_fieldname", randomSeed=42) - .withColumn("name", percentNulls=0.01, template=r'\\w \\w|\\w a. \\w') + .withColumn("name", percentNulls=0.01, template=r'\w \w|\w a. 
\w') .withColumn("payment_instrument_type", values=['paypal', 'visa', 'mastercard', 'amex'], random=True) .withColumn("int_payment_instrument", "int", minValue=0000, maxValue=9999, @@ -533,7 +533,7 @@ dataspec = ( .withColumn("payment_instrument", expr="format_number(int_payment_instrument, '**** ****** *####')", baseColumn="int_payment_instrument") - .withColumn("email", template=r'\\w.\\w@\\w.com') + .withColumn("email", template=r'\w.\w@\w.com') .withColumn("md5_payment_instrument", expr="md5(concat(payment_instrument_type, ':', payment_instrument))", baseColumn=['payment_instrument_type', 'payment_instrument']) diff --git a/docs/source/generating_cdc_data.rst b/docs/source/generating_cdc_data.rst index 6528f4df..a3750690 100644 --- a/docs/source/generating_cdc_data.rst +++ b/docs/source/generating_cdc_data.rst @@ -49,8 +49,8 @@ We'll add a timestamp for when the row was generated and a memo field to mark wh dataspec = (dg.DataGenerator(spark, rows=data_rows, partitions=partitions_requested) .withColumn("customer_id","long", uniqueValues=uniqueCustomers) - .withColumn("name", percentNulls=0.01, template=r'\\w \\w|\\w a. \\w') - .withColumn("alias", percentNulls=0.01, template=r'\\w \\w|\\w a. \\w') + .withColumn("name", percentNulls=0.01, template=r'\w \w|\w a. \w') + .withColumn("alias", percentNulls=0.01, template=r'\w \w|\w a. 
\w') .withColumn("payment_instrument_type", values=['paypal', 'Visa', 'Mastercard', 'American Express', 'discover', 'branded visa', 'branded mastercard'], random=True, distribution="normal") @@ -58,9 +58,9 @@ We'll add a timestamp for when the row was generated and a memo field to mark wh baseColumnType="hash", omit=True) .withColumn("payment_instrument", expr="format_number(int_payment_instrument, '**** ****** *####')", baseColumn="int_payment_instrument") - .withColumn("email", template=r'\\w.\\w@\\w.com|\\w-\\w@\\w') - .withColumn("email2", template=r'\\w.\\w@\\w.com') - .withColumn("ip_address", template=r'\\n.\\n.\\n.\\n') + .withColumn("email", template=r'\w.\w@\w.com|\w-\w@\w') + .withColumn("email2", template=r'\w.\w@\w.com') + .withColumn("ip_address", template=r'\n.\n.\n.\n') .withColumn("md5_payment_instrument", expr="md5(concat(payment_instrument_type, ':', payment_instrument))", base_column=['payment_instrument_type', 'payment_instrument']) diff --git a/docs/source/multi_table_data.rst b/docs/source/multi_table_data.rst index 44f99150..a3e4e988 100644 --- a/docs/source/multi_table_data.rst +++ b/docs/source/multi_table_data.rst @@ -167,7 +167,7 @@ when using hashed values, the range of the hashes produced can be large. customer_dataspec = (dg.DataGenerator(spark, rows=data_rows, partitions=partitions_requested) .withColumn("customer_id","decimal(10)", minValue=CUSTOMER_MIN_VALUE, uniqueValues=UNIQUE_CUSTOMERS) - .withColumn("customer_name", template=r"\\w \\w|\\w a. \\w") + .withColumn("customer_name", template=r"\w \w|\w a. \w") # use the following for a simple sequence #.withColumn("device_id","decimal(10)", minValue=DEVICE_MIN_VALUE, diff --git a/docs/source/textdata.rst b/docs/source/textdata.rst index 8b249783..b937f52c 100644 --- a/docs/source/textdata.rst +++ b/docs/source/textdata.rst @@ -164,6 +164,13 @@ If set to False, then the template ``r"\\dr_\\v"`` will generate the values ``"d to the values zero to 999. 
This conforms to earlier implementations for backwards compatibility. If set to True, then the template ``r"dr_\\v"`` will generate the values ``"dr_0"`` ... ``"dr_999"`` -when applied to the values zero to 999. This conforms to the preferred style going forward +when applied to the values zero to 999. This conforms to the preferred style going forward. In other words, the char `d` +will not be treated as a special char. + +.. note:: + The legacy mode of operation has a bug where the template sequence r'\\a' produces the same result as r'\a'. + This legacy behavior can be disabled by setting the parameter `legacyEscapeTreatment` to False on the TemplateGenerator + object. It is True by default. + + diff --git a/tests/test_text_templates.py b/tests/test_text_templates.py new file mode 100644 index 00000000..8396af83 --- /dev/null +++ b/tests/test_text_templates.py @@ -0,0 +1,478 @@ +import re +import pytest +import pandas as pd +import numpy as np + +import pyspark.sql.functions as F +from pyspark.sql.types import BooleanType, DateType +from pyspark.sql.types import StructType, StructField, IntegerType, StringType, TimestampType + +import dbldatagen as dg +from dbldatagen import TemplateGenerator, TextGenerator + +# add the following if using pandas udfs +# .config("spark.sql.execution.arrow.maxRecordsPerBatch", "1000") \ + + +spark = dg.SparkSingleton.getLocalInstance("unit tests") + +spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "20000") +spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true") + + +# Test manipulation and generation of test data for a large schema +class TestTextTemplates: + testDataSpec = None + row_count = 100000 + partitions_requested = 4 + + @pytest.mark.parametrize("templates, splitTemplates", + [ + (r"a|b", ['a', 'b']), + (r"a|b|", ['a', 'b', '']), + (r"a", ['a']), + (r"", ['']), + (r"a\|b", [r'a|b']), + (r"a\\|b", [r'a\\', 'b']), + (r"a\|b|c", [r'a|b', 'c']), + (r"123,$456|test test2 |\|\a\\a |021 \| 123", + ['123,$456', 'test 
test2 ', '|\\a\\\\a ', '021 | 123']), + ( + r"123 \\| 123 \|123 | 123|123|123 |asd023,\|23|", + ['123 \\\\', ' 123 |123 ', ' 123', '123', '123 ', 'asd023,|23', '']), + (r" 123|123|123 |asd023,\|23", [' 123', '123', '123 ', 'asd023,|23']), + (r'',[ '']) + ]) + def test_split_templates(self, templates, splitTemplates): + tg1 = TemplateGenerator("test", escapeSpecialChars=False) + + results = tg1._splitTemplates(templates) + + assert results == splitTemplates + + + @pytest.mark.parametrize("template_provided, escapeSpecial, useTemplateObject", + [ #(r'\w \w|\w \v. \w', False, False), + (r'A', False, True), + (r'D', False, True), + (r'K', False, True), + (r'X', False, True), + (r'\W', False, True), + (r'\W', True, True), + (r'\\w A. \\w|\\w \\w', False, False), + (r'\\w \\w|\\w A. \\w', False, False), + (r'\\w \\w|\\w A. \\w', False, True), + (r'\\w \\w|\\w A. \\w', True, True), + (r'\\w \\w|\\w A. \\w|\w n n \w', False, False), + (r'\\w \\w|\\w K. \\w', False, False), + (r'\\w \\w|\\w K. \\w', False, True), + (r'\\w \\w|\\w K. \\w', True, True), + (r'\\w \\w|\\w X. \\w', False, False), + (r'\\w \\w|\\w X. \\w', False, True), + (r'\\w \\w|\\w X. \\w', True, True), + (r'\\w \\w|\\w a. \\w', False, False), + (r'\\w \\w|\\w a. \\w', False, True), + (r'\\w \\w|\\w a. \\w', True, True), + (r'\\w \\w|\\w k. \\w', False, False), + (r'\\w \\w|\\w k. \\w', False, True), + (r'\\w \\w|\\w k. \\w', True, True), + (r'\\w \\w|\\w x. \\w', False, False), + (r'\\w \\w|\\w x. \\w', False, True), + (r'\\w \\w|\\w x. \\w', True, True), + (r'\\w a. \\w', False, True), + (r'\\w a. \\w|\\w \\w', False, False), + (r'\\w k. \\w', False, True), + (r'\\w k. \\w|\\w \\w', False, False), + (r'\n', False, True), + (r'\n', True, True), + (r'\v', False, True), + (r'\v', True, True), + (r'\w A. \w', False, False), + (r'\w \a. \w', True, True), + (r'\w \k. \w', True, True), + (r'\w \n \w', True, True), + (r'\w \w|\w A. \w', False, False), + (r'\w \w|\w \A. \w', True, True), + (r'\w \w|\w \a. 
\w', True, True), + (r'\w \w|\w \w \w|\w \n \w|\w \w \w \w', True, True), + (r'\w aAdDkK \w', False, False), + (r'\w aAdDkKxX \n \N \w', False, False), + (r'\w', False, False), + (r'\w', False, True), + (r'\w', True, True), + (r'a', False, True), + (r'b', False, False), + (r'b', False, True), + (r'b', True, True), + (r'd', False, True), + (r'k', False, True), + (r'x', False, True), + ('', False, False), + ('', False, True), + (r'', True, True), + ]) + + def test_rnd_compute(self, template_provided, escapeSpecial, useTemplateObject): + template1 = TemplateGenerator(template_provided, escapeSpecialChars=escapeSpecial) + print(f"template [{template_provided}]") + + arr = np.arange(100) + + template_choices, template_rnd_bounds, template_rnds = template1._prepare_random_bounds(arr) + + assert template_choices is not None + assert template_rnd_bounds is not None + assert template_rnds is not None + assert len(template_choices) == len(template_rnds) + assert len(template_choices) == len(template_rnd_bounds) + + for ix in range(len(template_choices)): + bounds = template_rnd_bounds[ix] + rnds = template_rnds[ix] + + assert len(bounds) == len(rnds) + + for iy in range(len(bounds)): + assert bounds[iy] == -1 or (rnds[iy] < bounds[iy]) + + @pytest.mark.parametrize("template_provided, escapeSpecial, useTemplateObject", + [ #(r'\w \w|\w \v. \w', False, False), + (r'\\w \\w|\\w a. \\w', False, False), + (r'\\w \\w|\\w a. \\w', False, True), + (r'\\w \\w|\\w a. \\w', True, True), + (r'\w \w|\w a. \w', False, False), + (r'\w.\w@\w.com', False, False), + (r'\n-\n', False, False), + (r'A', False, True), + (r'D', False, True), + (r'K', False, True), + (r'X', False, True), + (r'\W', False, True), + (r'\W', True, True), + (r'\\w A. \\w|\\w \\w', False, False), + (r'\\w \\w|\\w A. \\w', False, False), + (r'\\w \\w|\\w A. \\w', False, True), + (r'\\w \\w|\\w A. \\w', True, True), + (r'\\w \\w|\\w A. \\w|\w n n \w', False, False), + (r'\\w \\w|\\w K. 
\\w', False, False), + (r'\\w \\w|\\w K. \\w', False, True), + (r'\\w \\w|\\w K. \\w', True, True), + (r'\\w \\w|\\w X. \\w', False, False), + (r'\\w \\w|\\w X. \\w', False, True), + (r'\\w \\w|\\w X. \\w', True, True), + (r'\\w \\w|\\w a. \\w', False, False), + (r'\\w \\w|\\w a. \\w', False, True), + (r'\\w \\w|\\w a. \\w', True, True), + (r'\\w \\w|\\w k. \\w', False, False), + (r'\\w \\w|\\w k. \\w', False, True), + (r'\\w \\w|\\w k. \\w', True, True), + (r'\\w \\w|\\w x. \\w', False, False), + (r'\\w \\w|\\w x. \\w', False, True), + (r'\\w \\w|\\w x. \\w', True, True), + (r'\\w a. \\w', False, True), + (r'\\w a. \\w|\\w \\w', False, False), + (r'\\w k. \\w', False, True), + (r'\\w k. \\w|\\w \\w', False, False), + (r'\n', False, True), + (r'\n', True, True), + (r'\v', False, True), + (r'\v', True, True), + (r'\v|\v-\v', False, True), + (r'\v|\v-\v', True, True), + (r'short string|a much longer string which is bigger than short string', False, True), + (r'short string|a much longer string which is bigger than short string', True, True), + (r'\w A. \w', False, False), + (r'\w \a. \w', True, True), + (r'\w \k. \w', True, True), + (r'\w \n \w', True, True), + (r'\w \w|\w A. \w', False, False), + (r'\w \w|\w \A. \w', True, True), + (r'\w \w|\w \a. 
\w', True, True), + (r'\w \w|\w \w \w|\w \n \w|\w \w \w \w', True, True), + (r'\w aAdDkK \w', False, False), + (r'\w aAdDkKxX \n \N \w', False, False), + (r'\w', False, False), + (r'\w', False, True), + (r'\w', True, True), + (r'a', False, True), + (r'b', False, False), + (r'b', False, True), + (r'b', True, True), + (r'd', False, True), + (r'k', False, True), + (r'x', False, True), + ('', False, False), + ('', False, True), + (r'', True, True), + ('|', False, False), + ('|', False, True), + (r'|', True, True), + (r'\ww - not e\xpecting two wor\ds', False, False), + (r'\ww - not expecting two words', True, True) + ]) + def test_use_pandas(self, template_provided, escapeSpecial, useTemplateObject): + template1 = TemplateGenerator(template_provided, escapeSpecialChars=escapeSpecial) + + TEST_ROWS = 100 + + arr = np.arange(TEST_ROWS) + + template_choices, template_rnd_bounds, template_rnds = template1._prepare_random_bounds(arr) + + assert len(template_choices) == len(template_rnds) + assert len(template_choices) == len(template_rnd_bounds) + + for ix in range(len(template_choices)): + bounds = template_rnd_bounds[ix] + rnds = template_rnds[ix] + + assert len(bounds) == len(rnds) + + for iy in range(len(bounds)): + assert bounds[iy] == -1 or (rnds[iy] < bounds[iy]) + + + results = template1.pandasGenerateText(arr) + assert results is not None + + results_list = results.tolist() + + results_rows = len(results_list) + assert results_rows == TEST_ROWS + + for r in range(len(results)): + result_str = results[r] + assert result_str is not None and isinstance(result_str, str) + assert len(result_str) >= 0 + + print("results") + for i in range(len(results)): + print(f"{i}: '{results[i]}'") + + @pytest.mark.parametrize("templateProvided, escapeSpecial, legacyEscapeTreatment,expectedPattern ", + [ (r'\\w \w', False, True, r"[a-z]+ [a-z]+"), + (r'\\w \w', False, False, r"\\w [a-z]+"), + (r'\\\w \w', False, False, r"\\[a-z]+ [a-z]+"), + (r'\\w \w', True, True, r"[a-z]+ [a-z]+"), 
+ (r'\\w \w', True, False, r"\\w [a-z]+"), + (r'\\\w \w', True, False, r"\\[a-z]+ [a-z]+"), + (r'\n-\n', False, False, r"[0-9]+-[0-9]+"), + (r'\\n-\n', False, True, r"[0-9]+-[0-9]+"), + (r'\\n-\n', False, False, r"\\n-[0-9]+"), + (r'\\\n-\n', False, False, r"\\[0-9]+-[0-9]+"), + (r'\\n-\n', True, True, r"[0-9]+-[0-9]+"), + (r'\\n-\n', True, False, r"\\n-[0-9]+"), + (r'\\\n-\n', True, False, r"\\[0-9]+-[0-9]+"), + (r'\\\a', True, False, r"\\[a-z]"), + (r'\\a', False, False, r"\\[a-z]"), + (r'\\a c', False, True, r"[a-z] c"), + (r'\\a', True, True, r"[a-z]"), + ]) + def test_escape_treatment(self, templateProvided, escapeSpecial, legacyEscapeTreatment, expectedPattern): + + template1 = TemplateGenerator(templateProvided, escapeSpecialChars=escapeSpecial, + legacyEscapeTreatment=legacyEscapeTreatment) + + TEST_ROWS = 100 + + arr = np.arange(TEST_ROWS) + + results = template1.pandasGenerateText(arr) + assert results is not None + + results_list = results.tolist() + + results_rows = len(results_list) + assert results_rows == TEST_ROWS + + print(f"expected pattern - '{expectedPattern}', template '{templateProvided}'" ) + patt = re.compile(expectedPattern) + + print("results") + for i in range(len(results)): + print(f"{i}: '{results[i]}'") + assert isinstance(results[i], str) + assert patt.match(results[i]) is not None, f"expecting match '{results[i]}' === '{expectedPattern}'" + + @pytest.mark.parametrize("template_provided, escapeSpecial, legacyEscapeTreatment,expectedPattern ", + [ (r'\\w \\w|\\w a. \\w', False, False, ""), + (r'\\w \\w|\\w a. \\w', False, True, ""), + (r'\w \w|\w a. 
\w', False, False, ""), + (r'\w.\w@\w.com', False, False, ""), + (r'\n-\n', False, False, ""), + ]) + def test_value_sub1(self, template_provided, escapeSpecial, legacyEscapeTreatment, expectedPattern): + template1 = TemplateGenerator(template_provided, escapeSpecialChars=escapeSpecial) + print(f"template [{template_provided}]") + + print("max_placeholders", template1._max_placeholders ) + print("max_rnds", template1._max_rnds_needed) + print("placeholders", template1._placeholders_needed ) + print("bounds", template1._template_rnd_bounds) + + print("templates", template1.templates) + + TEST_ROWS = 100 + + arr = np.arange(TEST_ROWS) + + template_choices, template_rnd_bounds, template_rnds = template1._prepare_random_bounds(arr) + + print("choices", template_choices) + print("rnd bounds", template_rnd_bounds) + print("template_rnds", template_rnds) + + assert len(template_choices) == len(template_rnds) + assert len(template_choices) == len(template_rnd_bounds) + + for ix in range(len(template_choices)): + bounds = template_rnd_bounds[ix] + rnds = template_rnds[ix] + + assert len(bounds) == len(rnds) + + for iy in range(len(bounds)): + assert bounds[iy] == -1 or (rnds[iy] < bounds[iy]) + + + results = template1.pandasGenerateText(arr) + assert results is not None + + results_list = results.tolist() + + results_rows = len(results_list) + assert results_rows == TEST_ROWS + + for r in range(len(results)): + result_str = results[r] + assert result_str is not None and isinstance(result_str, str) + assert len(result_str) >= 0 + + print("results") + for i in range(len(results)): + print(f"{i}: '{results[i]}'") + + @pytest.mark.parametrize("template_provided, escapeSpecial, legacyEscapeTreatment,expectedPattern ", + [ (r'\\w \\w|\\w a. \\w', False, False, ""), + (r'\\w \\w|\\w a. \\w', False, True, ""), + (r'\w \w|\w a. 
\w', False, False, ""), + (r'\w.\w@\w.com', False, False, ""), + (r'\n-\n', False, False, ""), + ]) + def test_value_sub2(self, template_provided, escapeSpecial, legacyEscapeTreatment, expectedPattern): + template1 = TemplateGenerator(template_provided, escapeSpecialChars=escapeSpecial) + print(f"template [{template_provided}]") + + print("max_placeholders", template1._max_placeholders ) + print("max_rnds", template1._max_rnds_needed) + print("placeholders", template1._placeholders_needed ) + print("bounds", template1._template_rnd_bounds) + + print("templates", template1.templates) + + TEST_ROWS = 100 + + arr = np.arange(TEST_ROWS) + + template_choices, template_rnd_bounds, template_rnds = template1._prepare_random_bounds(arr) + + print("choices", template_choices) + print("rnd bounds", template_rnd_bounds) + print("template_rnds", template_rnds) + + assert len(template_choices) == len(template_rnds) + assert len(template_choices) == len(template_rnd_bounds) + + for ix in range(len(template_choices)): + bounds = template_rnd_bounds[ix] + rnds = template_rnds[ix] + + assert len(bounds) == len(rnds) + + for iy in range(len(bounds)): + assert bounds[iy] == -1 or (rnds[iy] < bounds[iy]) + + + results = template1.pandasGenerateText(arr) + assert results is not None + + results_list = results.tolist() + + results_rows = len(results_list) + assert results_rows == TEST_ROWS + + for r in range(len(results)): + result_str = results[r] + assert result_str is not None and isinstance(result_str, str) + assert len(result_str) >= 0 + + print("results") + for i in range(len(results)): + print(f"{i}: '{results[i]}'") + + + @pytest.mark.parametrize("template_provided, escapeSpecial, useTemplateObject", + [ (r'\w aAdDkK \w', False, False), + + (r'\\w \\w|\\w A. \\w', False, False), + (r'\w \w|\w A. \w', False, False), + (r'\w A. \w', False, False), + (r'\\w \\w|\\w a. \\w', False, False), + (r'\\w \\w|\\w k. \\w', False, False), + (r'\\w \\w|\\w K. 
\\w', False, False), + (r'\\w \\w|\\w x. \\w', False, False), + (r'\\w \\w|\\w X. \\w', False, False), + (r'\\w \\w|\\w A. \\w', False, True), + (r'\\w \\w|\\w a. \\w', False, True), + (r'\\w \\w|\\w k. \\w', False, True), + (r'\\w \\w|\\w K. \\w', False, True), + (r'\\w \\w|\\w x. \\w', False, True), + (r'\\w \\w|\\w X. \\w', False, True), + (r'\\w \\w|\\w A. \\w', True, True), + (r'\w \w|\w \A. \w', True, True), + (r'\\w \\w|\\w a. \\w', True, True), + (r'\w \w|\w \a. \w', True, True), + (r'\\w \\w|\\w k. \\w', True, True), + (r'\\w \\w|\\w K. \\w', True, True), + (r'\\w \\w|\\w x. \\w', True, True), + (r'\\w \\w|\\w X. \\w', True, True), + (r'\\w a. \\w|\\w \\w', False, False), + (r'\\w k. \\w|\\w \\w', False, False), + (r'\\w a. \\w', False, True), + (r'\\w k. \\w', False, True), + (r'\w \a. \w', True, True), + (r'\w \k. \w', True, True), + (r'\w \w|\w \w \w|\w \n \w|\w \w \w \w', True, True), + (r'\w \n \w', True, True), + (r'\w', True, True), + (r'\w', False, True), + (r'\w', False, False), + + ]) + + def test_full_build(self, template_provided, escapeSpecial, useTemplateObject): + pytest.skip("skipping to see if this is needed for coverage") + import dbldatagen as dg + print(f"template [{template_provided}]") + + data_rows = 10 * 1000 + + uniqueCustomers = 10 * 1000 + + dataspec = (dg.DataGenerator(spark, rows=data_rows, partitions=4) + .withColumn("customer_id", "long", uniqueValues=uniqueCustomers) + ) + + if useTemplateObject or escapeSpecial: + template1 = TemplateGenerator(template_provided, escapeSpecialChars=escapeSpecial) + dataspec = dataspec.withColumn("name", percentNulls=0.01, text=template1) + else: + dataspec = dataspec.withColumn("name", percentNulls=0.01, template=template_provided) + + df1 = dataspec.build() + df1.show() + + count = df1.where("name is not null").count() + assert count > 0 + +