V 0.0.5 Split station and time features
Shakleen committed Oct 2, 2024
1 parent 3354d58 commit 5cf3abc
Showing 4 changed files with 80 additions and 21 deletions.
requirements.txt: 2 changes (1 addition & 1 deletion)
@@ -5,7 +5,7 @@ blinker==1.8.2
cachetools==5.5.0
certifi==2024.8.30
charset-normalizer==3.3.2
CitiBike-Demand-Prediction==0.0.4
CitiBike-Demand-Prediction==0.0.5
click==8.1.7
cloudpickle==3.0.0
comm==0.2.2
setup.py: 2 changes (1 addition & 1 deletion)
@@ -25,7 +25,7 @@ def get_requirements(file_name: str) -> List[str]:

setup(
name="CitiBike Demand Prediction",
version="0.0.4",
version="0.0.5",
description="An End-to-End Machine Learning project where I predict demand of bikes at citibike stations at hourly level.",
author="Shakleen Ishfar",
author_email="shakleenishfar@gmail.com",
src/components/raw_to_bronze_transformer.py: 53 changes (45 additions & 8 deletions)
@@ -13,6 +13,7 @@
month,
coalesce,
)
from typing import Tuple

if __name__ == "__main__":
from src.logger import logging
@@ -23,6 +24,8 @@ class RawToBronzeTransformerConfig:
root_delta_path: str = os.path.join("Data", "delta")
raw_data_path: str = os.path.join(root_delta_path, "raw")
bronze_data_path: str = os.path.join(root_delta_path, "bronze")
station_data_path: str = os.path.join(root_delta_path, "station")
row_to_station_data_path: str = os.path.join(root_delta_path, "row_to_station")


class RawToBronzeTransformer:
@@ -33,12 +36,8 @@ def __init__(self, spark: SparkSession):
def read_raw_delta(self) -> DataFrame:
return self.spark.read.format("delta").load(self.config.raw_data_path)

def write_delta(self, df: DataFrame):
df.write.save(
path=self.config.bronze_data_path,
format="delta",
mode="overwrite",
)
def write_delta(self, df: DataFrame, path: str):
df.write.save(path=path, format="delta", mode="overwrite")

def create_file_name_column(self, df: DataFrame) -> DataFrame:
regex_str = "[^\\/]+$"
@@ -144,7 +143,7 @@ def get_station_dataframe(self, df: DataFrame) -> DataFrame:
)
)

def drup_duplicates_and_all_nulls(self, df: DataFrame) -> DataFrame:
def drop_duplicates_and_all_nulls(self, df: DataFrame) -> DataFrame:
return df.dropDuplicates().dropna(how="all")

def fill_in_station_id_using_name(self, df: DataFrame) -> DataFrame:
@@ -210,13 +209,51 @@ def fill_in_using_station_id(self, df: DataFrame) -> DataFrame:
.dropna(how="any")
)

def split_station_and_time(
self, df: DataFrame
) -> Tuple[DataFrame, DataFrame, DataFrame]:
# Separating station Data
station_df = self.get_station_dataframe(df)
station_df = self.drop_duplicates_and_all_nulls(station_df)
station_df = self.fill_in_station_id_using_name(station_df)
station_df = self.fill_in_using_station_id(station_df)

# Dropping rows with null station ids
df = df.dropna(subset=["start_station_id", "end_station_id"], how="any")

# Mapping df to station ids
row_to_station_df = df.select(
"row_number", "start_station_id", "end_station_id"
)

# Dropping station related columns
df = df.drop(
"start_station_id",
"start_station_name",
"start_station_latitude",
"start_station_longitude",
"end_station_id",
"end_station_name",
"end_station_latitude",
"end_station_longitude",
"member",
"file_path",
"file_name",
)

return (station_df, row_to_station_df, df)

def transform(self):
logging.info("Reading raw delta table")
df = self.read_raw_delta()
logging.info("Creating file name column")
df = self.create_file_name_column(df)

self.set_timestamp_datatype(df)
df = self.set_timestamp_datatype(df)

station_df, row_to_station_df, df = self.split_station_and_time(df)
self.write_delta(station_df, self.config.station_data_path)
self.write_delta(row_to_station_df, self.config.row_to_station_data_path)

logging.info("Writing to bronze delta table")
self.write_delta(df, self.config.bronze_data_path)
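A minimal usage sketch (not part of this commit) of the new transform flow, assuming delta-spark is installed, the raw delta table already exists under Data/delta/raw, and the module is importable as src.components.raw_to_bronze_transformer:

from delta import configure_spark_with_delta_pip
from pyspark.sql import SparkSession

from src.components.raw_to_bronze_transformer import RawToBronzeTransformer

# Build a Delta-enabled Spark session (assumed environment, not pinned by the commit).
builder = (
    SparkSession.builder.appName("raw-to-bronze")
    .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
    .config(
        "spark.sql.catalog.spark_catalog",
        "org.apache.spark.sql.delta.catalog.DeltaCatalog",
    )
)
spark = configure_spark_with_delta_pip(builder).getOrCreate()

# transform() now writes three delta tables: station, row_to_station,
# and the time-only bronze table.
RawToBronzeTransformer(spark).transform()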
test/components/raw_to_bronze_transformer_test.py: 44 changes (33 additions & 11 deletions)
@@ -56,7 +56,7 @@ def dataframe_2(spark: SparkSession):
None,
5414,
1,
1666447310848,
1,
"file:///media/ishrak/New%20Volume/Studies/Projects/CitiBike-Demand-Prediction/Data/CSVs/post_2020/202101-citibike-tripdata_1.csv",
"202101-citibike-tripdata_1.csv",
],
@@ -72,7 +72,7 @@ def dataframe_2(spark: SparkSession):
4789,
4829,
1,
1666447310849,
2,
"file:///media/ishrak/New%20Volume/Studies/Projects/CitiBike-Demand-Prediction/Data/CSVs/post_2020/202101-citibike-tripdata_1.csv",
"202101-citibike-tripdata_1.csv",
],
@@ -88,7 +88,7 @@ def dataframe_2(spark: SparkSession):
5406,
5414,
1,
1666447310848,
3,
"file:///media/ishrak/New%20Volume/Studies/Projects/CitiBike-Demand-Prediction/Data/CSVs/post_2020/202101-citibike-tripdata_1.csv",
"201501-citibike-tripdata_1.csv",
],
@@ -104,7 +104,7 @@ def dataframe_2(spark: SparkSession):
4789,
4829,
1,
1666447310849,
4,
"file:///media/ishrak/New%20Volume/Studies/Projects/CitiBike-Demand-Prediction/Data/CSVs/post_2020/202101-citibike-tripdata_1.csv",
"201501-citibike-tripdata_1.csv",
],
@@ -120,7 +120,7 @@ def dataframe_2(spark: SparkSession):
5406,
None,
1,
1666447310848,
5,
"file:///media/ishrak/New%20Volume/Studies/Projects/CitiBike-Demand-Prediction/Data/CSVs/post_2020/202101-citibike-tripdata_1.csv",
"201409-citibike-tripdata_1.csv",
],
@@ -136,7 +136,7 @@ def dataframe_2(spark: SparkSession):
4789,
4829,
1,
1666447310849,
6,
"file:///media/ishrak/New%20Volume/Studies/Projects/CitiBike-Demand-Prediction/Data/CSVs/post_2020/202101-citibike-tripdata_1.csv",
"201409-citibike-tripdata_1.csv",
],
@@ -272,6 +272,8 @@ def test_config():
assert hasattr(config, "root_delta_path")
assert hasattr(config, "raw_data_path")
assert hasattr(config, "bronze_data_path")
assert hasattr(config, "station_data_path")
assert hasattr(config, "row_to_station_data_path")


def test_init(transformer: RawToBronzeTransformer, spark: SparkSession):
@@ -303,10 +305,12 @@ def test_write_delta():
spark_mock = Mock(SparkSession)
transformer = RawToBronzeTransformer(spark_mock)

transformer.write_delta(dataframe)
transformer.write_delta(dataframe, transformer.config.bronze_data_path)

dataframe.write.save.assert_called_once_with(
path=transformer.config.bronze_data_path, format="delta", mode="overwrite"
path=transformer.config.bronze_data_path,
format="delta",
mode="overwrite",
)


@@ -409,7 +413,7 @@ def test_drup_duplicates_and_all_nulls(
):
output = transformer.get_station_dataframe(dataframe_2)
before = output.count()
output = transformer.drup_duplicates_and_all_nulls(output)
output = transformer.drop_duplicates_and_all_nulls(output)
after = output.count()

assert isinstance(output, DataFrame)
@@ -421,7 +425,7 @@ def test_fill_in_station_id_using_name(
transformer: RawToBronzeTransformer,
):
output = transformer.get_station_dataframe(dataframe_2)
output = transformer.drup_duplicates_and_all_nulls(output)
output = transformer.drop_duplicates_and_all_nulls(output)
output = transformer.fill_in_station_id_using_name(output)

assert isinstance(output, DataFrame)
@@ -433,7 +437,7 @@ def test_fill_in_using_station_id(
transformer: RawToBronzeTransformer,
):
output = transformer.get_station_dataframe(dataframe_2)
output = transformer.drup_duplicates_and_all_nulls(output)
output = transformer.drop_duplicates_and_all_nulls(output)
output = transformer.fill_in_station_id_using_name(output)
output = transformer.fill_in_using_station_id(output)

@@ -444,3 +448,21 @@
).count()
is 0
)


def test_split_station_and_time(
dataframe_2: DataFrame,
transformer: RawToBronzeTransformer,
):
station_df, row_to_station_df, df = transformer.split_station_and_time(dataframe_2)

assert isinstance(station_df, DataFrame)
assert isinstance(row_to_station_df, DataFrame)
assert isinstance(df, DataFrame)
assert set(station_df.columns) == {"id", "name", "latitude", "longitude"}
assert set(df.columns) == {"start_time", "end_time", "row_number"}
assert set(row_to_station_df.columns) == {
"row_number",
"start_station_id",
"end_station_id",
}

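A possible downstream read of the three new outputs (not part of this commit), assuming a Delta-enabled spark session like the one sketched after the transformer diff above; column names follow the assertions in test_split_station_and_time:

# Paths come from RawToBronzeTransformerConfig.
station_df = spark.read.format("delta").load("Data/delta/station")
row_to_station_df = spark.read.format("delta").load("Data/delta/row_to_station")
time_df = spark.read.format("delta").load("Data/delta/bronze")

# Re-attach start-station attributes to each trip row via the row_number mapping.
trips_with_start_station = (
    time_df.join(row_to_station_df, on="row_number", how="inner")
    .join(
        station_df.withColumnRenamed("id", "start_station_id"),
        on="start_station_id",
        how="left",
    )
    .select("row_number", "start_time", "end_time", "name", "latitude", "longitude")
)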