Add generation caching in TextEnvironment and fix bugs in TextEnvironment #2556
Conversation
I would be very grateful for a review by:
I was unable to execute the pre-commit hook, so I manually ran the linter.
Thanks for the PR!
Just to be sure, as I'm unfamiliar with their implementation: the trl trainers like PPO should not try to backpropagate through the generated tokens, right?
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
I noticed that GPT2 seems to only support the legacy cache format, so I am adding support for this (a rough sketch of such a conversion is below).
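For context, this is roughly what a conversion between the two formats could look like; the helper names and the `model_supports_cache_class` flag are illustrative, not the PR's actual code, while `DynamicCache.to_legacy_cache()` / `from_legacy_cache()` are the transformers APIs for this:

```python
from transformers import DynamicCache

def to_model_cache(cache, model_supports_cache_class: bool):
    # Models such as GPT2 expect the legacy format:
    # one (key, value) tensor pair per layer, as a tuple of tuples.
    if cache is None or model_supports_cache_class:
        return cache
    return cache.to_legacy_cache()

def from_model_cache(past_key_values):
    # Normalize whatever the model returned back into a Cache object.
    if isinstance(past_key_values, tuple):
        return DynamicCache.from_legacy_cache(past_key_values)
    return past_key_values
```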
It appears that it was still not fixed. Working on a solution and on testing StringStoppingCriteria.
As the tests did not include an encoder-decoder architecture, I did not test for it either; I think that is out of scope for this pull request. Where this was of concern in _generate_batched, I mirrored the implementation already provided.
I added a TODO in the code noting that the batch_size parameter should be exposed to the user. I think this is out of scope for this PR.
Should there be a general warning for TextEnvironment that support for encoder-decoder models is not (automatically) tested?
@qgallouedec I am ready for review. Please also read the comments/questions above and run the pre-commit hooks (as discussed). I hope you like the code :)
A notable behavioural change of this PR is that if input_ids exceeds TextEnvironment.max_length, we truncate all sequences instead of letting the model error. This technically breaks the isolation of the samples within the batch, as a history might be truncated because another history is too long, but I think this is acceptable behaviour for now. Future work might remove the limiting sequence(s), truncate their histories, and continue generating with the remaining histories; that is out of scope for this PR. A minimal sketch of the truncation is shown below.
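```python
import torch

def truncate_to_max_length(
    input_ids: torch.Tensor, attention_mask: torch.Tensor, max_length: int
):
    # Sketch of the behaviour described above (not the PR's exact code):
    # if the padded batch is longer than max_length, keep only the last
    # max_length tokens of every sequence, so the most recent context
    # survives at the cost of per-sample isolation within the batch.
    if input_ids.shape[1] > max_length:
        input_ids = input_ids[:, -max_length:]
        attention_mask = attention_mask[:, -max_length:]
    return input_ids, attention_mask
```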
This PR mainly affects the TextEnvironment class and adds caching between generation calls, so that all previous activations do not have to be recomputed when generating the next segment. This is mainly intended for use cases where many tool calls are performed sequentially and the activations for the (possibly quite large) system prompt would otherwise be recalculated at each step. For stability, caching is optional. A sketch of the idea follows.
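A rough sketch of the core idea, not the PR's exact implementation; it follows the cache-reuse pattern of recent transformers versions, and depending on the version/model the cache may need converting to the legacy format (see the GPT2 note above):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

enc = tokenizer("System prompt with tool descriptions ...", return_tensors="pt")

# First segment: generate up to the first tool call and keep the cache.
out = model.generate(
    **enc,
    max_new_tokens=20,
    use_cache=True,
    return_dict_in_generate=True,
    pad_token_id=tokenizer.eos_token_id,
)
cache = out.past_key_values  # activations for prompt + first segment

# Append an (illustrative) tool response and continue generating.
# Only the new tokens get fresh forward passes; cached positions are skipped.
tool_ids = tokenizer("<response>42</response>", return_tensors="pt").input_ids
next_ids = torch.cat([out.sequences, tool_ids], dim=1)
out = model.generate(
    next_ids,
    attention_mask=torch.ones_like(next_ids),
    past_key_values=cache,
    max_new_tokens=20,
    use_cache=True,
    return_dict_in_generate=True,
    pad_token_id=tokenizer.eos_token_id,
)
```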
Bug fixes:
This PR also addresses two bugs I encountered:
I fixed the bug and also added a check at generation time to ensure that the padded inputs do not exceed max length either (sketched below).
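An illustrative version of that check; the function name is mine, not the PR's:

```python
def assert_within_max_length(input_ids, max_length):
    # Fail loudly at generation time if padding pushed the batch past the limit.
    if input_ids.shape[1] > max_length:
        raise ValueError(
            f"Padded input of length {input_ids.shape[1]} exceeds "
            f"max_length={max_length}."
        )
```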
RE testing:
I only made sure that the tests in tests/test_environments.py were completing.
Using `make test`, some tests were failing and the suite was taking a long time to run. However, the only tests that call TextEnvironment seem to be in test_environments.py, so the rest should be unaffected as far as I know. Nevertheless, I would be grateful if somebody else could run all the tests before merging; I suspect my environment may not be ideally configured. Is testing automated via CI?
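For reference, assuming pytest is the test runner (as `make test` suggests), the affected file can be run in isolation with:

```
pytest tests/test_environments.py
```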