diff --git a/setup.py b/setup.py index 907269c2..e31d9fcc 100644 --- a/setup.py +++ b/setup.py @@ -62,6 +62,7 @@ "peft>=0.14.0", "pytest", "python-dotenv", + "requests", "ruff>=0.9.0", "safetensors>=0.3.3", "sentencepiece>=0.1.99", diff --git a/src/open_r1/rewards/api/code/unfoldml/htgen.py b/src/open_r1/rewards/api/code/unfoldml/htgen.py index db4474b9..5947a873 100644 --- a/src/open_r1/rewards/api/code/unfoldml/htgen.py +++ b/src/open_r1/rewards/api/code/unfoldml/htgen.py @@ -95,7 +95,7 @@ def verify_triple_33( v = res.json() except JSONDecodeError: v = None - print(v) + return v # else: except HTTPError as he: print(f"HTTP error: {he}") diff --git a/tests/test_api.py b/tests/test_api.py new file mode 100644 index 00000000..d5fa705d --- /dev/null +++ b/tests/test_api.py @@ -0,0 +1,40 @@ +import unittest + +from open_r1.rewards.api.code.unfoldml.htgen import gen_triples_33, verify_triple_33 + + +class TestApi(unittest.TestCase): + def test_gen_triples_structure(): + n_stmt = 3 + for o in gen_triples_33(n_examples = 1, n_stmt = n_stmt): + len_program = len(o['program']) + self.assertEqual(len_program, n_stmt) + def test_verify_triple_result(): + is_total = True + preconditions = "True" # trivial precondition + program = "v4 = (0 - v3)\nv3 = v3\nv5 = v4" + post_ok = "v5 == (0 - v3)" # post-condition that verifies + post_not_ok = "v5 == (1 - v3)" # post-condition that does not verify + # # should return True + o = verify_triple_33( + is_total = is_total, + preconditions = preconditions, + program = program, + postconditions = post_ok + ) + res_ok = o['prediction_is_correct'] + self.assertEqual(res_ok, True) + # # should return False + o = verify_triple_33( + is_total = is_total, + preconditions = preconditions, + program = program, + postconditions = post_not_ok + ) + res_not_ok = o['prediction_is_correct'] + salf.assertEqual(res_not_ok, False) + + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/test_rewards.py b/tests/test_rewards.py index 9e41bdb0..7bcf807c 100644 --- a/tests/test_rewards.py +++ b/tests/test_rewards.py @@ -9,7 +9,6 @@ reasoning_steps_reward, ) - class TestRewards(unittest.TestCase): def test_accuracy_reward_correct_answer(self): """Test accuracy_reward with a correct answer."""