Skip to content

Commit 4bb803a

Browse files
committed
cleanup and tests
1 parent 87c4cee commit 4bb803a

File tree

5 files changed

+138
-22
lines changed

5 files changed

+138
-22
lines changed

hbw/analysis/create_analysis.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,7 @@ def analysis_factory(configs: od.UniqueObjectIndex):
111111
)
112112
else:
113113
logger.warning(
114-
f"Campaign used for {config_name} has been changed since last initialization."
115-
"Difference: \n",
114+
f"Campaign used for {config_name} is being reinitialized: \n",
116115
)
117116
cpn_task.run()
118117

hbw/scripts/test_config.py

-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
default_analysis = law.config.get_expanded("analysis", "default_analysis")
1111
default_config = law.config.get_expanded("analysis", "default_config")
12-
default_config = "c22uhhpost"
1312
analysis_inst = ana = AnalysisTask.get_analysis_inst(default_analysis)
1413
config_inst = cfg = ana.get_config(default_config)
1514

hbw/tasks/campaigns.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -146,9 +146,9 @@ def run(self):
146146
output = self.output()
147147

148148
# cross check if the dataset summary did change
149-
backup_dataset_summary = self.target("backup_dataset_summary.yaml")
150-
if backup_dataset_summary.exists():
151-
backup_dataset_summary = backup_dataset_summary.load(formatter="yaml")
149+
backup_target = self.target("backup_dataset_summary.yaml")
150+
if backup_target.exists():
151+
backup_dataset_summary = backup_target.load(formatter="yaml")
152152
if backup_dataset_summary != self.dataset_summary:
153153
from hbw.util import gather_dict_diff
154154
logger.warning(
@@ -157,15 +157,15 @@ def run(self):
157157
)
158158
if self.recreate_backup_summary:
159159
logger.warning("Recreating backup dataset summary")
160-
backup_dataset_summary.dump(self.dataset_summary, formatter="yaml")
160+
backup_target.dump(self.dataset_summary, formatter="yaml")
161161
else:
162162
logger.warning(
163163
"Run the following command to recreate the backup dataset summary:\n"
164164
f"law run {self.task_family} --recreate_backup_summary --config {self.config} --remove-output 0,a,y", # noqa
165165
)
166166
else:
167167
logger.warning("No backup dataset summary found, creating one now")
168-
backup_dataset_summary.dump(self.dataset_summary, formatter="yaml")
168+
backup_target.dump(self.dataset_summary, formatter="yaml")
169169

170170
output["dataset_summary"].dump(self.dataset_summary, formatter="yaml")
171171
output["campaign_summary"].dump(self.campaign_summary, formatter="yaml")

hbw/util.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -280,33 +280,33 @@ def filter_unchanged_keys(d1: dict, d2: dict):
280280
return filtered if filtered else None
281281

282282

283-
def dict_diff_filtered(dict1: dict, dict2: dict):
283+
def dict_diff_filtered(old_dict: dict, new_dict: dict):
284284
"""Return the differences between two dictionaries with nested filtering of unchanged keys."""
285285
diff = {}
286286

287287
# Check keys present in either dict
288-
all_keys = set(dict1.keys()).union(set(dict2.keys()))
288+
all_keys = set(old_dict.keys()).union(set(new_dict.keys()))
289289

290290
for key in all_keys:
291-
if key in dict1 and key in dict2:
292-
if isinstance(dict1[key], dict) and isinstance(dict2[key], dict):
291+
if key in old_dict and key in new_dict:
292+
if isinstance(old_dict[key], dict) and isinstance(new_dict[key], dict):
293293
# Recur for nested dictionaries and get filtered diff
294-
nested_diff = filter_unchanged_keys(dict1[key], dict2[key])
294+
nested_diff = filter_unchanged_keys(old_dict[key], new_dict[key])
295295
if nested_diff:
296296
diff[key] = nested_diff
297-
elif dict1[key] != dict2[key]:
298-
diff[key] = {"old": dict1[key], "new": dict2[key]}
299-
elif key in dict1:
300-
diff[key] = {"old": dict1[key], "new": None}
297+
elif old_dict[key] != new_dict[key]:
298+
diff[key] = {"old": old_dict[key], "new": new_dict[key]}
299+
elif key in old_dict:
300+
diff[key] = {"old": old_dict[key], "new": None}
301301
else:
302-
diff[key] = {"old": None, "new": dict2[key]}
302+
diff[key] = {"old": None, "new": new_dict[key]}
303303

304304
return diff
305305

306306

307-
def gather_dict_diff(dict1: dict, dict2: dict) -> str:
307+
def gather_dict_diff(old_dict: dict, new_dict: dict) -> str:
308308
"""Gather the differences between two dictionaries and return them as a formatted string."""
309-
diff = filter_unchanged_keys(dict1, dict2)
309+
diff = filter_unchanged_keys(old_dict, new_dict)
310310
lines = []
311311

