Commit aaeb6c4

Removed imputation for station dataframe.
Shakleen committed Oct 3, 2024
1 parent 24487c0
Showing 4 changed files with 25 additions and 238 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
@@ -5,7 +5,7 @@ blinker==1.8.2
cachetools==5.5.0
certifi==2024.8.30
charset-normalizer==3.3.2
-CitiBike-Demand-Prediction==0.0.7
+CitiBike-Demand-Prediction==0.1.0
click==8.1.7
cloudpickle==3.0.0
comm==0.2.2
2 changes: 1 addition & 1 deletion setup.py
@@ -25,7 +25,7 @@ def get_requirements(file_name: str) -> List[str]:

setup(
    name="CitiBike Demand Prediction",
-    version="0.0.7",
+    version="0.1.0",
    description="An End-to-End Machine Learning project where I predict demand of bikes at citibike stations at hourly level.",
    author="Shakleen Ishfar",
    author_email="shakleenishfar@gmail.com",
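The hunk above shows only the call site of get_requirements; its body sits outside this diff. For orientation, here is a minimal sketch of what such a helper plausibly looks like — reading requirements.txt line by line and filtering the editable-install marker are assumptions here, not part of the commit:

```python
from typing import List


def get_requirements(file_name: str) -> List[str]:
    """Hypothetical reconstruction of the helper referenced in setup.py.

    The real body is not shown in this diff; stripping blanks and the
    editable-install entry ("-e .") is a common convention, assumed here.
    """
    with open(file_name) as file:
        requirements = [line.strip() for line in file.readlines()]

    # Keep only real package pins such as "click==8.1.7"
    return [req for req in requirements if req and req != "-e ."]
```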
75 changes: 3 additions & 72 deletions src/components/raw_to_bronze_transformer.py
@@ -165,84 +165,15 @@ def get_station_dataframe(self, df: DataFrame) -> DataFrame:
)
)

-    def drop_duplicates_and_all_nulls(self, df: DataFrame) -> DataFrame:
-        return df.dropDuplicates().dropna(how="all")
-
-    def fill_in_station_id_using_name(self, df: DataFrame) -> DataFrame:
-        # Create a mapping DataFrame with distinct non-null name and id pairs
-        mapping_df = df.filter(df["id"].isNotNull()).select("name", "id").distinct()
-
-        # Rename the id column in the mapping DataFrame to avoid conflicts
-        mapping_df = mapping_df.withColumnRenamed("id", "mapped_id")
-
-        # Join the original DataFrame with the mapping DataFrame
-        df_filled = df.alias("df1").join(mapping_df.alias("df2"), on="name", how="left")
-
-        # Use coalesce to fill null values in the id column
-        df_filled = df_filled.withColumn(
-            "id", coalesce(df_filled["df1.id"], df_filled["df2.mapped_id"])
-        )
-
-        # Drop the extra columns from the join
-        df_filled = df_filled.drop("mapped_id")
-
-        return df_filled
-
-    def fill_in_using_station_id(self, df: DataFrame) -> DataFrame:
-        # Create a mapping DataFrame with distinct non-null id and corresponding non-null values
-        mapping_df = (
-            df.filter(df["id"].isNotNull())
-            .select("id", "name", "latitude", "longitude")
-            .distinct()
-        )
-        mapping_df = (
-            mapping_df.withColumnRenamed("name", "mapped_name")
-            .withColumnRenamed("latitude", "mapped_latitude")
-            .withColumnRenamed("longitude", "mapped_longitude")
-        )
-
-        # Show the mapping DataFrame
-        mapping_df.show()
-
-        # Join the original DataFrame with the mapping DataFrame on the id column
-        df_filled = df.alias("df1").join(mapping_df.alias("df2"), on="id", how="left")
-
-        # Use coalesce to fill null values in the name, latitude, and longitude columns
-        df_filled = (
-            df_filled.withColumn(
-                "name", coalesce(df_filled["df1.name"], df_filled["mapped_name"])
-            )
-            .withColumn(
-                "latitude",
-                coalesce(df_filled["df1.latitude"], df_filled["mapped_latitude"]),
-            )
-            .withColumn(
-                "longitude",
-                coalesce(df_filled["df1.longitude"], df_filled["mapped_longitude"]),
-            )
-        )
-
-        # Drop the extra columns from the join
-        return (
-            df_filled.drop("mapped_name")
-            .drop("mapped_latitude")
-            .drop("mapped_longitude")
-            .dropDuplicates()
-            .dropna(how="any")
-        )

    def split_station_and_time(
        self, df: DataFrame
    ) -> Tuple[DataFrame, DataFrame, DataFrame]:
-        # Separating station Data
-        station_df = self.get_station_dataframe(df)
-        station_df = self.drop_duplicates_and_all_nulls(station_df)
-        station_df = self.fill_in_station_id_using_name(station_df)
-        station_df = self.fill_in_using_station_id(station_df)
+        # Dropping rows with null station ids
+        df = df.dropna(subset=["start_station_id", "end_station_id"], how="any")
+
+        # Separating station Data
+        station_df = self.get_station_dataframe(df)

        # Mapping df to station ids
        row_to_station_df = df.select(
            "row_number", "start_station_id", "end_station_id"
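For context on what was removed: fill_in_station_id_using_name and fill_in_using_station_id implemented a join-and-coalesce imputation pattern — collect the distinct non-null pairs as a mapping, left-join the mapping back onto the original frame, and coalesce each original column with its mapped twin. A self-contained sketch of that pattern (the station names and ids here are made up for the demo):

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import coalesce

spark = SparkSession.builder.master("local[*]").getOrCreate()

# Illustrative station rows: the second one lost its id
df = spark.createDataFrame(
    [("Forsyth St & Broome St", 531), ("Forsyth St & Broome St", None)],
    ["name", "id"],
)

# Mapping of known (name, id) pairs, as in the deleted fill_in_station_id_using_name
mapping_df = (
    df.filter(df["id"].isNotNull())
    .select("name", "id")
    .distinct()
    .withColumnRenamed("id", "mapped_id")
)

# Left-join the mapping back and coalesce the nulls away
filled = (
    df.join(mapping_df, on="name", how="left")
    .withColumn("id", coalesce("id", "mapped_id"))
    .drop("mapped_id")
)
filled.show()  # both rows now carry id 531
```

The commit abandons this recovery step in favor of simply discarding rows whose station ids are null (the dropna(subset=...) call added to split_station_and_time above), trading a little data loss for a much simpler pipeline.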
184 changes: 20 additions & 164 deletions test/components/raw_to_bronze_transformer_test.py
@@ -41,7 +41,7 @@ def transformer(spark: SparkSession):


@pytest.fixture
-def dataframe_2(spark: SparkSession):
+def dataframe(spark: SparkSession):
    return spark.createDataFrame(
        data=[
            [
@@ -164,108 +164,6 @@ def dataframe_2(spark: SparkSession):
    )


-@pytest.fixture
-def dataframe(spark: SparkSession):
-    return spark.createDataFrame(
-        data=[
-            [
-                "2019-08-01 00:00:01.4680",
-                "2019-08-01 00:06:35.3780",
-                "Forsyth St & Broome St",
-                40.71894073486328,
-                -73.99266052246094,
-                "Market St & Cherry St",
-                40.71076202392578,
-                -73.99400329589844,
-                531,
-                408,
-                1,
-                85899345920,
-                "file:///media/ishrak/New%20Volume/Studies/Projects/CitiBike-Demand-Prediction/Data/CSVs/pre_2020/201908-citibike-tripdata_1.csv",
-            ],
-            [
-                "2019-08-01 00:00:01.9290",
-                "2019-08-01 00:10:29.7840",
-                "Lafayette Ave & Fort Greene Pl",
-                40.686920166015625,
-                -73.9766845703125,
-                "Bergen St & Smith St",
-                40.686744689941406,
-                -73.99063110351562,
-                274,
-                3409,
-                1,
-                85899345921,
-                "file:///media/ishrak/New%20Volume/Studies/Projects/CitiBike-Demand-Prediction/Data/CSVs/pre_2020/201908-citibike-tripdata_1.csv",
-            ],
-            [
-                "2019-08-01 00:00:04.0480",
-                "2019-08-01 00:18:56.1650",
-                "Front St & Washington St",
-                40.70254898071289,
-                -73.9894027709961,
-                "President St & Henry St",
-                40.68280029296875,
-                -73.9999008178711,
-                2000,
-                3388,
-                1,
-                85899345922,
-                "file:///media/ishrak/New%20Volume/Studies/Projects/CitiBike-Demand-Prediction/Data/CSVs/pre_2020/201908-citibike-tripdata_1.csv",
-            ],
-            [
-                "2019-08-01 00:00:04.1630",
-                "2019-08-01 00:29:44.7940",
-                "9 Ave & W 45 St",
-                40.76019287109375,
-                -73.99125671386719,
-                "Rivington St & Chrystie St",
-                40.721099853515625,
-                -73.99192810058594,
-                479,
-                473,
-                1,
-                85899345923,
-                "file:///media/ishrak/New%20Volume/Studies/Projects/CitiBike-Demand-Prediction/Data/CSVs/pre_2020/201908-citibike-tripdata_1.csv",
-            ],
-            [
-                "2019-08-01 00:00:05.4580",
-                "2019-08-01 00:25:23.4550",
-                "1 Ave & E 94 St",
-                40.78172302246094,
-                -73.94593811035156,
-                "1 Ave & E 94 St",
-                40.78172302246094,
-                -73.94593811035156,
-                3312,
-                3312,
-                1,
-                85899345924,
-                "file:///media/ishrak/New%20Volume/Studies/Projects/CitiBike-Demand-Prediction/Data/CSVs/pre_2020/201908-citibike-tripdata_1.csv",
-            ],
-        ],
-        schema=StructType(
-            StructType(
-                [
-                    StructField("start_time", StringType(), True),
-                    StructField("end_time", StringType(), True),
-                    StructField("start_station_name", StringType(), True),
-                    StructField("start_station_latitude", FloatType(), True),
-                    StructField("start_station_longitude", FloatType(), True),
-                    StructField("end_station_name", StringType(), True),
-                    StructField("end_station_latitude", FloatType(), True),
-                    StructField("end_station_longitude", FloatType(), True),
-                    StructField("start_station_id", IntegerType(), True),
-                    StructField("end_station_id", IntegerType(), True),
-                    StructField("member", IntegerType(), True),
-                    StructField("row_number", LongType(), True),
-                    StructField("file_path", StringType(), True),
-                ]
-            )
-        ),
-    )


def test_config():
    config = RawToBronzeTransformerConfig()

@@ -284,11 +182,12 @@ def test_init(transformer: RawToBronzeTransformer, spark: SparkSession):
    assert isinstance(transformer.spark, SparkSession)


-def test_read_raw_delta(dataframe: DataFrame):
+def test_read_raw_delta():
    spark_mock = Mock(SparkSession)
    transformer = RawToBronzeTransformer(spark_mock)
+    dataframe_mock = Mock(DataFrame)

-    spark_mock.read.format("delta").load.return_value = dataframe
+    spark_mock.read.format("delta").load.return_value = dataframe_mock

    df = transformer.read_raw_delta()

@@ -297,7 +196,7 @@ def test_read_raw_delta(dataframe: DataFrame):
        transformer.config.raw_data_path
    )

-    assert df is dataframe
+    assert df is dataframe_mock


def test_write_delta():
@@ -324,30 +223,30 @@ def test_create_file_name_column(


def test_get_dataframe_timeformat_type_1(
-    dataframe_2: DataFrame,
+    dataframe: DataFrame,
    transformer: RawToBronzeTransformer,
):
-    output = transformer.get_dataframe_timeformat_type_1(dataframe_2)
+    output = transformer.get_dataframe_timeformat_type_1(dataframe)

    assert isinstance(output, DataFrame)
    assert output.count() == 2


def test_get_dataframe_timeformat_type_2(
-    dataframe_2: DataFrame,
+    dataframe: DataFrame,
    transformer: RawToBronzeTransformer,
):
-    output = transformer.get_dataframe_timeformat_type_2(dataframe_2)
+    output = transformer.get_dataframe_timeformat_type_2(dataframe)

    assert isinstance(output, DataFrame)
    assert output.count() == 2


def test_get_dataframe_timeformat_type_3(
-    dataframe_2: DataFrame,
+    dataframe: DataFrame,
    transformer: RawToBronzeTransformer,
):
-    output = transformer.get_dataframe_timeformat_type_3(dataframe_2)
+    output = transformer.get_dataframe_timeformat_type_3(dataframe)

    assert isinstance(output, DataFrame)
    assert output.count() == 2
@@ -362,12 +261,12 @@ def test_get_dataframe_timeformat_type_3(
    ],
)
def test_set_timestamp_for_format(
-    dataframe_2: DataFrame,
+    dataframe: DataFrame,
    transformer: RawToBronzeTransformer,
    time_format: str,
    count: int,
):
-    output = transformer.set_timestamp_for_format(dataframe_2, time_format)
+    output = transformer.set_timestamp_for_format(dataframe, time_format)

    assert isinstance(output, DataFrame)
    assert output.schema[0] == StructField("start_time", TimestampType(), True)
@@ -381,10 +280,10 @@ def test_set_timestamp_for_format(


def test_set_timestamp_datatype(
-    dataframe_2: DataFrame,
+    dataframe: DataFrame,
    transformer: RawToBronzeTransformer,
):
-    output = transformer.set_timestamp_datatype(dataframe_2)
+    output = transformer.set_timestamp_datatype(dataframe)

    assert isinstance(output, DataFrame)
    assert output.schema[0] == StructField("start_time", TimestampType(), True)
@@ -393,68 +292,25 @@ def test_set_timestamp_datatype(
        output.filter(col("start_time").isNotNull())
        .filter(col("end_time").isNotNull())
        .count()
-        == dataframe_2.count()
+        == dataframe.count()
    )


def test_get_station_dataframe(
-    dataframe_2: DataFrame,
+    dataframe: DataFrame,
    transformer: RawToBronzeTransformer,
):
-    output = transformer.get_station_dataframe(dataframe_2)
+    output = transformer.get_station_dataframe(dataframe)

    assert isinstance(output, DataFrame)
    assert output.columns == ["id", "name", "latitude", "longitude"]


-def test_drup_duplicates_and_all_nulls(
-    dataframe_2: DataFrame,
-    transformer: RawToBronzeTransformer,
-):
-    output = transformer.get_station_dataframe(dataframe_2)
-    before = output.count()
-    output = transformer.drop_duplicates_and_all_nulls(output)
-    after = output.count()
-
-    assert isinstance(output, DataFrame)
-    assert before - after == 3
-
-
-def test_fill_in_station_id_using_name(
-    dataframe_2: DataFrame,
-    transformer: RawToBronzeTransformer,
-):
-    output = transformer.get_station_dataframe(dataframe_2)
-    output = transformer.drop_duplicates_and_all_nulls(output)
-    output = transformer.fill_in_station_id_using_name(output)
-
-    assert isinstance(output, DataFrame)
-    assert output.filter(col("id").isNull()).count() is 0
-
-
-def test_fill_in_using_station_id(
-    dataframe_2: DataFrame,
-    transformer: RawToBronzeTransformer,
-):
-    output = transformer.get_station_dataframe(dataframe_2)
-    output = transformer.drop_duplicates_and_all_nulls(output)
-    output = transformer.fill_in_station_id_using_name(output)
-    output = transformer.fill_in_using_station_id(output)
-
-    assert isinstance(output, DataFrame)
-    assert (
-        output.filter(
-            col("name").isNull() | col("latitude").isNull() | col("longitude").isNull()
-        ).count()
-        is 0
-    )


def test_split_station_and_time(
-    dataframe_2: DataFrame,
+    dataframe: DataFrame,
    transformer: RawToBronzeTransformer,
):
-    station_df, row_to_station_df, df = transformer.split_station_and_time(dataframe_2)
+    station_df, row_to_station_df, df = transformer.split_station_and_time(dataframe)

    assert isinstance(station_df, DataFrame)
    assert isinstance(row_to_station_df, DataFrame)
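The reworked test_read_raw_delta above no longer needs a real Spark session or fixture data: both the session and the returned frame become Mock objects, and the test checks only call wiring and pass-through identity. A condensed, self-contained sketch of that pattern (the Reader class is a hypothetical stand-in for RawToBronzeTransformer):

```python
from unittest.mock import Mock

from pyspark.sql import DataFrame, SparkSession


class Reader:
    """Hypothetical stand-in for the transformer's read_raw_delta logic."""

    def __init__(self, spark):
        self.spark = spark

    def read(self, path: str) -> DataFrame:
        return self.spark.read.format("delta").load(path)


spark_mock = Mock(SparkSession)
dataframe_mock = Mock(DataFrame)

# Mock auto-creates attribute chains, so read.format("delta").load can be
# stubbed without ever starting a JVM or touching Delta Lake
spark_mock.read.format("delta").load.return_value = dataframe_mock

df = Reader(spark_mock).read("some/delta/path")

spark_mock.read.format.assert_called_with("delta")  # wiring was exercised
assert df is dataframe_mock  # the frame passes through untouched
```

Because calling format(...) on a Mock returns the same child mock regardless of arguments, the stubbed load return value is hit no matter which format string the code under test supplies — worth keeping in mind when reading the identity assertion.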
