loader simplified, runner cleaned
MDobransky committed Aug 26, 2024
1 parent 1699043 commit a185b20
Showing 17 changed files with 148 additions and 309 deletions.
8 changes: 5 additions & 3 deletions CHANGELOG.md
@@ -7,12 +7,13 @@ All notable changes to this project will be documented in this file.
- restructured runner config
- added metadata and feature loader sections
- target moved to pipeline
- dependency date_col is mandatory
- custom extras config is available in each pipeline and will be passed as dictionary
- dependency date_col is now mandatory
- custom extras config is available in each pipeline and is passed as a dictionary under pipeline_config.extras (see the sketch below)
- general section is renamed to runner
- transformation header changed
- added argument to skip dependency checking
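
A minimal, hedged sketch of reading the custom extras described above; per the changelog they arrive as a dictionary under pipeline_config.extras, and the key names here are purely illustrative:

```python
# Hedged sketch: pipeline_config is the pipeline's configuration object per the
# changelog; "score_threshold" and "mode" are made-up keys for illustration.
threshold = pipeline_config.extras["score_threshold"]
mode = pipeline_config.extras.get("mode", "batch")
```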
#### Jobs
- jobs are now the main way to create all pipelines
- config holder removed from jobs
- metadata_manager and feature_loader are now available arguments, depending on configuration
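
A hedged sketch of a job under the new decorator-based API. The @job decorator and its custom_name argument appear in this commit's decorators.py, and feature_loader is named above as an injectable argument; the run_date parameter name and the import path are assumptions for illustration.

```python
# Import path assumed from the file location rialto/jobs/decorators/decorators.py in this commit.
from rialto.jobs.decorators.decorators import job


@job(custom_name="customer_features")
def customer_features(feature_loader, run_date):
    # feature_loader: the configured PysparkFeatureLoader instance (per the changelog entry above)
    # run_date: assumed name for the injected run date
    return feature_loader.get_group("CustomerFeatures", run_date)
```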
#### TableReader
@@ -21,7 +22,8 @@ All notable changes to this project will be documented in this file.
- info_date_from -> date_from, info_date_to -> date_to
- date_column is now mandatory
- removed TableReader's ability to infer schema from partitions or properties
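
A hedged sketch of the renamed TableReader arguments above. TableReader(spark) and get_latest with date_until, date_column and uppercase_columns are taken from the pyspark_feature_loader.py diff below; the table name and date are placeholders, and spark is assumed to be an existing SparkSession.

```python
from datetime import date

from rialto.common import TableReader

reader = TableReader(spark)
df = reader.get_latest(
    "catalog.schema.feature_group",   # placeholder table name
    date_until=date(2024, 8, 1),      # latest partition on or before this date
    date_column="INFORMATION_DATE",   # date_column is now mandatory
    uppercase_columns=True,
)
# Ranged reads use the renamed date_from/date_to arguments (formerly info_date_from/info_date_to).
```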

#### Loader
- removed the DataLoader class; only PysparkFeatureLoader is needed now, with additional parameters (see the migration sketch below)
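
A hedged before/after sketch of the loader simplification, based on the README changes in this commit; the schema names are placeholders and spark is assumed to be an existing SparkSession.

```python
from rialto.loader import PysparkFeatureLoader

# Before: a separate DatabricksLoader had to be constructed and passed in.
# data_loader = DatabricksLoader(spark, "catalog.feature_schema")
# feature_loader = PysparkFeatureLoader(spark, data_loader, "catalog.metadata_schema")

# Now: a single class, with the feature schema and date column passed directly.
feature_loader = PysparkFeatureLoader(
    spark=spark,
    feature_schema="catalog.feature_schema",
    metadata_schema="catalog.metadata_schema",
    date_column="INFORMATION_DATE",  # optional; this default appears in the new __init__
)
```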

## 1.3.0 - 2024-06-07

30 changes: 7 additions & 23 deletions README.md
@@ -423,37 +423,23 @@ This module is used to load features from feature store into your models and scripts.

Two public classes are exposed from this module. **DatabricksLoader**(DataLoader), **PysparkFeatureLoader**(FeatureLoaderInterface).

### DatabricksLoader
This is a support class for feature loader and provides the data reading capability from the feature store.

This class needs to be instantiated with an active spark session and a path to the feature store schema (in the format of "catalog_name.schema_name").
Optionally a date_column information can be passed, otherwise it defaults to use INFORMATION_DATE
```python
from rialto.loader import DatabricksLoader

data_loader = DatabricksLoader(spark= spark_instance, schema= "catalog.schema", date_column= "INFORMATION_DATE")
```

This class provides one method, read_group(...), which returns a whole feature group for selected date. This is mostly used inside feature loader.

### PysparkFeatureLoader

This class needs to be instantiated with an active spark session, a feature schema and a path to the metadata schema (both in the format of "catalog_name.schema_name").

```python
from rialto.loader import PysparkFeatureLoader

feature_loader = PysparkFeatureLoader(spark= spark_instance, data_loader= data_loader_instance, metadata_schema= "catalog.schema")
feature_loader = PysparkFeatureLoader(spark=spark_instance, feature_schema="catalog.schema", metadata_schema="catalog.schema2", date_column="information_date")
```

#### Single feature

```python
from rialto.loader import DatabricksLoader, PysparkFeatureLoader
from rialto.loader import PysparkFeatureLoader
from datetime import datetime

data_loader = DatabricksLoader(spark, "feature_catalog.feature_schema")
feature_loader = PysparkFeatureLoader(spark, data_loader, "metadata_catalog.metadata_schema")
feature_loader = PysparkFeatureLoader(spark, "feature_catalog.feature_schema", "metadata_catalog.metadata_schema")
my_date = datetime.strptime("2020-01-01", "%Y-%m-%d").date()

feature = feature_loader.get_feature(group_name="CustomerFeatures", feature_name="AGE", information_date=my_date)
@@ -464,11 +450,10 @@ metadata = feature_loader.get_feature_metadata(group_name="CustomerFeatures", fe
This method of data access is only recommended for experimentation, as the group schema can evolve over time.

```python
from rialto.loader import DatabricksLoader, PysparkFeatureLoader
from rialto.loader import PysparkFeatureLoader
from datetime import datetime

data_loader = DatabricksLoader(spark, "feature_catalog.feature_schema")
feature_loader = PysparkFeatureLoader(spark, data_loader, "metadata_catalog.metadata_schema")
feature_loader = PysparkFeatureLoader(spark, "feature_catalog.feature_schema", "metadata_catalog.metadata_schema")
my_date = datetime.strptime("2020-01-01", "%Y-%m-%d").date()

features = feature_loader.get_group(group_name="CustomerFeatures", information_date=my_date)
Expand All @@ -478,11 +463,10 @@ metadata = feature_loader.get_group_metadata(group_name="CustomerFeatures")
#### Configuration

```python
from rialto.loader import DatabricksLoader, PysparkFeatureLoader
from rialto.loader import PysparkFeatureLoader
from datetime import datetime

data_loader = DatabricksLoader(spark, "feature_catalog.feature_schema")
feature_loader = PysparkFeatureLoader(spark, data_loader, "metadata_catalog.metadata_schema")
feature_loader = PysparkFeatureLoader(spark, "feature_catalog.feature_schema", "metadata_catalog.metadata_schema")
my_date = datetime.strptime("2020-01-01", "%Y-%m-%d").date()

features = feature_loader.get_features_from_cfg(path="local/configuration/file.yaml", information_date=my_date)
2 changes: 1 addition & 1 deletion rialto/common/__init__.py
@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from rialto.common.table_reader import TableReader
from rialto.common.table_reader import DataReader, TableReader
34 changes: 1 addition & 33 deletions rialto/common/utils.py
@@ -38,43 +38,11 @@ def load_yaml(path: str) -> Any:
        return yaml.load(stream, EnvLoader)


def get_date_col_property(spark, table: str, property: str) -> str:
    """
    Retrieve a data column name from a given table property
    :param spark: spark session
    :param table: path to table
    :param property: name of the property
    :return: data column name
    """
    props = spark.sql(f"show tblproperties {table}")
    date_col = props.filter(F.col("key") == property).select("value").collect()
    if len(date_col):
        return date_col[0].value
    else:
        raise RuntimeError(f"Table {table} has no property {property}.")


def get_delta_partition(spark, table: str) -> str:
    """
    Select first partition column of the delta table
    :param table: full table name
    :return: partition column name
    """
    columns = spark.catalog.listColumns(table)
    partition_columns = list(filter(lambda c: c.isPartition, columns))
    if len(partition_columns):
        return partition_columns[0].name
    else:
        raise RuntimeError(f"Delta table has no partitions: {table}.")


def cast_decimals_to_floats(df: DataFrame) -> DataFrame:
    """
    Find all decimal types in the table and cast them to floats. Fixes errors in .toPandas() conversions.
    :param df: pyspark DataFrame
    :param df: input df
    :return: pyspark DataFrame with fixed types
    """
    decimal_cols = [col_name for col_name, data_type in df.dtypes if "decimal" in data_type]
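
The cast_decimals_to_floats helper retained above converts decimal columns to floats so that .toPandas() succeeds. A minimal, hedged usage sketch; the column name and values are illustrative.

```python
from decimal import Decimal

from pyspark.sql import SparkSession
from pyspark.sql.types import DecimalType, StructField, StructType

from rialto.common.utils import cast_decimals_to_floats

spark = SparkSession.builder.getOrCreate()
schema = StructType([StructField("amount", DecimalType(10, 2))])
df = spark.createDataFrame([(Decimal("1.23"),), (Decimal("4.56"),)], schema)

pdf = cast_decimals_to_floats(df).toPandas()  # the decimal column arrives as plain floats
```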
2 changes: 1 addition & 1 deletion rialto/jobs/decorators/decorators.py
@@ -93,7 +93,7 @@ def job(*args, custom_name=None, disable_version=False):
    module = _get_module(stack)
    version = _get_version(module)

    # Use case where it's just raw @f. Otherwise we get [] here.
    # Use case where it's just raw @f. Otherwise, we get [] here.
    if len(args) == 1 and callable(args[0]):
        f = args[0]
        return _generate_rialto_job(callable=f, module=module, class_name=f.__name__, version=version)
2 changes: 1 addition & 1 deletion rialto/jobs/decorators/resolver.py
@@ -101,7 +101,7 @@ def cache_clear(cls) -> None:
        """
        Clear resolver cache.
        The resolve mehtod caches its results to avoid duplication of resolutions.
        The resolve method caches its results to avoid duplication of resolutions.
        However, in case we re-register some callables, we need to clear cache
        in order to ensure re-execution of all resolutions.
1 change: 0 additions & 1 deletion rialto/loader/__init__.py
@@ -12,5 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from rialto.loader.data_loader import DatabricksLoader
from rialto.loader.pyspark_feature_loader import PysparkFeatureLoader
45 changes: 0 additions & 45 deletions rialto/loader/data_loader.py

This file was deleted.

20 changes: 1 addition & 19 deletions rialto/loader/interfaces.py
@@ -12,31 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ["DataLoader", "FeatureLoaderInterface"]
__all__ = ["FeatureLoaderInterface"]

import abc
from datetime import date
from typing import Dict


class DataLoader(metaclass=abc.ABCMeta):
    """
    An interface to read feature groups from storage
    Requires read_group function.
    """

    @abc.abstractmethod
    def read_group(self, group: str, information_date: date):
        """
        Read one feature group
        :param group: Group name
        :param information_date: date
        """
        raise NotImplementedError


class FeatureLoaderInterface(metaclass=abc.ABCMeta):
    """
    A definition of feature loading interface
43 changes: 31 additions & 12 deletions rialto/loader/pyspark_feature_loader.py
@@ -20,9 +20,9 @@

from pyspark.sql import DataFrame, SparkSession

from rialto.common import TableReader
from rialto.common.utils import cast_decimals_to_floats
from rialto.loader.config_loader import FeatureConfig, GroupConfig, get_feature_config
from rialto.loader.data_loader import DataLoader
from rialto.loader.interfaces import FeatureLoaderInterface
from rialto.metadata.metadata_manager import (
    FeatureMetadata,
@@ -34,7 +34,13 @@
class PysparkFeatureLoader(FeatureLoaderInterface):
    """Implementation of feature loader for pyspark environment"""

    def __init__(self, spark: SparkSession, data_loader: DataLoader, metadata_schema: str):
    def __init__(
        self,
        spark: SparkSession,
        feature_schema: str,
        metadata_schema: str,
        date_column: str = "INFORMATION_DATE",
    ):
        """
        Init
@@ -44,11 +50,28 @@ def __init__(self, spark: SparkSession, data_loader: DataLoader, metadata_schema
        """
        super().__init__()
        self.spark = spark
        self.data_loader = data_loader
        self.reader = TableReader(spark)
        self.feature_schema = feature_schema
        self.date_col = date_column
        self.metadata = MetadataManager(spark, metadata_schema)

    KeyMap = namedtuple("KeyMap", ["df", "key"])

    def read_group(self, group: str, information_date: date) -> DataFrame:
        """
        Read a feature group by getting the latest partition by date
        :param group: group name
        :param information_date: partition date
        :return: dataframe
        """
        return self.reader.get_latest(
            f"{self.feature_schema}.{group}",
            date_until=information_date,
            date_column=self.date_col,
            uppercase_columns=True,
        )

    def get_feature(self, group_name: str, feature_name: str, information_date: date) -> DataFrame:
        """
        Get single feature
@@ -60,9 +83,7 @@ def get_feature(self, group_name: str, feature_name: str, information_date: date
        """
        print("This function is untested, use with caution!")
        key = self.get_group_metadata(group_name).key
        return self.data_loader.read_group(self.get_group_fs_name(group_name), information_date).select(
            *key, feature_name
        )
        return self.read_group(self.get_group_fs_name(group_name), information_date).select(*key, feature_name)

    def get_feature_metadata(self, group_name: str, feature_name: str) -> FeatureMetadata:
        """
@@ -83,7 +104,7 @@ def get_group(self, group_name: str, information_date: date) -> DataFrame:
        :return: A dataframe containing feature group key
        """
        print("This function is untested, use with caution!")
        return self.data_loader.read_group(self.get_group_fs_name(group_name), information_date)
        return self.read_group(self.get_group_fs_name(group_name), information_date)

    def get_group_metadata(self, group_name: str) -> GroupMetadata:
        """
@@ -144,7 +165,7 @@ def _get_keymaps(self, config: FeatureConfig, information_date: date) -> List[Ke
        """
        key_maps = []
        for mapping in config.maps:
            df = self.data_loader.read_group(self.get_group_fs_name(mapping), information_date).drop("INFORMATION_DATE")
            df = self.read_group(self.get_group_fs_name(mapping), information_date).drop("INFORMATION_DATE")
            key = self.metadata.get_group(mapping).key
            key_maps.append(PysparkFeatureLoader.KeyMap(df, key))
        return key_maps
@@ -174,17 +195,15 @@ def get_features_from_cfg(self, path: str, information_date: date) -> DataFrame:
        """
        config = get_feature_config(path)
        # 1 select keys from base
        base = self.data_loader.read_group(self.get_group_fs_name(config.base.group), information_date).select(
            config.base.keys
        )
        base = self.read_group(self.get_group_fs_name(config.base.group), information_date).select(config.base.keys)
        # 2 join maps onto base (resolve keys)
        if config.maps:
            key_maps = self._get_keymaps(config, information_date)
            base = self._join_keymaps(base, key_maps)

        # 3 read, select and join other tables
        for group_cfg in config.selection:
            df = self.data_loader.read_group(self.get_group_fs_name(group_cfg.group), information_date)
            df = self.read_group(self.get_group_fs_name(group_cfg.group), information_date)
            base = self._add_feature_group(base, df, group_cfg)

        # 4 fix dtypes for pandas conversion