Skip to content

Commit

Permalink
add API test
Browse files Browse the repository at this point in the history
  • Loading branch information
Marco Zocca authored and ocramz committed Feb 21, 2025
1 parent b3a9587 commit 41af428
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 2 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
"peft>=0.14.0",
"pytest",
"python-dotenv",
"requests",
"ruff>=0.9.0",
"safetensors>=0.3.3",
"sentencepiece>=0.1.99",
Expand Down
2 changes: 1 addition & 1 deletion src/open_r1/rewards/api/code/unfoldml/htgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def verify_triple_33(
v = res.json()
except JSONDecodeError:
v = None
print(v)
return v
# else:
except HTTPError as he:
print(f"HTTP error: {he}")
Expand Down
40 changes: 40 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import unittest

from open_r1.rewards.api.code.unfoldml.htgen import gen_triples_33, verify_triple_33


class TestApi(unittest.TestCase):
def test_gen_triples_structure():
n_stmt = 3
for o in gen_triples_33(n_examples = 1, n_stmt = n_stmt):
len_program = len(o['program'])
self.assertEqual(len_program, n_stmt)
def test_verify_triple_result():
is_total = True
preconditions = "True" # trivial precondition
program = "v4 = (0 - v3)\nv3 = v3\nv5 = v4"
post_ok = "v5 == (0 - v3)" # post-condition that verifies
post_not_ok = "v5 == (1 - v3)" # post-condition that does not verify
# # should return True
o = verify_triple_33(
is_total = is_total,
preconditions = preconditions,
program = program,
postconditions = post_ok
)
res_ok = o['prediction_is_correct']
self.assertEqual(res_ok, True)
# # should return False
o = verify_triple_33(
is_total = is_total,
preconditions = preconditions,
program = program,
postconditions = post_not_ok
)
res_not_ok = o['prediction_is_correct']
salf.assertEqual(res_not_ok, False)



if __name__ == "__main__":
unittest.main()
1 change: 0 additions & 1 deletion tests/test_rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
reasoning_steps_reward,
)


class TestRewards(unittest.TestCase):
def test_accuracy_reward_correct_answer(self):
"""Test accuracy_reward with a correct answer."""
Expand Down

0 comments on commit 41af428

Please sign in to comment.