Skip to content

Commit

Permalink
v0.1.2 Created demand columns
Browse files Browse the repository at this point in the history
  • Loading branch information
Shakleen committed Oct 4, 2024
1 parent c94850b commit 8f031d0
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 3 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ blinker==1.8.2
cachetools==5.5.0
certifi==2024.8.30
charset-normalizer==3.3.2
CitiBike-Demand-Prediction==0.1.1
CitiBike-Demand-Prediction==0.1.2
click==8.1.7
cloudpickle==3.0.0
comm==0.2.2
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def get_requirements(file_name: str) -> List[str]:

setup(
name="CitiBike Demand Prediction",
version="0.1.1",
version="0.1.2",
description="An End-to-End Machine Learning project where I predict demand of bikes at citibike stations at hourly level.",
author="Shakleen Ishfar",
author_email="shakleenishfar@gmail.com",
Expand Down
18 changes: 17 additions & 1 deletion src/components/bronze_to_silver_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,5 +65,21 @@ def attach_station_ids(

def count_group_by_station_and_time(self, df: DataFrame) -> DataFrame:
    """Count rides per station per hour.

    Truncates the ``time`` column to hour precision, then counts rows
    for each ``(station_id, time)`` pair.

    Args:
        df: DataFrame with at least ``station_id``, ``time`` (timestamp)
            and ``row_number`` columns.

    Returns:
        DataFrame with columns ``station_id``, ``time`` and ``count``.
    """
    # Bucket timestamps to the hour so rides aggregate at hourly level.
    df = df.withColumn("time", F.date_trunc("hour", "time"))
    # NOTE(review): the source contained this aggregation twice (a diff
    # artifact — old and new formatting of the same statement); the
    # duplicate has been removed so the groupBy runs only once.
    count_df = df.groupBy("station_id", "time").agg(
        F.count("row_number").alias("count")
    )
    return count_df

def combine_on_station_id_and_time(
    self,
    start_df: DataFrame,
    end_df: DataFrame,
) -> DataFrame:
    """Merge hourly start/end counts into a single demand table.

    Renames ``count`` to ``bike_demand`` (trips starting at a station)
    and ``dock_demand`` (trips ending at a station), then full-outer
    joins the two frames on ``(station_id, time)``.

    Args:
        start_df: Hourly counts of trips starting at each station.
        end_df: Hourly counts of trips ending at each station.

    Returns:
        DataFrame with columns ``station_id``, ``time``,
        ``bike_demand`` and ``dock_demand``.
    """
    bike_df = start_df.withColumnRenamed("count", "bike_demand")
    dock_df = end_df.withColumnRenamed("count", "dock_demand")
    joined_df = bike_df.join(
        dock_df,
        on=["station_id", "time"],
        how="fullouter",
    )
    # A station/hour missing from one side means zero demand of that kind.
    return joined_df.fillna(0)
54 changes: 54 additions & 0 deletions test/components/bronze_to_silver_transformer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,57 @@ def test_count_group_by_station_and_time(
assert set(output.columns) == {"station_id", "time", "count"}
assert output.select("count").toPandas().to_numpy().flatten().tolist() == [1, 2, 1]


def test_combine_on_station_id_and_time(
    transformer: BronzeToSilverTransformer,
    spark: SparkSession,
):
    """Full-outer join keeps rows unique to either side, zero-filling gaps.

    Station 3 appears only in ``start_df`` and station 4 only in
    ``end_df``; both must survive the join with the missing demand
    column filled with 0.
    """
    # Both inputs share the same shape — build the schema once.
    schema = T.StructType(
        [
            T.StructField("station_id", T.IntegerType(), True),
            T.StructField("time", T.StringType(), True),
            T.StructField("count", T.IntegerType(), True),
        ]
    )

    start_df = spark.createDataFrame(
        [
            [1, "2024-06-19 19:00:00", 100],
            [2, "2024-06-20 17:00:00", 200],
            [3, "2024-06-21 17:00:00", 300],
        ],
        schema=schema,
    ).withColumn("time", F.to_timestamp("time"))

    end_df = spark.createDataFrame(
        [
            [1, "2024-06-19 19:00:00", 100],
            [2, "2024-06-20 17:00:00", 200],
            [4, "2024-06-21 17:00:00", 400],
        ],
        schema=schema,
    ).withColumn("time", F.to_timestamp("time"))

    output = transformer.combine_on_station_id_and_time(start_df, end_df)

    assert isinstance(output, DataFrame)
    assert output.count() == 4
    assert set(output.columns) == {"station_id", "time", "bike_demand", "dock_demand"}

    # Spark gives no ordering guarantee on join output — sort explicitly
    # before comparing against position-dependent expected lists.
    rows = output.orderBy("station_id").collect()
    assert [row["bike_demand"] for row in rows] == [100, 200, 300, 0]
    assert [row["dock_demand"] for row in rows] == [100, 200, 0, 400]

0 comments on commit 8f031d0

Please sign in to comment.