diff --git a/dbldatagen/__init__.py b/dbldatagen/__init__.py index 49eea723..fd7c7f99 100644 --- a/dbldatagen/__init__.py +++ b/dbldatagen/__init__.py @@ -41,12 +41,14 @@ from .spark_singleton import SparkSingleton from .text_generators import TemplateGenerator, ILText, TextGenerator from .text_generator_plugins import PyfuncText, PyfuncTextFactory, FakerTextFactory, fakerText +from .text_generatestring import GenerateString +from .value_based_prng import ValueBasedPRNG from .html_utils import HtmlUtils __all__ = ["data_generator", "data_analyzer", "schema_parser", "daterange", "nrange", "column_generation_spec", "utils", "function_builder", "spark_singleton", "text_generators", "datarange", "datagen_constants", - "text_generator_plugins", "html_utils" + "text_generator_plugins", "html_utils", "text_generatestring", "value_based_prng" ] diff --git a/dbldatagen/column_generation_spec.py b/dbldatagen/column_generation_spec.py index 6456a389..da9c99f1 100644 --- a/dbldatagen/column_generation_spec.py +++ b/dbldatagen/column_generation_spec.py @@ -1107,7 +1107,7 @@ def _applyPrefixSuffixExpressions(self, cprefix, csuffix, new_def): new_def = concat(new_def.astype(IntegerType()), lit(text_separator), lit(csuffix)) return new_def - def _applyTextGenerationExpression(self, new_def, use_pandas_optimizations): + def _applyTextGenerationExpression(self, new_def, use_pandas_optimizations=True): """Apply text generation expression to column expression :param new_def : column definition being created @@ -1118,6 +1118,9 @@ def _applyTextGenerationExpression(self, new_def, use_pandas_optimizations): # while it seems like this could use a shared instance, this does not work if initialized # in a class method tg = self.textGenerator + + new_def = tg.prepareBaseValue(new_def) + if use_pandas_optimizations: self.executionHistory.append(f".. 
# ---------------------------------------------------------------------------
# NOTE: this source is a flattened multi-file patch.  Patch text that shared
# these source lines with dbldatagen/text_generatestring.py is preserved here
# as comments so that nothing from the patch is lost.
#
#   (tail of the dbldatagen/column_generation_spec.py hunk)
#       ... text generation via pandas scalar udf `{tg}`")
#           u_value_from_generator = pandas_udf(tg.pandasGenerateText,
#
#   diff --git a/dbldatagen/text_generatestring.py b/dbldatagen/text_generatestring.py
#   new file mode 100644
#   index 00000000..fdb561bd
#   --- /dev/null
#   +++ b/dbldatagen/text_generatestring.py
#   @@ -0,0 +1,181 @@
# ---------------------------------------------------------------------------
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
This file defines the GenerateString text generator
"""

import math
import random

import numpy as np
import pandas as pd

import pyspark.sql.functions as F

from .text_generators import TextGenerator
from .text_generators import _DIGITS_ZERO, _LETTERS_UPPER, _LETTERS_LOWER, _LETTERS_ALL


