Add generation caching in TextEnvironment and fix bugs in TextEnvironment #2556

Open · wants to merge 76 commits into main (head: text_environment_caching)
Changes from 1 commit of 76
ab86162
feat: add caching for TextEnvironment and fix bugs
Jan 10, 2025
d09ec63
feat: make TextEnvironment caching optional and add documentation
Jan 10, 2025
b7885cc
fix: failing TextEnvironment tests
Jan 10, 2025
034c5f7
test: add tests for TextEnvironment caching and fix cache combining bug
Jan 10, 2025
18eb106
test: remove unnecessary parametrized class decorator
Jan 10, 2025
44fd184
docs: update TextEnvironmentDocs with caching
Jan 10, 2025
28601c2
fix: run linter on TextEnvironment and TextEnvironment tests
Jan 10, 2025
2a7ec4e
fix: comment
Jan 10, 2025
af06d63
fix: Args comment
Jan 10, 2025
f6f12b5
fix: TextEnvironment cache combination and batching issue
Jan 10, 2025
ede7e81
tests: make caching test more complex
Jan 10, 2025
acddaa7
fix: combine caches of different sequence lengths
Jan 11, 2025
e38940e
docs: update caching warning
Jan 12, 2025
66d0ce4
fix: prevent bos tokens in tool response
Jan 12, 2025
a051e46
docs: Update docs/source/text_environments.md
konrad-gerlach Jan 12, 2025
9ea9287
Update trl/environment/base_environment.py
konrad-gerlach Jan 12, 2025
ae1233a
Merge branch 'main' into text_environment_caching
konrad-gerlach Jan 12, 2025
a2860bc
fix: code cleanup
Jan 12, 2025
23014fb
fix: attended to invalid last generated token and off-by-one in Strin…
Jan 14, 2025
bdaa922
Merge branch 'main' into text_environment_caching
konrad-gerlach Jan 15, 2025
a097c5b
Merge branch 'main' into text_environment_caching
konrad-gerlach Jan 17, 2025
7324ee1
fix: off by one error in StringStoppingCriteria
Jan 21, 2025
9b6a6ec
Merge branch 'main' into text_environment_caching
konrad-gerlach Jan 21, 2025
39763b1
feat: test logits are same with and without caching
Jan 22, 2025
b70f51c
fix: model and tokenizer were called gpt2 but were another model
Jan 22, 2025
7b2169d
docs: add warning for torch.compile with TextEnvironment use_cache
Jan 22, 2025
c4b5400
Merge branch 'text_environment_caching' of https://github.com/konrad-…
Jan 22, 2025
5725b18
fix: StringStoppingCriteria and add test
Jan 23, 2025
589dcb7
refactor: move StoppingCriteria test
Jan 23, 2025
5e1a7dd
feat: add support for models without cache class support
Jan 23, 2025
cc99580
refactor: make caching code optional in TextEnvironment
Jan 23, 2025
50119a8
docs: TextEnvironment use_cache note untested Encoder-Decoder archite…
Jan 23, 2025
772527b
refactor: extract method from _generate_batched
Jan 24, 2025
4b58de2
feat: add optional logits to TextHistory and remove eos segment creat…
Jan 24, 2025
e012571
tests: remove redundant test code
Jan 24, 2025
520abe8
tests: update logit similarity calculation
Jan 24, 2025
eae5b03
feat: pad each batch individually in TextEnvironment
Jan 24, 2025
0736ce3
refactor: remove redundant past_input_ids manipulation in TExtEnviron…
Jan 24, 2025
6681217
fix: remove redundant attention blocking for padding and eos token an…
Jan 24, 2025
37e5041
fix: code and documenation cleanup for TextEnvironment caching
Jan 25, 2025
a5fc0d7
fix: add more validation to tests and TextEnvironment caching
Jan 26, 2025
6504474
Merge branch 'main' into text_environment_caching
konrad-gerlach Jan 26, 2025
ac00cea
refactor: Add more validation to caching in TextEnv
Jan 29, 2025
fc6787f
Merge branch 'text_environment_caching' of https://github.com/konrad-…
Jan 29, 2025
9047db9
Merge branch 'main' into text_environment_caching
konrad-gerlach Jan 29, 2025
8e5ac7e
tests: fix logits comparison
Jan 29, 2025
dad1f5d
Merge branch 'text_environment_caching' of https://github.com/konrad-…
Jan 29, 2025
7fec08b
Merge branch 'main' into text_environment_caching
konrad-gerlach Jan 29, 2025
633d346
Merge branch 'main' into text_environment_caching
konrad-gerlach Jan 30, 2025
3174224
Merge branch 'main' into text_environment_caching
konrad-gerlach Jan 30, 2025
a4feeb2
tests: remove support_cache_class parametrization causing failing
Jan 31, 2025
74fcd44
Merge branch 'text_environment_caching' of https://github.com/konrad-…
Jan 31, 2025
e8d539f
tests: removed test for private method
Jan 31, 2025
cbf983b
Merge branch 'main' into text_environment_caching
konrad-gerlach Jan 31, 2025
6b9a5b5
refactor: fix code quality issue
Jan 31, 2025
37b373b
Merge branch 'text_environment_caching' of https://github.com/konrad-…
Jan 31, 2025
7e57eb9
Merge branch 'main' into text_environment_caching
konrad-gerlach Feb 2, 2025
ca6fb76
refactor: fix code quality
Feb 2, 2025
66ae80a
Merge branch 'text_environment_caching' of https://github.com/konrad-…
Feb 2, 2025
2a68e73
Merge branch 'main' into text_environment_caching
konrad-gerlach Feb 2, 2025
33702dd
Merge branch 'main' into text_environment_caching
konrad-gerlach Feb 4, 2025
2b85207
tests: test combining cache with emptied batch
Feb 4, 2025
414b016
Merge branch 'main' into text_environment_caching
konrad-gerlach Feb 6, 2025
6047a9a
fix: StoppingCriteria off by one
Feb 6, 2025
b78508a
Merge branch 'text_environment_caching' of https://github.com/konrad-…
Feb 6, 2025
2560fd0
refactor: fix formatting
Feb 6, 2025
bad6bd9
Merge branch 'main' into text_environment_caching
konrad-gerlach Feb 6, 2025
bae9f31
Merge branch 'main' into text_environment_caching
konrad-gerlach Feb 6, 2025
c2bcec6
Merge branch 'main' into text_environment_caching
konrad-gerlach Feb 6, 2025
e99c829
Merge branch 'main' into text_environment_caching
konrad-gerlach Feb 7, 2025
ec8ec31
docs: add incompatibility warning for beamsearch
Feb 7, 2025
dcf7aab
Merge branch 'text_environment_caching' of https://github.com/konrad-…
Feb 7, 2025
635ff40
fix: return when truncating
Feb 8, 2025
0db38f1
docs: add warning for early truncation
Feb 10, 2025
892feda
docs: update warning on early truncation
Feb 11, 2025
60c23bb
docs: add warning that internal batching provides little benefit with…
Feb 12, 2025
fix: failing TextEnvironment tests
Konrad Gerlach committed Jan 10, 2025
commit b7885ccbbec23472dfaccca9dacabee77cc6675c
8 changes: 4 additions & 4 deletions tests/test_environments.py
@@ -26,10 +26,10 @@ def __call__(self, text):
         return text


-def dummy_generate(histories):
+def dummy_generate(histories,past_key_values=None,past_attention_masks=None,past_input_ids=None,last_active_histories=None):
     for i in range(len(histories)):
         histories[i].append_segment("<request><DummyTool>test<call>", torch.tensor([1, 2, 3]), system=False)
-    return histories
+    return histories, None, None, None, None


 class TextHistoryTest(unittest.TestCase):
@@ -131,10 +131,10 @@ def test_text_environment_generate(self):

         model_inputs = [self.gpt2_tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts]

-        generations_batched = env._generate_batched(model_inputs, batch_size=2)
+        generations_batched,_,_,_,_ = env._generate_batched(model_inputs, batch_size=2)
         generations_batched = self.gpt2_tokenizer.batch_decode(generations_batched)

-        generations_single = [env._generate_batched([inputs], batch_size=1)[0] for inputs in model_inputs]
+        generations_single = [env._generate_batched([inputs], batch_size=1)[0][0] for inputs in model_inputs]
         generations_single = self.gpt2_tokenizer.batch_decode(generations_single)

         self.assertEqual(generations_single, generations_batched)
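The test changes above follow from `_generate_batched` growing from returning a single list of generations to returning a 5-tuple that also carries cache state. A minimal, self-contained sketch of the call-site pattern (hypothetical stand-in function, no real model involved):

```python
def generate_batched(inputs, batch_size=1, past_key_values=None):
    # Stand-in for TextEnvironment._generate_batched: returns generations
    # plus four pieces of cache state (None here, since nothing is cached).
    generations = [f"gen:{x}" for x in inputs]
    return generations, None, None, None, None

# Call sites must now unpack the tuple explicitly:
generations, past_kv, past_attn, past_ids, last_active = generate_batched(["a", "b"])
assert generations == ["gen:a", "gen:b"]

# When only generations are needed, index the tuple first ([0]) and then the
# batch ([0]), which is why the single-input test gained a second subscript:
generations_single = [generate_batched([x])[0][0] for x in ["a", "b"]]
assert generations_single == ["gen:a", "gen:b"]
```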
2 changes: 1 addition & 1 deletion trl/environment/base_environment.py
@@ -579,7 +579,7 @@ def _generate_batched(
         if past_input_ids is not None:
             padded_inputs["input_ids"] = torch.concatenate([past_input_ids,padded_inputs["input_ids"]],dim=1)

-        if padded_inputs["input_ids"].shape[-1]>self.max_length:
+        if self.max_length is not None and padded_inputs["input_ids"].shape[-1]>self.max_length:
             return None, None, None,None, True

         generations = extract_model_from_parallel(self.model).generate(**padded_inputs, **self.generation_kwargs)
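The fix in this hunk guards an optional `max_length` before comparing against it: with `max_length=None`, the old comparison `shape[-1] > None` raises a `TypeError` on Python 3. A minimal illustration of the guard in isolation (hypothetical helper name, plain ints instead of tensors):

```python
def exceeds_max_length(seq_len: int, max_length=None) -> bool:
    # None must mean "no length limit"; short-circuiting on the None check
    # avoids the invalid comparison `seq_len > None`.
    return max_length is not None and seq_len > max_length

assert exceeds_max_length(8, max_length=4) is True
assert exceeds_max_length(8, max_length=16) is False
assert exceeds_max_length(8, max_length=None) is False
```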