312312
if not diff:

tests/test_util.py

+120-2
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,129 @@
88

99
from columnflow.util import maybe_import
1010

11-
from hbw.util import build_param_product, round_sig, dict_diff, four_vec, call_once_on_config
11+
from hbw.util import build_param_product, round_sig, dict_diff, four_vec, call_once_on_config, gather_dict_diff
1212

1313
import order as od
1414

1515
np = maybe_import("numpy")
1616
ak = maybe_import("awkward")
1717

1818

19-
class HbwUtilTest(unittest.TestCase):
19+
class TestDictDiff(unittest.TestCase):
20+
def test_no_difference(self):
21+
dict1 = {"name": "Alice", "age": 25}
22+
dict2 = {"name": "Alice", "age": 25}
23+
result = gather_dict_diff(dict1, dict2)
24+
self.assertEqual(result, "✅ No differences found.")
25+
26+
def test_simple_modification(self):
27+
dict1 = {"name": "Alice", "age": 25}
28+
dict2 = {"name": "Alice", "age": 26}
29+
result = gather_dict_diff(dict1, dict2)
30+
expected_output = (
31+
"🔄 Modified: age:\n"
32+
" - Old: 25\n"
33+
" - New: 26"
34+
)
35+
self.assertEqual(result, expected_output)
36+
37+
def test_addition(self):
38+
dict1 = {"name": "Alice"}
39+
dict2 = {"name": "Alice", "hobby": "cycling"}
40+
result = gather_dict_diff(dict1, dict2)
41+
expected_output = "🔹 Added: hobby: cycling"
42+
self.assertEqual(result, expected_output)
43+
44+
def test_removal(self):
45+
dict1 = {"name": "Alice", "hobby": "cycling"}
46+
dict2 = {"name": "Alice"}
47+
result = gather_dict_diff(dict1, dict2)
48+
expected_output = "🔻 Removed: hobby: cycling"
49+
self.assertEqual(result, expected_output)
50+
51+
def test_nested_modification(self):
52+
dict1 = {
53+
"name": "Alice",
54+
"skills": {
55+
"python": "intermediate",
56+
"sql": "beginner"
57+
}
58+
}
59+
dict2 = {
60+
"name": "Alice",
61+
"skills": {
62+
"python": "advanced",
63+
"sql": "beginner"
64+
}
65+
}
66+
result = gather_dict_diff(dict1, dict2)
67+
expected_output = (
68+
"🔄 Modified: skills:\n"
69+
" 🔄 Modified: python:\n"
70+
" - Old: intermediate\n"
71+
" - New: advanced"
72+
)
73+
self.assertEqual(result, expected_output)
74+
75+
def test_nested_addition(self):
76+
dict1 = {
77+
"name": "Alice",
78+
"skills": {
79+
"python": "intermediate"
80+
}
81+
}
82+
dict2 = {
83+
"name": "Alice",
84+
"skills": {
85+
"python": "intermediate",
86+
"docker": "beginner"
87+
}
88+
}
89+
result = gather_dict_diff(dict1, dict2)
90+
expected_output = (
91+
"🔄 Modified: skills:\n"
92+
" 🔹 Added: docker: beginner"
93+
)
94+
self.assertEqual(result, expected_output)
95+
96+
def test_complex_diff(self):
97+
dict1 = {
98+
"name": "Alice",
99+
"age": 25,
100+
"skills": {
101+
"python": "intermediate",
102+
"sql": "beginner",
103+
},
104+
}
105+
dict2 = {
106+
"name": "Alice",
107+
"age": 26,
108+
"skills": {
109+
"python": "advanced",
110+
"sql": "beginner",
111+
"docker": "beginner",
112+
},
113+
"hobby": "cycling",
114+
}
115+
result = gather_dict_diff(dict1, dict2)
116+
expected_output = (
117+
"🔄 Modified: age:\n"
118+
" - Old: 25\n"
119+
" - New: 26\n"
120+
"🔄 Modified: skills:\n"
121+
" 🔄 Modified: python:\n"
122+
" - Old: intermediate\n"
123+
" - New: advanced\n"
124+
" 🔹 Added: docker: beginner\n"
125+
"🔹 Added: hobby: cycling"
126+
)
127+
self.assertEqual(result, expected_output)
128+
129+
130+
class HbwUtilTest(
131+
TestDictDiff,
132+
unittest.TestCase,
133+
):
20134

21135
def __init__(self, *args, **kwargs):
22136
super().__init__(*args, **kwargs)
@@ -97,3 +211,7 @@ def some_config_function(config: od.Config) -> str:
97211

98212
# on second call, function should not be called -> returns None
99213
self.assertEqual(some_config_function(self.config_inst), None)
214+
215+
216+
if __name__ == "__main__":
217+
unittest.main()

0 commit comments

Comments
 (0)