diff --git a/CHANGELOG.md b/CHANGELOG.md
index b2a4b2e..8c25f0e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,12 +7,13 @@ All notable changes to this project will be documented in this file.
   - restructured runner config
     - added metadata and feature loader sections
     - target moved to pipeline
-    - dependency date_col is mandatory
-    - custom extras config is available in each pipeline and will be passed as dictionary
+    - dependency date_col is now mandatory
+    - custom extras config is available in each pipeline and will be passed as a dictionary available under pipeline_config.extras
   - general section is renamed to runner
   - transformation header changed
   - added argument to skip dependency checking
 #### Jobs
+  - jobs are now the main way to create all pipelines
   - config holder removed from jobs
   - metadata_manager and feature_loader are now available arguments, depending on configuration
 #### TableReader
@@ -21,7 +22,8 @@ All notable changes to this project will be documented in this file.
   - info_date_from -> date_from, info_date_to -> date_to
   - date_column is now mandatory
   - removed TableReaders ability to infer schema from partitions or properties
-
+#### Loader
+  - removed DataLoader class, now only PysparkFeatureLoader is needed with additional parameters
 
 ## 1.3.0 - 2024-06-07
diff --git a/README.md b/README.md
index cc7b01d..56ccaea 100644
--- a/README.md
+++ b/README.md
@@ -423,19 +423,6 @@ This module is used to load features from feature store into your models and scr
-Two public classes are exposed form this module. **DatabricksLoader**(DataLoader), **PysparkFeatureLoader**(FeatureLoaderInterface).
+One public class is exposed from this module: **PysparkFeatureLoader**(FeatureLoaderInterface).
 
-### DatabricksLoader
-This is a support class for feature loader and provides the data reading capability from the feature store.
-
-This class needs to be instantiated with an active spark session and a path to the feature store schema (in the format of "catalog_name.schema_name").
-Optionally a date_column information can be passed, otherwise it defaults to use INFORMATION_DATE
-```python
-from rialto.loader import DatabricksLoader
-
-data_loader = DatabricksLoader(spark= spark_instance, schema= "catalog.schema", date_column= "INFORMATION_DATE")
-```
-
-This class provides one method, read_group(...), which returns a whole feature group for selected date. This is mostly used inside feature loader.
-
 ### PysparkFeatureLoader
-This class needs to be instantiated with an active spark session, data loader and a path to the metadata schema (in the format of "catalog_name.schema_name").
+This class needs to be instantiated with an active spark session, a feature schema and a metadata schema (both in the format of "catalog_name.schema_name"), and optionally a date column, which defaults to INFORMATION_DATE.
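For existing code, a minimal before/after migration sketch, assuming an active SparkSession named `spark` and illustrative schema names; the updated README example in the next hunk is the source of truth:

```python
from rialto.loader import PysparkFeatureLoader

# Before (removed API):
#   data_loader = DatabricksLoader(spark, schema="catalog.feature_schema", date_column="INFORMATION_DATE")
#   feature_loader = PysparkFeatureLoader(spark, data_loader, "catalog.metadata_schema")

# After: a single class; the feature schema and date column are passed directly.
feature_loader = PysparkFeatureLoader(
    spark,
    feature_schema="catalog.feature_schema",
    metadata_schema="catalog.metadata_schema",
    date_column="INFORMATION_DATE",
)
```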
@@ -443,17 +430,16 @@ This class needs to be instantiated with an active spark session, data loader an ```python from rialto.loader import PysparkFeatureLoader -feature_loader = PysparkFeatureLoader(spark= spark_instance, data_loader= data_loader_instance, metadata_schema= "catalog.schema") +feature_loader = PysparkFeatureLoader(spark= spark_instance, feature_schema="catalog.schema", metadata_schema= "catalog.schema2", date_column="information_date") ``` #### Single feature ```python -from rialto.loader import DatabricksLoader, PysparkFeatureLoader +from rialto.loader import PysparkFeatureLoader from datetime import datetime -data_loader = DatabricksLoader(spark, "feature_catalog.feature_schema") -feature_loader = PysparkFeatureLoader(spark, data_loader, "metadata_catalog.metadata_schema") +feature_loader = PysparkFeatureLoader(spark, "feature_catalog.feature_schema", "metadata_catalog.metadata_schema") my_date = datetime.strptime("2020-01-01", "%Y-%m-%d").date() feature = feature_loader.get_feature(group_name="CustomerFeatures", feature_name="AGE", information_date=my_date) @@ -464,11 +450,10 @@ metadata = feature_loader.get_feature_metadata(group_name="CustomerFeatures", fe This method of data access is only recommended for experimentation, as the group schema can evolve over time. ```python -from rialto.loader import DatabricksLoader, PysparkFeatureLoader +from rialto.loader import PysparkFeatureLoader from datetime import datetime -data_loader = DatabricksLoader(spark, "feature_catalog.feature_schema") -feature_loader = PysparkFeatureLoader(spark, data_loader, "metadata_catalog.metadata_schema") +feature_loader = PysparkFeatureLoader(spark, "feature_catalog.feature_schema", "metadata_catalog.metadata_schema") my_date = datetime.strptime("2020-01-01", "%Y-%m-%d").date() features = feature_loader.get_group(group_name="CustomerFeatures", information_date=my_date) @@ -478,11 +463,10 @@ metadata = feature_loader.get_group_metadata(group_name="CustomerFeatures") #### Configuration ```python -from rialto.loader import DatabricksLoader, PysparkFeatureLoader +from rialto.loader import PysparkFeatureLoader from datetime import datetime -data_loader = DatabricksLoader(spark, "feature_catalog.feature_schema") -feature_loader = PysparkFeatureLoader(spark, data_loader, "metadata_catalog.metadata_schema") +feature_loader = PysparkFeatureLoader(spark, "feature_catalog.feature_schema", "metadata_catalog.metadata_schema") my_date = datetime.strptime("2020-01-01", "%Y-%m-%d").date() features = feature_loader.get_features_from_cfg(path="local/configuration/file.yaml", information_date=my_date) diff --git a/rialto/common/__init__.py b/rialto/common/__init__.py index 93e8922..1bd5055 100644 --- a/rialto/common/__init__.py +++ b/rialto/common/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
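The DataReader export added just below is what custom jobs can now type against instead of the concrete TableReader. A minimal sketch of a Transformation written this way; the table name, date column, and the names of the optional tool arguments are illustrative, inferred from the changelog and the runner call sites elsewhere in this diff:

```python
import datetime

from pyspark.sql import DataFrame, SparkSession

from rialto.common import DataReader
from rialto.runner.transformation import Transformation


class MyGroup(Transformation):
    """Minimal job sketch written against the DataReader interface."""

    def run(
        self,
        reader: DataReader,
        run_date: datetime.date,
        spark: SparkSession = None,
        config=None,
        metadata_manager=None,
        feature_loader=None,
    ) -> DataFrame:
        # get_latest returns the newest partition on or before run_date;
        # date_column is mandatory after the TableReader changes above.
        return reader.get_latest(
            "my_catalog.my_schema.my_table", date_until=run_date, date_column="INFORMATION_DATE"
        )
```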
-from rialto.common.table_reader import TableReader +from rialto.common.table_reader import DataReader, TableReader diff --git a/rialto/common/utils.py b/rialto/common/utils.py index 6c2952c..b2e19b4 100644 --- a/rialto/common/utils.py +++ b/rialto/common/utils.py @@ -38,43 +38,11 @@ def load_yaml(path: str) -> Any: return yaml.load(stream, EnvLoader) -def get_date_col_property(spark, table: str, property: str) -> str: - """ - Retrieve a data column name from a given table property - - :param spark: spark session - :param table: path to table - :param property: name of the property - :return: data column name - """ - props = spark.sql(f"show tblproperties {table}") - date_col = props.filter(F.col("key") == property).select("value").collect() - if len(date_col): - return date_col[0].value - else: - raise RuntimeError(f"Table {table} has no property {property}.") - - -def get_delta_partition(spark, table: str) -> str: - """ - Select first partition column of the delta table - - :param table: full table name - :return: partition column name - """ - columns = spark.catalog.listColumns(table) - partition_columns = list(filter(lambda c: c.isPartition, columns)) - if len(partition_columns): - return partition_columns[0].name - else: - raise RuntimeError(f"Delta table has no partitions: {table}.") - - def cast_decimals_to_floats(df: DataFrame) -> DataFrame: """ Find all decimal types in the table and cast them to floats. Fixes errors in .toPandas() conversions. - :param df: pyspark DataFrame + :param df: input df :return: pyspark DataFrame with fixed types """ decimal_cols = [col_name for col_name, data_type in df.dtypes if "decimal" in data_type] diff --git a/rialto/jobs/decorators/decorators.py b/rialto/jobs/decorators/decorators.py index 94b7409..217b436 100644 --- a/rialto/jobs/decorators/decorators.py +++ b/rialto/jobs/decorators/decorators.py @@ -93,7 +93,7 @@ def job(*args, custom_name=None, disable_version=False): module = _get_module(stack) version = _get_version(module) - # Use case where it's just raw @f. Otherwise we get [] here. + # Use case where it's just raw @f. Otherwise, we get [] here. if len(args) == 1 and callable(args[0]): f = args[0] return _generate_rialto_job(callable=f, module=module, class_name=f.__name__, version=version) diff --git a/rialto/jobs/decorators/resolver.py b/rialto/jobs/decorators/resolver.py index f13f0eb..26856d1 100644 --- a/rialto/jobs/decorators/resolver.py +++ b/rialto/jobs/decorators/resolver.py @@ -101,7 +101,7 @@ def cache_clear(cls) -> None: """ Clear resolver cache. - The resolve mehtod caches its results to avoid duplication of resolutions. + The resolve method caches its results to avoid duplication of resolutions. However, in case we re-register some callables, we need to clear cache in order to ensure re-execution of all resolutions. diff --git a/rialto/loader/__init__.py b/rialto/loader/__init__.py index 7adc52d..7e1e936 100644 --- a/rialto/loader/__init__.py +++ b/rialto/loader/__init__.py @@ -12,5 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from rialto.loader.data_loader import DatabricksLoader from rialto.loader.pyspark_feature_loader import PysparkFeatureLoader diff --git a/rialto/loader/data_loader.py b/rialto/loader/data_loader.py deleted file mode 100644 index dc13572..0000000 --- a/rialto/loader/data_loader.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2022 ABSA Group Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -__all__ = ["DatabricksLoader"] - -from datetime import date - -from pyspark.sql import DataFrame, SparkSession - -from rialto.common.table_reader import TableReader -from rialto.loader.interfaces import DataLoader - - -class DatabricksLoader(DataLoader): - """Implementation of DataLoader using TableReader to access feature tables""" - - def __init__(self, spark: SparkSession, schema: str, date_column: str = "INFORMATION_DATE"): - super().__init__() - - self.reader = TableReader(spark) - self.schema = schema - self.date_col = date_column - - def read_group(self, group: str, information_date: date) -> DataFrame: - """ - Read a feature group by getting the latest partition by date - - :param group: group name - :param information_date: partition date - :return: dataframe - """ - return self.reader.get_latest( - f"{self.schema}.{group}", date_until=information_date, date_column=self.date_col, uppercase_columns=True - ) diff --git a/rialto/loader/interfaces.py b/rialto/loader/interfaces.py index dad08e6..9089f40 100644 --- a/rialto/loader/interfaces.py +++ b/rialto/loader/interfaces.py @@ -12,31 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = ["DataLoader", "FeatureLoaderInterface"] +__all__ = ["FeatureLoaderInterface"] import abc from datetime import date from typing import Dict -class DataLoader(metaclass=abc.ABCMeta): - """ - An interface to read feature groups from storage - - Requires read_group function. 
- """ - - @abc.abstractmethod - def read_group(self, group: str, information_date: date): - """ - Read one feature group - - :param group: Group name - :param information_date: date - """ - raise NotImplementedError - - class FeatureLoaderInterface(metaclass=abc.ABCMeta): """ A definition of feature loading interface diff --git a/rialto/loader/pyspark_feature_loader.py b/rialto/loader/pyspark_feature_loader.py index d0eef20..7ee78fc 100644 --- a/rialto/loader/pyspark_feature_loader.py +++ b/rialto/loader/pyspark_feature_loader.py @@ -20,9 +20,9 @@ from pyspark.sql import DataFrame, SparkSession +from rialto.common import TableReader from rialto.common.utils import cast_decimals_to_floats from rialto.loader.config_loader import FeatureConfig, GroupConfig, get_feature_config -from rialto.loader.data_loader import DataLoader from rialto.loader.interfaces import FeatureLoaderInterface from rialto.metadata.metadata_manager import ( FeatureMetadata, @@ -34,7 +34,13 @@ class PysparkFeatureLoader(FeatureLoaderInterface): """Implementation of feature loader for pyspark environment""" - def __init__(self, spark: SparkSession, data_loader: DataLoader, metadata_schema: str): + def __init__( + self, + spark: SparkSession, + feature_schema: str, + metadata_schema: str, + date_column: str = "INFORMATION_DATE", + ): """ Init @@ -44,11 +50,28 @@ def __init__(self, spark: SparkSession, data_loader: DataLoader, metadata_schema """ super().__init__() self.spark = spark - self.data_loader = data_loader + self.reader = TableReader(spark) + self.feature_schema = feature_schema + self.date_col = date_column self.metadata = MetadataManager(spark, metadata_schema) KeyMap = namedtuple("KeyMap", ["df", "key"]) + def read_group(self, group: str, information_date: date) -> DataFrame: + """ + Read a feature group by getting the latest partition by date + + :param group: group name + :param information_date: partition date + :return: dataframe + """ + return self.reader.get_latest( + f"{self.feature_schema}.{group}", + date_until=information_date, + date_column=self.date_col, + uppercase_columns=True, + ) + def get_feature(self, group_name: str, feature_name: str, information_date: date) -> DataFrame: """ Get single feature @@ -60,9 +83,7 @@ def get_feature(self, group_name: str, feature_name: str, information_date: date """ print("This function is untested, use with caution!") key = self.get_group_metadata(group_name).key - return self.data_loader.read_group(self.get_group_fs_name(group_name), information_date).select( - *key, feature_name - ) + return self.read_group(self.get_group_fs_name(group_name), information_date).select(*key, feature_name) def get_feature_metadata(self, group_name: str, feature_name: str) -> FeatureMetadata: """ @@ -83,7 +104,7 @@ def get_group(self, group_name: str, information_date: date) -> DataFrame: :return: A dataframe containing feature group key """ print("This function is untested, use with caution!") - return self.data_loader.read_group(self.get_group_fs_name(group_name), information_date) + return self.read_group(self.get_group_fs_name(group_name), information_date) def get_group_metadata(self, group_name: str) -> GroupMetadata: """ @@ -144,7 +165,7 @@ def _get_keymaps(self, config: FeatureConfig, information_date: date) -> List[Ke """ key_maps = [] for mapping in config.maps: - df = self.data_loader.read_group(self.get_group_fs_name(mapping), information_date).drop("INFORMATION_DATE") + df = self.read_group(self.get_group_fs_name(mapping), information_date).drop("INFORMATION_DATE") 
key = self.metadata.get_group(mapping).key key_maps.append(PysparkFeatureLoader.KeyMap(df, key)) return key_maps @@ -174,9 +195,7 @@ def get_features_from_cfg(self, path: str, information_date: date) -> DataFrame: """ config = get_feature_config(path) # 1 select keys from base - base = self.data_loader.read_group(self.get_group_fs_name(config.base.group), information_date).select( - config.base.keys - ) + base = self.read_group(self.get_group_fs_name(config.base.group), information_date).select(config.base.keys) # 2 join maps onto base (resolve keys) if config.maps: key_maps = self._get_keymaps(config, information_date) @@ -184,7 +203,7 @@ def get_features_from_cfg(self, path: str, information_date: date) -> DataFrame: # 3 read, select and join other tables for group_cfg in config.selection: - df = self.data_loader.read_group(self.get_group_fs_name(group_cfg.group), information_date) + df = self.read_group(self.get_group_fs_name(group_cfg.group), information_date) base = self._add_feature_group(base, df, group_cfg) # 4 fix dtypes for pandas conversion diff --git a/rialto/runner/runner.py b/rialto/runner/runner.py index e3efe01..3fc13b4 100644 --- a/rialto/runner/runner.py +++ b/rialto/runner/runner.py @@ -16,22 +16,15 @@ import datetime from datetime import date -from importlib import import_module from typing import List, Tuple import pyspark.sql.functions as F from loguru import logger from pyspark.sql import DataFrame, SparkSession +import rialto.runner.utils as utils from rialto.common import TableReader -from rialto.loader import DatabricksLoader, PysparkFeatureLoader -from rialto.metadata import MetadataManager -from rialto.runner.config_loader import ( - ModuleConfig, - PipelineConfig, - ScheduleConfig, - get_pipelines_config, -) +from rialto.runner.config_loader import PipelineConfig, get_pipelines_config from rialto.runner.date_manager import DateManager from rialto.runner.table import Table from rialto.runner.tracker import Record, Tracker @@ -84,39 +77,16 @@ def __init__( raise ValueError(f"Invalid date range from {self.date_from} until {self.date_until}") logger.info(f"Running period from {self.date_from} until {self.date_until}") - def _load_module(self, cfg: ModuleConfig) -> Transformation: + def _execute(self, instance: Transformation, run_date: date, pipeline: PipelineConfig) -> DataFrame: """ - Load feature group - - :param cfg: Feature configuration - :return: Transformation object - """ - module = import_module(cfg.python_module) - class_obj = getattr(module, cfg.python_class) - return class_obj() - - def _generate(self, instance: Transformation, run_date: date, pipeline: PipelineConfig) -> DataFrame: - """ - Run feature group + Run the job :param instance: Instance of Transformation :param run_date: date to run for :param pipeline: pipeline configuration :return: Dataframe """ - if pipeline.metadata_manager is not None: - metadata_manager = MetadataManager(self.spark, pipeline.metadata_manager.metadata_schema) - else: - metadata_manager = None - - if pipeline.feature_loader is not None: - feature_loader = PysparkFeatureLoader( - self.spark, - DatabricksLoader(self.spark, schema=pipeline.feature_loader.feature_schema), - metadata_schema=pipeline.feature_loader.metadata_schema, - ) - else: - feature_loader = None + metadata_manager, feature_loader = utils.init_tools(self.spark, pipeline) df = instance.run( spark=self.spark, @@ -130,15 +100,6 @@ def _generate(self, instance: Transformation, run_date: date, pipeline: Pipeline return df - def _table_exists(self, table: str) 
-> bool: - """ - Check table exists in spark catalog - - :param table: full table path - :return: bool - """ - return self.spark.catalog.tableExists(table) - def _write(self, df: DataFrame, info_date: date, table: Table) -> None: """ Write dataframe to storage @@ -152,35 +113,6 @@ def _write(self, df: DataFrame, info_date: date, table: Table) -> None: df.write.partitionBy(table.partition).mode("overwrite").saveAsTable(table.get_table_path()) logger.info(f"Results writen to {table.get_table_path()}") - def _delta_partition(self, table: str) -> str: - """ - Select first partition column, should be only one - - :param table: full table name - :return: partition column name - """ - columns = self.spark.catalog.listColumns(table) - partition_columns = list(filter(lambda c: c.isPartition, columns)) - if len(partition_columns): - return partition_columns[0].name - else: - raise RuntimeError(f"Delta table has no partitions: {table}.") - - def _get_partitions(self, table: Table) -> List[date]: - """ - Get partition values - - :param table: Table object - :return: List of partition values - """ - rows = ( - self.reader.get_table(table.get_table_path(), date_column=table.partition) - .select(table.partition) - .distinct() - .collect() - ) - return [r[table.partition] for r in rows] - def check_dates_have_partition(self, table: Table, dates: List[date]) -> List[bool]: """ For given list of dates, check if there is a matching partition for each @@ -189,8 +121,8 @@ def check_dates_have_partition(self, table: Table, dates: List[date]) -> List[bo :param dates: list of dates to check :return: list of bool """ - if self._table_exists(table.get_table_path()): - partitions = self._get_partitions(table) + if utils.table_exists(table.get_table_path()): + partitions = utils.get_partitions(self.reader, table) return [(date in partitions) for date in dates] else: logger.info(f"Table {table.get_table_path()} doesn't exist!") @@ -230,25 +162,6 @@ def check_dependencies(self, pipeline: PipelineConfig, run_date: date) -> bool: return True - def get_possible_run_dates(self, schedule: ScheduleConfig) -> List[date]: - """ - List possible run dates according to parameters and config - - :param schedule: schedule config - :return: List of dates - """ - return DateManager.run_dates(self.date_from, self.date_until, schedule) - - def get_info_dates(self, schedule: ScheduleConfig, run_dates: List[date]) -> List[date]: - """ - Transform given dates into info dates according to the config - - :param schedule: schedule config - :param run_dates: date list - :return: list of modified dates - """ - return [DateManager.to_info_date(x, schedule) for x in run_dates] - def _get_completion(self, target: Table, info_dates: List[date]) -> List[bool]: """ Check if model has run for given dates @@ -270,8 +183,8 @@ def _select_run_dates(self, pipeline: PipelineConfig, table: Table) -> Tuple[Lis :param table: table path :return: list of run dates and list of info dates """ - possible_run_dates = self.get_possible_run_dates(pipeline.schedule) - possible_info_dates = self.get_info_dates(pipeline.schedule, possible_run_dates) + possible_run_dates = DateManager.run_dates(self.date_from, self.date_until, pipeline.schedule) + possible_info_dates = [DateManager.to_info_date(x, pipeline.schedule) for x in possible_run_dates] current_state = self._get_completion(table, possible_info_dates) selection = [ @@ -300,8 +213,8 @@ def _run_one_date(self, pipeline: PipelineConfig, run_date: date, info_date: dat if self.skip_dependencies or 
self.check_dependencies(pipeline, run_date): logger.info(f"Running {pipeline.name} for {run_date}") - feature_group = self._load_module(pipeline.module) - df = self._generate(feature_group, run_date, pipeline) + feature_group = utils.load_module(pipeline.module) + df = self._execute(feature_group, run_date, pipeline) records = df.count() if records > 0: self._write(df, info_date, target) @@ -349,8 +262,8 @@ def _run_pipeline(self, pipeline: PipelineConfig): ) ) except Exception as error: - print(f"An exception occurred in pipeline {pipeline.name}") - print(error) + logger.error(f"An exception occurred in pipeline {pipeline.name}") + logger.error(error) self.tracker.add( Record( job=pipeline.name, @@ -364,7 +277,7 @@ def _run_pipeline(self, pipeline: PipelineConfig): ) ) except KeyboardInterrupt: - print(f"Pipeline {pipeline.name} interrupted") + logger.error(f"Pipeline {pipeline.name} interrupted") self.tracker.add( Record( job=pipeline.name, diff --git a/rialto/runner/transformation.py b/rialto/runner/transformation.py index 7b5eaa8..5b6f2eb 100644 --- a/rialto/runner/transformation.py +++ b/rialto/runner/transformation.py @@ -19,7 +19,7 @@ from pyspark.sql import DataFrame, SparkSession -from rialto.common import TableReader +from rialto.common import DataReader from rialto.loader import PysparkFeatureLoader from rialto.metadata import MetadataManager from rialto.runner.config_loader import PipelineConfig @@ -31,7 +31,7 @@ class Transformation(metaclass=abc.ABCMeta): @abc.abstractmethod def run( self, - reader: TableReader, + reader: DataReader, run_date: datetime.date, spark: SparkSession = None, config: PipelineConfig = None, diff --git a/rialto/runner/utils.py b/rialto/runner/utils.py new file mode 100644 index 0000000..b74ec1b --- /dev/null +++ b/rialto/runner/utils.py @@ -0,0 +1,74 @@ +from datetime import date +from importlib import import_module +from typing import List, Tuple + +from pyspark.sql import SparkSession + +from rialto.common import DataReader +from rialto.loader import PysparkFeatureLoader +from rialto.metadata import MetadataManager +from rialto.runner.config_loader import ModuleConfig, PipelineConfig +from rialto.runner.table import Table +from rialto.runner.transformation import Transformation + + +def load_module(cfg: ModuleConfig) -> Transformation: + """ + Load feature group + + :param cfg: Feature configuration + :return: Transformation object + """ + module = import_module(cfg.python_module) + class_obj = getattr(module, cfg.python_class) + return class_obj() + + +def table_exists(spark: SparkSession, table: str) -> bool: + """ + Check table exists in spark catalog + + :param table: full table path + :return: bool + """ + return spark.catalog.tableExists(table) + + +def get_partitions(reader: DataReader, table: Table) -> List[date]: + """ + Get partition values + + :param table: Table object + :return: List of partition values + """ + rows = ( + reader.get_table(table.get_table_path(), date_column=table.partition) + .select(table.partition) + .distinct() + .collect() + ) + return [r[table.partition] for r in rows] + + +def init_tools(spark: SparkSession, pipeline: PipelineConfig) -> Tuple[MetadataManager, PysparkFeatureLoader]: + """ + Initialize metadata manager and feature loader + + :param spark: Spark session + :param pipeline: Pipeline configuration + :return: MetadataManager and PysparkFeatureLoader + """ + if pipeline.metadata_manager is not None: + metadata_manager = MetadataManager(spark, pipeline.metadata_manager.metadata_schema) + else: + 
metadata_manager = None + + if pipeline.feature_loader is not None: + feature_loader = PysparkFeatureLoader( + spark, + feature_schema=pipeline.feature_loader.feature_schema, + metadata_schema=pipeline.feature_loader.metadata_schema, + ) + else: + feature_loader = None + return metadata_manager, feature_loader diff --git a/tests/jobs/test_job_base.py b/tests/jobs/test_job_base.py index fa8f19c..55fced1 100644 --- a/tests/jobs/test_job_base.py +++ b/tests/jobs/test_job_base.py @@ -20,7 +20,7 @@ import tests.jobs.resources as resources from rialto.jobs.decorators.resolver import Resolver -from rialto.loader import DatabricksLoader, PysparkFeatureLoader +from rialto.loader import PysparkFeatureLoader def test_setup_except_feature_loader(spark): @@ -39,7 +39,7 @@ def test_setup_except_feature_loader(spark): def test_setup_feature_loader(spark): table_reader = MagicMock() date = datetime.date(2023, 1, 1) - feature_loader = PysparkFeatureLoader(spark, DatabricksLoader(spark, "", ""), "") + feature_loader = PysparkFeatureLoader(spark, "", "", "") resources.CustomJobNoReturnVal().run( reader=table_reader, run_date=date, spark=spark, config=None, feature_loader=feature_loader diff --git a/tests/loader/pyspark/dummy_loaders.py b/tests/loader/pyspark/dummy_loaders.py deleted file mode 100644 index a2b0cb8..0000000 --- a/tests/loader/pyspark/dummy_loaders.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright 2022 ABSA Group Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from datetime import date - -from rialto.loader.data_loader import DataLoader - - -class DummyDataLoader(DataLoader): - def __init__(self): - super().__init__() - - def read_group(self, group: str, information_date: date): - return None diff --git a/tests/loader/pyspark/test_from_cfg.py b/tests/loader/pyspark/test_from_cfg.py index 3ad653e..dd2049f 100644 --- a/tests/loader/pyspark/test_from_cfg.py +++ b/tests/loader/pyspark/test_from_cfg.py @@ -21,7 +21,6 @@ from rialto.loader.config_loader import get_feature_config from rialto.loader.pyspark_feature_loader import PysparkFeatureLoader from tests.loader.pyspark.dataframe_builder import dataframe_builder as dfb -from tests.loader.pyspark.dummy_loaders import DummyDataLoader @pytest.fixture(scope="session") @@ -45,7 +44,7 @@ def spark(request): @pytest.fixture(scope="session") def loader(spark): - return PysparkFeatureLoader(spark, DummyDataLoader(), MagicMock()) + return PysparkFeatureLoader(spark, MagicMock(), MagicMock()) VALID_LIST = [(["a"], ["a"]), (["a"], ["a", "b", "c"]), (["c", "a"], ["a", "b", "c"])] @@ -90,7 +89,7 @@ def __call__(self, *args, **kwargs): metadata = MagicMock() monkeypatch.setattr(metadata, "get_group", GroupMd()) - loader = PysparkFeatureLoader(spark, DummyDataLoader(), "") + loader = PysparkFeatureLoader(spark, "", "") loader.metadata = metadata base = dfb(spark, data=r.base_frame_data, columns=r.base_frame_columns) @@ -105,7 +104,7 @@ def __call__(self, *args, **kwargs): def test_get_group_metadata(spark, mocker): mocker.patch("rialto.loader.pyspark_feature_loader.MetadataManager.get_group", return_value=7) - loader = PysparkFeatureLoader(spark, DummyDataLoader(), "") + loader = PysparkFeatureLoader(spark, "", "") ret_val = loader.get_group_metadata("group_name") assert ret_val == 7 @@ -115,7 +114,7 @@ def test_get_group_metadata(spark, mocker): def test_get_feature_metadata(spark, mocker): mocker.patch("rialto.loader.pyspark_feature_loader.MetadataManager.get_feature", return_value=8) - loader = PysparkFeatureLoader(spark, DummyDataLoader(), "") + loader = PysparkFeatureLoader(spark, "", "") ret_val = loader.get_feature_metadata("group_name", "feature") assert ret_val == 8 @@ -129,7 +128,7 @@ def test_get_metadata_from_cfg(spark, mocker): ) mocker.patch("rialto.loader.pyspark_feature_loader.MetadataManager.get_group", side_effect=lambda g: {"B": 10}[g]) - loader = PysparkFeatureLoader(spark, DummyDataLoader(), "") + loader = PysparkFeatureLoader(spark, "", "") metadata = loader.get_metadata_from_cfg("tests/loader/pyspark/example_cfg.yaml") assert metadata["B_F1"] == 1 diff --git a/tests/runner/test_runner.py b/tests/runner/test_runner.py index 85ddf95..2171c7b 100644 --- a/tests/runner/test_runner.py +++ b/tests/runner/test_runner.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from collections import namedtuple from datetime import datetime from typing import Optional @@ -61,20 +60,10 @@ def get_latest( def test_table_exists(spark, mocker, basic_runner): mock = mocker.patch("pyspark.sql.Catalog.tableExists", return_value=True) - basic_runner._table_exists("abc") + basic_runner.table_exists("abc") mock.assert_called_once_with("abc") -def test_infer_column(spark, mocker, basic_runner): - column = namedtuple("catalog", ["name", "isPartition"]) - catalog = [column("a", True), column("b", False), column("c", False)] - - mock = mocker.patch("pyspark.sql.Catalog.listColumns", return_value=catalog) - partition = basic_runner._delta_partition("aaa") - assert partition == "a" - mock.assert_called_once_with("aaa") - - def test_load_module(spark, basic_runner): module = basic_runner._load_module(basic_runner.config.pipelines[0].module) assert isinstance(module, SimpleGroup) @@ -84,7 +73,7 @@ def test_generate(spark, mocker, basic_runner): run = mocker.patch("tests.runner.transformations.simple_group.SimpleGroup.run") group = SimpleGroup() config = basic_runner.config.pipelines[0] - basic_runner._generate(group, DateManager.str_to_date("2023-01-31"), config) + basic_runner._execute(group, DateManager.str_to_date("2023-01-31"), config) run.assert_called_once_with( reader=basic_runner.reader, @@ -99,7 +88,7 @@ def test_generate(spark, mocker, basic_runner): def test_generate_w_dep(spark, mocker, basic_runner): run = mocker.patch("tests.runner.transformations.simple_group.SimpleGroup.run") group = SimpleGroup() - basic_runner._generate(group, DateManager.str_to_date("2023-01-31"), basic_runner.config.pipelines[2]) + basic_runner._execute(group, DateManager.str_to_date("2023-01-31"), basic_runner.config.pipelines[2]) run.assert_called_once_with( reader=basic_runner.reader, run_date=DateManager.str_to_date("2023-01-31"), @@ -133,27 +122,6 @@ def test_init_dates(spark): assert runner.date_until == DateManager.str_to_date("2023-03-31") -def test_possible_run_dates(spark): - runner = Runner( - spark, - config_path="tests/runner/transformations/config.yaml", - date_from="2023-03-01", - date_until="2023-03-31", - ) - - dates = runner.get_possible_run_dates(runner.config.pipelines[0].schedule) - expected = ["2023-03-05", "2023-03-12", "2023-03-19", "2023-03-26"] - assert dates == [DateManager.str_to_date(d) for d in expected] - - -def test_info_dates(spark, basic_runner): - run = ["2023-02-05", "2023-02-12", "2023-02-19", "2023-02-26", "2023-03-05"] - run = [DateManager.str_to_date(d) for d in run] - info = basic_runner.get_info_dates(basic_runner.config.pipelines[0].schedule, run) - expected = ["2023-02-02", "2023-02-09", "2023-02-16", "2023-02-23", "2023-03-02"] - assert info == [DateManager.str_to_date(d) for d in expected] - - def test_completion(spark, mocker, basic_runner): mocker.patch("rialto.runner.runner.Runner._table_exists", return_value=True)
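With _table_exists relocated to rialto.runner.utils and taking the SparkSession explicitly, a hedged sketch of exercising the relocated helper directly, using the suite's existing spark and mocker (pytest-mock) fixtures:

```python
from rialto.runner import utils


def test_table_exists_in_utils(spark, mocker):
    # Same patch target as the existing runner test above.
    mock = mocker.patch("pyspark.sql.Catalog.tableExists", return_value=True)
    assert utils.table_exists(spark, "abc") is True
    mock.assert_called_once_with("abc")
```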