
Add generation caching in TextEnvironment and fix bugs in TextEnvironment #2556

Open
wants to merge 50 commits into
base: main
Commits
ab86162
feat: add caching for TextEnvironment and fix bugs
Jan 10, 2025
d09ec63
feat: make TextEnvironment caching optional and add documentation
Jan 10, 2025
b7885cc
fix: failing TextEnvironment tests
Jan 10, 2025
034c5f7
test: add tests for TextEnvironment caching and fix cache combining bug
Jan 10, 2025
18eb106
test: remove unnecessary parametrized class decorator
Jan 10, 2025
44fd184
docs: update TextEnvironmentDocs with caching
Jan 10, 2025
28601c2
fix: run linter on TextEnvironment and TextEnvironment tests
Jan 10, 2025
2a7ec4e
fix: comment
Jan 10, 2025
af06d63
fix: Args comment
Jan 10, 2025
f6f12b5
fix: TextEnvironment cache combination and batching issue
Jan 10, 2025
ede7e81
tests: make caching test more complex
Jan 10, 2025
acddaa7
fix: combine caches of different sequence lengths
Jan 11, 2025
e38940e
docs: update caching warning
Jan 12, 2025
66d0ce4
fix: prevent bos tokens in tool response
Jan 12, 2025
a051e46
docs: Update docs/source/text_environments.md
konrad-gerlach Jan 12, 2025
9ea9287
Update trl/environment/base_environment.py
konrad-gerlach Jan 12, 2025
ae1233a
Merge branch 'main' into text_environment_caching
konrad-gerlach Jan 12, 2025
a2860bc
fix: code cleanup
Jan 12, 2025
23014fb
fix: attended to invalid last generated token and off-by-one in Strin…
Jan 14, 2025
bdaa922
Merge branch 'main' into text_environment_caching
konrad-gerlach Jan 15, 2025
a097c5b
Merge branch 'main' into text_environment_caching
konrad-gerlach Jan 17, 2025
7324ee1
fix: off by one error in StringStoppingCriteria
Jan 21, 2025
9b6a6ec
Merge branch 'main' into text_environment_caching
konrad-gerlach Jan 21, 2025
39763b1
feat: test logits are same with and without caching
Jan 22, 2025
b70f51c
fix: model and tokenizer were called gpt2 but were another model
Jan 22, 2025
7b2169d
docs: add warning for torch.compile with TextEnvironment use_cache
Jan 22, 2025
c4b5400
Merge branch 'text_environment_caching' of https://github.com/konrad-…
Jan 22, 2025
5725b18
fix: StringStoppingCriteria and add test
Jan 23, 2025
589dcb7
refactor: move StoppingCriteria test
Jan 23, 2025
5e1a7dd
feat: add support for models without cache class support
Jan 23, 2025
cc99580
refactor: make caching code optional in TextEnvironment
Jan 23, 2025
50119a8
docs: TextEnvironment use_cache note untested Encoder-Decoder archite…
Jan 23, 2025
772527b
refactor: extract method from _generate_batched
Jan 24, 2025
4b58de2
feat: add optional logits to TextHistory and remove eos segment creat…
Jan 24, 2025
e012571
tests: remove redundant test code
Jan 24, 2025
520abe8
tests: update logit similarity calculation
Jan 24, 2025
eae5b03
feat: pad each batch individually in TextEnvironment
Jan 24, 2025
0736ce3
refactor: remove redundant past_input_ids manipulation in TExtEnviron…
Jan 24, 2025
6681217
fix: remove redundant attention blocking for padding and eos token an…
Jan 24, 2025
37e5041
fix: code and documenation cleanup for TextEnvironment caching
Jan 25, 2025
a5fc0d7
fix: add more validation to tests and TextEnvironment caching
Jan 26, 2025
6504474
Merge branch 'main' into text_environment_caching
konrad-gerlach Jan 26, 2025
ac00cea
refactor: Add more validation to caching in TextEnv
Jan 29, 2025
fc6787f
Merge branch 'text_environment_caching' of https://github.com/konrad-…
Jan 29, 2025
9047db9
Merge branch 'main' into text_environment_caching
konrad-gerlach Jan 29, 2025
8e5ac7e
tests: fix logits comparison
Jan 29, 2025
dad1f5d
Merge branch 'text_environment_caching' of https://github.com/konrad-…
Jan 29, 2025
7fec08b
Merge branch 'main' into text_environment_caching
konrad-gerlach Jan 29, 2025
633d346
Merge branch 'main' into text_environment_caching
konrad-gerlach Jan 30, 2025
3174224
Merge branch 'main' into text_environment_caching
konrad-gerlach Jan 30, 2025
3 changes: 3 additions & 0 deletions docs/source/text_environments.md
@@ -114,6 +114,8 @@ Let's decompose the settings:
| `max_tool_response`| The tool response is truncated to this number to avoid running out of model context.|
| `max_length` | The maximum number of tokens to allow in an episode. |
| `generation_kwargs`| Generation settings used by the language model. |
| `use_cache` | Cache keys and values between segment generation. When caching is enabled, [`TextEnvironment`] is not suited for training, i.e. backpropagation through the generated graph, but it can still be used with TRL trainers. Furthermore, caching requires that there are no computation dependencies between examples at inference time; when using `BatchNorm`, the model should therefore be in eval mode. Caching is not guaranteed to produce results identical to generating without it, so you should test whether it suits your needs, model, and `generation_kwargs`. Compatibility with encoder-decoder architectures is untested. `use_cache` may currently be incompatible with `torch.compile` due to a possible issue in the transformers library's generate method. See this comment in [_get_initial_cache_position](https://github.com/huggingface/transformers/blob/2e752ead46a8845e8a160d2043c1336447895690/src/transformers/generation/utils.py#L1582).|
| `save_logits` | Whether to save logits for the generated tokens in the returned histories. Mainly intended to help users test whether caching works for their use case. Backpropagation through the saved logits is not supported. |
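The `use_cache` option relies on a basic property of causal attention: keys and values computed for an earlier segment do not change when new tokens are appended, so they can be cached and reused. The following is a minimal, self-contained sketch of that equivalence in plain `torch` (illustrative names only, not TRL code):

```python
# Conceptual sketch: with causal attention, processing only the new tokens
# against cached keys/values reproduces the last rows of a full recomputation.
import torch


def attention(q, k, v):
    # Scaled dot-product attention with a causal mask that accounts for
    # q and k covering different numbers of positions.
    d = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / d**0.5
    mask = torch.triu(
        torch.ones(q.shape[-2], k.shape[-2], dtype=torch.bool),
        diagonal=k.shape[-2] - q.shape[-2] + 1,
    )
    scores = scores.masked_fill(mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v


torch.manual_seed(0)
prefix, new, dim = 5, 3, 8
k = torch.randn(prefix + new, dim)
v = torch.randn(prefix + new, dim)
q = torch.randn(prefix + new, dim)

# Full recomputation over all tokens.
full = attention(q, k, v)

# Incremental pass: only the new queries, attending to cached + new K/V
# (the cached part plays the role of `past_key_values`).
cached_k, cached_v = k[:prefix], v[:prefix]
inc = attention(
    q[prefix:],
    torch.cat([cached_k, k[prefix:]]),
    torch.cat([cached_v, v[prefix:]]),
)

# The incremental output matches the last rows of the full pass.
print(torch.allclose(inc, full[prefix:], atol=1e-6))  # True
```

This only holds when examples are independent at inference time, which is why the note above requires eval mode for layers like `BatchNorm` that mix information across the batch.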

You can customize the environment to your needs and add custom tools and settings. Let's see how you can use the environment to have the model interact with the available tools!

@@ -169,6 +171,7 @@ The following table summarises the available attributes of the `TextHistory` class:
| `token_masks` | The token masks can be used to ignore system generated tokens by masking them. |
| `completed` | Indicates if the interaction with the environment has completed. |
| `truncated` | Indicates if the interaction with the environment has completed because max length was reached. |
| `logits` | A list containing a tensor for each non-system segment. If no logits were saved, then `logits` is an empty list. The tensors contain the logits for the tokens generated by the model. See `save_logits` in [`TextEnvironment`]. |
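Since `save_logits` is mainly intended for verifying caching, a natural use is to compare the logits saved on two histories, one generated with `use_cache=True` and one without. A hedged sketch of such a comparison helper (the helper and the toy data are illustrative, not part of TRL):

```python
# Illustrative helper: compare the per-segment logits saved on two
# TextHistory objects (see `save_logits`), segment by segment.
import torch


def max_logit_deviation(logits_a, logits_b):
    """Largest absolute elementwise difference across all non-system
    segments; 0.0 if both lists are empty."""
    assert len(logits_a) == len(logits_b), "histories have different segment counts"
    worst = 0.0
    for seg_a, seg_b in zip(logits_a, logits_b):
        assert seg_a.shape == seg_b.shape, "segment shapes differ"
        worst = max(worst, (seg_a - seg_b).abs().max().item())
    return worst


# Toy tensors standing in for the logits of two two-segment histories.
a = [torch.zeros(4, 10), torch.ones(2, 10)]
b = [torch.zeros(4, 10), torch.ones(2, 10) * 1.001]
print(max_logit_deviation(a, b))  # ~0.001
```

Because caching is not guaranteed to be bitwise identical, comparing against a tolerance appropriate for your model and dtype is more robust than requiring exact equality.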

With these attributes you can reconstruct every interaction of the model with the `TextEnvironment`. The `TextHistory` also lets you visualize the text history. Let's have a look!