class GenerateString(TextGenerator):  # lgtm [py/missing-equals]
    """This class handles the generation of string text of specified length drawn from alphanumeric characters.

    The set of chars to be used can be modified based on the parameters

    This will generate deterministic strings chosen from the pool of characters `0-9`, `a-z`, `A-Z`, or from a
    custom character range if specified.

    :param length: length of string. Can be integer, or tuple (min, max)
    :param leadingAlpha: If True, leading character will be in range a-zA-Z
    :param allUpper: If True, any alpha chars will be uppercase
    :param allLower: If True, any alpha chars will be lowercase
    :param allAlpha: If True, all chars will be non numeric
    :param customChars: If supplied, specifies a list of chars to use, or string of chars to use.

    This method will generate deterministic strings varying in size from `minLength` to `maxLength`.
    The characters chosen will be in the range `0-9`, `a-z`, `A-Z` unless modified using the `leadingAlpha`,
    `allUpper`, `allLower`, `allAlpha` or `customChars` parameters.

    The modifiers can be combined - for example GenerateString((1, 5), leadingAlpha=True, allUpper=True)

    When the length is specified to be a tuple, it will generate variable length strings of lengths from the
    lower bound to the upper bound inclusive.

    The strings are generated deterministically so that they can be used for predictable primary and foreign keys.

    If the column definition that includes this specifies `random` then the string generation will be determined by a
    seeded random number according to the rules for random numbers and random seeds used in other columns

    If random is false, then the string will be generated from a pseudo random sequence generated purely from the
    SQL hash of the `baseColumns`

    .. note::
       If customChars are specified, then the flag `allAlpha` will only remove digits.

    .. note::
       If `leadingAlpha` is requested but the effective alphabet contains no non-digit characters
       (for example ``customChars="0123"``), the leading character is drawn from the full alphabet.

    """

    def __init__(self, length, leadingAlpha=True, allUpper=False, allLower=False, allAlpha=False, customChars=None):
        super().__init__()

        assert not customChars or isinstance(customChars, (list, str)), \
            "`customChars` should be list of characters or string containing custom chars"

        assert not allUpper or not allLower, "allUpper and allLower cannot both be True"

        if isinstance(customChars, str):
            assert len(customChars) > 0, "string of customChars must be non-empty"
        elif isinstance(customChars, list):
            assert all(isinstance(c, str) for c in customChars)
            assert len(customChars) > 0, "list of customChars must be non-empty"

        self.leadingAlpha = leadingAlpha
        self.allUpper = allUpper
        self.allLower = allLower
        self.allAlpha = allAlpha

        # determine base alphabet
        if isinstance(customChars, list):
            charAlphabet = set("".join(customChars))
        elif isinstance(customChars, str):
            charAlphabet = set(customChars)
        else:
            charAlphabet = set(_LETTERS_ALL).union(set(_DIGITS_ZERO))

        if allLower:
            charAlphabet = charAlphabet.difference(set(_LETTERS_UPPER))
        elif allUpper:
            charAlphabet = charAlphabet.difference(set(_LETTERS_LOWER))

        if allAlpha:
            charAlphabet = charAlphabet.difference(set(_DIGITS_ZERO))

        if not charAlphabet:
            raise ValueError("the combination of `customChars` and modifiers leaves no characters to choose from")

        # Iteration order of a set of strings differs between Python processes (hash randomisation),
        # so the alphabet is sorted: the same offset must select the same character on every worker.
        self._charAlphabet = np.array(sorted(charAlphabet))

        firstCharAlphabet = charAlphabet.difference(set(_DIGITS_ZERO)) if leadingAlpha else charAlphabet
        if firstCharAlphabet:
            self._firstCharAlphabet = np.array(sorted(firstCharAlphabet))
        else:
            # only digits are available - fall back to the full alphabet rather than failing
            self._firstCharAlphabet = self._charAlphabet

        # compute string lengths
        if isinstance(length, int):
            self._minLength = length
            self._maxLength = length
        elif isinstance(length, tuple):
            assert len(length) == 2, "only 2 elements can be specified if length is a tuple"
            assert all(isinstance(el, int) for el in length)
            self._minLength, self._maxLength = length
        else:
            raise ValueError("`length` must be an integer or a tuple of two integers")

        if self._minLength < 0 or self._maxLength < self._minLength:
            raise ValueError("`length` bounds must satisfy 0 <= min <= max")

        # compute bounds for generated strings: one exclusive upper bound per character position
        self._bounds = [len(self._firstCharAlphabet)] + [len(self._charAlphabet)] * max(self._maxLength - 1, 0)

    def __repr__(self):
        return f"GenerateString(length={(self._minLength, self._maxLength)}, leadingAlpha={self.leadingAlpha})"

    def make_variable_length_mask(self, v, lengths):
        """ Build a boolean mask selecting the first `lengths[r]` columns of each row

        :param v: 2-d array of dimensions [r, c]
        :param lengths: 1-d array of dimensions [r]
        :return: boolean array of the same shape as `v`, True where col_index[r, c] < lengths[r]
        """
        lengths = np.asarray(lengths)
        assert v.shape[0] == lengths.shape[0], "values and lengths must agree on dimension 0"

        column_indices = np.arange(v.shape[1])
        return column_indices[np.newaxis, :] < lengths[:, np.newaxis]

    def mk_bounds(self, v, minLength, maxLength):
        """ Produce one length in the range [minLength, maxLength] for each row of `v`

        :param v: array-like whose first dimension gives the number of rows
        :param minLength: smallest length to produce (inclusive)
        :param maxLength: largest length to produce (inclusive)
        :return: 1-d numpy array of lengths

        The generator is created with a fixed seed on every call, so the result depends only on the number
        of rows and on the bounds.
        """
        rng = np.random.default_rng(42)
        v_bounds = np.full(v.shape[0], (maxLength - minLength) + 1)
        return rng.integers(v_bounds) + minLength

    def prepareBaseValue(self, baseDef):
        """ Prepare the base value for processing

        :param baseDef: base value expression
        :return: expression computing `abs(hash(baseDef))`

        For generate string processing, we'll use the SQL function abs(hash(baseDef))

        This will ensure that even if there are multiple base values, only a single value is passed to the UDF
        """
        return F.abs(F.hash(baseDef))

    def pandasGenerateText(self, v):
        """ entry point to use for pandas udfs

        Implementation uses vectorized implementation of process

        :param v: Pandas series of values passed as base values
        :return: Pandas series of generated strings

        """
        base_values = v.to_numpy()
        num_rows = base_values.shape[0]
        width = self._maxLength

        if width == 0:
            return pd.Series([""] * num_rows, dtype=object)

        # string length per row, in the inclusive range [minLength, maxLength]
        # numpy's modulo is non-negative for a positive divisor, even for negative base values
        length_choices = (self._maxLength - self._minLength) + 1
        lengths = base_values % length_choices + self._minLength

        # exclusive upper bound for the character offset at each position;
        # position 0 draws from the (possibly smaller) leading character alphabet
        bounds = np.full((num_rows, width), len(self._charAlphabet), dtype=np.int64)
        bounds[:, 0] = len(self._firstCharAlphabet)

        # NOTE(review): the generator comes from the base class (not visible here); the character choice
        # therefore presumably depends on its seed rather than on the base values - confirm intent
        rng = self.getNPRandomGenerator()
        offsets = rng.integers(bounds)

        chars = self._charAlphabet[offsets]
        chars[:, 0] = self._firstCharAlphabet[offsets[:, 0]]

        # blank out every position at or beyond the row's length, then join each row into one string
        in_string = self.make_variable_length_mask(chars, lengths)
        chars = np.where(in_string, chars, "")

        return pd.Series(["".join(row) for row in chars], dtype=object)

