Skip to content

Commit

Permalink
Bugfix: Align columns of two types of dataframe properly
Browse files Browse the repository at this point in the history
  • Loading branch information
Shakleen committed Oct 1, 2024
1 parent 7fa80fd commit c0d2486
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 34 deletions.
15 changes: 15 additions & 0 deletions src/components/data_ingestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,19 @@ class DataIngestorConfig:
)

raw_data_save_dir: str = os.path.join(root_data_path, "delta", "raw")
column_order = [
"start_time",
"end_time",
"start_station_name",
"start_station_latitude",
"start_station_longitude",
"end_station_name",
"end_station_latitude",
"end_station_longitude",
"start_station_id",
"end_station_id",
"member",
]


class DataIngestor:
Expand Down Expand Up @@ -121,6 +134,7 @@ def standardize_columns_for_post2020(self, post_2020_df: DataFrame):
"member", when(col("member_casual") == "casual", 0).otherwise(1)
)
.drop("ride_id", "rideable_type", "member_casual")
.select(self.config.column_order)
)

def standardize_columns_for_pre2020(self, pre_2020_df: DataFrame):
Expand All @@ -145,6 +159,7 @@ def standardize_columns_for_pre2020(self, pre_2020_df: DataFrame):
"end station id",
"usertype",
)
.select(self.config.column_order)
)

def ingest(self):
Expand Down
37 changes: 3 additions & 34 deletions test/components/data_ingestor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ def test_config():
assert hasattr(config, "post_2020_csv_dir")
assert hasattr(config, "post_2020_schema")
assert hasattr(config, "raw_data_save_dir")
assert hasattr(config, "column_order")


def test_init(spark: SparkSession, ingestor: DataIngestor):
Expand Down Expand Up @@ -246,50 +247,18 @@ def test_standardize_columns_pre2020(
dataframe_pre2020: DataFrame,
ingestor: DataIngestor,
):
expected_columns = set(
[
"start_time",
"end_time",
"start_station_name",
"start_station_latitude",
"start_station_longitude",
"end_station_name",
"end_station_latitude",
"end_station_longitude",
"start_station_id",
"end_station_id",
"member",
]
)

df = ingestor.standardize_columns_for_pre2020(dataframe_pre2020)

assert set(df.columns) == expected_columns
assert df.columns == ingestor.config.column_order


def test_standardize_columns_post2020(
dataframe_post2020: DataFrame,
ingestor: DataIngestor,
):
expected_columns = set(
[
"start_time",
"end_time",
"start_station_name",
"start_station_id",
"end_station_name",
"end_station_id",
"start_station_latitude",
"start_station_longitude",
"end_station_latitude",
"end_station_longitude",
"member",
]
)

df = ingestor.standardize_columns_for_post2020(dataframe_post2020)

assert set(df.columns) == expected_columns
assert df.columns == ingestor.config.column_order


def test_combine_dataframes(
Expand Down

0 comments on commit c0d2486

Please sign in to comment.