# ---------------------------------------------------------------------------
# Remaining patch text carried on these source lines (preserved as comments):
#
#   diff --git a/dbldatagen/text_generators.py b/dbldatagen/text_generators.py
#   index 965350be..403e06d8 100644
#   --- a/dbldatagen/text_generators.py
#   +++ b/dbldatagen/text_generators.py
#   @@ -161,6 +161,16 @@ def getAsTupleOrElse(v, defaultValue, valueName):
#            return defaultValue
#
#   +    def prepareBaseValue(self, baseDef):
#   +        """ Prepare the base value for processing
#   +
#   +        :param baseDef: base value expression
#   +        :return: base value expression unchanged
#   +
#   +        Derived classes are expected to override this if needed
#   +        """
#   +        return baseDef
#   +
#    class TemplateGenerator(TextGenerator):  # lgtm [py/missing-equals]
#        """This class handles the generation of text from templates
#
#   diff --git a/dbldatagen/value_based_prng.py b/dbldatagen/value_based_prng.py
#   new file mode 100644
#   index 00000000..a7b65e86
#   --- /dev/null
#   +++ b/dbldatagen/value_based_prng.py
#   @@ (hunk header continues on the next source line)
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# NOTE: this source is a flattened multi-file patch.  This block is the new
# file dbldatagen/value_based_prng.py (hunk header: @@ -0,0 +1,301 @@).
# Patch text for other files that shared these source lines is preserved as
# comments at the end of the block.
# ---------------------------------------------------------------------------
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
This module defines the ValueBasedPRNG class

The Value Based PRNG (pseudo random number generator) uses variations of the PCG algorithm to generate repeatable
pseudo random numbers.

PCG operates by using classic LCG (Linear Congruential Generator) approaches to generate the random state
but applies a transform to generate the final pseudo random numbers.

More details of the LCG and PCG algorithms can be found here.
  - https://en.wikipedia.org/wiki/Linear_congruential_generator
  - https://en.wikipedia.org/wiki/Permuted_congruential_generator


The implementation uses a numpy array of values to seed different random seeds for each row
so that generating pseudo random sequences across each row of a numpy 2d array are completely repeatable and the
sequences for a given row will depend on the value of the seed

The goal here is to seed each stream with a seed based on the value passed in for a given row
so that two rows with the same seed value will produce the same sequence.

This differs from existing Numpy implementations as in existing Numpy random number generators,
each row does not have a separate random seed.

In practice, this makes it ideal for generating repeatable foreign keys for a given primary key.

"""

import hashlib
import math
import numpy as np


class ValueBasedPRNG(object):
    """ Value based pseudo-random number generator designed to generate repeatable random sequences.

    It uses the supplied array `values` to seed a set of states for generating random numbers in sequences
    of shape where the first dimension is a multiple of the size of the seed array `values`.

    The overall goal is to provide a prng that is replacement for the numpy random number generator where only
    the `integers` method is used.

    Internal state is of type `numpy.uint64`; raw emitted values fit in 32 bits.

    Note that by array-like, we mean a list, a numpy array or a numeric scalar value
    """

    MASK_32 = 2 ** 32 - 1
    MASK_64 = 2 ** 64 - 1

    # reference constants for PCG - not in effect, kept under distinct names for documentation
    PCG_MULTIPLIER = 0x5851f42d4c957f2d
    PCG_INITIAL_STATE = 0x4d595df4d0f33173
    PCG_INCREMENT = 0x14057b7ef767814f

    ROWS = 20000
    COLUMNS = 50
    BITS_FOR_COLUMNS = int(math.ceil(math.log(COLUMNS, 2)))

    # constants for Java style LCG - these are the values in effect
    # (the original class body assigned the PCG values first and then overwrote them with these)
    MULTIPLIER = 0x5DEECE66D
    INITIAL_STATE = 0
    INCREMENT = 11

    MAX_CYCLES = 25

    def __init__(self, values, shape=None, additionalSeed=None):
        """ Initialize the PRNG

        :param values: array-like list of values to seed the number generator. If values has a second
                       dimension, or holds strings or objects, the seed for a row is a hash of that row.
                       Must be scalar, 1d or 2d array-like value
        :param shape: If supplied, indicates that generation should optimize for generations of values of supplied
                      shape.
                      If supplied, must be 1-d or 2-d shape and first dimension must be a multiple of the
                      number of seed values
        :param additionalSeed: if provided, the additional seed will be combined with the per-row seed.
                               Must be of type int

        Sufficient state is created to generate the number of values indicated by the shape independently.
        In effect, each row has a separate random generator and each column is generated by a separate stream.

        Raises ValueError for values of more than two dimensions or an invalid `shape`, and TypeError if a
        seed value is unhashable or cannot be converted to an unsigned integer.
        """
        assert values is not None, "`values` must be supplied"

        self._additionalSeed = additionalSeed

        effective_values = np.array(values)
        supplied_shape = effective_values.shape

        if len(supplied_shape) > 2:
            raise ValueError("`values` must be scalar, 1d or 2d array-like value")

        if shape is not None and not isinstance(shape, tuple):
            raise ValueError("`shape` must be tuple, if supplied")

        if shape is not None and len(shape) > 2:
            raise ValueError("`shape` must be a 1-d or 2-d shape, if supplied")

        self._output_shape = shape or supplied_shape

        # reshape as needed
        effective_values = self._reshapeAtLeast2d(effective_values)

        # rows with several values, strings and arbitrary objects are reduced to one integer per row
        value_dtype = effective_values.dtype
        if (effective_values.shape[1] > 1
                or np.issubdtype(value_dtype, np.character)
                or np.issubdtype(value_dtype, np.object_)):
            effective_values = self._hashSeedValues(effective_values)

        # cast if needed to numpy.uint64
        effective_values = self._convertTypeIfNecessary(effective_values)
        effective_values = self._reshapeAtLeast2d(effective_values)

        if effective_values.shape[1] != 1:
            raise ValueError("`values` must provide at least one value per row")

        self._seed_values = effective_values

        stream_seeds = effective_values
        if additionalSeed is not None:
            stream_seeds = stream_seeds ^ np.uint64(int(additionalSeed) & self.MASK_64)

        # one state per output cell: repeat the seeds when the output has more rows than seed values
        rows = self._rows_from_shape(self._output_shape)
        columns = self._columns_from_shape(self._output_shape)
        seed_rows = stream_seeds.shape[0]

        if rows != seed_rows:
            if seed_rows == 0 or rows % seed_rows != 0:
                raise ValueError("first dimension of `shape` must be a multiple of the number of seed values")
            stream_seeds = np.tile(stream_seeds, (rows // seed_rows, 1))

        # column ids run from 1 to `columns`, so `columns.bit_length()` bits keep them clear of the seed bits
        bits_for_columns = np.uint64(int(columns).bit_length())
        column_ids = np.arange(1, columns + 1, dtype=np.uint64)

        # LCG increments must be odd, hence the final `<< 1 | 1`
        increments = (((stream_seeds << bits_for_columns) | column_ids) << np.uint64(1)) | np.uint64(1)

        self._incr = np.array(increments, dtype=np.uint64).reshape((rows, columns))
        self._state = self._incr + np.uint64(self.INITIAL_STATE)

    def _reshapeAtLeast2d(self, arr):
        """ Reshape array as 2d: scalars become (1, 1) and 1-d arrays become a single column """
        if arr.shape == ():
            return arr.reshape((1, 1))
        elif len(arr.shape) == 1:
            return arr.reshape((arr.shape[0], 1))
        else:
            return arr

    @staticmethod
    def _stableRowHash(row):
        """ Compute a 64 bit hash of a row of seed values that is identical in every Python process

        The builtin `hash` is salted per process for strings, so it cannot be used for repeatable seeds.

        Raises TypeError if an element of the row is unhashable
        """
        digest = hashlib.blake2b(digest_size=8)
        for item in row:
            plain = item.item() if isinstance(item, np.generic) else item
            hash(plain)  # validation only: raises TypeError for unhashable seed values such as dict
            digest.update(f"{type(plain).__name__}:{plain!r};".encode("utf-8"))
        return int.from_bytes(digest.digest(), "little")

    def _hashSeedValues(self, values):
        """ Hash each row of the 2d array `values`, returning a 1d array of `numpy.uint64` """
        return np.array([self._stableRowHash(row) for row in values], dtype=np.uint64)

    def _convertTypeIfNecessary(self, values):
        """ Convert values array to `numpy.uint64`, raising TypeError if no cast exists """
        if values.dtype == np.uint64:
            return values

        if not np.can_cast(values.dtype, np.uint64, casting="unsafe"):
            raise TypeError("Cant cast values to np.uint64")

        return values.astype(np.uint64, casting="unsafe")

    @staticmethod
    def _columns_from_shape(shape):
        """ Number of columns implied by `shape`: 1 for scalar and 1-d shapes """
        assert isinstance(shape, tuple), "expecting tuple for shape"

        if len(shape) < 2:
            return 1
        return shape[1]

    @staticmethod
    def _rows_from_shape(shape):
        """ Number of rows implied by `shape`: 1 for a scalar shape """
        assert isinstance(shape, tuple), "expecting tuple for shape"

        if len(shape) == 0:
            return 1
        return shape[0]

    @property
    def shape(self):
        """get the `shape` attribute"""
        return self._output_shape

    @property
    def seedValues(self):
        """ Get the values that were used to seed the PRNG"""
        return self._seed_values

    def random_r(self, state, incr):
        """ Advance `state` in place by one LCG step and return one raw value per state

        :param state: numpy.uint64 array of generator states, updated in place
        :param incr: numpy.uint64 array of per-stream increments, same shape as `state`
        :return: numpy.uint64 array of the same shape holding values below 2 ** 32

        The output is derived from the advanced state using a data dependent right shift.
        """
        # uint64 array arithmetic wraps modulo 2 ** 64, which is the modulus of the LCG
        with np.errstate(over="ignore"):
            advanced = state * np.uint64(self.MULTIPLIER) + incr
        state[...] = advanced

        # the top 5 bits select a shift of 16 .. 47 bits
        count = advanced >> np.uint64(59)
        return (advanced >> (count + np.uint64(16))) & np.uint64(self.MASK_32)

    def bounded_random(self, bound, state, incr):
        """ Generate values in the range [0, bound) using rejection sampling to remove modulo bias

        :param bound: exclusive upper bound: positive int or array of ints broadcastable to `state`
        :param state: numpy.uint64 array of generator states, updated in place
        :param incr: numpy.uint64 array of per-stream increments, same shape as `state`
        :return: numpy.uint64 array of the shape of `state`

        Modulo bias occurs when the size of the set of numbers used in a modulus calculation is not an even
        multiple of the divisor.

        For example: for the set of numbers 0 .. 67, `n` mod 16 will return more numbers in the range 0 .. 3

        To avoid this, raw values below a threshold are discarded and drawn again. For a 32 bit raw value
        the threshold is ``2 ** 32 % bound``, which equals ``(-bound & MASK_32) % bound``.

        Only the streams whose value was rejected are advanced again, so the sequence of a stream depends only
        on its own seed. The number of retries is limited to MAX_CYCLES; a stream still rejected after that
        keeps its last value, which produces a slightly non-uniform distribution in rare cases.

        As the intent for this random number generator is to compute random word and character offsets and
        other uses where the bounds are relatively low, this is acceptable for our use case.

        See:
           "Efficiently Generating a Random Number in a Range" by Dr M.E O'Neill
            https://www.pcg-random.org/posts/bounded-rands.html

        """
        bound = np.asarray(bound)
        if np.any(bound <= 0):
            raise ValueError("`bound` must be positive")
        if np.any(bound > self.MASK_32):
            raise ValueError("`bound` must fit in 32 bits")

        # work in uint64 throughout: mixing uint64 with int64 would promote to float64
        bound = np.broadcast_to(bound.astype(np.uint64), state.shape)
        stream_incr = np.broadcast_to(incr, state.shape)
        threshold = np.uint64(1 << 32) % bound

        r = self.random_r(state, incr)
        pending = r < threshold
        cycles = 0

        while pending.any() and cycles < self.MAX_CYCLES:
            # boolean indexing copies, so advance the copy and write the new states back
            retry_state = state[pending]
            r[pending] = self.random_r(retry_state, stream_incr[pending])
            state[pending] = retry_state

            pending = r < threshold
            cycles = cycles + 1

        return r % bound

    def integers(self, low, high=None, size=None, dtype=None, endpoint=False):
        """ Return pseudo-random integers from low (inclusive) to high (exclusive), or if endpoint=True,
        low (inclusive) to high (inclusive).

        :param low: int or array-like of ints. Lowest (signed) integers to be drawn from the distribution
                    (unless high=None, in which case this parameter is 0 and this value is used for high).
        :param high: int or array-like of ints, optional. If provided, one above the largest (signed) integer
                     to be drawn from the distribution (see above for behavior if high=None).
                     If array-like, must contain integer values
        :param size: int or tuple of ints, optional that defines the output shape.
                     Must describe as many values as the PRNG has streams.
                     Default is None, in which case the shape of the PRNG is used.
        :param dtype: Desired Numpy dtype of the result. The default value is np.int64
        :param endpoint: If true, sample from the interval [low, high] instead of the default [low, high).
                         Defaults to False
        :return: ndarray of ints of the requested shape, or a single such int for a scalar shape.

        Raises ValueError if any range is empty, wider than 32 bits, or not broadcastable to the output shape.
        """
        low_arr = np.asarray(low)

        if high is None:
            high_arr = low_arr
            low_arr = np.zeros_like(low_arr)
        else:
            high_arr = np.asarray(high)

        low_arr = low_arr.astype(np.int64)
        high_arr = high_arr.astype(np.int64)

        if endpoint:
            high_arr = high_arr + 1

        span = high_arr - low_arr
        if np.any(span <= 0):
            raise ValueError("low >= high")
        if np.any(span > self.MASK_32):
            raise ValueError("range between low and high must fit in 32 bits")

        if size is None:
            target_shape = self._output_shape
        elif isinstance(size, (int, np.integer)):
            target_shape = (int(size),)
        else:
            target_shape = tuple(size)

        if int(np.prod(target_shape, dtype=np.int64)) != self._state.size:
            raise ValueError("`size` must describe as many values as the shape of the PRNG")

        # align bounds with the user facing shape first, then with the 2-d state
        state_shape = self._state.shape
        span = np.broadcast_to(span, target_shape).reshape(state_shape)
        offset = np.broadcast_to(low_arr, target_shape).reshape(state_shape)

        drawn = self.bounded_random(span, self._state, self._incr)
        result = (drawn.astype(np.int64) + offset).reshape(target_shape)

        if dtype is not None:
            result = result.astype(dtype)

        if result.ndim == 0:
            return result[()]
        return result

# ---------------------------------------------------------------------------
# Remaining patch text carried on these source lines (preserved as comments):
#
#   diff --git a/docs/utils/mk_quick_index.py b/docs/utils/mk_quick_index.py
#   index c3d08953..a4f554a6 100644
#   --- a/docs/utils/mk_quick_index.py
#   +++ b/docs/utils/mk_quick_index.py
#   @@ -33,10 +33,14 @@
#                                   "grouping": "main classes"},
#        "text_generator_plugins.py": {"briefDesc": "Text data generation",
#                                      "grouping": "main classes"},
#   +    "text_generatestring.py": {"briefDesc": "Text data generation",
#   +                               "grouping": "main classes"},
#        "data_analyzer.py": {"briefDesc": "Analysis of existing data",
#                             "grouping": "main classes"},
#        "function_builder.py": {"briefDesc": "Internal utilities to create functions related to weights",
#                                "grouping": "internal classes"},
#   +    "value_based_prng.py": {"briefDesc": "Value based pseudo-random number generator",
#   +                            "grouping": "internal classes"},
#        "schema_parser.py": {"briefDesc": "Internal utilities to parse Spark SQL schema information",
#                             "grouping": "internal classes"},
#        "spark_singleton.py": {"briefDesc": "Spark singleton for test purposes",
#
#   diff --git a/tests/test_text_generatestring.py b/tests/test_text_generatestring.py
#   new file mode 100644
#   index 00000000..9d6fc599
#   --- /dev/null
#   +++ b/tests/test_text_generatestring.py
#   @@ -0,0 +1,99 @@
#   +import pytest
#   +import pyspark.sql.functions as F
#   +from pyspark.sql.types import BooleanType, DateType
#   +from pyspark.sql.types import StructType, StructField, IntegerType, StringType, TimestampType
#   +
#   +import dbldatagen as dg
#   +
#   +spark = dg.SparkSingleton.getLocalInstance("unit tests")
#   +
#   +spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "20000")
#   +spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
#   +
#   +#: list of digits for template generation
#   +_DIGITS_ZERO = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
#   +
#   +#: (comment continues on the next source line)
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# NOTE: this source is a flattened multi-file patch.  These source lines hold
# two new test modules; both are reproduced in full below, each introduced
# by the name of the file it belongs to.
# ---------------------------------------------------------------------------

# ===========================================================================
# tests/test_text_generatestring.py
# ===========================================================================
import pytest
import pyspark.sql.functions as F
from pyspark.sql.types import BooleanType, DateType
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, TimestampType

import dbldatagen as dg

spark = dg.SparkSingleton.getLocalInstance("unit tests")

spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "20000")
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

#: list of digits for template generation
_DIGITS_ZERO = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

#: list of uppercase letters for template generation
_LETTERS_UPPER = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
                  'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z']

#: list of lowercase letters for template generation
_LETTERS_LOWER = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
                  'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']

#: list of all letters uppercase and lowercase
_LETTERS_ALL = _LETTERS_LOWER + _LETTERS_UPPER

#: list of alphanumeric chars in lowercase
_ALNUM_LOWER = _LETTERS_LOWER + _DIGITS_ZERO

#: list of alphanumeric chars in uppercase
_ALNUM_UPPER = _LETTERS_UPPER + _DIGITS_ZERO


# Test construction of the GenerateString text generator and its use in a data generation spec
class TestTextGenerateString:

    @pytest.mark.parametrize("length, leadingAlpha, allUpper, allLower, allAlpha, customChars",
                             [
                                 (5, True, True, False, False, None),
                                 (5, True, False, True, False, None),
                                 (5, True, False, False, True, None),
                                 (5, False, False, False, False, None),
                                 (5, False, True, False, True, None),
                                 (5, False, False, True, True, None),
                                 (5, False, False, False, False, "01234567890ABCDEF"),
                             ])
    def test_basics(self, length, leadingAlpha, allUpper, allLower, allAlpha, customChars):
        """The character alphabet must be the one implied by the modifier flags."""
        tg1 = dg.GenerateString(length, leadingAlpha=leadingAlpha, allUpper=allUpper, allLower=allLower,
                                allAlpha=allAlpha, customChars=customChars)

        assert tg1._charAlphabet is not None
        assert tg1._firstCharAlphabet is not None

        if allUpper and allAlpha:
            alphabet = _LETTERS_UPPER
        elif allLower and allAlpha:
            alphabet = _LETTERS_LOWER
        elif allLower:
            alphabet = _LETTERS_LOWER + _DIGITS_ZERO
        elif allUpper:
            alphabet = _LETTERS_UPPER + _DIGITS_ZERO
        elif allAlpha:
            alphabet = _LETTERS_UPPER + _LETTERS_LOWER
        else:
            alphabet = _LETTERS_UPPER + _LETTERS_LOWER + _DIGITS_ZERO

        if customChars is not None:
            alphabet = set(alphabet).intersection(set(customChars))

        assert set(tg1._charAlphabet) == set(alphabet)

    @pytest.mark.parametrize("genstr",
                             [
                                 dg.GenerateString((1, 10)),
                                 dg.GenerateString((1, 10), leadingAlpha=True),
                                 dg.GenerateString((4, 64), allUpper=True),
                                 dg.GenerateString((10, 20), allLower=True),
                                 dg.GenerateString((1, 10)),
                                 dg.GenerateString((3, 15)),
                                 dg.GenerateString((17, 22)),
                                 dg.GenerateString((1, 10)),
                             ])
    def test_simple_data(self, genstr):
        """A spec using the parametrized generator must build and respect the maximum length."""
        row_count = 10000

        dgspec = (dg.DataGenerator(sparkSession=spark, name="alt_data_set", rows=row_count,
                                   partitions=4, seedMethod='hash_fieldname', verbose=True,
                                   seedColumnName="_id")
                  .withIdOutput()
                  .withColumn("code2", IntegerType(), min=0, max=10)
                  .withColumn("code3", StringType(), values=['a', 'b', 'c'])
                  .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True)
                  # use the generator under test (the original always built GenerateString((1, 10)) here)
                  .withColumn("code5", StringType(), text=genstr)
                  )

        fieldsFromGenerator = set(dgspec.getOutputColumnNames())
        assert "code5" in fieldsFromGenerator

        df_testdata = dgspec.build()

        df_testdata.show()

        assert df_testdata.count() == row_count
        assert df_testdata.where(F.length("code5") > genstr._maxLength).count() == 0


# ===========================================================================
# tests/test_value_based_PRNG.py
#
#   diff --git a/tests/test_value_based_PRNG.py b/tests/test_value_based_PRNG.py
#   new file mode 100644
#   index 00000000..d34ba84b
#   --- /dev/null
#   +++ b/tests/test_value_based_PRNG.py
#   @@ -0,0 +1,144 @@
# ===========================================================================
import re
import pytest
import pandas as pd
import numpy as np

import pyspark.sql.functions as F

from pyspark.sql.types import StructType, StructField, IntegerType, StringType, TimestampType

import dbldatagen as dg

# add the following if using pandas udfs
#    .config("spark.sql.execution.arrow.maxRecordsPerBatch", "1000") \


spark = dg.SparkSingleton.getLocalInstance("unit tests")


# Test construction of the value based PRNG and generation of bounded integers
class TestValueBasedPRNG:

    def test_basics(self):
        """Integer seed values are kept unchanged and the requested shape is reported."""
        arr = np.arange(100)

        rng = dg.ValueBasedPRNG(arr, shape=(len(arr), 10))

        assert rng is not None

        assert rng.shape == (len(arr), 10)

        print(rng.seedValues)

        assert rng.seedValues is not None
        print(rng.seedValues.shape, rng.seedValues.dtype)
        assert np.array_equal(rng.seedValues.reshape(arr.shape), arr)

    @pytest.mark.skip(reason="work in progress")
    @pytest.mark.parametrize("data, shape, bounds, endpoint",
                             [
                                 (np.arange(1024), (1024, 10), 255, False),
                                 (np.arange(1024), (1024, 15), np.arange(15) * 3 + 1, False),
                                 (np.arange(1024), (1024,), 255, False),
                                 (np.arange(1024), (1024, 10), 255, False),
                                 (np.arange(1024), (1024, 15), np.arange(15) * 3 + 1, False),
                                 (1, (1, 5), [1, 2, 3, 4, 5], False),
                                 (10, None, [1, 2, 3, 4, 5], False),
                             ])
    def test_integers(self, data, shape, bounds, endpoint):
        """Generated integers have the target shape and are not simply the bounds passed in."""
        arr = np.array(data)
        rng = dg.ValueBasedPRNG(arr, shape=shape)

        bounds_arr = np.full(shape, bounds)

        # make sure that we do not have results that are always simply the passed in bounds
        # they can be occasionally equal but not always equal
        cumulative_equality = np.full(bounds_arr.shape, True)

        for ix in range(10):
            results = rng.integers(bounds_arr, endpoint=endpoint)

            assert results is not None, "results should be not None"
            assert results.shape == bounds_arr.shape, "results should be of target shape"

            # element-wise `&`: the `and` operator raises for arrays with more than one element
            cumulative_equality = cumulative_equality & (results == bounds_arr)
            print(cumulative_equality)

        assert not np.all(cumulative_equality)

    @pytest.mark.parametrize("data, shape",
                             [
                                 (23, (1024, 10)),
                                 (23.10, (1024, 15)),
                                 (23, None),
                                 (23.10, None),
                                 ([[1, 2, 3], [1, 2, 3]], None),
                                 ([[1, "two", 3], [1, 2, "three"]], None),
                                 ("test", None),
                                 (["test", "test2"], None),
                                 ([True, False, True], None),
                                 ((1, 2, 3), None),
                                 (np.datetime64('2005-10-24'), None)

                             ])
    def test_initialization_success(self, data, shape):
        """Scalars, sequences, strings, booleans and dates are all accepted as seed values."""
        arr = np.array(data)
        rng = dg.ValueBasedPRNG(arr, shape=shape)

        assert rng is not None

        print(rng.seedValues.shape, rng.seedValues.dtype, rng.seedValues)

    @pytest.mark.parametrize("data, shape",
                             [
                                 ({'a': 1, 'b': 2}, 25),  # bad shape
                                 (34, 45),  # bad shape
                                 ([[[1, 2, 3], [4, 5, 6]],
                                   [[1, 2, 3], [4, 5, 6]],
                                   [[1, 2, 3], [4, 5, 6]]], None),  # bad value - dimensions
                             ])
    def test_initialization_fail_value(self, data, shape):
        """A shape that is not a tuple, or values of more than two dimensions, raise ValueError."""
        arr = np.array(data)

        with pytest.raises(ValueError):
            dg.ValueBasedPRNG(arr, shape=shape)

    @pytest.mark.parametrize("data, shape",
                             [
                                 ({'a': 1, 'b': 2}, (1024,)),
                             ])
    def test_initialization_fail_type(self, data, shape):
        """Seed values that cannot be hashed raise TypeError."""
        arr = np.array(data)

        with pytest.raises(TypeError):
            dg.ValueBasedPRNG(arr, shape=shape)

    @pytest.mark.skip(reason="work in progress")
    def test_udfs(self):
        """Exploratory check that a pandas udf can hash single and array valued columns."""

        def exampleUdf(v):
            v1 = pd.DataFrame(v)
            mk_str_fn = lambda x: str(hash(tuple(x)))  # str(x)
            results = v1.apply(mk_str_fn, axis=1)
            return pd.Series(results)

        testUdf = F.pandas_udf(exampleUdf, returnType=StringType()).asNondeterministic()

        df = (spark.range(1024)
              .withColumn("v1", F.expr("id * id"))
              .withColumn("v2", F.expr("id * id"))
              .withColumn("v3", F.expr("id * id"))
              .withColumn("v4", F.expr("'test'"))
              .withColumn("v5", testUdf(F.array(F.col("v3"), F.col("v4"))))
              .withColumn("v6", testUdf(F.col("v3")))
              )

        df.show()