From ab861626b56abb10492ee0be77de631fc669fbc7 Mon Sep 17 00:00:00 2001 From: Konrad Gerlach Date: Fri, 10 Jan 2025 10:02:33 +0100 Subject: [PATCH 01/38] feat: add caching for TextEnvironment and fix bugs --- trl/environment/base_environment.py | 146 ++++++++++++++++++++++++---- 1 file changed, 125 insertions(+), 21 deletions(-) diff --git a/trl/environment/base_environment.py b/trl/environment/base_environment.py index 133ce97170..e94e65ee13 100644 --- a/trl/environment/base_environment.py +++ b/trl/environment/base_environment.py @@ -13,11 +13,12 @@ # limitations under the License. import re + from typing import Optional import torch from accelerate.utils import extract_model_from_parallel -from transformers import StoppingCriteria, StoppingCriteriaList +from transformers import StoppingCriteria, StoppingCriteriaList, DynamicCache from ..import_utils import is_rich_available @@ -36,7 +37,7 @@ def __init__(self, stop_strings, tokenizer): self.first_call = True def __call__(self, input_ids, scores, **kwargs): - """Returns true if all generated sequences contain any of the stop strings.""" + """Returns true if all generated sequences contain any of the stop strings or terminated early.""" if self.first_call: self.generated_tokens = [1 for _ in range(input_ids.shape[0])] self.start_length = input_ids.shape[-1] - 1 @@ -45,7 +46,7 @@ def __call__(self, input_ids, scores, **kwargs): done = [] for i, decoded_generation in enumerate(decoded_generations): - sequence_complete = any(stop_string in decoded_generation for stop_string in self.stop_strings) + sequence_complete = any(stop_string in decoded_generation for stop_string in self.stop_strings) or self.tokenizer.eos_token_id in input_ids[i, self.start_length :] done.append(sequence_complete) if not sequence_complete: self.generated_tokens[i] += 1 @@ -127,6 +128,14 @@ def last_text_segment(self): """ start, end = self.text_spans[-1] return self.text[start:end] + + @property + def last_token_segment(self): + """ + Get the last token segment + """ + start, end = self.token_spans[-1] + return self.tokens[start:end] def split_query_response_tokens(self): """ @@ -214,6 +223,8 @@ def show_colour_legend(self): class TextEnvironment: """ The TextEnvironment enables interaction of a LLM with an environment using tools. + When using caching, TextEnvironment is not suited for training use, i.e. backpropagation through the generated graph. Use with Trainers is of course possible. + Furthermore, caching requires, that there be no calculation dependencies between examples at inference time. When using BatchNorm, the model should thus be in eval mode. """ def __init__( @@ -283,8 +294,10 @@ def run(self, queries, **rewards_kwargs): histories = [TextHistory(q, qt, system=True) for q, qt in zip(queries, queries_tokens)] + past_key_values,past_attention_masks,past_input_ids,last_active_histories = (None,None,None,None) + while any(not history.completed for history in histories) and turns < self.max_turns: - histories = self.generate(histories) + histories,past_key_values,past_attention_masks,past_input_ids,last_active_histories = self.generate(histories,past_key_values,past_attention_masks,past_input_ids,last_active_histories) histories = self.tasks_end_check(histories) # TODO: make this parallel rather than for-loop for i in range(len(histories)): @@ -370,25 +383,69 @@ def compute_reward(self, histories, **reward_kwargs): history.reward = reward return histories - def generate(self, histories): + def _next_input(self,history): + return history.last_token_segment if not history.completed else torch.tensor([]) + + #combines all caches in order to exclude completed histories from further generation + #batch_examples: list of masks indicating for each example, whether it is supposed to remain or not + def _combine_cache(self,example_mask,past_key_values,past_attention_masks,past_input_ids): + legacy_format = [cache.to_legacy_cache() for cache in past_key_values ] + #combines all caches, excluding + example_mask_offset = 0 + combined_cache = [] + for layer_id in range(len(legacy_format[0])): + layer = None + for cache_idx, cache in enumerate(legacy_format): + layer = cache[layer_id] + num_examples = len(layer[0]) + new_keys = layer[0][example_mask[example_mask_offset:example_mask_offset+num_examples]] + new_values = layer[1][example_mask[example_mask_offset:example_mask_offset+num_examples]] + if layer is None: + layer = (new_keys,new_values) + else: + other_new_keys,other_new_values = layer + layer = (torch.concat([other_new_keys,new_keys],dim=0),torch.concat([other_new_values,new_values],dim=0)) + example_mask_offset += num_examples + combined_cache.append(layer) + combined_cache = tuple(combined_cache) + + combined_attention_masks = torch.concat(past_attention_masks,dim=0)[example_mask] + combined_input_ids = torch.concat(past_input_ids,dim=0)[example_mask] + + return combined_cache, combined_attention_masks, combined_input_ids + + def generate(self, histories,past_key_values=None,past_attention_masks=None,past_input_ids=None,last_active_histories=None): """ Generate responses for a list of histories. """ - active_histories = [i for i, history in enumerate(histories) if not history.completed] - - query_tensors = [histories[i].tokens for i in active_histories] - response_tensors = self._generate_batched(query_tensors) - response_texts = self.tokenizer.batch_decode(response_tensors) - - for i, response_text, response_tensor in zip(active_histories, response_texts, response_tensors): - histories[i].append_segment(response_text, response_tensor, system=False) + active_histories = [i for i in range(len(histories)) if not histories[i].completed] + query_tensors = [self._next_input(histories[i]) for i in active_histories] + combined_past_key_values,combined_past_attention_masks, combined_past_input_ids = (None,None,None) + if past_key_values is not None: + example_mask = [(not histories[i].completed) for i in last_active_histories] + combined_past_key_values,combined_past_attention_masks, combined_past_input_ids = self._combine_cache(example_mask,past_key_values,past_attention_masks,past_input_ids) + + response_tensors,past_key_values,past_attention_masks,past_input_ids, truncated = self._generate_batched(query_tensors,combined_past_key_values=combined_past_key_values,combined_past_attention_masks=combined_past_attention_masks, combined_past_input_ids=combined_past_input_ids) + if not truncated: + response_texts = self.tokenizer.batch_decode(response_tensors) + for i, response_text, response_tensor in zip(active_histories, response_texts, response_tensors): + history = histories[i] + if not history.completed: + history.append_segment(response_text, response_tensor, system=False) + else: + for history in histories: + if not history.completed: + #Adds an eos token, so that we always end on a non-system segment + history.append_segment(self.tokenizer.eos_token, torch.tensor([self.tokenizer.eos_token_id]).to(self.current_device), system=False) + history.complete(truncated=True) - return histories + return histories,past_key_values,past_attention_masks,past_input_ids, active_histories def tasks_end_check(self, histories, model_turn=True): """ Check if the current generation sequences have finished. """ + for history in histories: if not history.completed: truncated, ended = self.task_end_check(history, model_turn=model_turn) @@ -404,7 +461,7 @@ def task_end_check(self, history, model_turn=True): ended = False if history.completed: return truncated, ended - if self.max_length is not None and len(self.tokenizer(history.text).input_ids[0]) > self.max_length: + if self.max_length is not None and len(history.tokens) > self.max_length: truncated = True ended = True elif self.tokenizer.eos_token in history.text: @@ -418,11 +475,26 @@ def task_end_check(self, history, model_turn=True): ended = True return truncated, ended + #builds the cache for the current batch + def _get_batched_cache(self,start_index, end_index, combined_past_key_values, combined_attention_masks, combined_input_ids): + current_cache = [] + for layer_id, layer in enumerate(combined_past_key_values): + keys,values = layer + new_keys = keys[start_index:end_index] + new_values = values[start_index:end_index] + current_cache.append((new_keys,new_values)) + current_cache = tuple(current_cache) + return DynamicCache().from_legacy_cache(current_cache), combined_attention_masks[start_index:end_index], combined_input_ids[start_index:end_index] + + #TODO make batch_size changeable def _generate_batched( self, query_tensors, batch_size: int = 16, pad_to_multiple_of: Optional[int] = None, + combined_past_key_values=None,#past_key_values in legacy format + combined_past_attention_masks=None, + combined_past_input_ids = None ): """ Generate responses for a list of query tensors. @@ -437,15 +509,20 @@ def _generate_batched( if not self.is_encoder_decoder: self.tokenizer.padding_side = "left" + + new_past_key_values = [] + new_past_attention_masks = [] + new_past_input_ids = [] # in case we have fewer examples than bs batch_size = min(len(query_tensors), batch_size) - - for i in range(0, len(query_tensors), batch_size): + for batch_index,i in enumerate(range(0, len(query_tensors), batch_size)): # prevent overflow if query tensors are not even multiple of bs end_index = min(len(query_tensors), i + batch_size) - batch = query_tensors[i:end_index] batch_mask = [torch.ones_like(element) for element in batch] + past_key_values, past_attention_masks, past_input_ids = (None,None,None) + if combined_past_key_values is not None: + past_key_values, past_attention_masks, past_input_ids = self._get_batched_cache(i,end_index,combined_past_key_values,combined_past_attention_masks,combined_past_input_ids) inputs = {"input_ids": batch, "attention_mask": batch_mask} padded_inputs = self.tokenizer.pad( @@ -455,25 +532,52 @@ def _generate_batched( pad_to_multiple_of=pad_to_multiple_of, return_tensors="pt", ).to(self.current_device) + input_attention_mask = padded_inputs["attention_mask"].clone() stopping_criteria = StringStoppingCriteria([self.call_token, self.submit_token], self.tokenizer) self.generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stopping_criteria]) + self.generation_kwargs["use_cache"] = True + self.generation_kwargs["return_dict_in_generate"] = True + #handle caching + self.generation_kwargs["past_key_values"] = past_key_values if past_key_values is not None else DynamicCache() + if past_attention_masks is not None: + padded_inputs["attention_mask"] = torch.concatenate([past_attention_masks,padded_inputs["attention_mask"]],dim=1) + if past_input_ids is not None: + padded_inputs["input_ids"] = torch.concatenate([past_input_ids,padded_inputs["input_ids"]],dim=1) + + if padded_inputs["input_ids"].shape[-1]>self.max_length: + return None, None, None,None, True generations = extract_model_from_parallel(self.model).generate(**padded_inputs, **self.generation_kwargs) + new_past_key_values.append(generations.past_key_values) + + past_attention_mask = torch.ones_like(generations.sequences) + #Don't attend to generated padding or eos tokens + past_attention_mask[torch.logical_or(generations.sequences==self.tokenizer.eos_token_id, generations.sequences==self.tokenizer.pad_token_id)] = 0 + past_attention_mask[:,:input_attention_mask.shape[1]] = input_attention_mask + + generations = generations.sequences - for generation, mask, generated_tokens in zip( - generations, padded_inputs["attention_mask"], stopping_criteria.generated_tokens + new_past_input_ids.append(generations) + for generation, mask, generated_tokens, new_attention_mask in zip( + generations, padded_inputs["attention_mask"], stopping_criteria.generated_tokens,past_attention_mask ): if not self.is_encoder_decoder: output = generation[(1 - mask).sum() :] # remove padding + padding_removed_past_attention_mask = new_attention_mask[(1 - mask).sum() :] else: output = generation + padding_removed_past_attention_mask = new_attention_mask if not self.is_encoder_decoder: output = output[(mask).sum() :] # remove prompt + generated_tokens_attention_mask = padding_removed_past_attention_mask[(mask).sum() :] # remove chunk generated after stopping criteria in batch mode outputs.append(output[:generated_tokens]) + #Do not attend to tokens that were generated after or + generated_tokens_attention_mask[generated_tokens:]=0 + new_past_attention_masks.append(past_attention_mask) self.tokenizer.padding_side = padding_side_default - return outputs + return outputs, new_past_key_values, new_past_attention_masks,new_past_input_ids, False From d09ec63878aba955cfe6922ba0114e891265ad84 Mon Sep 17 00:00:00 2001 From: Konrad Gerlach Date: Fri, 10 Jan 2025 10:51:18 +0100 Subject: [PATCH 02/38] feat: make TextEnvironment caching optional and add documentation --- trl/environment/base_environment.py | 55 +++++++++++++++++++++++------ 1 file changed, 44 insertions(+), 11 deletions(-) diff --git a/trl/environment/base_environment.py b/trl/environment/base_environment.py index e94e65ee13..543b012db9 100644 --- a/trl/environment/base_environment.py +++ b/trl/environment/base_environment.py @@ -223,8 +223,6 @@ def show_colour_legend(self): class TextEnvironment: """ The TextEnvironment enables interaction of a LLM with an environment using tools. - When using caching, TextEnvironment is not suited for training use, i.e. backpropagation through the generated graph. Use with Trainers is of course possible. - Furthermore, caching requires, that there be no calculation dependencies between examples at inference time. When using BatchNorm, the model should thus be in eval mode. """ def __init__( @@ -238,6 +236,7 @@ def __init__( max_tool_reponse=100, max_length=None, generation_kwargs=None, + use_cache=False ): """ Initialize TextEnvironment. @@ -252,6 +251,7 @@ def __init__( max_tool_response (Optional[int]): The maximum number of characters to allow in a tool response. max_length (Optional[int]): The maximum number of tokens to allow in an episode. generation_kwargs (Optional[dict]): A dictionary of keyword arguments to pass to the model's generate method. + use_cache (bool): Whether or not to cache past_key_values between segments. When using caching, TextEnvironment is not suited for training use, i.e. backpropagation through the generated graph. Use with Trainers is of course possible. Furthermore, caching requires, that there be no calculation dependencies between examples at inference time. When using BatchNorm, the model should thus be in eval mode. """ self.model = model self.tokenizer = tokenizer @@ -268,6 +268,7 @@ def __init__( self.submit_token = "" self.max_turns = max_turns self.max_tool_response = max_tool_reponse + self.use_cache = use_cache if generation_kwargs is None: self.generation_kwargs = dict() @@ -293,11 +294,15 @@ def run(self, queries, **rewards_kwargs): ] histories = [TextHistory(q, qt, system=True) for q, qt in zip(queries, queries_tokens)] - + past_key_values,past_attention_masks,past_input_ids,last_active_histories = (None,None,None,None) while any(not history.completed for history in histories) and turns < self.max_turns: - histories,past_key_values,past_attention_masks,past_input_ids,last_active_histories = self.generate(histories,past_key_values,past_attention_masks,past_input_ids,last_active_histories) + if self.use_cache: + histories,past_key_values,past_attention_masks,past_input_ids,last_active_histories = self.generate(histories,past_key_values,past_attention_masks,past_input_ids,last_active_histories) + else: + #Discard cache + histories,_,_,_,_ = self.generate(histories,past_key_values,past_attention_masks,past_input_ids,last_active_histories) histories = self.tasks_end_check(histories) # TODO: make this parallel rather than for-loop for i in range(len(histories)): @@ -385,12 +390,18 @@ def compute_reward(self, histories, **reward_kwargs): def _next_input(self,history): return history.last_token_segment if not history.completed else torch.tensor([]) - - #combines all caches in order to exclude completed histories from further generation - #batch_examples: list of masks indicating for each example, whether it is supposed to remain or not + def _combine_cache(self,example_mask,past_key_values,past_attention_masks,past_input_ids): + """ + combines all caches in order to exclude completed histories from further generation + + Args: + batch_examples (list[bool]): mask indicating for each example, whether it is supposed to remain or not + past_key_values (list[transformers.DynamicCache]) : Batched list of caches from the last generation + past_attention_masks (list[torch.Tensor]): Batched list of attention masks from the last generation + past_input_ids (list[torch.Tensor]): Batched list of input ids from the last generation + """ legacy_format = [cache.to_legacy_cache() for cache in past_key_values ] - #combines all caches, excluding example_mask_offset = 0 combined_cache = [] for layer_id in range(len(legacy_format[0])): @@ -417,13 +428,23 @@ def _combine_cache(self,example_mask,past_key_values,past_attention_masks,past_i def generate(self, histories,past_key_values=None,past_attention_masks=None,past_input_ids=None,last_active_histories=None): """ Generate responses for a list of histories. + Either all of past_key_values, past_attention_masks, past_input_ids,last_active_histories are provided or all are None. + Args: + histories (list[TextHistory]): + past_key_values (Optional[list[transformers.DynamicCache]]): Batched list of caches from the last generation + past_attention_masks (Optional[list[torch.Tensor]]): Batched list of attention masks from the last generation + past_input_ids (Optional[list[torch.Tensor]]): Batched list of input ids from the last generation + last_active_histories (Optional[list[int]]): indices of histories for which generation took place during the last generation turn """ active_histories = [i for i in range(len(histories)) if not histories[i].completed] - query_tensors = [self._next_input(histories[i]) for i in active_histories] combined_past_key_values,combined_past_attention_masks, combined_past_input_ids = (None,None,None) + if past_key_values is not None: + query_tensors = [self._next_input(histories[i]) for i in active_histories] example_mask = [(not histories[i].completed) for i in last_active_histories] combined_past_key_values,combined_past_attention_masks, combined_past_input_ids = self._combine_cache(example_mask,past_key_values,past_attention_masks,past_input_ids) + else: + query_tensors = [histories[i].tokens for i in active_histories] response_tensors,past_key_values,past_attention_masks,past_input_ids, truncated = self._generate_batched(query_tensors,combined_past_key_values=combined_past_key_values,combined_past_attention_masks=combined_past_attention_masks, combined_past_input_ids=combined_past_input_ids) if not truncated: @@ -477,6 +498,14 @@ def task_end_check(self, history, model_turn=True): #builds the cache for the current batch def _get_batched_cache(self,start_index, end_index, combined_past_key_values, combined_attention_masks, combined_input_ids): + """ + Extract (batch) cache for current batch + start_index (int): start index of current batch + end_index (int): end index of current batch (points to first element not in batch) + combined_past_key_values (tuple[tuple[torch.Tensor]]) : The combined (unbatched) cache in legacy format from the last generation + combined_past_attention_masks (torch.Tensor): The combined (unbatched) attention masks from the last generation + combined_past_input_ids (torch.Tensor): The combined (unbatched) input ids from the last generation + """ current_cache = [] for layer_id, layer in enumerate(combined_past_key_values): keys,values = layer @@ -486,23 +515,27 @@ def _get_batched_cache(self,start_index, end_index, combined_past_key_values, co current_cache = tuple(current_cache) return DynamicCache().from_legacy_cache(current_cache), combined_attention_masks[start_index:end_index], combined_input_ids[start_index:end_index] + #TODO make batch_size changeable def _generate_batched( self, query_tensors, batch_size: int = 16, pad_to_multiple_of: Optional[int] = None, - combined_past_key_values=None,#past_key_values in legacy format + combined_past_key_values=None, combined_past_attention_masks=None, combined_past_input_ids = None ): """ Generate responses for a list of query tensors. - + Either all of combined_past_key_values, combined_past_attention_masks, combined_past_input_ids are provided or all are None. Args: query_tensors (list[torch.Tensor]): A list of query tensors to generate responses for. batch_size (int): The batch size to use for generation. pad_to_multiple_of (int): The padding length to use for generation. + combined_past_key_values (Optional[tuple[tuple[torch.Tensor]]]) : The combined (unbatched) cache in legacy format from the last generation + combined_past_attention_masks (Optional[torch.Tensor]): The combined (unbatched) attention masks from the last generation + combined_past_input_ids (Optional[torch.Tensor]): The combined (unbatched) input ids from the last generation """ outputs = [] padding_side_default = self.tokenizer.padding_side From b7885ccbbec23472dfaccca9dacabee77cc6675c Mon Sep 17 00:00:00 2001 From: Konrad Gerlach Date: Fri, 10 Jan 2025 12:12:44 +0100 Subject: [PATCH 03/38] fix: failing TextEnvironment tests --- tests/test_environments.py | 8 ++++---- trl/environment/base_environment.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_environments.py b/tests/test_environments.py index e4b2cd52a1..02e310f7fe 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -26,10 +26,10 @@ def __call__(self, text): return text -def dummy_generate(histories): +def dummy_generate(histories,past_key_values=None,past_attention_masks=None,past_input_ids=None,last_active_histories=None): for i in range(len(histories)): histories[i].append_segment("test", torch.tensor([1, 2, 3]), system=False) - return histories + return histories, None, None, None, None class TextHistoryTest(unittest.TestCase): @@ -131,10 +131,10 @@ def test_text_environment_generate(self): model_inputs = [self.gpt2_tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts] - generations_batched = env._generate_batched(model_inputs, batch_size=2) + generations_batched,_,_,_,_ = env._generate_batched(model_inputs, batch_size=2) generations_batched = self.gpt2_tokenizer.batch_decode(generations_batched) - generations_single = [env._generate_batched([inputs], batch_size=1)[0] for inputs in model_inputs] + generations_single = [env._generate_batched([inputs], batch_size=1)[0][0] for inputs in model_inputs] generations_single = self.gpt2_tokenizer.batch_decode(generations_single) self.assertEqual(generations_single, generations_batched) diff --git a/trl/environment/base_environment.py b/trl/environment/base_environment.py index 543b012db9..62b8f7cc97 100644 --- a/trl/environment/base_environment.py +++ b/trl/environment/base_environment.py @@ -579,7 +579,7 @@ def _generate_batched( if past_input_ids is not None: padded_inputs["input_ids"] = torch.concatenate([past_input_ids,padded_inputs["input_ids"]],dim=1) - if padded_inputs["input_ids"].shape[-1]>self.max_length: + if self.max_length is not None and padded_inputs["input_ids"].shape[-1]>self.max_length: return None, None, None,None, True generations = extract_model_from_parallel(self.model).generate(**padded_inputs, **self.generation_kwargs) From 034c5f742edc391d8aaea3fc3439cc4959d00bcf Mon Sep 17 00:00:00 2001 From: Konrad Gerlach Date: Fri, 10 Jan 2025 15:09:49 +0100 Subject: [PATCH 04/38] test: add tests for TextEnvironment caching and fix cache combining bug --- tests/test_environments.py | 99 ++++++++++++++++++++++++++++- trl/environment/base_environment.py | 14 ++-- 2 files changed, 104 insertions(+), 9 deletions(-) diff --git a/tests/test_environments.py b/tests/test_environments.py index 02e310f7fe..2c5047bbff 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -13,10 +13,11 @@ # limitations under the License. import unittest +from parameterized import parameterized_class from unittest.mock import patch import torch -from transformers import AutoTokenizer +from transformers import AutoTokenizer, DynamicCache from trl import AutoModelForCausalLMWithValueHead, TextEnvironment, TextHistory @@ -79,6 +80,7 @@ def test_text_history_last_segment(self): history.append_segment("General Kenobi!", torch.tensor([4, 5, 6])) history.append_segment("You are a bold one!", torch.tensor([7, 8, 9])) self.assertEqual(history.last_text_segment, "You are a bold one!") + self.assertTrue(torch.all(history.last_token_segment== torch.tensor([7, 8, 9])).item()) def test_text_history_split_query_response(self): text = "Hello there!" @@ -93,6 +95,9 @@ def test_text_history_split_query_response(self): self.assertTrue(torch.equal(mask, torch.tensor([1, 1, 1, 0, 0, 0]))) + + +@parameterized_class(("use_cache",),[(True,),(False,)]) class TextEnvironmentTester(unittest.TestCase): def setUp(self): # model_id @@ -103,6 +108,7 @@ def setUp(self): self.gpt2_tokenizer = AutoTokenizer.from_pretrained(self.model_id) self.gpt2_tokenizer.pad_token = self.gpt2_tokenizer.eos_token + def test_text_environment_setup(self): env = TextEnvironment( self.gpt2_model, @@ -276,3 +282,94 @@ def test_text_environment_run(self, mock_generate): ("I am a prompt!\n" + "Hello there! General Kenobi!") + (2 * "testtest"), ) + + def test_combine_cache(self): + env = TextEnvironment( + self.gpt2_model, + self.gpt2_tokenizer, + tools={"DummyTool": DummyTool()}, + reward_fn=lambda x: [torch.tensor(i) for i, _ in enumerate(x)], + prompt="I am a prompt!\n", + max_turns=2, + ) + + caches = [((torch.tensor([[1],[2]]), + torch.tensor([[3],[4]])),), + ((torch.tensor([[5]]), + torch.tensor([[6]])),)] + caches = [DynamicCache().from_legacy_cache(cache) for cache in caches] + attention_masks = [torch.tensor([[0],[1]]),torch.tensor([[2]])] + input_ids = [torch.tensor([[1],[2]]),torch.tensor([[3]])] + example_mask = [True,False,True] + + expected_cache = ((torch.tensor([[1],[5]]),torch.tensor([[3],[6]])),) + expected_attention_mask = torch.tensor([[0],[2]]) + expected_input_ids = torch.tensor([[1],[3]]) + + combined_cache, combined_attention_masks, combined_input_ids = env._combine_cache(example_mask, caches, attention_masks,input_ids) + + self.assertEqual(len(combined_cache),len(expected_cache)) + self.assertEqual(len(combined_cache[0]),len(expected_cache[0])) + self.assertTrue(torch.all(combined_cache[0][0]==expected_cache[0][0])) + self.assertTrue(torch.all(combined_cache[0][1]==expected_cache[0][1])) + self.assertTrue(torch.all(combined_attention_masks==expected_attention_mask)) + self.assertTrue(torch.all(combined_input_ids==expected_input_ids)) + + def test_get_batched_cache(self): + env = TextEnvironment( + self.gpt2_model, + self.gpt2_tokenizer, + tools={"DummyTool": DummyTool()}, + reward_fn=lambda x: [torch.tensor(i) for i, _ in enumerate(x)], + prompt="I am a prompt!\n", + max_turns=2, + ) + + cache = ((torch.tensor([[1],[2],[3]]),torch.tensor([[4],[5],[6]])),) + attention_masks = torch.tensor([[1],[2],[3]]) + input_ids = torch.tensor([[4],[5],[6]]) + batched_cache, batched_attention_masks, batched_input_ids = env._get_batched_cache(1,3,cache,attention_masks,input_ids) + batched_cache = batched_cache.to_legacy_cache() + expected_cache = ((torch.tensor([[2],[3]]),torch.tensor([[5],[6]])),) + + self.assertEqual(len(batched_cache),len(expected_cache)) + self.assertEqual(len(batched_cache[0]),len(expected_cache[0])) + self.assertTrue(torch.all(batched_cache[0][0]==expected_cache[0][0])) + self.assertTrue(torch.all(batched_cache[0][1]==expected_cache[0][1])) + + expected_attention_mask = torch.tensor([[2],[3]]) + self.assertTrue(torch.all(batched_attention_masks==expected_attention_mask)) + + expected_input_ids = torch.tensor([[5],[6]]) + self.assertTrue(torch.all(batched_input_ids==expected_input_ids)) + + def test_cached_generate_batched(self): + generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": self.gpt2_tokenizer.eos_token_id} + env = TextEnvironment( + self.gpt2_model, + self.gpt2_tokenizer, + tools=[DummyTool()], + reward_fn=lambda x: torch.tensor(1), + prompt="I am a prompt!\n", + generation_kwargs=generation_kwargs, + ) + + input_texts = ["this is a test", "this is another, longer test"] + model_inputs = [self.gpt2_tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts] + outputs, past_key_values, past_attention_masks,past_input_ids, _ = env._generate_batched(model_inputs, batch_size=2) + past_key_values = past_key_values[0].to_legacy_cache() + past_attention_masks = past_attention_masks[0] + past_input_ids = past_input_ids[0] + + input_texts2 = [" short interim", " a slightly longer interim"] + model_inputs2 = [self.gpt2_tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts2] + + outputs_cached, _,_,_,_ = env._generate_batched(model_inputs2, batch_size=2,combined_past_key_values=past_key_values,combined_past_attention_masks=past_attention_masks,combined_past_input_ids=past_input_ids) + + model_inputs2_full = [torch.concat([in1,out1,in2],dim=0) for in1,out1,in2 in zip(model_inputs,outputs, model_inputs2)] + + outputs_uncached, _, _,_, _ = env._generate_batched(model_inputs2_full, batch_size=2) + + for cached, uncached in zip(outputs_cached,outputs_uncached): + self.assertTrue(torch.all(cached==uncached)) + diff --git a/trl/environment/base_environment.py b/trl/environment/base_environment.py index 62b8f7cc97..b84eab2a97 100644 --- a/trl/environment/base_environment.py +++ b/trl/environment/base_environment.py @@ -405,19 +405,19 @@ def _combine_cache(self,example_mask,past_key_values,past_attention_masks,past_i example_mask_offset = 0 combined_cache = [] for layer_id in range(len(legacy_format[0])): - layer = None + combined_layer = None for cache_idx, cache in enumerate(legacy_format): layer = cache[layer_id] num_examples = len(layer[0]) new_keys = layer[0][example_mask[example_mask_offset:example_mask_offset+num_examples]] new_values = layer[1][example_mask[example_mask_offset:example_mask_offset+num_examples]] - if layer is None: - layer = (new_keys,new_values) + if combined_layer is None: + combined_layer = (new_keys,new_values) else: - other_new_keys,other_new_values = layer - layer = (torch.concat([other_new_keys,new_keys],dim=0),torch.concat([other_new_values,new_values],dim=0)) + other_new_keys,other_new_values = combined_layer + combined_layer = (torch.concat([other_new_keys,new_keys],dim=0),torch.concat([other_new_values,new_values],dim=0)) example_mask_offset += num_examples - combined_cache.append(layer) + combined_cache.append(combined_layer) combined_cache = tuple(combined_cache) combined_attention_masks = torch.concat(past_attention_masks,dim=0)[example_mask] @@ -542,7 +542,6 @@ def _generate_batched( if not self.is_encoder_decoder: self.tokenizer.padding_side = "left" - new_past_key_values = [] new_past_attention_masks = [] new_past_input_ids = [] @@ -566,7 +565,6 @@ def _generate_batched( return_tensors="pt", ).to(self.current_device) input_attention_mask = padded_inputs["attention_mask"].clone() - stopping_criteria = StringStoppingCriteria([self.call_token, self.submit_token], self.tokenizer) self.generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stopping_criteria]) From 18eb106796760c70daca7e1399a426d137b67516 Mon Sep 17 00:00:00 2001 From: Konrad Gerlach Date: Fri, 10 Jan 2025 15:12:14 +0100 Subject: [PATCH 05/38] test: remove unnecessary parametrized class decorator --- tests/test_environments.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_environments.py b/tests/test_environments.py index 2c5047bbff..a6cb0a9b66 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -13,7 +13,6 @@ # limitations under the License. import unittest -from parameterized import parameterized_class from unittest.mock import patch import torch @@ -97,7 +96,6 @@ def test_text_history_split_query_response(self): -@parameterized_class(("use_cache",),[(True,),(False,)]) class TextEnvironmentTester(unittest.TestCase): def setUp(self): # model_id From 44fd1841888513772e33218d40d4c81e009c45ad Mon Sep 17 00:00:00 2001 From: Konrad Gerlach Date: Fri, 10 Jan 2025 15:34:22 +0100 Subject: [PATCH 06/38] docs: update TextEnvironmentDocs with caching --- docs/source/text_environments.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/text_environments.md b/docs/source/text_environments.md index c7b0bd0cfd..1ab2e131da 100644 --- a/docs/source/text_environments.md +++ b/docs/source/text_environments.md @@ -114,6 +114,7 @@ Let's decompose the settings: | `max_tool_response`| The tool response is truncated to this number to avoid running out of model context.| | `max_length` | The maximum number of tokens to allow in an episode. | | `generation_kwargs`| Generation settings used by the language model. | +| `use_cache` | Cache keys and values between segment generation. Warning: When using caching, TextEnvironment is not suited for training use, i.e. backpropagation through the generated graph. Use with trl Trainers is of course possible. Furthermore, caching requires, that there be no calculation dependencies between examples at inference time. When using BatchNorm, the model should thus be in eval mode.| You can customize the environment to your needs and add custom tools and settings. Let's see how you can use the environment to have the model interact with the available tools! From 28601c2500813056978c05a586c6b0b3265c348b Mon Sep 17 00:00:00 2001 From: Konrad Gerlach Date: Fri, 10 Jan 2025 15:51:20 +0100 Subject: [PATCH 07/38] fix: run linter on TextEnvironment and TextEnvironment tests --- tests/test_environments.py | 106 +++++++++++--------- trl/environment/base_environment.py | 149 ++++++++++++++++++---------- 2 files changed, 156 insertions(+), 99 deletions(-) diff --git a/tests/test_environments.py b/tests/test_environments.py index a6cb0a9b66..f0f57ad2f9 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -26,7 +26,9 @@ def __call__(self, text): return text -def dummy_generate(histories,past_key_values=None,past_attention_masks=None,past_input_ids=None,last_active_histories=None): +def dummy_generate( + histories, past_key_values=None, past_attention_masks=None, past_input_ids=None, last_active_histories=None +): for i in range(len(histories)): histories[i].append_segment("test", torch.tensor([1, 2, 3]), system=False) return histories, None, None, None, None @@ -79,7 +81,7 @@ def test_text_history_last_segment(self): history.append_segment("General Kenobi!", torch.tensor([4, 5, 6])) history.append_segment("You are a bold one!", torch.tensor([7, 8, 9])) self.assertEqual(history.last_text_segment, "You are a bold one!") - self.assertTrue(torch.all(history.last_token_segment== torch.tensor([7, 8, 9])).item()) + self.assertTrue(torch.all(history.last_token_segment == torch.tensor([7, 8, 9])).item()) def test_text_history_split_query_response(self): text = "Hello there!" @@ -94,8 +96,6 @@ def test_text_history_split_query_response(self): self.assertTrue(torch.equal(mask, torch.tensor([1, 1, 1, 0, 0, 0]))) - - class TextEnvironmentTester(unittest.TestCase): def setUp(self): # model_id @@ -106,7 +106,6 @@ def setUp(self): self.gpt2_tokenizer = AutoTokenizer.from_pretrained(self.model_id) self.gpt2_tokenizer.pad_token = self.gpt2_tokenizer.eos_token - def test_text_environment_setup(self): env = TextEnvironment( self.gpt2_model, @@ -135,7 +134,7 @@ def test_text_environment_generate(self): model_inputs = [self.gpt2_tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts] - generations_batched,_,_,_,_ = env._generate_batched(model_inputs, batch_size=2) + generations_batched, _, _, _, _ = env._generate_batched(model_inputs, batch_size=2) generations_batched = self.gpt2_tokenizer.batch_decode(generations_batched) generations_single = [env._generate_batched([inputs], batch_size=1)[0][0] for inputs in model_inputs] @@ -291,28 +290,30 @@ def test_combine_cache(self): max_turns=2, ) - caches = [((torch.tensor([[1],[2]]), - torch.tensor([[3],[4]])),), - ((torch.tensor([[5]]), - torch.tensor([[6]])),)] + caches = [ + ((torch.tensor([[1], [2]]), torch.tensor([[3], [4]])),), + ((torch.tensor([[5]]), torch.tensor([[6]])),), + ] caches = [DynamicCache().from_legacy_cache(cache) for cache in caches] - attention_masks = [torch.tensor([[0],[1]]),torch.tensor([[2]])] - input_ids = [torch.tensor([[1],[2]]),torch.tensor([[3]])] - example_mask = [True,False,True] - - expected_cache = ((torch.tensor([[1],[5]]),torch.tensor([[3],[6]])),) - expected_attention_mask = torch.tensor([[0],[2]]) - expected_input_ids = torch.tensor([[1],[3]]) - - combined_cache, combined_attention_masks, combined_input_ids = env._combine_cache(example_mask, caches, attention_masks,input_ids) - - self.assertEqual(len(combined_cache),len(expected_cache)) - self.assertEqual(len(combined_cache[0]),len(expected_cache[0])) - self.assertTrue(torch.all(combined_cache[0][0]==expected_cache[0][0])) - self.assertTrue(torch.all(combined_cache[0][1]==expected_cache[0][1])) - self.assertTrue(torch.all(combined_attention_masks==expected_attention_mask)) - self.assertTrue(torch.all(combined_input_ids==expected_input_ids)) - + attention_masks = [torch.tensor([[0], [1]]), torch.tensor([[2]])] + input_ids = [torch.tensor([[1], [2]]), torch.tensor([[3]])] + example_mask = [True, False, True] + + expected_cache = ((torch.tensor([[1], [5]]), torch.tensor([[3], [6]])),) + expected_attention_mask = torch.tensor([[0], [2]]) + expected_input_ids = torch.tensor([[1], [3]]) + + combined_cache, combined_attention_masks, combined_input_ids = env._combine_cache( + example_mask, caches, attention_masks, input_ids + ) + + self.assertEqual(len(combined_cache), len(expected_cache)) + self.assertEqual(len(combined_cache[0]), len(expected_cache[0])) + self.assertTrue(torch.all(combined_cache[0][0] == expected_cache[0][0])) + self.assertTrue(torch.all(combined_cache[0][1] == expected_cache[0][1])) + self.assertTrue(torch.all(combined_attention_masks == expected_attention_mask)) + self.assertTrue(torch.all(combined_input_ids == expected_input_ids)) + def test_get_batched_cache(self): env = TextEnvironment( self.gpt2_model, @@ -323,23 +324,25 @@ def test_get_batched_cache(self): max_turns=2, ) - cache = ((torch.tensor([[1],[2],[3]]),torch.tensor([[4],[5],[6]])),) - attention_masks = torch.tensor([[1],[2],[3]]) - input_ids = torch.tensor([[4],[5],[6]]) - batched_cache, batched_attention_masks, batched_input_ids = env._get_batched_cache(1,3,cache,attention_masks,input_ids) + cache = ((torch.tensor([[1], [2], [3]]), torch.tensor([[4], [5], [6]])),) + attention_masks = torch.tensor([[1], [2], [3]]) + input_ids = torch.tensor([[4], [5], [6]]) + batched_cache, batched_attention_masks, batched_input_ids = env._get_batched_cache( + 1, 3, cache, attention_masks, input_ids + ) batched_cache = batched_cache.to_legacy_cache() - expected_cache = ((torch.tensor([[2],[3]]),torch.tensor([[5],[6]])),) + expected_cache = ((torch.tensor([[2], [3]]), torch.tensor([[5], [6]])),) - self.assertEqual(len(batched_cache),len(expected_cache)) - self.assertEqual(len(batched_cache[0]),len(expected_cache[0])) - self.assertTrue(torch.all(batched_cache[0][0]==expected_cache[0][0])) - self.assertTrue(torch.all(batched_cache[0][1]==expected_cache[0][1])) + self.assertEqual(len(batched_cache), len(expected_cache)) + self.assertEqual(len(batched_cache[0]), len(expected_cache[0])) + self.assertTrue(torch.all(batched_cache[0][0] == expected_cache[0][0])) + self.assertTrue(torch.all(batched_cache[0][1] == expected_cache[0][1])) - expected_attention_mask = torch.tensor([[2],[3]]) - self.assertTrue(torch.all(batched_attention_masks==expected_attention_mask)) + expected_attention_mask = torch.tensor([[2], [3]]) + self.assertTrue(torch.all(batched_attention_masks == expected_attention_mask)) - expected_input_ids = torch.tensor([[5],[6]]) - self.assertTrue(torch.all(batched_input_ids==expected_input_ids)) + expected_input_ids = torch.tensor([[5], [6]]) + self.assertTrue(torch.all(batched_input_ids == expected_input_ids)) def test_cached_generate_batched(self): generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": self.gpt2_tokenizer.eos_token_id} @@ -354,7 +357,9 @@ def test_cached_generate_batched(self): input_texts = ["this is a test", "this is another, longer test"] model_inputs = [self.gpt2_tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts] - outputs, past_key_values, past_attention_masks,past_input_ids, _ = env._generate_batched(model_inputs, batch_size=2) + outputs, past_key_values, past_attention_masks, past_input_ids, _ = env._generate_batched( + model_inputs, batch_size=2 + ) past_key_values = past_key_values[0].to_legacy_cache() past_attention_masks = past_attention_masks[0] past_input_ids = past_input_ids[0] @@ -362,12 +367,19 @@ def test_cached_generate_batched(self): input_texts2 = [" short interim", " a slightly longer interim"] model_inputs2 = [self.gpt2_tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts2] - outputs_cached, _,_,_,_ = env._generate_batched(model_inputs2, batch_size=2,combined_past_key_values=past_key_values,combined_past_attention_masks=past_attention_masks,combined_past_input_ids=past_input_ids) - - model_inputs2_full = [torch.concat([in1,out1,in2],dim=0) for in1,out1,in2 in zip(model_inputs,outputs, model_inputs2)] + outputs_cached, _, _, _, _ = env._generate_batched( + model_inputs2, + batch_size=2, + combined_past_key_values=past_key_values, + combined_past_attention_masks=past_attention_masks, + combined_past_input_ids=past_input_ids, + ) - outputs_uncached, _, _,_, _ = env._generate_batched(model_inputs2_full, batch_size=2) + model_inputs2_full = [ + torch.concat([in1, out1, in2], dim=0) for in1, out1, in2 in zip(model_inputs, outputs, model_inputs2) + ] - for cached, uncached in zip(outputs_cached,outputs_uncached): - self.assertTrue(torch.all(cached==uncached)) + outputs_uncached, _, _, _, _ = env._generate_batched(model_inputs2_full, batch_size=2) + for cached, uncached in zip(outputs_cached, outputs_uncached): + self.assertTrue(torch.all(cached == uncached)) diff --git a/trl/environment/base_environment.py b/trl/environment/base_environment.py index b84eab2a97..bb4329f959 100644 --- a/trl/environment/base_environment.py +++ b/trl/environment/base_environment.py @@ -46,7 +46,10 @@ def __call__(self, input_ids, scores, **kwargs): done = [] for i, decoded_generation in enumerate(decoded_generations): - sequence_complete = any(stop_string in decoded_generation for stop_string in self.stop_strings) or self.tokenizer.eos_token_id in input_ids[i, self.start_length :] + sequence_complete = ( + any(stop_string in decoded_generation for stop_string in self.stop_strings) + or self.tokenizer.eos_token_id in input_ids[i, self.start_length :] + ) done.append(sequence_complete) if not sequence_complete: self.generated_tokens[i] += 1 @@ -128,7 +131,7 @@ def last_text_segment(self): """ start, end = self.text_spans[-1] return self.text[start:end] - + @property def last_token_segment(self): """ @@ -154,8 +157,7 @@ def show_text(self, show_legend=False): """ if not is_rich_available(): raise ImportError( - "The `rich` library is required to display text with formatting. " - "Install it using `pip install rich`." + "The `rich` library is required to display text with formatting. Install it using `pip install rich`." ) text = Text(self.text) @@ -236,7 +238,7 @@ def __init__( max_tool_reponse=100, max_length=None, generation_kwargs=None, - use_cache=False + use_cache=False, ): """ Initialize TextEnvironment. @@ -294,15 +296,21 @@ def run(self, queries, **rewards_kwargs): ] histories = [TextHistory(q, qt, system=True) for q, qt in zip(queries, queries_tokens)] - - past_key_values,past_attention_masks,past_input_ids,last_active_histories = (None,None,None,None) + + past_key_values, past_attention_masks, past_input_ids, last_active_histories = (None, None, None, None) while any(not history.completed for history in histories) and turns < self.max_turns: if self.use_cache: - histories,past_key_values,past_attention_masks,past_input_ids,last_active_histories = self.generate(histories,past_key_values,past_attention_masks,past_input_ids,last_active_histories) + histories, past_key_values, past_attention_masks, past_input_ids, last_active_histories = ( + self.generate( + histories, past_key_values, past_attention_masks, past_input_ids, last_active_histories + ) + ) else: - #Discard cache - histories,_,_,_,_ = self.generate(histories,past_key_values,past_attention_masks,past_input_ids,last_active_histories) + # Discard cache + histories, _, _, _, _ = self.generate( + histories, past_key_values, past_attention_masks, past_input_ids, last_active_histories + ) histories = self.tasks_end_check(histories) # TODO: make this parallel rather than for-loop for i in range(len(histories)): @@ -388,10 +396,10 @@ def compute_reward(self, histories, **reward_kwargs): history.reward = reward return histories - def _next_input(self,history): + def _next_input(self, history): return history.last_token_segment if not history.completed else torch.tensor([]) - def _combine_cache(self,example_mask,past_key_values,past_attention_masks,past_input_ids): + def _combine_cache(self, example_mask, past_key_values, past_attention_masks, past_input_ids): """ combines all caches in order to exclude completed histories from further generation @@ -401,7 +409,7 @@ def _combine_cache(self,example_mask,past_key_values,past_attention_masks,past_i past_attention_masks (list[torch.Tensor]): Batched list of attention masks from the last generation past_input_ids (list[torch.Tensor]): Batched list of input ids from the last generation """ - legacy_format = [cache.to_legacy_cache() for cache in past_key_values ] + legacy_format = [cache.to_legacy_cache() for cache in past_key_values] example_mask_offset = 0 combined_cache = [] for layer_id in range(len(legacy_format[0])): @@ -409,23 +417,33 @@ def _combine_cache(self,example_mask,past_key_values,past_attention_masks,past_i for cache_idx, cache in enumerate(legacy_format): layer = cache[layer_id] num_examples = len(layer[0]) - new_keys = layer[0][example_mask[example_mask_offset:example_mask_offset+num_examples]] - new_values = layer[1][example_mask[example_mask_offset:example_mask_offset+num_examples]] + new_keys = layer[0][example_mask[example_mask_offset : example_mask_offset + num_examples]] + new_values = layer[1][example_mask[example_mask_offset : example_mask_offset + num_examples]] if combined_layer is None: - combined_layer = (new_keys,new_values) + combined_layer = (new_keys, new_values) else: - other_new_keys,other_new_values = combined_layer - combined_layer = (torch.concat([other_new_keys,new_keys],dim=0),torch.concat([other_new_values,new_values],dim=0)) + other_new_keys, other_new_values = combined_layer + combined_layer = ( + torch.concat([other_new_keys, new_keys], dim=0), + torch.concat([other_new_values, new_values], dim=0), + ) example_mask_offset += num_examples combined_cache.append(combined_layer) combined_cache = tuple(combined_cache) - combined_attention_masks = torch.concat(past_attention_masks,dim=0)[example_mask] - combined_input_ids = torch.concat(past_input_ids,dim=0)[example_mask] + combined_attention_masks = torch.concat(past_attention_masks, dim=0)[example_mask] + combined_input_ids = torch.concat(past_input_ids, dim=0)[example_mask] return combined_cache, combined_attention_masks, combined_input_ids - def generate(self, histories,past_key_values=None,past_attention_masks=None,past_input_ids=None,last_active_histories=None): + def generate( + self, + histories, + past_key_values=None, + past_attention_masks=None, + past_input_ids=None, + last_active_histories=None, + ): """ Generate responses for a list of histories. Either all of past_key_values, past_attention_masks, past_input_ids,last_active_histories are provided or all are None. @@ -437,16 +455,23 @@ def generate(self, histories,past_key_values=None,past_attention_masks=None,past last_active_histories (Optional[list[int]]): indices of histories for which generation took place during the last generation turn """ active_histories = [i for i in range(len(histories)) if not histories[i].completed] - combined_past_key_values,combined_past_attention_masks, combined_past_input_ids = (None,None,None) - + combined_past_key_values, combined_past_attention_masks, combined_past_input_ids = (None, None, None) + if past_key_values is not None: query_tensors = [self._next_input(histories[i]) for i in active_histories] example_mask = [(not histories[i].completed) for i in last_active_histories] - combined_past_key_values,combined_past_attention_masks, combined_past_input_ids = self._combine_cache(example_mask,past_key_values,past_attention_masks,past_input_ids) + combined_past_key_values, combined_past_attention_masks, combined_past_input_ids = self._combine_cache( + example_mask, past_key_values, past_attention_masks, past_input_ids + ) else: query_tensors = [histories[i].tokens for i in active_histories] - response_tensors,past_key_values,past_attention_masks,past_input_ids, truncated = self._generate_batched(query_tensors,combined_past_key_values=combined_past_key_values,combined_past_attention_masks=combined_past_attention_masks, combined_past_input_ids=combined_past_input_ids) + response_tensors, past_key_values, past_attention_masks, past_input_ids, truncated = self._generate_batched( + query_tensors, + combined_past_key_values=combined_past_key_values, + combined_past_attention_masks=combined_past_attention_masks, + combined_past_input_ids=combined_past_input_ids, + ) if not truncated: response_texts = self.tokenizer.batch_decode(response_tensors) for i, response_text, response_tensor in zip(active_histories, response_texts, response_tensors): @@ -456,11 +481,15 @@ def generate(self, histories,past_key_values=None,past_attention_masks=None,past else: for history in histories: if not history.completed: - #Adds an eos token, so that we always end on a non-system segment - history.append_segment(self.tokenizer.eos_token, torch.tensor([self.tokenizer.eos_token_id]).to(self.current_device), system=False) + # Adds an eos token, so that we always end on a non-system segment + history.append_segment( + self.tokenizer.eos_token, + torch.tensor([self.tokenizer.eos_token_id]).to(self.current_device), + system=False, + ) history.complete(truncated=True) - return histories,past_key_values,past_attention_masks,past_input_ids, active_histories + return histories, past_key_values, past_attention_masks, past_input_ids, active_histories def tasks_end_check(self, histories, model_turn=True): """ @@ -496,8 +525,10 @@ def task_end_check(self, history, model_turn=True): ended = True return truncated, ended - #builds the cache for the current batch - def _get_batched_cache(self,start_index, end_index, combined_past_key_values, combined_attention_masks, combined_input_ids): + # builds the cache for the current batch + def _get_batched_cache( + self, start_index, end_index, combined_past_key_values, combined_attention_masks, combined_input_ids + ): """ Extract (batch) cache for current batch start_index (int): start index of current batch @@ -508,15 +539,18 @@ def _get_batched_cache(self,start_index, end_index, combined_past_key_values, co """ current_cache = [] for layer_id, layer in enumerate(combined_past_key_values): - keys,values = layer + keys, values = layer new_keys = keys[start_index:end_index] new_values = values[start_index:end_index] - current_cache.append((new_keys,new_values)) + current_cache.append((new_keys, new_values)) current_cache = tuple(current_cache) - return DynamicCache().from_legacy_cache(current_cache), combined_attention_masks[start_index:end_index], combined_input_ids[start_index:end_index] - + return ( + DynamicCache().from_legacy_cache(current_cache), + combined_attention_masks[start_index:end_index], + combined_input_ids[start_index:end_index], + ) - #TODO make batch_size changeable + # TODO make batch_size changeable def _generate_batched( self, query_tensors, @@ -524,7 +558,7 @@ def _generate_batched( pad_to_multiple_of: Optional[int] = None, combined_past_key_values=None, combined_past_attention_masks=None, - combined_past_input_ids = None + combined_past_input_ids=None, ): """ Generate responses for a list of query tensors. @@ -547,14 +581,16 @@ def _generate_batched( new_past_input_ids = [] # in case we have fewer examples than bs batch_size = min(len(query_tensors), batch_size) - for batch_index,i in enumerate(range(0, len(query_tensors), batch_size)): + for batch_index, i in enumerate(range(0, len(query_tensors), batch_size)): # prevent overflow if query tensors are not even multiple of bs end_index = min(len(query_tensors), i + batch_size) batch = query_tensors[i:end_index] batch_mask = [torch.ones_like(element) for element in batch] - past_key_values, past_attention_masks, past_input_ids = (None,None,None) + past_key_values, past_attention_masks, past_input_ids = (None, None, None) if combined_past_key_values is not None: - past_key_values, past_attention_masks, past_input_ids = self._get_batched_cache(i,end_index,combined_past_key_values,combined_past_attention_masks,combined_past_input_ids) + past_key_values, past_attention_masks, past_input_ids = self._get_batched_cache( + i, end_index, combined_past_key_values, combined_past_attention_masks, combined_past_input_ids + ) inputs = {"input_ids": batch, "attention_mask": batch_mask} padded_inputs = self.tokenizer.pad( @@ -570,29 +606,38 @@ def _generate_batched( self.generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stopping_criteria]) self.generation_kwargs["use_cache"] = True self.generation_kwargs["return_dict_in_generate"] = True - #handle caching - self.generation_kwargs["past_key_values"] = past_key_values if past_key_values is not None else DynamicCache() + # handle caching + self.generation_kwargs["past_key_values"] = ( + past_key_values if past_key_values is not None else DynamicCache() + ) if past_attention_masks is not None: - padded_inputs["attention_mask"] = torch.concatenate([past_attention_masks,padded_inputs["attention_mask"]],dim=1) + padded_inputs["attention_mask"] = torch.concatenate( + [past_attention_masks, padded_inputs["attention_mask"]], dim=1 + ) if past_input_ids is not None: - padded_inputs["input_ids"] = torch.concatenate([past_input_ids,padded_inputs["input_ids"]],dim=1) + padded_inputs["input_ids"] = torch.concatenate([past_input_ids, padded_inputs["input_ids"]], dim=1) - if self.max_length is not None and padded_inputs["input_ids"].shape[-1]>self.max_length: - return None, None, None,None, True + if self.max_length is not None and padded_inputs["input_ids"].shape[-1] > self.max_length: + return None, None, None, None, True generations = extract_model_from_parallel(self.model).generate(**padded_inputs, **self.generation_kwargs) new_past_key_values.append(generations.past_key_values) past_attention_mask = torch.ones_like(generations.sequences) - #Don't attend to generated padding or eos tokens - past_attention_mask[torch.logical_or(generations.sequences==self.tokenizer.eos_token_id, generations.sequences==self.tokenizer.pad_token_id)] = 0 - past_attention_mask[:,:input_attention_mask.shape[1]] = input_attention_mask + # Don't attend to generated padding or eos tokens + past_attention_mask[ + torch.logical_or( + generations.sequences == self.tokenizer.eos_token_id, + generations.sequences == self.tokenizer.pad_token_id, + ) + ] = 0 + past_attention_mask[:, : input_attention_mask.shape[1]] = input_attention_mask generations = generations.sequences new_past_input_ids.append(generations) for generation, mask, generated_tokens, new_attention_mask in zip( - generations, padded_inputs["attention_mask"], stopping_criteria.generated_tokens,past_attention_mask + generations, padded_inputs["attention_mask"], stopping_criteria.generated_tokens, past_attention_mask ): if not self.is_encoder_decoder: output = generation[(1 - mask).sum() :] # remove padding @@ -607,8 +652,8 @@ def _generate_batched( # remove chunk generated after stopping criteria in batch mode outputs.append(output[:generated_tokens]) - #Do not attend to tokens that were generated after or - generated_tokens_attention_mask[generated_tokens:]=0 + # Do not attend to tokens that were generated after or + generated_tokens_attention_mask[generated_tokens:] = 0 new_past_attention_masks.append(past_attention_mask) self.tokenizer.padding_side = padding_side_default - return outputs, new_past_key_values, new_past_attention_masks,new_past_input_ids, False + return outputs, new_past_key_values, new_past_attention_masks, new_past_input_ids, False From 2a7ec4ef4777c7c24a90a6ba0c6fff988999e930 Mon Sep 17 00:00:00 2001 From: Konrad Gerlach Date: Fri, 10 Jan 2025 15:58:50 +0100 Subject: [PATCH 08/38] fix: comment --- trl/environment/base_environment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/environment/base_environment.py b/trl/environment/base_environment.py index bb4329f959..81dbd0029d 100644 --- a/trl/environment/base_environment.py +++ b/trl/environment/base_environment.py @@ -481,7 +481,7 @@ def generate( else: for history in histories: if not history.completed: - # Adds an eos token, so that we always end on a non-system segment + # Adds an eos token, so that we end on a non-system segment history.append_segment( self.tokenizer.eos_token, torch.tensor([self.tokenizer.eos_token_id]).to(self.current_device), From af06d6319477e4e3b1e3b6ccfec7fa98df5e7179 Mon Sep 17 00:00:00 2001 From: Konrad Gerlach Date: Fri, 10 Jan 2025 16:16:21 +0100 Subject: [PATCH 09/38] fix: Args comment --- trl/environment/base_environment.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/trl/environment/base_environment.py b/trl/environment/base_environment.py index 81dbd0029d..8e5aaa5c01 100644 --- a/trl/environment/base_environment.py +++ b/trl/environment/base_environment.py @@ -531,11 +531,12 @@ def _get_batched_cache( ): """ Extract (batch) cache for current batch - start_index (int): start index of current batch - end_index (int): end index of current batch (points to first element not in batch) - combined_past_key_values (tuple[tuple[torch.Tensor]]) : The combined (unbatched) cache in legacy format from the last generation - combined_past_attention_masks (torch.Tensor): The combined (unbatched) attention masks from the last generation - combined_past_input_ids (torch.Tensor): The combined (unbatched) input ids from the last generation + Args: + start_index (int): start index of current batch + end_index (int): end index of current batch (points to first element not in batch) + combined_past_key_values (tuple[tuple[torch.Tensor]]) : The combined (unbatched) cache in legacy format from the last generation + combined_past_attention_masks (torch.Tensor): The combined (unbatched) attention masks from the last generation + combined_past_input_ids (torch.Tensor): The combined (unbatched) input ids from the last generation """ current_cache = [] for layer_id, layer in enumerate(combined_past_key_values): From f6f12b50d1601b3cca04553487ff919d67101c8a Mon Sep 17 00:00:00 2001 From: Konrad Gerlach Date: Fri, 10 Jan 2025 16:55:08 +0100 Subject: [PATCH 10/38] fix: TextEnvironment cache combination and batching issue --- tests/test_environments.py | 44 +++++++++++++++++++++-------- trl/environment/base_environment.py | 31 ++++++++++++-------- 2 files changed, 51 insertions(+), 24 deletions(-) diff --git a/tests/test_environments.py b/tests/test_environments.py index f0f57ad2f9..328fe07690 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -291,15 +291,24 @@ def test_combine_cache(self): ) caches = [ - ((torch.tensor([[1], [2]]), torch.tensor([[3], [4]])),), - ((torch.tensor([[5]]), torch.tensor([[6]])),), + ( + (torch.tensor([[1], [2]]), torch.tensor([[3], [4]])), + (torch.tensor([[7], [8]]), torch.tensor([[9], [10]])), + ), + ( + (torch.tensor([[5]]), torch.tensor([[6]])), + (torch.tensor([[11]]), torch.tensor([[12]])), + ), ] caches = [DynamicCache().from_legacy_cache(cache) for cache in caches] attention_masks = [torch.tensor([[0], [1]]), torch.tensor([[2]])] input_ids = [torch.tensor([[1], [2]]), torch.tensor([[3]])] example_mask = [True, False, True] - expected_cache = ((torch.tensor([[1], [5]]), torch.tensor([[3], [6]])),) + expected_cache = ( + (torch.tensor([[1], [5]]), torch.tensor([[3], [6]])), + (torch.tensor([[7], [11]]), torch.tensor([[9], [12]])), + ) expected_attention_mask = torch.tensor([[0], [2]]) expected_input_ids = torch.tensor([[1], [3]]) @@ -311,6 +320,9 @@ def test_combine_cache(self): self.assertEqual(len(combined_cache[0]), len(expected_cache[0])) self.assertTrue(torch.all(combined_cache[0][0] == expected_cache[0][0])) self.assertTrue(torch.all(combined_cache[0][1] == expected_cache[0][1])) + self.assertEqual(len(combined_cache[1]), len(expected_cache[1])) + self.assertTrue(torch.all(combined_cache[1][0] == expected_cache[1][0])) + self.assertTrue(torch.all(combined_cache[1][1] == expected_cache[1][1])) self.assertTrue(torch.all(combined_attention_masks == expected_attention_mask)) self.assertTrue(torch.all(combined_input_ids == expected_input_ids)) @@ -324,19 +336,28 @@ def test_get_batched_cache(self): max_turns=2, ) - cache = ((torch.tensor([[1], [2], [3]]), torch.tensor([[4], [5], [6]])),) + cache = ( + (torch.tensor([[1], [2], [3]]), torch.tensor([[4], [5], [6]])), + (torch.tensor([[7], [8], [9]]), torch.tensor([[10], [11], [12]])), + ) attention_masks = torch.tensor([[1], [2], [3]]) input_ids = torch.tensor([[4], [5], [6]]) batched_cache, batched_attention_masks, batched_input_ids = env._get_batched_cache( 1, 3, cache, attention_masks, input_ids ) batched_cache = batched_cache.to_legacy_cache() - expected_cache = ((torch.tensor([[2], [3]]), torch.tensor([[5], [6]])),) + expected_cache = ( + (torch.tensor([[2], [3]]), torch.tensor([[5], [6]])), + (torch.tensor([[8], [9]]), torch.tensor([[11], [12]])), + ) self.assertEqual(len(batched_cache), len(expected_cache)) self.assertEqual(len(batched_cache[0]), len(expected_cache[0])) self.assertTrue(torch.all(batched_cache[0][0] == expected_cache[0][0])) self.assertTrue(torch.all(batched_cache[0][1] == expected_cache[0][1])) + self.assertEqual(len(batched_cache[1]), len(expected_cache[1])) + self.assertTrue(torch.all(batched_cache[1][0] == expected_cache[1][0])) + self.assertTrue(torch.all(batched_cache[1][1] == expected_cache[1][1])) expected_attention_mask = torch.tensor([[2], [3]]) self.assertTrue(torch.all(batched_attention_masks == expected_attention_mask)) @@ -355,16 +376,17 @@ def test_cached_generate_batched(self): generation_kwargs=generation_kwargs, ) - input_texts = ["this is a test", "this is another, longer test"] + input_texts = ["this is a test", "this is another, longer test", "some other batch"] model_inputs = [self.gpt2_tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts] outputs, past_key_values, past_attention_masks, past_input_ids, _ = env._generate_batched( model_inputs, batch_size=2 ) - past_key_values = past_key_values[0].to_legacy_cache() - past_attention_masks = past_attention_masks[0] - past_input_ids = past_input_ids[0] - input_texts2 = [" short interim", " a slightly longer interim"] + past_key_values, past_attention_masks, past_input_ids = env._combine_cache( + [True, True, True], past_key_values, past_attention_masks, past_input_ids + ) + + input_texts2 = [" short interim", " a slightly longer interim", "another interim"] model_inputs2 = [self.gpt2_tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts2] outputs_cached, _, _, _, _ = env._generate_batched( @@ -378,8 +400,6 @@ def test_cached_generate_batched(self): model_inputs2_full = [ torch.concat([in1, out1, in2], dim=0) for in1, out1, in2 in zip(model_inputs, outputs, model_inputs2) ] - outputs_uncached, _, _, _, _ = env._generate_batched(model_inputs2_full, batch_size=2) - for cached, uncached in zip(outputs_cached, outputs_uncached): self.assertTrue(torch.all(cached == uncached)) diff --git a/trl/environment/base_environment.py b/trl/environment/base_environment.py index 8e5aaa5c01..2a97ff6d2c 100644 --- a/trl/environment/base_environment.py +++ b/trl/environment/base_environment.py @@ -410,10 +410,10 @@ def _combine_cache(self, example_mask, past_key_values, past_attention_masks, pa past_input_ids (list[torch.Tensor]): Batched list of input ids from the last generation """ legacy_format = [cache.to_legacy_cache() for cache in past_key_values] - example_mask_offset = 0 combined_cache = [] for layer_id in range(len(legacy_format[0])): combined_layer = None + example_mask_offset = 0 for cache_idx, cache in enumerate(legacy_format): layer = cache[layer_id] num_examples = len(layer[0]) @@ -580,27 +580,34 @@ def _generate_batched( new_past_key_values = [] new_past_attention_masks = [] new_past_input_ids = [] + + # pad all batches to same length for cache compatibility + mask = [torch.ones_like(element) for element in query_tensors] + inputs = {"input_ids": query_tensors, "attention_mask": mask} + all_padded_inputs = self.tokenizer.pad( + inputs, + padding=True, + max_length=None, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors="pt", + ).to(self.current_device) + # in case we have fewer examples than bs batch_size = min(len(query_tensors), batch_size) for batch_index, i in enumerate(range(0, len(query_tensors), batch_size)): # prevent overflow if query tensors are not even multiple of bs end_index = min(len(query_tensors), i + batch_size) - batch = query_tensors[i:end_index] - batch_mask = [torch.ones_like(element) for element in batch] past_key_values, past_attention_masks, past_input_ids = (None, None, None) if combined_past_key_values is not None: past_key_values, past_attention_masks, past_input_ids = self._get_batched_cache( i, end_index, combined_past_key_values, combined_past_attention_masks, combined_past_input_ids ) - inputs = {"input_ids": batch, "attention_mask": batch_mask} - - padded_inputs = self.tokenizer.pad( - inputs, - padding=True, - max_length=None, - pad_to_multiple_of=pad_to_multiple_of, - return_tensors="pt", - ).to(self.current_device) + + padded_inputs = { + "input_ids": all_padded_inputs["input_ids"][i:end_index], + "attention_mask": all_padded_inputs["attention_mask"][i:end_index], + } + input_attention_mask = padded_inputs["attention_mask"].clone() stopping_criteria = StringStoppingCriteria([self.call_token, self.submit_token], self.tokenizer) From ede7e819f7f965db28d167285523115cd7714e35 Mon Sep 17 00:00:00 2001 From: Konrad Gerlach Date: Fri, 10 Jan 2025 17:17:50 +0100 Subject: [PATCH 11/38] tests: make caching test more complex --- tests/test_environments.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_environments.py b/tests/test_environments.py index 328fe07690..e8a708cd62 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -376,17 +376,17 @@ def test_cached_generate_batched(self): generation_kwargs=generation_kwargs, ) - input_texts = ["this is a test", "this is another, longer test", "some other batch"] + input_texts = ["this is a test", "this is another, longer test", "some other batch", "something unnecessary"] model_inputs = [self.gpt2_tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts] outputs, past_key_values, past_attention_masks, past_input_ids, _ = env._generate_batched( model_inputs, batch_size=2 ) past_key_values, past_attention_masks, past_input_ids = env._combine_cache( - [True, True, True], past_key_values, past_attention_masks, past_input_ids + [True, True, True, False], past_key_values, past_attention_masks, past_input_ids ) - input_texts2 = [" short interim", " a slightly longer interim", "another interim"] + input_texts2 = [" short interim", " a somewhat longer section in between", "something else entirely! So, "] model_inputs2 = [self.gpt2_tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts2] outputs_cached, _, _, _, _ = env._generate_batched( @@ -398,7 +398,7 @@ def test_cached_generate_batched(self): ) model_inputs2_full = [ - torch.concat([in1, out1, in2], dim=0) for in1, out1, in2 in zip(model_inputs, outputs, model_inputs2) + torch.concat([in1, out1, in2], dim=0) for in1, out1, in2 in zip(model_inputs[:-1], outputs, model_inputs2) ] outputs_uncached, _, _, _, _ = env._generate_batched(model_inputs2_full, batch_size=2) for cached, uncached in zip(outputs_cached, outputs_uncached): From acddaa7d999e50024db68c56f0bef298cf9524ca Mon Sep 17 00:00:00 2001 From: Konrad Gerlach Date: Sat, 11 Jan 2025 13:56:43 +0100 Subject: [PATCH 12/38] fix: combine caches of different sequence lengths --- tests/test_environments.py | 110 ++++++++++++++++++++++++---- trl/environment/base_environment.py | 45 ++++++++++-- 2 files changed, 135 insertions(+), 20 deletions(-) diff --git a/tests/test_environments.py b/tests/test_environments.py index e8a708cd62..584318d164 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -34,6 +34,16 @@ def dummy_generate( return histories, None, None, None, None +def reshape_cache(cache): + new_cache = [] + for layer in cache: + keys, values = layer + keys = keys.reshape((-1, 1, 1, 1)) + values = values.reshape((-1, 1, 1, 1)) + new_cache.append((keys, values)) + return tuple(new_cache) + + class TextHistoryTest(unittest.TestCase): def test_text_history_init(self): text = "Hello there!" @@ -300,17 +310,19 @@ def test_combine_cache(self): (torch.tensor([[11]]), torch.tensor([[12]])), ), ] - caches = [DynamicCache().from_legacy_cache(cache) for cache in caches] - attention_masks = [torch.tensor([[0], [1]]), torch.tensor([[2]])] - input_ids = [torch.tensor([[1], [2]]), torch.tensor([[3]])] + caches = [DynamicCache().from_legacy_cache(reshape_cache(cache)) for cache in caches] + attention_masks = [torch.tensor([[0, 1], [1, 0]]), torch.tensor([[2, 4]])] + input_ids = [torch.tensor([[1, 4], [2, 5]]), torch.tensor([[3, 6]])] example_mask = [True, False, True] - expected_cache = ( - (torch.tensor([[1], [5]]), torch.tensor([[3], [6]])), - (torch.tensor([[7], [11]]), torch.tensor([[9], [12]])), + expected_cache = reshape_cache( + ( + (torch.tensor([[1], [5]]), torch.tensor([[3], [6]])), + (torch.tensor([[7], [11]]), torch.tensor([[9], [12]])), + ) ) - expected_attention_mask = torch.tensor([[0], [2]]) - expected_input_ids = torch.tensor([[1], [3]]) + expected_attention_mask = torch.tensor([[0, 1], [2, 4]]) + expected_input_ids = torch.tensor([[1, 4], [3, 6]]) combined_cache, combined_attention_masks, combined_input_ids = env._combine_cache( example_mask, caches, attention_masks, input_ids @@ -336,9 +348,11 @@ def test_get_batched_cache(self): max_turns=2, ) - cache = ( - (torch.tensor([[1], [2], [3]]), torch.tensor([[4], [5], [6]])), - (torch.tensor([[7], [8], [9]]), torch.tensor([[10], [11], [12]])), + cache = reshape_cache( + ( + (torch.tensor([[1], [2], [3]]), torch.tensor([[4], [5], [6]])), + (torch.tensor([[7], [8], [9]]), torch.tensor([[10], [11], [12]])), + ) ) attention_masks = torch.tensor([[1], [2], [3]]) input_ids = torch.tensor([[4], [5], [6]]) @@ -346,9 +360,11 @@ def test_get_batched_cache(self): 1, 3, cache, attention_masks, input_ids ) batched_cache = batched_cache.to_legacy_cache() - expected_cache = ( - (torch.tensor([[2], [3]]), torch.tensor([[5], [6]])), - (torch.tensor([[8], [9]]), torch.tensor([[11], [12]])), + expected_cache = reshape_cache( + ( + (torch.tensor([[2], [3]]), torch.tensor([[5], [6]])), + (torch.tensor([[8], [9]]), torch.tensor([[11], [12]])), + ) ) self.assertEqual(len(batched_cache), len(expected_cache)) @@ -386,8 +402,12 @@ def test_cached_generate_batched(self): [True, True, True, False], past_key_values, past_attention_masks, past_input_ids ) - input_texts2 = [" short interim", " a somewhat longer section in between", "something else entirely! So, "] + input_texts2 = [" short interim", " a somewhat longer section in between"] model_inputs2 = [self.gpt2_tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts2] + # for single token query + model_inputs2.append( + torch.tensor([self.gpt2_tokenizer(" a", return_tensors="pt").input_ids], dtype=model_inputs2[0].dtype) + ) outputs_cached, _, _, _, _ = env._generate_batched( model_inputs2, @@ -403,3 +423,63 @@ def test_cached_generate_batched(self): outputs_uncached, _, _, _, _ = env._generate_batched(model_inputs2_full, batch_size=2) for cached, uncached in zip(outputs_cached, outputs_uncached): self.assertTrue(torch.all(cached == uncached)) + + def test_different_sequence_lengths(self): + generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": self.gpt2_tokenizer.eos_token_id} + env = TextEnvironment( + self.gpt2_model, + self.gpt2_tokenizer, + tools=[DummyTool()], + reward_fn=lambda x: torch.tensor(1), + prompt="I am a prompt!\n", + generation_kwargs=generation_kwargs, + ) + + input_texts = ["this is a test", "this is another, longer test", "some other batch"] + model_inputs = [self.gpt2_tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts] + outputs, past_key_values, past_attention_masks, past_input_ids, _ = env._generate_batched( + model_inputs, batch_size=2 + ) + # remove the last two tokens from the second batch to pretend they were never generated + second_cache = past_key_values[1].to_legacy_cache() + edited_cache = [] + for layer in second_cache: + keys, values = layer + new_keys = keys[:, :, :-2, :] + new_values = values[:, :, :-2, :] + edited_cache.append((new_keys, new_values)) + + past_key_values[1] = DynamicCache().from_legacy_cache(tuple(edited_cache)) + past_attention_masks[1] = past_attention_masks[1][:, :-2] + past_input_ids[1] = past_input_ids[1][:, :-2] + + # ensure this actually removes generated tokens and not skipped tokens / padding + self.assertEqual(len(outputs[2]), 4) + + past_key_values, past_attention_masks, past_input_ids = env._combine_cache( + [True, True, True], past_key_values, past_attention_masks, past_input_ids + ) + + self.assertEqual(past_attention_masks.shape, past_input_ids.shape) + self.assertEqual(past_key_values[0][0].shape[2], past_attention_masks.shape[1] - 1) + self.assertEqual(past_key_values[0][0].shape[0], past_attention_masks.shape[0]) + input_texts2 = [" short interim", " a somewhat longer section in between"] + model_inputs2 = [self.gpt2_tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts2] + # for single token query + model_inputs2.append( + torch.tensor([self.gpt2_tokenizer(" a", return_tensors="pt").input_ids], dtype=model_inputs2[0].dtype) + ) + outputs_cached, _, _, _, _ = env._generate_batched( + model_inputs2, + batch_size=2, + combined_past_key_values=past_key_values, + combined_past_attention_masks=past_attention_masks, + combined_past_input_ids=past_input_ids, + ) + outputs[2] = outputs[2][:-2] # remove last two generated tokens from input + model_inputs2_full = [ + torch.concat([in1, out1, in2], dim=0) for in1, out1, in2 in zip(model_inputs, outputs, model_inputs2) + ] + outputs_uncached, _, _, _, _ = env._generate_batched(model_inputs2_full, batch_size=2) + for cached, uncached in zip(outputs_cached, outputs_uncached): + self.assertTrue(torch.all(cached == uncached)) diff --git a/trl/environment/base_environment.py b/trl/environment/base_environment.py index 2a97ff6d2c..a913dc3462 100644 --- a/trl/environment/base_environment.py +++ b/trl/environment/base_environment.py @@ -409,6 +409,8 @@ def _combine_cache(self, example_mask, past_key_values, past_attention_masks, pa past_attention_masks (list[torch.Tensor]): Batched list of attention masks from the last generation past_input_ids (list[torch.Tensor]): Batched list of input ids from the last generation """ + max_sequence_length = max([attention_mask.shape[1] for attention_mask in past_attention_masks]) + legacy_format = [cache.to_legacy_cache() for cache in past_key_values] combined_cache = [] for layer_id in range(len(legacy_format[0])): @@ -417,8 +419,22 @@ def _combine_cache(self, example_mask, past_key_values, past_attention_masks, pa for cache_idx, cache in enumerate(legacy_format): layer = cache[layer_id] num_examples = len(layer[0]) - new_keys = layer[0][example_mask[example_mask_offset : example_mask_offset + num_examples]] - new_values = layer[1][example_mask[example_mask_offset : example_mask_offset + num_examples]] + extracted_keys = layer[0][example_mask[example_mask_offset : example_mask_offset + num_examples]] + extracted_values = layer[1][example_mask[example_mask_offset : example_mask_offset + num_examples]] + + # pad to max_sequence_length -1 + new_keys = torch.zeros( + ( + extracted_keys.shape[0], + extracted_keys.shape[1], + max_sequence_length - 1, + extracted_keys.shape[3], + ) + ).to(self.current_device) + new_values = torch.zeros_like(new_keys).to(self.current_device) + new_keys[:, :, : extracted_keys.shape[2], :] = extracted_keys + new_values[:, :, : extracted_values.shape[2], :] = extracted_values + if combined_layer is None: combined_layer = (new_keys, new_values) else: @@ -431,8 +447,23 @@ def _combine_cache(self, example_mask, past_key_values, past_attention_masks, pa combined_cache.append(combined_layer) combined_cache = tuple(combined_cache) - combined_attention_masks = torch.concat(past_attention_masks, dim=0)[example_mask] - combined_input_ids = torch.concat(past_input_ids, dim=0)[example_mask] + padded_attentions_masks = [] + padded_past_input_ids = [] + for attention_mask, input_ids in zip(past_attention_masks, past_input_ids): + padded_attention_mask = torch.zeros( + (attention_mask.shape[0], max_sequence_length), dtype=attention_mask.dtype + ).to(self.current_device) + padded_attention_mask[:, : attention_mask.shape[1]] = attention_mask + padded_attentions_masks.append(padded_attention_mask) + + padded_input_ids = torch.full( + (input_ids.shape[0], max_sequence_length), self.tokenizer.pad_token_id, dtype=input_ids.dtype + ).to(self.current_device) + padded_input_ids[:, : input_ids.shape[1]] = input_ids + padded_past_input_ids.append(padded_input_ids) + + combined_attention_masks = torch.concat(padded_attentions_masks, dim=0)[example_mask] + combined_input_ids = torch.concat(padded_past_input_ids, dim=0)[example_mask] return combined_cache, combined_attention_masks, combined_input_ids @@ -565,13 +596,17 @@ def _generate_batched( Generate responses for a list of query tensors. Either all of combined_past_key_values, combined_past_attention_masks, combined_past_input_ids are provided or all are None. Args: - query_tensors (list[torch.Tensor]): A list of query tensors to generate responses for. + query_tensors (list[torch.Tensor]): A list of non-empty query tensors to generate responses for. batch_size (int): The batch size to use for generation. pad_to_multiple_of (int): The padding length to use for generation. combined_past_key_values (Optional[tuple[tuple[torch.Tensor]]]) : The combined (unbatched) cache in legacy format from the last generation combined_past_attention_masks (Optional[torch.Tensor]): The combined (unbatched) attention masks from the last generation combined_past_input_ids (Optional[torch.Tensor]): The combined (unbatched) input ids from the last generation """ + # Ensures, that the next token is never conditioned on a padding token. This should never be a problem, as empty system prompts are not particularly useful and between segments there is always a response token. + for query in query_tensors: + if len(query) == 0: + raise Exception("Cannot input empty query") outputs = [] padding_side_default = self.tokenizer.padding_side if not self.is_encoder_decoder: From e38940e03fef77307fc1c8db8fcbedafcc3001d6 Mon Sep 17 00:00:00 2001 From: Konrad Gerlach Date: Sun, 12 Jan 2025 16:24:56 +0100 Subject: [PATCH 13/38] docs: update caching warning --- docs/source/text_environments.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/text_environments.md b/docs/source/text_environments.md index 1ab2e131da..c8a349fa4f 100644 --- a/docs/source/text_environments.md +++ b/docs/source/text_environments.md @@ -114,7 +114,7 @@ Let's decompose the settings: | `max_tool_response`| The tool response is truncated to this number to avoid running out of model context.| | `max_length` | The maximum number of tokens to allow in an episode. | | `generation_kwargs`| Generation settings used by the language model. | -| `use_cache` | Cache keys and values between segment generation. Warning: When using caching, TextEnvironment is not suited for training use, i.e. backpropagation through the generated graph. Use with trl Trainers is of course possible. Furthermore, caching requires, that there be no calculation dependencies between examples at inference time. When using BatchNorm, the model should thus be in eval mode.| +| `use_cache` | Cache keys and values between segment generation. Warning: This feature is experimental! When using caching, TextEnvironment is not suited for training use, i.e. backpropagation through the generated graph. Use with trl Trainers is of course possible. Furthermore, caching requires, that there be no calculation dependencies between examples at inference time. When using BatchNorm, the model should thus be in eval model. Caching is not guaranteed to produce identical results compared to not using caching and you should test for yourself, if it is suited to your needs, model and generation_kwargs. Cache use has been tested for GPT-2 with greedy search. | You can customize the environment to your needs and add custom tools and settings. Let's see how you can use the environment to have the model interact with the available tools! From 66d0ce4350e10c5b5dfc025f46ab8ad3671c852d Mon Sep 17 00:00:00 2001 From: Konrad Gerlach Date: Sun, 12 Jan 2025 16:40:45 +0100 Subject: [PATCH 14/38] fix: prevent bos tokens in tool response --- trl/environment/base_environment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/environment/base_environment.py b/trl/environment/base_environment.py index a913dc3462..6435e09d7e 100644 --- a/trl/environment/base_environment.py +++ b/trl/environment/base_environment.py @@ -354,7 +354,7 @@ def step(self, history): history.append_segment( response + self.response_token, - self.tokenizer(response + self.response_token, return_tensors="pt") + self.tokenizer(response + self.response_token, return_tensors="pt", add_special_tokens=False) .input_ids[0] .to(self.model.pretrained_model.device), system=True, From a051e46735726ceffeb0ad9e937bfa68ebc530c4 Mon Sep 17 00:00:00 2001 From: konrad-gerlach <58958090+konrad-gerlach@users.noreply.github.com> Date: Sun, 12 Jan 2025 19:48:12 +0100 Subject: [PATCH 15/38] docs: Update docs/source/text_environments.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- docs/source/text_environments.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/text_environments.md b/docs/source/text_environments.md index c8a349fa4f..82e60d02ab 100644 --- a/docs/source/text_environments.md +++ b/docs/source/text_environments.md @@ -114,7 +114,7 @@ Let's decompose the settings: | `max_tool_response`| The tool response is truncated to this number to avoid running out of model context.| | `max_length` | The maximum number of tokens to allow in an episode. | | `generation_kwargs`| Generation settings used by the language model. | -| `use_cache` | Cache keys and values between segment generation. Warning: This feature is experimental! When using caching, TextEnvironment is not suited for training use, i.e. backpropagation through the generated graph. Use with trl Trainers is of course possible. Furthermore, caching requires, that there be no calculation dependencies between examples at inference time. When using BatchNorm, the model should thus be in eval model. Caching is not guaranteed to produce identical results compared to not using caching and you should test for yourself, if it is suited to your needs, model and generation_kwargs. Cache use has been tested for GPT-2 with greedy search. | +| `use_cache` | Cache keys and values between segment generation. When using caching, [`TextEnvironment`] is not suited for training use, i.e. backpropagation through the generated graph. Use with trl Trainers is of course possible. Furthermore, caching requires, that there be no calculation dependencies between examples at inference time. When using `BatchNorm`, the model should thus be in eval model. Caching is not guaranteed to produce identical results compared to not using caching and you should test for yourself, if it is suited to your needs, model and `generation_kwargs`. Cache use has been tested for GPT-2 with greedy search. | You can customize the environment to your needs and add custom tools and settings. Let's see how you can use the environment to have the model interact with the available tools! From 9ea92874e78cc663c6afcf5dda82acff77cb9653 Mon Sep 17 00:00:00 2001 From: konrad-gerlach <58958090+konrad-gerlach@users.noreply.github.com> Date: Sun, 12 Jan 2025 19:48:49 +0100 Subject: [PATCH 16/38] Update trl/environment/base_environment.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- trl/environment/base_environment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/environment/base_environment.py b/trl/environment/base_environment.py index 6435e09d7e..9664f46e68 100644 --- a/trl/environment/base_environment.py +++ b/trl/environment/base_environment.py @@ -253,7 +253,7 @@ def __init__( max_tool_response (Optional[int]): The maximum number of characters to allow in a tool response. max_length (Optional[int]): The maximum number of tokens to allow in an episode. generation_kwargs (Optional[dict]): A dictionary of keyword arguments to pass to the model's generate method. - use_cache (bool): Whether or not to cache past_key_values between segments. When using caching, TextEnvironment is not suited for training use, i.e. backpropagation through the generated graph. Use with Trainers is of course possible. Furthermore, caching requires, that there be no calculation dependencies between examples at inference time. When using BatchNorm, the model should thus be in eval mode. + use_cache (bool): Whether to cache past_key_values between segments. When using caching, [`TextEnvironment`] is not suited for training use, i.e. backpropagation through the generated graph. Use with Trainers is of course possible. Furthermore, caching requires, that there be no calculation dependencies between examples at inference time. When using `BatchNorm`, the model should thus be in eval mode. """ self.model = model self.tokenizer = tokenizer From a2860bc4b051e6941b9ffae01cfce9026af78c18 Mon Sep 17 00:00:00 2001 From: Konrad Gerlach Date: Sun, 12 Jan 2025 20:15:33 +0100 Subject: [PATCH 17/38] fix: code cleanup --- trl/environment/base_environment.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/trl/environment/base_environment.py b/trl/environment/base_environment.py index 9664f46e68..e633a8c06c 100644 --- a/trl/environment/base_environment.py +++ b/trl/environment/base_environment.py @@ -416,7 +416,7 @@ def _combine_cache(self, example_mask, past_key_values, past_attention_masks, pa for layer_id in range(len(legacy_format[0])): combined_layer = None example_mask_offset = 0 - for cache_idx, cache in enumerate(legacy_format): + for cache in legacy_format: layer = cache[layer_id] num_examples = len(layer[0]) extracted_keys = layer[0][example_mask[example_mask_offset : example_mask_offset + num_examples]] @@ -570,7 +570,7 @@ def _get_batched_cache( combined_past_input_ids (torch.Tensor): The combined (unbatched) input ids from the last generation """ current_cache = [] - for layer_id, layer in enumerate(combined_past_key_values): + for layer in combined_past_key_values: keys, values = layer new_keys = keys[start_index:end_index] new_values = values[start_index:end_index] @@ -629,7 +629,7 @@ def _generate_batched( # in case we have fewer examples than bs batch_size = min(len(query_tensors), batch_size) - for batch_index, i in enumerate(range(0, len(query_tensors), batch_size)): + for i in range(0, len(query_tensors), batch_size): # prevent overflow if query tensors are not even multiple of bs end_index = min(len(query_tensors), i + batch_size) past_key_values, past_attention_masks, past_input_ids = (None, None, None) From 23014fb2ed508bb5d39a70d57fec3da91082ca42 Mon Sep 17 00:00:00 2001 From: Konrad Gerlach Date: Tue, 14 Jan 2025 22:06:21 +0100 Subject: [PATCH 18/38] fix: attended to invalid last generated token and off-by-one in StringStoppingCriteria --- tests/test_environments.py | 26 ++++++------ trl/environment/base_environment.py | 61 ++++++++++++++++++++++------- 2 files changed, 58 insertions(+), 29 deletions(-) diff --git a/tests/test_environments.py b/tests/test_environments.py index 584318d164..272622070d 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -302,27 +302,25 @@ def test_combine_cache(self): caches = [ ( - (torch.tensor([[1], [2]]), torch.tensor([[3], [4]])), - (torch.tensor([[7], [8]]), torch.tensor([[9], [10]])), + (torch.tensor([[[[1], [13]]], [[[2], [14]]]]), torch.tensor([[[[3], [15]]], [[[4], [16]]]])), + (torch.tensor([[[[7], [17]]], [[[8], [18]]]]), torch.tensor([[[[9], [19]]], [[[10], [20]]]])), ), ( - (torch.tensor([[5]]), torch.tensor([[6]])), - (torch.tensor([[11]]), torch.tensor([[12]])), + (torch.tensor([[[[5]]]]), torch.tensor([[[[6]]]])), + (torch.tensor([[[[11]]]]), torch.tensor([[[[12]]]])), ), ] - caches = [DynamicCache().from_legacy_cache(reshape_cache(cache)) for cache in caches] - attention_masks = [torch.tensor([[0, 1], [1, 0]]), torch.tensor([[2, 4]])] - input_ids = [torch.tensor([[1, 4], [2, 5]]), torch.tensor([[3, 6]])] + caches = [DynamicCache().from_legacy_cache(cache) for cache in caches] + attention_masks = [torch.tensor([[-1, 1, 7], [1, 0, 8]]), torch.tensor([[2, 4]])] + input_ids = [torch.tensor([[1, 4, 7], [2, 5, 8]]), torch.tensor([[3, 6]])] example_mask = [True, False, True] - expected_cache = reshape_cache( - ( - (torch.tensor([[1], [5]]), torch.tensor([[3], [6]])), - (torch.tensor([[7], [11]]), torch.tensor([[9], [12]])), - ) + expected_cache = ( + (torch.tensor([[[[1], [13]]], [[[0], [5]]]]), torch.tensor([[[[3], [15]]], [[[0], [6]]]])), + (torch.tensor([[[[7], [17]]], [[[0], [11]]]]), torch.tensor([[[[9], [19]]], [[[0], [12]]]])), ) - expected_attention_mask = torch.tensor([[0, 1], [2, 4]]) - expected_input_ids = torch.tensor([[1, 4], [3, 6]]) + expected_attention_mask = torch.tensor([[-1, 1, 7], [0, 2, 4]]) + expected_input_ids = torch.tensor([[1, 4, 7], [self.gpt2_tokenizer.pad_token_id, 3, 6]]) combined_cache, combined_attention_masks, combined_input_ids = env._combine_cache( example_mask, caches, attention_masks, input_ids diff --git a/trl/environment/base_environment.py b/trl/environment/base_environment.py index e633a8c06c..0dcdab1aa2 100644 --- a/trl/environment/base_environment.py +++ b/trl/environment/base_environment.py @@ -39,7 +39,7 @@ def __init__(self, stop_strings, tokenizer): def __call__(self, input_ids, scores, **kwargs): """Returns true if all generated sequences contain any of the stop strings or terminated early.""" if self.first_call: - self.generated_tokens = [1 for _ in range(input_ids.shape[0])] + self.generated_tokens = [0 for _ in range(input_ids.shape[0])] self.start_length = input_ids.shape[-1] - 1 self.first_call = False decoded_generations = self.tokenizer.batch_decode(input_ids[:, self.start_length :]) @@ -431,9 +431,14 @@ def _combine_cache(self, example_mask, past_key_values, past_attention_masks, pa extracted_keys.shape[3], ) ).to(self.current_device) + + if extracted_keys.shape[2] != extracted_values.shape[2]: + raise Exception("Cache format incompatible") + # left padding ensures, that the last valid generated token is what the next generated token is conditioned on + start_position = max_sequence_length - 1 - extracted_keys.shape[2] new_values = torch.zeros_like(new_keys).to(self.current_device) - new_keys[:, :, : extracted_keys.shape[2], :] = extracted_keys - new_values[:, :, : extracted_values.shape[2], :] = extracted_values + new_keys[:, :, start_position:, :] = extracted_keys + new_values[:, :, start_position:, :] = extracted_values if combined_layer is None: combined_layer = (new_keys, new_values) @@ -450,16 +455,19 @@ def _combine_cache(self, example_mask, past_key_values, past_attention_masks, pa padded_attentions_masks = [] padded_past_input_ids = [] for attention_mask, input_ids in zip(past_attention_masks, past_input_ids): + if attention_mask.shape[1] != input_ids.shape[1]: + raise Exception("Cache format incompatible") + start_position = max_sequence_length - attention_mask.shape[1] padded_attention_mask = torch.zeros( (attention_mask.shape[0], max_sequence_length), dtype=attention_mask.dtype ).to(self.current_device) - padded_attention_mask[:, : attention_mask.shape[1]] = attention_mask + padded_attention_mask[:, start_position:] = attention_mask padded_attentions_masks.append(padded_attention_mask) padded_input_ids = torch.full( (input_ids.shape[0], max_sequence_length), self.tokenizer.pad_token_id, dtype=input_ids.dtype ).to(self.current_device) - padded_input_ids[:, : input_ids.shape[1]] = input_ids + padded_input_ids[:, start_position:] = input_ids padded_past_input_ids.append(padded_input_ids) combined_attention_masks = torch.concat(padded_attentions_masks, dim=0)[example_mask] @@ -519,6 +527,7 @@ def generate( system=False, ) history.complete(truncated=True) + return histories, None, None, None, [] # invalidate cache return histories, past_key_values, past_attention_masks, past_input_ids, active_histories @@ -664,6 +673,9 @@ def _generate_batched( return None, None, None, None, True generations = extract_model_from_parallel(self.model).generate(**padded_inputs, **self.generation_kwargs) + + if generations.past_key_values.to_legacy_cache()[0][0].shape[2] != generations.sequences.shape[1] - 1: + raise Exception("Cache should not contain keys and values for last generated token") new_past_key_values.append(generations.past_key_values) past_attention_mask = torch.ones_like(generations.sequences) @@ -677,26 +689,45 @@ def _generate_batched( past_attention_mask[:, : input_attention_mask.shape[1]] = input_attention_mask generations = generations.sequences - - new_past_input_ids.append(generations) - for generation, mask, generated_tokens, new_attention_mask in zip( - generations, padded_inputs["attention_mask"], stopping_criteria.generated_tokens, past_attention_mask + # copy for in-place modification + batch_new_past_input_ids = generations.detach().clone() + for generation, mask, num_generated_tokens, new_attention_mask, example_input_ids, i in zip( + generations, + padded_inputs["attention_mask"], + stopping_criteria.generated_tokens, + past_attention_mask, + batch_new_past_input_ids, + range(len(generations)), ): if not self.is_encoder_decoder: - output = generation[(1 - mask).sum() :] # remove padding + # remove padding + output = generation[(1 - mask).sum() :] + padding_removed_b_n_past_input_ids = example_input_ids[(1 - mask).sum() :] padding_removed_past_attention_mask = new_attention_mask[(1 - mask).sum() :] else: output = generation padding_removed_past_attention_mask = new_attention_mask + padding_removed_b_n_past_input_ids = example_input_ids[(1 - mask).sum() :] if not self.is_encoder_decoder: - output = output[(mask).sum() :] # remove prompt - generated_tokens_attention_mask = padding_removed_past_attention_mask[(mask).sum() :] + # remove prompt + output = output[(mask).sum() :] + padding_removed_b_n_past_input_ids = padding_removed_b_n_past_input_ids[(mask).sum() :] + padding_removed_past_attention_mask = padding_removed_past_attention_mask[(mask).sum() :] # remove chunk generated after stopping criteria in batch mode - outputs.append(output[:generated_tokens]) - # Do not attend to tokens that were generated after or - generated_tokens_attention_mask[generated_tokens:] = 0 + generated_tokens = output[:num_generated_tokens] + if num_generated_tokens < 1: + raise Exception("Generation failed to produce any valid token") + + outputs.append(generated_tokens) + # Do not attend to invalid tokens that were generated after or or the last valid generated token, as we move it to the end of the sequence + padding_removed_past_attention_mask[num_generated_tokens - 1 :] = 0 + # move last valid generated token to the end of the sequence to be the start of the next generation + padding_removed_b_n_past_input_ids[-1] = padding_removed_b_n_past_input_ids[num_generated_tokens - 1] + padding_removed_past_attention_mask[-1] = 1 # attend to the last valid generated token + new_past_attention_masks.append(past_attention_mask) + new_past_input_ids.append(batch_new_past_input_ids) self.tokenizer.padding_side = padding_side_default return outputs, new_past_key_values, new_past_attention_masks, new_past_input_ids, False From 7324ee16af362b6ee0e50d6afe55684841daf17d Mon Sep 17 00:00:00 2001 From: Konrad Gerlach Date: Tue, 21 Jan 2025 15:59:38 +0100 Subject: [PATCH 19/38] fix: off by one error in StringStoppingCriteria --- trl/environment/base_environment.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/trl/environment/base_environment.py b/trl/environment/base_environment.py index 0dcdab1aa2..9d044965c0 100644 --- a/trl/environment/base_environment.py +++ b/trl/environment/base_environment.py @@ -51,8 +51,8 @@ def __call__(self, input_ids, scores, **kwargs): or self.tokenizer.eos_token_id in input_ids[i, self.start_length :] ) done.append(sequence_complete) - if not sequence_complete: - self.generated_tokens[i] += 1 + # we still consider the last generated token to be valid + self.generated_tokens[i] += 1 if all(done): self.first_call = True @@ -717,8 +717,9 @@ def _generate_batched( # remove chunk generated after stopping criteria in batch mode generated_tokens = output[:num_generated_tokens] - if num_generated_tokens < 1: - raise Exception("Generation failed to produce any valid token") + if len(generated_tokens) < 1: + input_length = padded_inputs["input_ids"].shape[0] + raise Exception(f"Generation failed to produce any valid token; input length {input_length}") outputs.append(generated_tokens) # Do not attend to invalid tokens that were generated after or or the last valid generated token, as we move it to the end of the sequence From 39763b1fbe97fa0a99c3b62d7d0ab01cdaddab78 Mon Sep 17 00:00:00 2001 From: Konrad Gerlach Date: Wed, 22 Jan 2025 15:00:13 +0100 Subject: [PATCH 20/38] feat: test logits are same with and without caching --- tests/test_environments.py | 33 +++++++++++++++++++++++------ trl/environment/base_environment.py | 30 ++++++++++++++++++++------ 2 files changed, 49 insertions(+), 14 deletions(-) diff --git a/tests/test_environments.py b/tests/test_environments.py index 272622070d..f14a542929 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -406,21 +406,28 @@ def test_cached_generate_batched(self): model_inputs2.append( torch.tensor([self.gpt2_tokenizer(" a", return_tensors="pt").input_ids], dtype=model_inputs2[0].dtype) ) - - outputs_cached, _, _, _, _ = env._generate_batched( + outputs_cached, _, _, _, _, all_logits_cached = env._generate_batched( model_inputs2, batch_size=2, combined_past_key_values=past_key_values, combined_past_attention_masks=past_attention_masks, combined_past_input_ids=past_input_ids, + output_logits=True, ) model_inputs2_full = [ torch.concat([in1, out1, in2], dim=0) for in1, out1, in2 in zip(model_inputs[:-1], outputs, model_inputs2) ] - outputs_uncached, _, _, _, _ = env._generate_batched(model_inputs2_full, batch_size=2) - for cached, uncached in zip(outputs_cached, outputs_uncached): + outputs_uncached, _, _, _, _, all_logits_uncached = env._generate_batched( + model_inputs2_full, batch_size=2, output_logits=True + ) + for cached, uncached, logits_cached, logits_uncached in zip( + outputs_cached, outputs_uncached, all_logits_cached, all_logits_uncached + ): self.assertTrue(torch.all(cached == uncached)) + self.assertEqual(logits_cached.shape[0], 4) + self.assertEqual(logits_uncached.shape[0], 4) + self.assertTrue(torch.all(torch.abs(logits_cached - logits_uncached) < 1e-6)) def test_different_sequence_lengths(self): generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": self.gpt2_tokenizer.eos_token_id} @@ -467,17 +474,29 @@ def test_different_sequence_lengths(self): model_inputs2.append( torch.tensor([self.gpt2_tokenizer(" a", return_tensors="pt").input_ids], dtype=model_inputs2[0].dtype) ) - outputs_cached, _, _, _, _ = env._generate_batched( + outputs_cached, _, _, _, _, all_logits_cached = env._generate_batched( model_inputs2, batch_size=2, combined_past_key_values=past_key_values, combined_past_attention_masks=past_attention_masks, combined_past_input_ids=past_input_ids, + output_logits=True, ) outputs[2] = outputs[2][:-2] # remove last two generated tokens from input model_inputs2_full = [ torch.concat([in1, out1, in2], dim=0) for in1, out1, in2 in zip(model_inputs, outputs, model_inputs2) ] - outputs_uncached, _, _, _, _ = env._generate_batched(model_inputs2_full, batch_size=2) - for cached, uncached in zip(outputs_cached, outputs_uncached): + outputs_uncached, _, _, _, _, all_logits_uncached = env._generate_batched( + model_inputs2_full, batch_size=2, output_logits=True + ) + for cached, uncached, logits_cached, logits_uncached in zip( + outputs_cached, outputs_uncached, all_logits_cached, all_logits_uncached + ): self.assertTrue(torch.all(cached == uncached)) + self.assertEqual(logits_cached.shape[0], 4) + self.assertEqual(logits_uncached.shape[0], 4) + self.assertTrue(torch.all(torch.abs(logits_cached - logits_uncached) < 1e-6)) + + +if __name__ == "__main__": + pass diff --git a/trl/environment/base_environment.py b/trl/environment/base_environment.py index 9d044965c0..4c341154b1 100644 --- a/trl/environment/base_environment.py +++ b/trl/environment/base_environment.py @@ -13,6 +13,7 @@ # limitations under the License. import re +import copy from typing import Optional @@ -600,6 +601,7 @@ def _generate_batched( combined_past_key_values=None, combined_past_attention_masks=None, combined_past_input_ids=None, + output_logits=False, ): """ Generate responses for a list of query tensors. @@ -624,6 +626,8 @@ def _generate_batched( new_past_key_values = [] new_past_attention_masks = [] new_past_input_ids = [] + if output_logits: + all_logits = [] # pad all batches to same length for cache compatibility mask = [torch.ones_like(element) for element in query_tensors] @@ -655,13 +659,16 @@ def _generate_batched( input_attention_mask = padded_inputs["attention_mask"].clone() stopping_criteria = StringStoppingCriteria([self.call_token, self.submit_token], self.tokenizer) - self.generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stopping_criteria]) - self.generation_kwargs["use_cache"] = True - self.generation_kwargs["return_dict_in_generate"] = True + generation_kwargs = copy.deepcopy(self.generation_kwargs) + + generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stopping_criteria]) + generation_kwargs["use_cache"] = True + generation_kwargs["return_dict_in_generate"] = True + if output_logits: + generation_kwargs["output_logits"] = True + # handle caching - self.generation_kwargs["past_key_values"] = ( - past_key_values if past_key_values is not None else DynamicCache() - ) + generation_kwargs["past_key_values"] = past_key_values if past_key_values is not None else DynamicCache() if past_attention_masks is not None: padded_inputs["attention_mask"] = torch.concatenate( [past_attention_masks, padded_inputs["attention_mask"]], dim=1 @@ -672,7 +679,7 @@ def _generate_batched( if self.max_length is not None and padded_inputs["input_ids"].shape[-1] > self.max_length: return None, None, None, None, True - generations = extract_model_from_parallel(self.model).generate(**padded_inputs, **self.generation_kwargs) + generations = extract_model_from_parallel(self.model).generate(**padded_inputs, **generation_kwargs) if generations.past_key_values.to_legacy_cache()[0][0].shape[2] != generations.sequences.shape[1] - 1: raise Exception("Cache should not contain keys and values for last generated token") @@ -688,6 +695,8 @@ def _generate_batched( ] = 0 past_attention_mask[:, : input_attention_mask.shape[1]] = input_attention_mask + if output_logits: + logits = generations.logits generations = generations.sequences # copy for in-place modification batch_new_past_input_ids = generations.detach().clone() @@ -727,8 +736,15 @@ def _generate_batched( # move last valid generated token to the end of the sequence to be the start of the next generation padding_removed_b_n_past_input_ids[-1] = padding_removed_b_n_past_input_ids[num_generated_tokens - 1] padding_removed_past_attention_mask[-1] = 1 # attend to the last valid generated token + if output_logits: + for i, num_generated_tokens in enumerate(stopping_criteria.generated_tokens): + relevant_logits = [batched_logits[i] for batched_logits in logits[:num_generated_tokens]] + all_logits.append(torch.stack(relevant_logits, dim=0)) new_past_attention_masks.append(past_attention_mask) new_past_input_ids.append(batch_new_past_input_ids) self.tokenizer.padding_side = padding_side_default + if output_logits: + return outputs, new_past_key_values, new_past_attention_masks, new_past_input_ids, False, all_logits + return outputs, new_past_key_values, new_past_attention_masks, new_past_input_ids, False From b70f51c33d390daebea2c3f885a106ba976fa707 Mon Sep 17 00:00:00 2001 From: Konrad Gerlach Date: Wed, 22 Jan 2025 15:48:36 +0100 Subject: [PATCH 21/38] fix: model and tokenizer were called gpt2 but were another model --- docs/source/text_environments.md | 2 +- tests/test_environments.py | 76 ++++++++++++++++---------------- 2 files changed, 39 insertions(+), 39 deletions(-) diff --git a/docs/source/text_environments.md b/docs/source/text_environments.md index 82e60d02ab..c8864ad3d3 100644 --- a/docs/source/text_environments.md +++ b/docs/source/text_environments.md @@ -114,7 +114,7 @@ Let's decompose the settings: | `max_tool_response`| The tool response is truncated to this number to avoid running out of model context.| | `max_length` | The maximum number of tokens to allow in an episode. | | `generation_kwargs`| Generation settings used by the language model. | -| `use_cache` | Cache keys and values between segment generation. When using caching, [`TextEnvironment`] is not suited for training use, i.e. backpropagation through the generated graph. Use with trl Trainers is of course possible. Furthermore, caching requires, that there be no calculation dependencies between examples at inference time. When using `BatchNorm`, the model should thus be in eval model. Caching is not guaranteed to produce identical results compared to not using caching and you should test for yourself, if it is suited to your needs, model and `generation_kwargs`. Cache use has been tested for GPT-2 with greedy search. | +| `use_cache` | Cache keys and values between segment generation. When using caching, [`TextEnvironment`] is not suited for training use, i.e. backpropagation through the generated graph. Use with trl Trainers is of course possible. Furthermore, caching requires, that there be no calculation dependencies between examples at inference time. When using `BatchNorm`, the model should thus be in eval model. Caching is not guaranteed to produce identical results compared to not using caching and you should test for yourself, if it is suited to your needs, model and `generation_kwargs`. | You can customize the environment to your needs and add custom tools and settings. Let's see how you can use the environment to have the model interact with the available tools! diff --git a/tests/test_environments.py b/tests/test_environments.py index f14a542929..b6d4c60bec 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -112,14 +112,14 @@ def setUp(self): self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" # get models and tokenizer - self.gpt2_model = AutoModelForCausalLMWithValueHead.from_pretrained(self.model_id) - self.gpt2_tokenizer = AutoTokenizer.from_pretrained(self.model_id) - self.gpt2_tokenizer.pad_token = self.gpt2_tokenizer.eos_token + self.model = AutoModelForCausalLMWithValueHead.from_pretrained(self.model_id) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + self.tokenizer.pad_token = self.tokenizer.eos_token def test_text_environment_setup(self): env = TextEnvironment( - self.gpt2_model, - self.gpt2_tokenizer, + self.model, + self.tokenizer, tools=[DummyTool()], reward_fn=lambda x: torch.tensor(1), prompt="I am a prompt!\n", @@ -130,10 +130,10 @@ def test_text_environment_setup(self): self.assertEqual(env.reward_fn("Hello there!"), 1) def test_text_environment_generate(self): - generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": self.gpt2_tokenizer.eos_token_id} + generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": self.tokenizer.eos_token_id} env = TextEnvironment( - self.gpt2_model, - self.gpt2_tokenizer, + self.model, + self.tokenizer, tools=[DummyTool()], reward_fn=lambda x: torch.tensor(1), prompt="I am a prompt!\n", @@ -142,13 +142,13 @@ def test_text_environment_generate(self): input_texts = ["this is a test", "this is another, longer test"] - model_inputs = [self.gpt2_tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts] + model_inputs = [self.tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts] generations_batched, _, _, _, _ = env._generate_batched(model_inputs, batch_size=2) - generations_batched = self.gpt2_tokenizer.batch_decode(generations_batched) + generations_batched = self.tokenizer.batch_decode(generations_batched) generations_single = [env._generate_batched([inputs], batch_size=1)[0][0] for inputs in model_inputs] - generations_single = self.gpt2_tokenizer.batch_decode(generations_single) + generations_single = self.tokenizer.batch_decode(generations_single) self.assertEqual(generations_single, generations_batched) @@ -160,8 +160,8 @@ def test_text_environment_tool_call_parsing(self): string_invalid_random = "<>abcdefghijklm<>nopqrstuvwxyz<>" env = TextEnvironment( - self.gpt2_model, - self.gpt2_tokenizer, + self.model, + self.tokenizer, tools=[DummyTool()], reward_fn=lambda x: torch.tensor(1), prompt="I am a prompt!\n", @@ -188,8 +188,8 @@ def test_text_environment_tool_call_parsing(self): def test_text_environment_tool_truncation(self): env = TextEnvironment( - self.gpt2_model, - self.gpt2_tokenizer, + self.model, + self.tokenizer, tools={"dummy": lambda x: "a" * 1000}, reward_fn=lambda x: torch.tensor(1), prompt="I am a prompt!\n", @@ -214,8 +214,8 @@ def test_text_environment_tool_truncation(self): @patch.object(TextEnvironment, "generate", side_effect=dummy_generate) def test_text_environment_max_calls(self, mock_generate): env = TextEnvironment( - self.gpt2_model, - self.gpt2_tokenizer, + self.model, + self.tokenizer, tools={"DummyTool": DummyTool()}, reward_fn=lambda x: [torch.tensor(1) for _ in x], prompt="I am a prompt!\n", @@ -244,8 +244,8 @@ def test_text_environment_max_calls(self, mock_generate): def test_text_environment_compute_rewards(self): env = TextEnvironment( - self.gpt2_model, - self.gpt2_tokenizer, + self.model, + self.tokenizer, tools={"DummyTool": DummyTool()}, reward_fn=lambda x: [torch.tensor(i) for i, _ in enumerate(x)], prompt="I am a prompt!\n", @@ -260,8 +260,8 @@ def test_text_environment_compute_rewards(self): @patch.object(TextEnvironment, "generate", side_effect=dummy_generate) def test_text_environment_run(self, mock_generate): env = TextEnvironment( - self.gpt2_model, - self.gpt2_tokenizer, + self.model, + self.tokenizer, tools={"DummyTool": DummyTool()}, reward_fn=lambda x: [torch.tensor(i) for i, _ in enumerate(x)], prompt="I am a prompt!\n", @@ -292,8 +292,8 @@ def test_text_environment_run(self, mock_generate): def test_combine_cache(self): env = TextEnvironment( - self.gpt2_model, - self.gpt2_tokenizer, + self.model, + self.tokenizer, tools={"DummyTool": DummyTool()}, reward_fn=lambda x: [torch.tensor(i) for i, _ in enumerate(x)], prompt="I am a prompt!\n", @@ -320,7 +320,7 @@ def test_combine_cache(self): (torch.tensor([[[[7], [17]]], [[[0], [11]]]]), torch.tensor([[[[9], [19]]], [[[0], [12]]]])), ) expected_attention_mask = torch.tensor([[-1, 1, 7], [0, 2, 4]]) - expected_input_ids = torch.tensor([[1, 4, 7], [self.gpt2_tokenizer.pad_token_id, 3, 6]]) + expected_input_ids = torch.tensor([[1, 4, 7], [self.tokenizer.pad_token_id, 3, 6]]) combined_cache, combined_attention_masks, combined_input_ids = env._combine_cache( example_mask, caches, attention_masks, input_ids @@ -338,8 +338,8 @@ def test_combine_cache(self): def test_get_batched_cache(self): env = TextEnvironment( - self.gpt2_model, - self.gpt2_tokenizer, + self.model, + self.tokenizer, tools={"DummyTool": DummyTool()}, reward_fn=lambda x: [torch.tensor(i) for i, _ in enumerate(x)], prompt="I am a prompt!\n", @@ -380,10 +380,10 @@ def test_get_batched_cache(self): self.assertTrue(torch.all(batched_input_ids == expected_input_ids)) def test_cached_generate_batched(self): - generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": self.gpt2_tokenizer.eos_token_id} + generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": self.tokenizer.eos_token_id} env = TextEnvironment( - self.gpt2_model, - self.gpt2_tokenizer, + self.model, + self.tokenizer, tools=[DummyTool()], reward_fn=lambda x: torch.tensor(1), prompt="I am a prompt!\n", @@ -391,7 +391,7 @@ def test_cached_generate_batched(self): ) input_texts = ["this is a test", "this is another, longer test", "some other batch", "something unnecessary"] - model_inputs = [self.gpt2_tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts] + model_inputs = [self.tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts] outputs, past_key_values, past_attention_masks, past_input_ids, _ = env._generate_batched( model_inputs, batch_size=2 ) @@ -401,10 +401,10 @@ def test_cached_generate_batched(self): ) input_texts2 = [" short interim", " a somewhat longer section in between"] - model_inputs2 = [self.gpt2_tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts2] + model_inputs2 = [self.tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts2] # for single token query model_inputs2.append( - torch.tensor([self.gpt2_tokenizer(" a", return_tensors="pt").input_ids], dtype=model_inputs2[0].dtype) + torch.tensor([self.tokenizer(" a", return_tensors="pt").input_ids], dtype=model_inputs2[0].dtype) ) outputs_cached, _, _, _, _, all_logits_cached = env._generate_batched( model_inputs2, @@ -430,10 +430,10 @@ def test_cached_generate_batched(self): self.assertTrue(torch.all(torch.abs(logits_cached - logits_uncached) < 1e-6)) def test_different_sequence_lengths(self): - generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": self.gpt2_tokenizer.eos_token_id} + generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": self.tokenizer.eos_token_id} env = TextEnvironment( - self.gpt2_model, - self.gpt2_tokenizer, + self.model, + self.tokenizer, tools=[DummyTool()], reward_fn=lambda x: torch.tensor(1), prompt="I am a prompt!\n", @@ -441,7 +441,7 @@ def test_different_sequence_lengths(self): ) input_texts = ["this is a test", "this is another, longer test", "some other batch"] - model_inputs = [self.gpt2_tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts] + model_inputs = [self.tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts] outputs, past_key_values, past_attention_masks, past_input_ids, _ = env._generate_batched( model_inputs, batch_size=2 ) @@ -469,10 +469,10 @@ def test_different_sequence_lengths(self): self.assertEqual(past_key_values[0][0].shape[2], past_attention_masks.shape[1] - 1) self.assertEqual(past_key_values[0][0].shape[0], past_attention_masks.shape[0]) input_texts2 = [" short interim", " a somewhat longer section in between"] - model_inputs2 = [self.gpt2_tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts2] + model_inputs2 = [self.tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts2] # for single token query model_inputs2.append( - torch.tensor([self.gpt2_tokenizer(" a", return_tensors="pt").input_ids], dtype=model_inputs2[0].dtype) + torch.tensor([self.tokenizer(" a", return_tensors="pt").input_ids], dtype=model_inputs2[0].dtype) ) outputs_cached, _, _, _, _, all_logits_cached = env._generate_batched( model_inputs2, From 7b2169de9058645381b76a9f3ed670efa0cb8853 Mon Sep 17 00:00:00 2001 From: Konrad Gerlach Date: Wed, 22 Jan 2025 15:59:08 +0100 Subject: [PATCH 22/38] docs: add warning for torch.compile with TextEnvironment use_cache --- docs/source/text_environments.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/text_environments.md b/docs/source/text_environments.md index c8864ad3d3..068832523b 100644 --- a/docs/source/text_environments.md +++ b/docs/source/text_environments.md @@ -114,7 +114,7 @@ Let's decompose the settings: | `max_tool_response`| The tool response is truncated to this number to avoid running out of model context.| | `max_length` | The maximum number of tokens to allow in an episode. | | `generation_kwargs`| Generation settings used by the language model. | -| `use_cache` | Cache keys and values between segment generation. When using caching, [`TextEnvironment`] is not suited for training use, i.e. backpropagation through the generated graph. Use with trl Trainers is of course possible. Furthermore, caching requires, that there be no calculation dependencies between examples at inference time. When using `BatchNorm`, the model should thus be in eval model. Caching is not guaranteed to produce identical results compared to not using caching and you should test for yourself, if it is suited to your needs, model and `generation_kwargs`. | +| `use_cache` | Cache keys and values between segment generation. When using caching, [`TextEnvironment`] is not suited for training use, i.e. backpropagation through the generated graph. Use with trl Trainers is of course possible. Furthermore, caching requires, that there be no calculation dependencies between examples at inference time. When using `BatchNorm`, the model should thus be in eval model. Caching is not guaranteed to produce identical results compared to not using caching and you should test for yourself, if it is suited to your needs, model and `generation_kwargs`. `use_cache` may currently be incompatible with torch.compile due to a possible issue in the transformers library's generate method. See this comment in [_get_initial_cache_position](https://github.com/huggingface/transformers/blob/2e752ead46a8845e8a160d2043c1336447895690/src/transformers/generation/utils.py#L1582).| You can customize the environment to your needs and add custom tools and settings. Let's see how you can use the environment to have the model interact with the available tools! From 5725b1827f186e49d97534f12fc056d4c1a8e8cd Mon Sep 17 00:00:00 2001 From: Konrad Gerlach Date: Thu, 23 Jan 2025 11:06:59 +0100 Subject: [PATCH 23/38] fix: StringStoppingCriteria and add test --- tests/test_environments.py | 29 +++++++++++++++++++++++++++++ trl/environment/base_environment.py | 6 +++++- 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/tests/test_environments.py b/tests/test_environments.py index b6d4c60bec..3908554263 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -19,6 +19,7 @@ from transformers import AutoTokenizer, DynamicCache from trl import AutoModelForCausalLMWithValueHead, TextEnvironment, TextHistory +from trl.environment.base_environment import StringStoppingCriteria class DummyTool: @@ -116,6 +117,34 @@ def setUp(self): self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) self.tokenizer.pad_token = self.tokenizer.eos_token + def test_stopping_criteria(self): + stopping_criteria = StringStoppingCriteria(["stop", "end"], self.tokenizer) + encoded = self.tokenizer( + ["Lorem ipsum stop dolor sit amet", "ipsumenddolor sit amet, consectetur adipiscing", "token"], + return_tensors="pt", + padding=True, + padding_side="right", + ) + end_positions = [] + end_positions.append(encoded.char_to_token(batch_or_char_index=0, char_index=15)) + end_positions.append(encoded.char_to_token(batch_or_char_index=1, char_index=7)) + end_index = encoded.char_to_token(batch_or_char_index=1, char_index=4) + encoded["input_ids"][end_index + 1] = self.tokenizer.eos_token_id + end_positions.append(end_index) + i = 0 + is_stopped = False + while not is_stopped and i < 100: + # the first token is assumed to be the original input + is_stopped = stopping_criteria(encoded["input_ids"][:, : i + 2], None) + self.assertEqual( + stopping_criteria.generated_tokens, [min(end_position, i + 1) for end_position in end_positions] + ) + i += 1 + + self.assertTrue(is_stopped) + self.assertEqual(i, max(*end_positions)) + self.assertEqual(end_positions, stopping_criteria.generated_tokens) + def test_text_environment_setup(self): env = TextEnvironment( self.model, diff --git a/trl/environment/base_environment.py b/trl/environment/base_environment.py index 4c341154b1..7b79cfe3db 100644 --- a/trl/environment/base_environment.py +++ b/trl/environment/base_environment.py @@ -43,6 +43,7 @@ def __call__(self, input_ids, scores, **kwargs): self.generated_tokens = [0 for _ in range(input_ids.shape[0])] self.start_length = input_ids.shape[-1] - 1 self.first_call = False + self.last_done = [False for _ in range(input_ids.shape[0])] decoded_generations = self.tokenizer.batch_decode(input_ids[:, self.start_length :]) done = [] @@ -53,7 +54,10 @@ def __call__(self, input_ids, scores, **kwargs): ) done.append(sequence_complete) # we still consider the last generated token to be valid - self.generated_tokens[i] += 1 + if not self.last_done[i]: + self.generated_tokens[i] += 1 + + self.last_done = done if all(done): self.first_call = True From 589dcb7d3f5e3446901590790a4d006c35a8d9e5 Mon Sep 17 00:00:00 2001 From: Konrad Gerlach Date: Thu, 23 Jan 2025 11:22:27 +0100 Subject: [PATCH 24/38] refactor: move StoppingCriteria test --- tests/test_environments.py | 64 +++++++++++++++++++++----------------- 1 file changed, 36 insertions(+), 28 deletions(-) diff --git a/tests/test_environments.py b/tests/test_environments.py index 3908554263..09ab79cb5e 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -45,6 +45,42 @@ def reshape_cache(cache): return tuple(new_cache) +class StoppingCriteriaTester(unittest.TestCase): + def setUp(self): + # model_id + self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + self.tokenizer.pad_token = self.tokenizer.eos_token + + def test_stopping_criteria(self): + stopping_criteria = StringStoppingCriteria(["stop", "end"], self.tokenizer) + encoded = self.tokenizer( + ["Lorem ipsum stop dolor sit amet", "ipsumenddolor sit amet, consectetur adipiscing", "token"], + return_tensors="pt", + padding=True, + padding_side="right", + ) + end_positions = [] + end_positions.append(encoded.char_to_token(batch_or_char_index=0, char_index=15)) + end_positions.append(encoded.char_to_token(batch_or_char_index=1, char_index=7)) + end_index = encoded.char_to_token(batch_or_char_index=1, char_index=4) + encoded["input_ids"][end_index + 1] = self.tokenizer.eos_token_id + end_positions.append(end_index) + i = 0 + is_stopped = False + while not is_stopped and i < 100: + # the first token is assumed to be the original input + is_stopped = stopping_criteria(encoded["input_ids"][:, : i + 2], None) + self.assertEqual( + stopping_criteria.generated_tokens, [min(end_position, i + 1) for end_position in end_positions] + ) + i += 1 + + self.assertTrue(is_stopped) + self.assertEqual(i, max(*end_positions)) + self.assertEqual(end_positions, stopping_criteria.generated_tokens) + + class TextHistoryTest(unittest.TestCase): def test_text_history_init(self): text = "Hello there!" @@ -117,34 +153,6 @@ def setUp(self): self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) self.tokenizer.pad_token = self.tokenizer.eos_token - def test_stopping_criteria(self): - stopping_criteria = StringStoppingCriteria(["stop", "end"], self.tokenizer) - encoded = self.tokenizer( - ["Lorem ipsum stop dolor sit amet", "ipsumenddolor sit amet, consectetur adipiscing", "token"], - return_tensors="pt", - padding=True, - padding_side="right", - ) - end_positions = [] - end_positions.append(encoded.char_to_token(batch_or_char_index=0, char_index=15)) - end_positions.append(encoded.char_to_token(batch_or_char_index=1, char_index=7)) - end_index = encoded.char_to_token(batch_or_char_index=1, char_index=4) - encoded["input_ids"][end_index + 1] = self.tokenizer.eos_token_id - end_positions.append(end_index) - i = 0 - is_stopped = False - while not is_stopped and i < 100: - # the first token is assumed to be the original input - is_stopped = stopping_criteria(encoded["input_ids"][:, : i + 2], None) - self.assertEqual( - stopping_criteria.generated_tokens, [min(end_position, i + 1) for end_position in end_positions] - ) - i += 1 - - self.assertTrue(is_stopped) - self.assertEqual(i, max(*end_positions)) - self.assertEqual(end_positions, stopping_criteria.generated_tokens) - def test_text_environment_setup(self): env = TextEnvironment( self.model, From 5e1a7dd328beb4730ff45bad34c18cadfa4de18d Mon Sep 17 00:00:00 2001 From: Konrad Gerlach Date: Thu, 23 Jan 2025 14:44:48 +0100 Subject: [PATCH 25/38] feat: add support for models without cache class support --- tests/test_environments.py | 266 +++++++++++++++++----------- trl/environment/base_environment.py | 30 ++-- 2 files changed, 183 insertions(+), 113 deletions(-) diff --git a/tests/test_environments.py b/tests/test_environments.py index 09ab79cb5e..cbaf3476a0 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -14,9 +14,10 @@ import unittest from unittest.mock import patch +from parameterized import parameterized import torch -from transformers import AutoTokenizer, DynamicCache +from transformers import AutoTokenizer, Cache from trl import AutoModelForCausalLMWithValueHead, TextEnvironment, TextHistory from trl.environment.base_environment import StringStoppingCriteria @@ -35,6 +36,15 @@ def dummy_generate( return histories, None, None, None, None +def cache_class_support_forward(support_cache_class, feedback): + def _forward(*args, **kwargs): + if isinstance(kwargs["past_key_values"], Cache) == support_cache_class: + feedback[0] += 1 + raise Exception("Testing") + + return _forward + + def reshape_cache(cache): new_cache = [] for layer in cache: @@ -347,7 +357,6 @@ def test_combine_cache(self): (torch.tensor([[[[11]]]]), torch.tensor([[[[12]]]])), ), ] - caches = [DynamicCache().from_legacy_cache(cache) for cache in caches] attention_masks = [torch.tensor([[-1, 1, 7], [1, 0, 8]]), torch.tensor([[2, 4]])] input_ids = [torch.tensor([[1, 4, 7], [2, 5, 8]]), torch.tensor([[3, 6]])] example_mask = [True, False, True] @@ -394,7 +403,6 @@ def test_get_batched_cache(self): batched_cache, batched_attention_masks, batched_input_ids = env._get_batched_cache( 1, 3, cache, attention_masks, input_ids ) - batched_cache = batched_cache.to_legacy_cache() expected_cache = reshape_cache( ( (torch.tensor([[2], [3]]), torch.tensor([[5], [6]])), @@ -416,57 +424,68 @@ def test_get_batched_cache(self): expected_input_ids = torch.tensor([[5], [6]]) self.assertTrue(torch.all(batched_input_ids == expected_input_ids)) - def test_cached_generate_batched(self): - generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": self.tokenizer.eos_token_id} - env = TextEnvironment( - self.model, - self.tokenizer, - tools=[DummyTool()], - reward_fn=lambda x: torch.tensor(1), - prompt="I am a prompt!\n", - generation_kwargs=generation_kwargs, - ) + @parameterized.expand([(True,), (False,)]) + def test_cached_generate_batched(self, support_cache_class): + with patch.object(self.model.pretrained_model, "_supports_cache_class", new=support_cache_class): + generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": self.tokenizer.eos_token_id} + env = TextEnvironment( + self.model, + self.tokenizer, + tools=[DummyTool()], + reward_fn=lambda x: torch.tensor(1), + prompt="I am a prompt!\n", + generation_kwargs=generation_kwargs, + ) - input_texts = ["this is a test", "this is another, longer test", "some other batch", "something unnecessary"] - model_inputs = [self.tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts] - outputs, past_key_values, past_attention_masks, past_input_ids, _ = env._generate_batched( - model_inputs, batch_size=2 - ) + input_texts = [ + "this is a test", + "this is another, longer test", + "some other batch", + "something unnecessary", + ] + model_inputs = [self.tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts] + outputs, past_key_values, past_attention_masks, past_input_ids, _ = env._generate_batched( + model_inputs, batch_size=2 + ) - past_key_values, past_attention_masks, past_input_ids = env._combine_cache( - [True, True, True, False], past_key_values, past_attention_masks, past_input_ids - ) + past_key_values, past_attention_masks, past_input_ids = env._combine_cache( + [True, True, True, False], past_key_values, past_attention_masks, past_input_ids + ) - input_texts2 = [" short interim", " a somewhat longer section in between"] - model_inputs2 = [self.tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts2] - # for single token query - model_inputs2.append( - torch.tensor([self.tokenizer(" a", return_tensors="pt").input_ids], dtype=model_inputs2[0].dtype) - ) - outputs_cached, _, _, _, _, all_logits_cached = env._generate_batched( - model_inputs2, - batch_size=2, - combined_past_key_values=past_key_values, - combined_past_attention_masks=past_attention_masks, - combined_past_input_ids=past_input_ids, - output_logits=True, - ) + input_texts2 = [" short interim", " a somewhat longer section in between"] + model_inputs2 = [self.tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts2] + # for single token query + model_inputs2.append( + torch.tensor([self.tokenizer(" a", return_tensors="pt").input_ids], dtype=model_inputs2[0].dtype) + ) + outputs_cached, _, _, _, _, all_logits_cached = env._generate_batched( + model_inputs2, + batch_size=2, + combined_past_key_values=past_key_values, + combined_past_attention_masks=past_attention_masks, + combined_past_input_ids=past_input_ids, + output_logits=True, + ) + + model_inputs2_full = [ + torch.concat([in1, out1, in2], dim=0) + for in1, out1, in2 in zip(model_inputs[:-1], outputs, model_inputs2) + ] + outputs_uncached, _, _, _, _, all_logits_uncached = env._generate_batched( + model_inputs2_full, batch_size=2, output_logits=True + ) + for cached, uncached, logits_cached, logits_uncached in zip( + outputs_cached, outputs_uncached, all_logits_cached, all_logits_uncached + ): + self.assertTrue(torch.all(cached == uncached)) + self.assertEqual(logits_cached.shape[0], 4) + self.assertEqual(logits_uncached.shape[0], 4) + self.assertTrue(torch.all(torch.abs(logits_cached - logits_uncached) < 1e-6)) + + @parameterized.expand([(True,), (False,)]) + def test_cache_class_support(self, support_cache_class): + self.assertEqual(self.model_id, "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") - model_inputs2_full = [ - torch.concat([in1, out1, in2], dim=0) for in1, out1, in2 in zip(model_inputs[:-1], outputs, model_inputs2) - ] - outputs_uncached, _, _, _, _, all_logits_uncached = env._generate_batched( - model_inputs2_full, batch_size=2, output_logits=True - ) - for cached, uncached, logits_cached, logits_uncached in zip( - outputs_cached, outputs_uncached, all_logits_cached, all_logits_uncached - ): - self.assertTrue(torch.all(cached == uncached)) - self.assertEqual(logits_cached.shape[0], 4) - self.assertEqual(logits_uncached.shape[0], 4) - self.assertTrue(torch.all(torch.abs(logits_cached - logits_uncached) < 1e-6)) - - def test_different_sequence_lengths(self): generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": self.tokenizer.eos_token_id} env = TextEnvironment( self.model, @@ -477,62 +496,103 @@ def test_different_sequence_lengths(self): generation_kwargs=generation_kwargs, ) - input_texts = ["this is a test", "this is another, longer test", "some other batch"] - model_inputs = [self.tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts] - outputs, past_key_values, past_attention_masks, past_input_ids, _ = env._generate_batched( - model_inputs, batch_size=2 - ) - # remove the last two tokens from the second batch to pretend they were never generated - second_cache = past_key_values[1].to_legacy_cache() - edited_cache = [] - for layer in second_cache: - keys, values = layer - new_keys = keys[:, :, :-2, :] - new_values = values[:, :, :-2, :] - edited_cache.append((new_keys, new_values)) - - past_key_values[1] = DynamicCache().from_legacy_cache(tuple(edited_cache)) - past_attention_masks[1] = past_attention_masks[1][:, :-2] - past_input_ids[1] = past_input_ids[1][:, :-2] - - # ensure this actually removes generated tokens and not skipped tokens / padding - self.assertEqual(len(outputs[2]), 4) + input_texts = ["test"] + model_inputs = list(self.tokenizer(input_texts, return_tensors="pt").input_ids) + _, past_key_values, past_attention_masks, past_input_ids, _ = env._generate_batched(model_inputs, batch_size=2) past_key_values, past_attention_masks, past_input_ids = env._combine_cache( - [True, True, True], past_key_values, past_attention_masks, past_input_ids - ) + [True], past_key_values, past_attention_masks, past_input_ids + ) + + input_texts2 = [" short interim"] + model_inputs2 = list(self.tokenizer(input_texts2, return_tensors="pt").input_ids) + feedback = torch.tensor([0]) + with patch.object(self.model.pretrained_model, "_supports_cache_class", new=support_cache_class): + with patch.object( + self.model.pretrained_model, "forward", new=cache_class_support_forward(support_cache_class, feedback) + ): + try: + _, _, _, _, _, all_logits_cached = env._generate_batched( + model_inputs2, + batch_size=2, + combined_past_key_values=past_key_values, + combined_past_attention_masks=past_attention_masks, + combined_past_input_ids=past_input_ids, + output_logits=True, + ) + except: + pass + self.assertTrue(torch.all(feedback == 1.0)) + + @parameterized.expand([(True,), (False,)]) + def test_different_sequence_lengths(self, support_cache_class): + with patch.object(self.model.pretrained_model, "_supports_cache_class", new=support_cache_class): + generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": self.tokenizer.eos_token_id} + env = TextEnvironment( + self.model, + self.tokenizer, + tools=[DummyTool()], + reward_fn=lambda x: torch.tensor(1), + prompt="I am a prompt!\n", + generation_kwargs=generation_kwargs, + ) - self.assertEqual(past_attention_masks.shape, past_input_ids.shape) - self.assertEqual(past_key_values[0][0].shape[2], past_attention_masks.shape[1] - 1) - self.assertEqual(past_key_values[0][0].shape[0], past_attention_masks.shape[0]) - input_texts2 = [" short interim", " a somewhat longer section in between"] - model_inputs2 = [self.tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts2] - # for single token query - model_inputs2.append( - torch.tensor([self.tokenizer(" a", return_tensors="pt").input_ids], dtype=model_inputs2[0].dtype) - ) - outputs_cached, _, _, _, _, all_logits_cached = env._generate_batched( - model_inputs2, - batch_size=2, - combined_past_key_values=past_key_values, - combined_past_attention_masks=past_attention_masks, - combined_past_input_ids=past_input_ids, - output_logits=True, - ) - outputs[2] = outputs[2][:-2] # remove last two generated tokens from input - model_inputs2_full = [ - torch.concat([in1, out1, in2], dim=0) for in1, out1, in2 in zip(model_inputs, outputs, model_inputs2) - ] - outputs_uncached, _, _, _, _, all_logits_uncached = env._generate_batched( - model_inputs2_full, batch_size=2, output_logits=True - ) - for cached, uncached, logits_cached, logits_uncached in zip( - outputs_cached, outputs_uncached, all_logits_cached, all_logits_uncached - ): - self.assertTrue(torch.all(cached == uncached)) - self.assertEqual(logits_cached.shape[0], 4) - self.assertEqual(logits_uncached.shape[0], 4) - self.assertTrue(torch.all(torch.abs(logits_cached - logits_uncached) < 1e-6)) + input_texts = ["this is a test", "this is another, longer test", "some other batch"] + model_inputs = [self.tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts] + outputs, past_key_values, past_attention_masks, past_input_ids, _ = env._generate_batched( + model_inputs, batch_size=2 + ) + # remove the last two tokens from the second batch to pretend they were never generated + second_cache = past_key_values[1] + edited_cache = [] + for layer in second_cache: + keys, values = layer + new_keys = keys[:, :, :-2, :] + new_values = values[:, :, :-2, :] + edited_cache.append((new_keys, new_values)) + + past_key_values[1] = tuple(edited_cache) + past_attention_masks[1] = past_attention_masks[1][:, :-2] + past_input_ids[1] = past_input_ids[1][:, :-2] + + # ensure this actually removes generated tokens and not skipped tokens / padding + self.assertEqual(len(outputs[2]), 4) + + past_key_values, past_attention_masks, past_input_ids = env._combine_cache( + [True, True, True], past_key_values, past_attention_masks, past_input_ids + ) + + self.assertEqual(past_attention_masks.shape, past_input_ids.shape) + self.assertEqual(past_key_values[0][0].shape[2], past_attention_masks.shape[1] - 1) + self.assertEqual(past_key_values[0][0].shape[0], past_attention_masks.shape[0]) + input_texts2 = [" short interim", " a somewhat longer section in between"] + model_inputs2 = [self.tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts2] + # for single token query + model_inputs2.append( + torch.tensor([self.tokenizer(" a", return_tensors="pt").input_ids], dtype=model_inputs2[0].dtype) + ) + outputs_cached, _, _, _, _, all_logits_cached = env._generate_batched( + model_inputs2, + batch_size=2, + combined_past_key_values=past_key_values, + combined_past_attention_masks=past_attention_masks, + combined_past_input_ids=past_input_ids, + output_logits=True, + ) + outputs[2] = outputs[2][:-2] # remove last two generated tokens from input + model_inputs2_full = [ + torch.concat([in1, out1, in2], dim=0) for in1, out1, in2 in zip(model_inputs, outputs, model_inputs2) + ] + outputs_uncached, _, _, _, _, all_logits_uncached = env._generate_batched( + model_inputs2_full, batch_size=2, output_logits=True + ) + for cached, uncached, logits_cached, logits_uncached in zip( + outputs_cached, outputs_uncached, all_logits_cached, all_logits_uncached + ): + self.assertTrue(torch.all(cached == uncached)) + self.assertEqual(logits_cached.shape[0], 4) + self.assertEqual(logits_uncached.shape[0], 4) + self.assertTrue(torch.all(torch.abs(logits_cached - logits_uncached) < 1e-6)) if __name__ == "__main__": diff --git a/trl/environment/base_environment.py b/trl/environment/base_environment.py index 7b79cfe3db..f8063e37ac 100644 --- a/trl/environment/base_environment.py +++ b/trl/environment/base_environment.py @@ -410,18 +410,17 @@ def _combine_cache(self, example_mask, past_key_values, past_attention_masks, pa Args: batch_examples (list[bool]): mask indicating for each example, whether it is supposed to remain or not - past_key_values (list[transformers.DynamicCache]) : Batched list of caches from the last generation + past_key_values (tuple[tuple[torch.Tensor]]) : Batched list of caches (in legacy format) from the last generation past_attention_masks (list[torch.Tensor]): Batched list of attention masks from the last generation past_input_ids (list[torch.Tensor]): Batched list of input ids from the last generation """ max_sequence_length = max([attention_mask.shape[1] for attention_mask in past_attention_masks]) - legacy_format = [cache.to_legacy_cache() for cache in past_key_values] combined_cache = [] - for layer_id in range(len(legacy_format[0])): + for layer_id in range(len(past_key_values[0])): combined_layer = None example_mask_offset = 0 - for cache in legacy_format: + for cache in past_key_values: layer = cache[layer_id] num_examples = len(layer[0]) extracted_keys = layer[0][example_mask[example_mask_offset : example_mask_offset + num_examples]] @@ -493,7 +492,7 @@ def generate( Either all of past_key_values, past_attention_masks, past_input_ids,last_active_histories are provided or all are None. Args: histories (list[TextHistory]): - past_key_values (Optional[list[transformers.DynamicCache]]): Batched list of caches from the last generation + past_key_values (Optional[tuple[tuple[torch.Tensor]]]): Batched list of caches in legacy format from the last generation past_attention_masks (Optional[list[torch.Tensor]]): Batched list of attention masks from the last generation past_input_ids (Optional[list[torch.Tensor]]): Batched list of input ids from the last generation last_active_histories (Optional[list[int]]): indices of histories for which generation took place during the last generation turn @@ -591,7 +590,7 @@ def _get_batched_cache( current_cache.append((new_keys, new_values)) current_cache = tuple(current_cache) return ( - DynamicCache().from_legacy_cache(current_cache), + current_cache, combined_attention_masks[start_index:end_index], combined_input_ids[start_index:end_index], ) @@ -666,13 +665,14 @@ def _generate_batched( generation_kwargs = copy.deepcopy(self.generation_kwargs) generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stopping_criteria]) - generation_kwargs["use_cache"] = True generation_kwargs["return_dict_in_generate"] = True + if output_logits: generation_kwargs["output_logits"] = True # handle caching - generation_kwargs["past_key_values"] = past_key_values if past_key_values is not None else DynamicCache() + generation_kwargs["use_cache"] = True + generation_kwargs["return_legacy_cache"] = True if past_attention_masks is not None: padded_inputs["attention_mask"] = torch.concatenate( [past_attention_masks, padded_inputs["attention_mask"]], dim=1 @@ -683,9 +683,19 @@ def _generate_batched( if self.max_length is not None and padded_inputs["input_ids"].shape[-1] > self.max_length: return None, None, None, None, True - generations = extract_model_from_parallel(self.model).generate(**padded_inputs, **generation_kwargs) + extracted_model = extract_model_from_parallel(self.model) + if extracted_model.pretrained_model._supports_cache_class: + generation_kwargs["past_key_values"] = ( + DynamicCache().from_legacy_cache(past_key_values) + if past_key_values is not None + else DynamicCache() + ) + else: + generation_kwargs["past_key_values"] = past_key_values + + generations = extracted_model.generate(**padded_inputs, **generation_kwargs) - if generations.past_key_values.to_legacy_cache()[0][0].shape[2] != generations.sequences.shape[1] - 1: + if generations.past_key_values[0][0].shape[2] != generations.sequences.shape[1] - 1: raise Exception("Cache should not contain keys and values for last generated token") new_past_key_values.append(generations.past_key_values) From cc99580e53e0661302b42d8c41f030a07d7e29c1 Mon Sep 17 00:00:00 2001 From: Konrad Gerlach Date: Thu, 23 Jan 2025 16:54:40 +0100 Subject: [PATCH 26/38] refactor: make caching code optional in TextEnvironment --- tests/test_environments.py | 8 +- trl/environment/base_environment.py | 121 +++++++++++++++------------- 2 files changed, 72 insertions(+), 57 deletions(-) diff --git a/tests/test_environments.py b/tests/test_environments.py index cbaf3476a0..0146571319 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -445,7 +445,7 @@ def test_cached_generate_batched(self, support_cache_class): ] model_inputs = [self.tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts] outputs, past_key_values, past_attention_masks, past_input_ids, _ = env._generate_batched( - model_inputs, batch_size=2 + model_inputs, batch_size=2, return_cache=True ) past_key_values, past_attention_masks, past_input_ids = env._combine_cache( @@ -498,7 +498,9 @@ def test_cache_class_support(self, support_cache_class): input_texts = ["test"] model_inputs = list(self.tokenizer(input_texts, return_tensors="pt").input_ids) - _, past_key_values, past_attention_masks, past_input_ids, _ = env._generate_batched(model_inputs, batch_size=2) + _, past_key_values, past_attention_masks, past_input_ids, _ = env._generate_batched( + model_inputs, batch_size=2, return_cache=True + ) past_key_values, past_attention_masks, past_input_ids = env._combine_cache( [True], past_key_values, past_attention_masks, past_input_ids @@ -540,7 +542,7 @@ def test_different_sequence_lengths(self, support_cache_class): input_texts = ["this is a test", "this is another, longer test", "some other batch"] model_inputs = [self.tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts] outputs, past_key_values, past_attention_masks, past_input_ids, _ = env._generate_batched( - model_inputs, batch_size=2 + model_inputs, batch_size=2, return_cache=True ) # remove the last two tokens from the second batch to pretend they were never generated second_cache = past_key_values[1] diff --git a/trl/environment/base_environment.py b/trl/environment/base_environment.py index f8063e37ac..fe2b1fdf26 100644 --- a/trl/environment/base_environment.py +++ b/trl/environment/base_environment.py @@ -514,6 +514,7 @@ def generate( combined_past_key_values=combined_past_key_values, combined_past_attention_masks=combined_past_attention_masks, combined_past_input_ids=combined_past_input_ids, + return_cache=self.use_cache, ) if not truncated: response_texts = self.tokenizer.batch_decode(response_tensors) @@ -595,6 +596,23 @@ def _get_batched_cache( combined_input_ids[start_index:end_index], ) + def extract_generation(self, sequence, mask): + """Remove padding and prompt based on the attention mask to extract generated tokens + Args: + sequence (torch.Tensor): A sequence with length corresponding to input sequence length + generation sequence length + mask (torch.Tensor): The input attention mask + """ + if not self.is_encoder_decoder: + # remove padding + output = sequence[(1 - mask).sum() :] + else: + output = sequence + + if not self.is_encoder_decoder: + # remove prompt + output = output[(mask).sum() :] + return output + # TODO make batch_size changeable def _generate_batched( self, @@ -605,6 +623,7 @@ def _generate_batched( combined_past_attention_masks=None, combined_past_input_ids=None, output_logits=False, + return_cache=False, ): """ Generate responses for a list of query tensors. @@ -617,6 +636,7 @@ def _generate_batched( combined_past_attention_masks (Optional[torch.Tensor]): The combined (unbatched) attention masks from the last generation combined_past_input_ids (Optional[torch.Tensor]): The combined (unbatched) input ids from the last generation """ + caching_enabled = return_cache or (combined_past_key_values is not None) # Ensures, that the next token is never conditioned on a padding token. This should never be a problem, as empty system prompts are not particularly useful and between segments there is always a response token. for query in query_tensors: if len(query) == 0: @@ -626,9 +646,11 @@ def _generate_batched( if not self.is_encoder_decoder: self.tokenizer.padding_side = "left" - new_past_key_values = [] - new_past_attention_masks = [] - new_past_input_ids = [] + if return_cache: + new_past_key_values, new_past_attention_masks, new_past_input_ids = ([], [], []) + else: + new_past_key_values, new_past_attention_masks, new_past_input_ids = (None, None, None) + if output_logits: all_logits = [] @@ -670,9 +692,9 @@ def _generate_batched( if output_logits: generation_kwargs["output_logits"] = True - # handle caching - generation_kwargs["use_cache"] = True - generation_kwargs["return_legacy_cache"] = True + if caching_enabled: + generation_kwargs["use_cache"] = True + generation_kwargs["return_legacy_cache"] = True if past_attention_masks is not None: padded_inputs["attention_mask"] = torch.concatenate( [past_attention_masks, padded_inputs["attention_mask"]], dim=1 @@ -684,60 +706,26 @@ def _generate_batched( return None, None, None, None, True extracted_model = extract_model_from_parallel(self.model) - if extracted_model.pretrained_model._supports_cache_class: + if caching_enabled and extracted_model.pretrained_model._supports_cache_class: generation_kwargs["past_key_values"] = ( DynamicCache().from_legacy_cache(past_key_values) if past_key_values is not None else DynamicCache() ) - else: + elif caching_enabled: generation_kwargs["past_key_values"] = past_key_values generations = extracted_model.generate(**padded_inputs, **generation_kwargs) - if generations.past_key_values[0][0].shape[2] != generations.sequences.shape[1] - 1: - raise Exception("Cache should not contain keys and values for last generated token") - new_past_key_values.append(generations.past_key_values) - - past_attention_mask = torch.ones_like(generations.sequences) - # Don't attend to generated padding or eos tokens - past_attention_mask[ - torch.logical_or( - generations.sequences == self.tokenizer.eos_token_id, - generations.sequences == self.tokenizer.pad_token_id, - ) - ] = 0 - past_attention_mask[:, : input_attention_mask.shape[1]] = input_attention_mask - if output_logits: logits = generations.logits - generations = generations.sequences + sequences = generations.sequences # copy for in-place modification - batch_new_past_input_ids = generations.detach().clone() - for generation, mask, num_generated_tokens, new_attention_mask, example_input_ids, i in zip( - generations, - padded_inputs["attention_mask"], - stopping_criteria.generated_tokens, - past_attention_mask, - batch_new_past_input_ids, - range(len(generations)), + batch_new_past_input_ids = sequences.detach().clone() + for generation, mask, num_generated_tokens in zip( + sequences, padded_inputs["attention_mask"], stopping_criteria.generated_tokens ): - if not self.is_encoder_decoder: - # remove padding - output = generation[(1 - mask).sum() :] - padding_removed_b_n_past_input_ids = example_input_ids[(1 - mask).sum() :] - padding_removed_past_attention_mask = new_attention_mask[(1 - mask).sum() :] - else: - output = generation - padding_removed_past_attention_mask = new_attention_mask - padding_removed_b_n_past_input_ids = example_input_ids[(1 - mask).sum() :] - - if not self.is_encoder_decoder: - # remove prompt - output = output[(mask).sum() :] - padding_removed_b_n_past_input_ids = padding_removed_b_n_past_input_ids[(mask).sum() :] - padding_removed_past_attention_mask = padding_removed_past_attention_mask[(mask).sum() :] - + output = self.extract_generation(generation, mask) # remove chunk generated after stopping criteria in batch mode generated_tokens = output[:num_generated_tokens] if len(generated_tokens) < 1: @@ -745,18 +733,43 @@ def _generate_batched( raise Exception(f"Generation failed to produce any valid token; input length {input_length}") outputs.append(generated_tokens) - # Do not attend to invalid tokens that were generated after or or the last valid generated token, as we move it to the end of the sequence - padding_removed_past_attention_mask[num_generated_tokens - 1 :] = 0 - # move last valid generated token to the end of the sequence to be the start of the next generation - padding_removed_b_n_past_input_ids[-1] = padding_removed_b_n_past_input_ids[num_generated_tokens - 1] - padding_removed_past_attention_mask[-1] = 1 # attend to the last valid generated token + + if return_cache: + if generations.past_key_values[0][0].shape[2] != generations.sequences.shape[1] - 1: + raise Exception("Cache should not contain keys and values for last generated token") + new_past_key_values.append(generations.past_key_values) + new_past_attention_mask = torch.ones_like(generations.sequences) + # Don't attend to generated padding or eos tokens + new_past_attention_mask[ + torch.logical_or( + generations.sequences == self.tokenizer.eos_token_id, + generations.sequences == self.tokenizer.pad_token_id, + ) + ] = 0 + new_past_attention_mask[:, : input_attention_mask.shape[1]] = input_attention_mask + + for mask, num_generated_tokens, new_attention_mask, example_input_ids in zip( + padded_inputs["attention_mask"], + stopping_criteria.generated_tokens, + new_past_attention_mask, + batch_new_past_input_ids, + ): + extracted_past_input_ids = self.extract_generation(example_input_ids, mask) + extracted_past_attention_mask = self.extract_generation(new_attention_mask, mask) + # Do not attend to invalid tokens that were generated after or or the last valid generated token, as we move it to the end of the sequence + extracted_past_attention_mask[num_generated_tokens - 1 :] = 0 + # move last valid generated token to the end of the sequence to be the start of the next generation + extracted_past_input_ids[-1] = extracted_past_input_ids[num_generated_tokens - 1] + extracted_past_attention_mask[-1] = 1 # attend to the last valid generated token + + new_past_attention_masks.append(new_past_attention_mask) + new_past_input_ids.append(batch_new_past_input_ids) + if output_logits: for i, num_generated_tokens in enumerate(stopping_criteria.generated_tokens): relevant_logits = [batched_logits[i] for batched_logits in logits[:num_generated_tokens]] all_logits.append(torch.stack(relevant_logits, dim=0)) - new_past_attention_masks.append(past_attention_mask) - new_past_input_ids.append(batch_new_past_input_ids) self.tokenizer.padding_side = padding_side_default if output_logits: return outputs, new_past_key_values, new_past_attention_masks, new_past_input_ids, False, all_logits From 50119a8a6f12f466d15afeae3caa09a268bdb74a Mon Sep 17 00:00:00 2001 From: Konrad Gerlach Date: Thu, 23 Jan 2025 16:58:00 +0100 Subject: [PATCH 27/38] docs: TextEnvironment use_cache note untested Encoder-Decoder architecture --- docs/source/text_environments.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/text_environments.md b/docs/source/text_environments.md index 068832523b..ccd9202dd6 100644 --- a/docs/source/text_environments.md +++ b/docs/source/text_environments.md @@ -114,7 +114,7 @@ Let's decompose the settings: | `max_tool_response`| The tool response is truncated to this number to avoid running out of model context.| | `max_length` | The maximum number of tokens to allow in an episode. | | `generation_kwargs`| Generation settings used by the language model. | -| `use_cache` | Cache keys and values between segment generation. When using caching, [`TextEnvironment`] is not suited for training use, i.e. backpropagation through the generated graph. Use with trl Trainers is of course possible. Furthermore, caching requires, that there be no calculation dependencies between examples at inference time. When using `BatchNorm`, the model should thus be in eval model. Caching is not guaranteed to produce identical results compared to not using caching and you should test for yourself, if it is suited to your needs, model and `generation_kwargs`. `use_cache` may currently be incompatible with torch.compile due to a possible issue in the transformers library's generate method. See this comment in [_get_initial_cache_position](https://github.com/huggingface/transformers/blob/2e752ead46a8845e8a160d2043c1336447895690/src/transformers/generation/utils.py#L1582).| +| `use_cache` | Cache keys and values between segment generation. When using caching, [`TextEnvironment`] is not suited for training use, i.e. backpropagation through the generated graph. Use with trl Trainers is of course possible. Furthermore, caching requires, that there be no calculation dependencies between examples at inference time. When using `BatchNorm`, the model should thus be in eval model. Caching is not guaranteed to produce identical results compared to not using caching and you should test for yourself, if it is suited to your needs, model and `generation_kwargs`. Compatibility with Encoder-Decoder architectures is untested. `use_cache` may currently be incompatible with torch.compile due to a possible issue in the transformers library's generate method. See this comment in [_get_initial_cache_position](https://github.com/huggingface/transformers/blob/2e752ead46a8845e8a160d2043c1336447895690/src/transformers/generation/utils.py#L1582).| You can customize the environment to your needs and add custom tools and settings. Let's see how you can use the environment to have the model interact with the available tools! From 772527b08f459866edbbe198cd85612b2779a94c Mon Sep 17 00:00:00 2001 From: Konrad Gerlach Date: Fri, 24 Jan 2025 10:39:58 +0100 Subject: [PATCH 28/38] refactor: extract method from _generate_batched --- trl/environment/base_environment.py | 66 +++++++++++++++++------------ 1 file changed, 38 insertions(+), 28 deletions(-) diff --git a/trl/environment/base_environment.py b/trl/environment/base_environment.py index fe2b1fdf26..ecafe90d2e 100644 --- a/trl/environment/base_environment.py +++ b/trl/environment/base_environment.py @@ -596,7 +596,7 @@ def _get_batched_cache( combined_input_ids[start_index:end_index], ) - def extract_generation(self, sequence, mask): + def _extract_generation(self, sequence, mask): """Remove padding and prompt based on the attention mask to extract generated tokens Args: sequence (torch.Tensor): A sequence with length corresponding to input sequence length + generation sequence length @@ -613,6 +613,39 @@ def extract_generation(self, sequence, mask): output = output[(mask).sum() :] return output + def _create_new_past_inputs(self, sequences, input_attention_mask, generated_tokens): + """Creates the new past_input_ids and new past_attention_mask for a batch. + Args: + sequences (torch.Tensor): The sequences returned by model.generate(...) + input_attention_mask (torch.Tensor): The attention mask that was input into model.generate(...) + generated_tokens (list[int]): The number of tokens generated for each history in the batch + """ + new_past_attention_mask = torch.ones_like(sequences) + # Don't attend to generated padding or eos tokens + new_past_attention_mask[ + torch.logical_or( + sequences == self.tokenizer.eos_token_id, + sequences == self.tokenizer.pad_token_id, + ) + ] = 0 + new_past_attention_mask[:, : input_attention_mask.shape[1]] = input_attention_mask + # copy for in-place modification + batch_new_past_input_ids = sequences.detach().clone() + for mask, num_generated_tokens, new_attention_mask, example_input_ids in zip( + input_attention_mask, + generated_tokens, + new_past_attention_mask, + batch_new_past_input_ids, + ): + extracted_past_input_ids = self._extract_generation(example_input_ids, mask) + extracted_past_attention_mask = self._extract_generation(new_attention_mask, mask) + # Do not attend to invalid tokens that were generated after or or the last valid generated token, as we move it to the end of the sequence + extracted_past_attention_mask[num_generated_tokens - 1 :] = 0 + # move last valid generated token to the end of the sequence to be the start of the next generation + extracted_past_input_ids[-1] = extracted_past_input_ids[num_generated_tokens - 1] + extracted_past_attention_mask[-1] = 1 # attend to the last valid generated token + return batch_new_past_input_ids, new_past_attention_mask + # TODO make batch_size changeable def _generate_batched( self, @@ -720,12 +753,10 @@ def _generate_batched( if output_logits: logits = generations.logits sequences = generations.sequences - # copy for in-place modification - batch_new_past_input_ids = sequences.detach().clone() for generation, mask, num_generated_tokens in zip( sequences, padded_inputs["attention_mask"], stopping_criteria.generated_tokens ): - output = self.extract_generation(generation, mask) + output = self._extract_generation(generation, mask) # remove chunk generated after stopping criteria in batch mode generated_tokens = output[:num_generated_tokens] if len(generated_tokens) < 1: @@ -738,30 +769,9 @@ def _generate_batched( if generations.past_key_values[0][0].shape[2] != generations.sequences.shape[1] - 1: raise Exception("Cache should not contain keys and values for last generated token") new_past_key_values.append(generations.past_key_values) - new_past_attention_mask = torch.ones_like(generations.sequences) - # Don't attend to generated padding or eos tokens - new_past_attention_mask[ - torch.logical_or( - generations.sequences == self.tokenizer.eos_token_id, - generations.sequences == self.tokenizer.pad_token_id, - ) - ] = 0 - new_past_attention_mask[:, : input_attention_mask.shape[1]] = input_attention_mask - - for mask, num_generated_tokens, new_attention_mask, example_input_ids in zip( - padded_inputs["attention_mask"], - stopping_criteria.generated_tokens, - new_past_attention_mask, - batch_new_past_input_ids, - ): - extracted_past_input_ids = self.extract_generation(example_input_ids, mask) - extracted_past_attention_mask = self.extract_generation(new_attention_mask, mask) - # Do not attend to invalid tokens that were generated after or or the last valid generated token, as we move it to the end of the sequence - extracted_past_attention_mask[num_generated_tokens - 1 :] = 0 - # move last valid generated token to the end of the sequence to be the start of the next generation - extracted_past_input_ids[-1] = extracted_past_input_ids[num_generated_tokens - 1] - extracted_past_attention_mask[-1] = 1 # attend to the last valid generated token - + batch_new_past_input_ids, new_past_attention_mask = self._create_new_past_inputs( + sequences, padded_inputs["attention_mask"], stopping_criteria.generated_tokens + ) new_past_attention_masks.append(new_past_attention_mask) new_past_input_ids.append(batch_new_past_input_ids) From 4b58de2373ecf1ae43725b995c279e41a5eff8db Mon Sep 17 00:00:00 2001 From: Konrad Gerlach Date: Fri, 24 Jan 2025 14:31:48 +0100 Subject: [PATCH 29/38] feat: add optional logits to TextHistory and remove eos segment created by truncation --- docs/source/text_environments.md | 2 ++ tests/test_environments.py | 53 ++++++++++++++++++++++++++--- trl/environment/base_environment.py | 46 ++++++++++++++----------- 3 files changed, 77 insertions(+), 24 deletions(-) diff --git a/docs/source/text_environments.md b/docs/source/text_environments.md index ccd9202dd6..d742293396 100644 --- a/docs/source/text_environments.md +++ b/docs/source/text_environments.md @@ -115,6 +115,7 @@ Let's decompose the settings: | `max_length` | The maximum number of tokens to allow in an episode. | | `generation_kwargs`| Generation settings used by the language model. | | `use_cache` | Cache keys and values between segment generation. When using caching, [`TextEnvironment`] is not suited for training use, i.e. backpropagation through the generated graph. Use with trl Trainers is of course possible. Furthermore, caching requires, that there be no calculation dependencies between examples at inference time. When using `BatchNorm`, the model should thus be in eval model. Caching is not guaranteed to produce identical results compared to not using caching and you should test for yourself, if it is suited to your needs, model and `generation_kwargs`. Compatibility with Encoder-Decoder architectures is untested. `use_cache` may currently be incompatible with torch.compile due to a possible issue in the transformers library's generate method. See this comment in [_get_initial_cache_position](https://github.com/huggingface/transformers/blob/2e752ead46a8845e8a160d2043c1336447895690/src/transformers/generation/utils.py#L1582).| +| `save_logits` | Whether to save logits for the generated tokens in the returned histories. Mainly intended to help the user test caching for their use case. Backpropagation through logits is not supported. | You can customize the environment to your needs and add custom tools and settings. Let's see how you can use the environment to have the model interact with the available tools! @@ -170,6 +171,7 @@ The following table summarises the available attributes of the `TextEnvironment` | `token_masks` | The token masks can be used to ignore system generated tokens by masking them. | | `completed` | Indicates if the interaction with the environment has completed. | | `truncated` | Indicates if the interaction with the environment has completed because max length was reached. | +| `logits` | A list containing a tensor for each non-system segment. If no logits were saved, then `logits` is an empty list. The tensors contain the logits for the tokens generated by the model. See `save_logits` in [`TextEnvironment`]. | With these attributes you can reconstruct every interaction of the model with the `TextEnvironment`. The `TextHistory` also lets you visualize the text history. Let's have a look! diff --git a/tests/test_environments.py b/tests/test_environments.py index 0146571319..093693f621 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -191,7 +191,7 @@ def test_text_environment_generate(self): model_inputs = [self.tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts] - generations_batched, _, _, _, _ = env._generate_batched(model_inputs, batch_size=2) + generations_batched, _, _, _, _, _ = env._generate_batched(model_inputs, batch_size=2) generations_batched = self.tokenizer.batch_decode(generations_batched) generations_single = [env._generate_batched([inputs], batch_size=1)[0][0] for inputs in model_inputs] @@ -444,7 +444,7 @@ def test_cached_generate_batched(self, support_cache_class): "something unnecessary", ] model_inputs = [self.tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts] - outputs, past_key_values, past_attention_masks, past_input_ids, _ = env._generate_batched( + outputs, past_key_values, past_attention_masks, past_input_ids, _, _ = env._generate_batched( model_inputs, batch_size=2, return_cache=True ) @@ -498,7 +498,7 @@ def test_cache_class_support(self, support_cache_class): input_texts = ["test"] model_inputs = list(self.tokenizer(input_texts, return_tensors="pt").input_ids) - _, past_key_values, past_attention_masks, past_input_ids, _ = env._generate_batched( + _, past_key_values, past_attention_masks, past_input_ids, _, _ = env._generate_batched( model_inputs, batch_size=2, return_cache=True ) @@ -514,7 +514,7 @@ def test_cache_class_support(self, support_cache_class): self.model.pretrained_model, "forward", new=cache_class_support_forward(support_cache_class, feedback) ): try: - _, _, _, _, _, all_logits_cached = env._generate_batched( + _, _, _, _, _, _ = env._generate_batched( model_inputs2, batch_size=2, combined_past_key_values=past_key_values, @@ -541,7 +541,7 @@ def test_different_sequence_lengths(self, support_cache_class): input_texts = ["this is a test", "this is another, longer test", "some other batch"] model_inputs = [self.tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts] - outputs, past_key_values, past_attention_masks, past_input_ids, _ = env._generate_batched( + outputs, past_key_values, past_attention_masks, past_input_ids, _, _ = env._generate_batched( model_inputs, batch_size=2, return_cache=True ) # remove the last two tokens from the second batch to pretend they were never generated @@ -596,6 +596,49 @@ def test_different_sequence_lengths(self, support_cache_class): self.assertEqual(logits_uncached.shape[0], 4) self.assertTrue(torch.all(torch.abs(logits_cached - logits_uncached) < 1e-6)) + def test_run_with_caching(self): + generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": self.tokenizer.eos_token_id} + caching_env = TextEnvironment( + self.model, + self.tokenizer, + tools=[DummyTool()], + reward_fn=lambda x: torch.tensor([1, 2, 3]), + prompt="I am a prompt\n", + generation_kwargs=generation_kwargs, + use_cache=True, + save_logits=True, + max_turns=1, + ) + + queries = ["Request goodbye ", " this is another, longer test", " batch"] + _, responses_cached, _, _, histories_cached = caching_env.run(queries) + + generation_kwargs2 = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": self.tokenizer.eos_token_id} + uncached_env = TextEnvironment( + self.model, + self.tokenizer, + tools=[DummyTool()], + reward_fn=lambda x: torch.tensor([1, 2, 3]), + prompt="I am a prompt\n", + generation_kwargs=generation_kwargs2, + use_cache=False, + save_logits=True, + max_turns=1, + ) + _, responses_uncached, _, _, histories_uncached = uncached_env.run(queries) + for response_uncached, response_cached, history_uncached, history_cached in zip( + responses_uncached, responses_cached, histories_uncached, histories_cached + ): + self.assertTrue(torch.all(response_uncached == response_cached)) + self.assertEqual(len(history_uncached.logits), 1) + self.assertEqual(len(history_cached.logits), 1) + for logit_segment_uncached, logit_segment_cached in zip(history_uncached.logits, history_cached.logits): + self.assertEqual(len(logit_segment_uncached), 4) + self.assertEqual(logit_segment_uncached.shape[-1], self.model.config.vocab_size) + self.assertEqual(len(logit_segment_cached), 4) + self.assertEqual(logit_segment_cached.shape[-1], self.model.config.vocab_size) + self.assertTrue(torch.all(torch.abs(logit_segment_uncached - logit_segment_cached) < 1e-6)) + if __name__ == "__main__": pass diff --git a/trl/environment/base_environment.py b/trl/environment/base_environment.py index ecafe90d2e..34f28144b6 100644 --- a/trl/environment/base_environment.py +++ b/trl/environment/base_environment.py @@ -81,6 +81,7 @@ def __init__(self, text, tokens, system=True): self.text_spans = [] self.token_spans = [] self.token_masks = torch.tensor([], dtype=torch.long).to(tokens.device) + self.logits = [] self.text = "" self.tokens = torch.tensor([], dtype=torch.long).to(tokens.device) self.completed = False @@ -94,7 +95,7 @@ def __init__(self, text, tokens, system=True): self.append_segment(text, tokens, system=system) - def append_segment(self, text, tokens, system=True): + def append_segment(self, text, tokens, system=True, logits=None): """ Append a new segment to the history. @@ -102,6 +103,7 @@ def append_segment(self, text, tokens, system=True): text (`str`): The text of the new segment. tokens (`torch.LongTensor`): The tokens of the new segment. system (`bool`, *optional*): Whether the new segment is a system or user segment. + logits (`torch.Tensor`, *optional*): The logits for a non-system segment. """ if len(text) == 0 or len(tokens) == 0: @@ -120,6 +122,8 @@ def append_segment(self, text, tokens, system=True): self.token_masks = torch.cat((self.token_masks, torch.zeros_like(tokens))) else: self.token_masks = torch.cat((self.token_masks, torch.ones_like(tokens))) + if logits is not None: + self.logits.append(logits) self.token_spans.append((original_token_length, len(self.tokens))) def complete(self, truncated=False): @@ -244,6 +248,7 @@ def __init__( max_length=None, generation_kwargs=None, use_cache=False, + save_logits=False, ): """ Initialize TextEnvironment. @@ -259,6 +264,7 @@ def __init__( max_length (Optional[int]): The maximum number of tokens to allow in an episode. generation_kwargs (Optional[dict]): A dictionary of keyword arguments to pass to the model's generate method. use_cache (bool): Whether to cache past_key_values between segments. When using caching, [`TextEnvironment`] is not suited for training use, i.e. backpropagation through the generated graph. Use with Trainers is of course possible. Furthermore, caching requires, that there be no calculation dependencies between examples at inference time. When using `BatchNorm`, the model should thus be in eval mode. + save_logits (bool): Whether to save logits in the returned histories. Mainly intended to help the user test caching for their use case. Backpropagation through logits is not supported. """ self.model = model self.tokenizer = tokenizer @@ -276,6 +282,7 @@ def __init__( self.max_turns = max_turns self.max_tool_response = max_tool_reponse self.use_cache = use_cache + self.save_logits = save_logits if generation_kwargs is None: self.generation_kwargs = dict() @@ -509,28 +516,29 @@ def generate( else: query_tensors = [histories[i].tokens for i in active_histories] - response_tensors, past_key_values, past_attention_masks, past_input_ids, truncated = self._generate_batched( - query_tensors, - combined_past_key_values=combined_past_key_values, - combined_past_attention_masks=combined_past_attention_masks, - combined_past_input_ids=combined_past_input_ids, - return_cache=self.use_cache, + response_tensors, past_key_values, past_attention_masks, past_input_ids, truncated, logits = ( + self._generate_batched( + query_tensors, + combined_past_key_values=combined_past_key_values, + combined_past_attention_masks=combined_past_attention_masks, + combined_past_input_ids=combined_past_input_ids, + return_cache=self.use_cache, + output_logits=self.save_logits, + ) ) if not truncated: response_texts = self.tokenizer.batch_decode(response_tensors) - for i, response_text, response_tensor in zip(active_histories, response_texts, response_tensors): + for i, response_text, response_tensor, j in zip( + active_histories, response_texts, response_tensors, range(len(active_histories)) + ): history = histories[i] if not history.completed: - history.append_segment(response_text, response_tensor, system=False) + history.append_segment( + response_text, response_tensor, system=False, logits=(logits[j] if self.save_logits else None) + ) else: for history in histories: if not history.completed: - # Adds an eos token, so that we end on a non-system segment - history.append_segment( - self.tokenizer.eos_token, - torch.tensor([self.tokenizer.eos_token_id]).to(self.current_device), - system=False, - ) history.complete(truncated=True) return histories, None, None, None, [] # invalidate cache @@ -686,6 +694,8 @@ def _generate_batched( if output_logits: all_logits = [] + else: + all_logits = None # pad all batches to same length for cache compatibility mask = [torch.ones_like(element) for element in query_tensors] @@ -778,10 +788,8 @@ def _generate_batched( if output_logits: for i, num_generated_tokens in enumerate(stopping_criteria.generated_tokens): relevant_logits = [batched_logits[i] for batched_logits in logits[:num_generated_tokens]] - all_logits.append(torch.stack(relevant_logits, dim=0)) + all_logits.append(torch.stack(relevant_logits, dim=0).detach().clone()) self.tokenizer.padding_side = padding_side_default - if output_logits: - return outputs, new_past_key_values, new_past_attention_masks, new_past_input_ids, False, all_logits - return outputs, new_past_key_values, new_past_attention_masks, new_past_input_ids, False + return outputs, new_past_key_values, new_past_attention_masks, new_past_input_ids, False, all_logits From e012571879aaf1eca891f7f2cab6cf8e02221bd2 Mon Sep 17 00:00:00 2001 From: Konrad Gerlach Date: Fri, 24 Jan 2025 17:15:43 +0100 Subject: [PATCH 30/38] tests: remove redundant test code --- tests/test_environments.py | 38 +++++++++----------------------------- 1 file changed, 9 insertions(+), 29 deletions(-) diff --git a/tests/test_environments.py b/tests/test_environments.py index 093693f621..46442120ce 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -596,48 +596,28 @@ def test_different_sequence_lengths(self, support_cache_class): self.assertEqual(logits_uncached.shape[0], 4) self.assertTrue(torch.all(torch.abs(logits_cached - logits_uncached) < 1e-6)) - def test_run_with_caching(self): + def test_output_logits(self): generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": self.tokenizer.eos_token_id} - caching_env = TextEnvironment( + env = TextEnvironment( self.model, self.tokenizer, tools=[DummyTool()], reward_fn=lambda x: torch.tensor([1, 2, 3]), prompt="I am a prompt\n", generation_kwargs=generation_kwargs, - use_cache=True, + use_cache=False, save_logits=True, max_turns=1, ) queries = ["Request goodbye ", " this is another, longer test", " batch"] - _, responses_cached, _, _, histories_cached = caching_env.run(queries) + _, _, _, _, histories = env.run(queries) - generation_kwargs2 = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": self.tokenizer.eos_token_id} - uncached_env = TextEnvironment( - self.model, - self.tokenizer, - tools=[DummyTool()], - reward_fn=lambda x: torch.tensor([1, 2, 3]), - prompt="I am a prompt\n", - generation_kwargs=generation_kwargs2, - use_cache=False, - save_logits=True, - max_turns=1, - ) - _, responses_uncached, _, _, histories_uncached = uncached_env.run(queries) - for response_uncached, response_cached, history_uncached, history_cached in zip( - responses_uncached, responses_cached, histories_uncached, histories_cached - ): - self.assertTrue(torch.all(response_uncached == response_cached)) - self.assertEqual(len(history_uncached.logits), 1) - self.assertEqual(len(history_cached.logits), 1) - for logit_segment_uncached, logit_segment_cached in zip(history_uncached.logits, history_cached.logits): - self.assertEqual(len(logit_segment_uncached), 4) - self.assertEqual(logit_segment_uncached.shape[-1], self.model.config.vocab_size) - self.assertEqual(len(logit_segment_cached), 4) - self.assertEqual(logit_segment_cached.shape[-1], self.model.config.vocab_size) - self.assertTrue(torch.all(torch.abs(logit_segment_uncached - logit_segment_cached) < 1e-6)) + for history in histories: + self.assertEqual(len(history.logits), 1) + for logit_segment in history.logits: + self.assertEqual(len(logit_segment), 4) + self.assertEqual(logit_segment.shape[-1], self.model.config.vocab_size) if __name__ == "__main__": From 520abe8a5e6557658326029b7d7d8016afa74421 Mon Sep 17 00:00:00 2001 From: Konrad Gerlach Date: Fri, 24 Jan 2025 17:16:22 +0100 Subject: [PATCH 31/38] tests: update logit similarity calculation --- tests/test_environments.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/test_environments.py b/tests/test_environments.py index 46442120ce..0154c27b5e 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -45,6 +45,13 @@ def _forward(*args, **kwargs): return _forward +# max-normalized logit comparison +def almost_equal_logits(logits1, logits2): + return torch.all( + (torch.abs(logits1 - logits2) / torch.max(torch.abs(logits1).max(), torch.abs(logits2).max())) < 1e-6 + ) + + def reshape_cache(cache): new_cache = [] for layer in cache: @@ -480,7 +487,7 @@ def test_cached_generate_batched(self, support_cache_class): self.assertTrue(torch.all(cached == uncached)) self.assertEqual(logits_cached.shape[0], 4) self.assertEqual(logits_uncached.shape[0], 4) - self.assertTrue(torch.all(torch.abs(logits_cached - logits_uncached) < 1e-6)) + self.assertTrue(almost_equal_logits(logits_cached, logits_uncached)) @parameterized.expand([(True,), (False,)]) def test_cache_class_support(self, support_cache_class): @@ -594,7 +601,7 @@ def test_different_sequence_lengths(self, support_cache_class): self.assertTrue(torch.all(cached == uncached)) self.assertEqual(logits_cached.shape[0], 4) self.assertEqual(logits_uncached.shape[0], 4) - self.assertTrue(torch.all(torch.abs(logits_cached - logits_uncached) < 1e-6)) + self.assertTrue(almost_equal_logits(logits_cached, logits_uncached)) def test_output_logits(self): generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": self.tokenizer.eos_token_id} From eae5b03063cb74f2c6843a3eede21ff7febe46f0 Mon Sep 17 00:00:00 2001 From: Konrad Gerlach Date: Fri, 24 Jan 2025 17:17:59 +0100 Subject: [PATCH 32/38] feat: pad each batch individually in TextEnvironment --- trl/environment/base_environment.py | 27 ++++++++++----------------- 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/trl/environment/base_environment.py b/trl/environment/base_environment.py index 34f28144b6..451b21a123 100644 --- a/trl/environment/base_environment.py +++ b/trl/environment/base_environment.py @@ -697,17 +697,6 @@ def _generate_batched( else: all_logits = None - # pad all batches to same length for cache compatibility - mask = [torch.ones_like(element) for element in query_tensors] - inputs = {"input_ids": query_tensors, "attention_mask": mask} - all_padded_inputs = self.tokenizer.pad( - inputs, - padding=True, - max_length=None, - pad_to_multiple_of=pad_to_multiple_of, - return_tensors="pt", - ).to(self.current_device) - # in case we have fewer examples than bs batch_size = min(len(query_tensors), batch_size) for i in range(0, len(query_tensors), batch_size): @@ -719,12 +708,16 @@ def _generate_batched( i, end_index, combined_past_key_values, combined_past_attention_masks, combined_past_input_ids ) - padded_inputs = { - "input_ids": all_padded_inputs["input_ids"][i:end_index], - "attention_mask": all_padded_inputs["attention_mask"][i:end_index], - } - - input_attention_mask = padded_inputs["attention_mask"].clone() + query_batch = query_tensors[i:end_index] + mask = [torch.ones_like(element) for element in query_batch] + inputs = {"input_ids": query_batch, "attention_mask": mask} + padded_inputs = self.tokenizer.pad( + inputs, + padding=True, + max_length=None, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors="pt", + ).to(self.current_device) stopping_criteria = StringStoppingCriteria([self.call_token, self.submit_token], self.tokenizer) generation_kwargs = copy.deepcopy(self.generation_kwargs) From 0736ce332098cfa147195744f70cdddecd6bebee Mon Sep 17 00:00:00 2001 From: Konrad Gerlach Date: Fri, 24 Jan 2025 18:36:13 +0100 Subject: [PATCH 33/38] refactor: remove redundant past_input_ids manipulation in TExtEnvironment caching --- trl/environment/base_environment.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/trl/environment/base_environment.py b/trl/environment/base_environment.py index 451b21a123..1abf5d8103 100644 --- a/trl/environment/base_environment.py +++ b/trl/environment/base_environment.py @@ -621,7 +621,7 @@ def _extract_generation(self, sequence, mask): output = output[(mask).sum() :] return output - def _create_new_past_inputs(self, sequences, input_attention_mask, generated_tokens): + def _create_new_past_attention_mask(self, sequences, input_attention_mask, generated_tokens): """Creates the new past_input_ids and new past_attention_mask for a batch. Args: sequences (torch.Tensor): The sequences returned by model.generate(...) @@ -638,21 +638,15 @@ def _create_new_past_inputs(self, sequences, input_attention_mask, generated_tok ] = 0 new_past_attention_mask[:, : input_attention_mask.shape[1]] = input_attention_mask # copy for in-place modification - batch_new_past_input_ids = sequences.detach().clone() - for mask, num_generated_tokens, new_attention_mask, example_input_ids in zip( + for mask, num_generated_tokens, new_attention_mask in zip( input_attention_mask, generated_tokens, new_past_attention_mask, - batch_new_past_input_ids, ): - extracted_past_input_ids = self._extract_generation(example_input_ids, mask) extracted_past_attention_mask = self._extract_generation(new_attention_mask, mask) - # Do not attend to invalid tokens that were generated after or or the last valid generated token, as we move it to the end of the sequence - extracted_past_attention_mask[num_generated_tokens - 1 :] = 0 - # move last valid generated token to the end of the sequence to be the start of the next generation - extracted_past_input_ids[-1] = extracted_past_input_ids[num_generated_tokens - 1] - extracted_past_attention_mask[-1] = 1 # attend to the last valid generated token - return batch_new_past_input_ids, new_past_attention_mask + # Do not attend to invalid tokens that were generated after or + extracted_past_attention_mask[num_generated_tokens:] = 0 + return new_past_attention_mask # TODO make batch_size changeable def _generate_batched( @@ -772,11 +766,11 @@ def _generate_batched( if generations.past_key_values[0][0].shape[2] != generations.sequences.shape[1] - 1: raise Exception("Cache should not contain keys and values for last generated token") new_past_key_values.append(generations.past_key_values) - batch_new_past_input_ids, new_past_attention_mask = self._create_new_past_inputs( + new_past_attention_mask = self._create_new_past_attention_mask( sequences, padded_inputs["attention_mask"], stopping_criteria.generated_tokens ) new_past_attention_masks.append(new_past_attention_mask) - new_past_input_ids.append(batch_new_past_input_ids) + new_past_input_ids.append(sequences.clone()) if output_logits: for i, num_generated_tokens in enumerate(stopping_criteria.generated_tokens): From 66812175bf072338cc0a675a642dcce77bda2260 Mon Sep 17 00:00:00 2001 From: Konrad Gerlach Date: Fri, 24 Jan 2025 19:09:54 +0100 Subject: [PATCH 34/38] fix: remove redundant attention blocking for padding and eos token and add tests for utility methods in TextEnvironment --- tests/test_environments.py | 35 +++++++++++++++++++++++++++++ trl/environment/base_environment.py | 7 ------ 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/tests/test_environments.py b/tests/test_environments.py index 0154c27b5e..602d2ad95e 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -431,6 +431,41 @@ def test_get_batched_cache(self): expected_input_ids = torch.tensor([[5], [6]]) self.assertTrue(torch.all(batched_input_ids == expected_input_ids)) + def test_extract_generation(self): + env = TextEnvironment( + self.model, + self.tokenizer, + tools=[DummyTool()], + reward_fn=lambda x: torch.tensor(1), + prompt="I am a prompt!\n", + ) + sequences = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8], [9, 10, 11, 12, 13, 14, 15, 16]]) + attention_mask = torch.tensor([[0, 0, 0, 1, 1, 1], [1, 1, 1, 1, 1, 1]]) + out1 = env._extract_generation(sequences[0], attention_mask[0]) + self.assertTrue(torch.all(out1 == torch.tensor([7, 8]))) + out2 = env._extract_generation(sequences[1], attention_mask[1]) + self.assertTrue(torch.all(out2 == torch.tensor([15, 16]))) + + def test_create_new_past_attention_mask(self): + env = TextEnvironment( + self.model, + self.tokenizer, + tools=[DummyTool()], + reward_fn=lambda x: torch.tensor(1), + prompt="I am a prompt!\n", + ) + sequences = torch.tensor( + [ + [1, 2, 3, 4, 5, 6, 7, self.tokenizer.pad_token_id, 100, 200], + [9, 10, 11, 12, 13, 14, self.tokenizer.pad_token_id, 400, self.tokenizer.eos_token_id, 16], + ] + ) + attention_mask = torch.tensor([[0, 0, 0, 1, 1, 1], [1, 1, 1, 1, 1, 1]]) + generated_tokens = [2, 3] + expected_attention_mask = torch.tensor([[0, 0, 0, 1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]) + created_attention_mask = env._create_new_past_attention_mask(sequences, attention_mask, generated_tokens) + self.assertTrue(torch.all(expected_attention_mask == created_attention_mask)) + @parameterized.expand([(True,), (False,)]) def test_cached_generate_batched(self, support_cache_class): with patch.object(self.model.pretrained_model, "_supports_cache_class", new=support_cache_class): diff --git a/trl/environment/base_environment.py b/trl/environment/base_environment.py index 1abf5d8103..7aa6cdbb1c 100644 --- a/trl/environment/base_environment.py +++ b/trl/environment/base_environment.py @@ -629,13 +629,6 @@ def _create_new_past_attention_mask(self, sequences, input_attention_mask, gener generated_tokens (list[int]): The number of tokens generated for each history in the batch """ new_past_attention_mask = torch.ones_like(sequences) - # Don't attend to generated padding or eos tokens - new_past_attention_mask[ - torch.logical_or( - sequences == self.tokenizer.eos_token_id, - sequences == self.tokenizer.pad_token_id, - ) - ] = 0 new_past_attention_mask[:, : input_attention_mask.shape[1]] = input_attention_mask # copy for in-place modification for mask, num_generated_tokens, new_attention_mask in zip( From 37e5041ff300db7f94bf82b4199a9314d621f8b1 Mon Sep 17 00:00:00 2001 From: Konrad Gerlach Date: Sat, 25 Jan 2025 21:41:57 +0100 Subject: [PATCH 35/38] fix: code and documenation cleanup for TextEnvironment caching --- tests/test_environments.py | 17 +++++++++++++-- trl/environment/base_environment.py | 34 +++++++++++++++++++++-------- 2 files changed, 40 insertions(+), 11 deletions(-) diff --git a/tests/test_environments.py b/tests/test_environments.py index 602d2ad95e..ad6395fafc 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -431,6 +431,21 @@ def test_get_batched_cache(self): expected_input_ids = torch.tensor([[5], [6]]) self.assertTrue(torch.all(batched_input_ids == expected_input_ids)) + def test_same_is_none(self): + env = TextEnvironment( + self.model, + self.tokenizer, + tools=[DummyTool()], + reward_fn=lambda x: torch.tensor(1), + prompt="I am a prompt!\n", + ) + self.assertTrue(env._same_is_none("", "")) + self.assertTrue(env._same_is_none(None, None)) + self.assertFalse(env._same_is_none("", None)) + self.assertFalse(env._same_is_none(None, "")) + self.assertTrue(env._same_is_none(None)) + self.assertTrue(env._same_is_none("")) + def test_extract_generation(self): env = TextEnvironment( self.model, @@ -526,8 +541,6 @@ def test_cached_generate_batched(self, support_cache_class): @parameterized.expand([(True,), (False,)]) def test_cache_class_support(self, support_cache_class): - self.assertEqual(self.model_id, "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") - generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": self.tokenizer.eos_token_id} env = TextEnvironment( self.model, diff --git a/trl/environment/base_environment.py b/trl/environment/base_environment.py index 7aa6cdbb1c..593fa5cb35 100644 --- a/trl/environment/base_environment.py +++ b/trl/environment/base_environment.py @@ -416,7 +416,7 @@ def _combine_cache(self, example_mask, past_key_values, past_attention_masks, pa combines all caches in order to exclude completed histories from further generation Args: - batch_examples (list[bool]): mask indicating for each example, whether it is supposed to remain or not + example_mask (list[bool]): mask indicating for each example, whether it is supposed to remain or not past_key_values (tuple[tuple[torch.Tensor]]) : Batched list of caches (in legacy format) from the last generation past_attention_masks (list[torch.Tensor]): Batched list of attention masks from the last generation past_input_ids (list[torch.Tensor]): Batched list of input ids from the last generation @@ -445,7 +445,6 @@ def _combine_cache(self, example_mask, past_key_values, past_attention_masks, pa if extracted_keys.shape[2] != extracted_values.shape[2]: raise Exception("Cache format incompatible") - # left padding ensures, that the last valid generated token is what the next generated token is conditioned on start_position = max_sequence_length - 1 - extracted_keys.shape[2] new_values = torch.zeros_like(new_keys).to(self.current_device) new_keys[:, :, start_position:, :] = extracted_keys @@ -486,6 +485,17 @@ def _combine_cache(self, example_mask, past_key_values, past_attention_masks, pa return combined_cache, combined_attention_masks, combined_input_ids + def _same_is_none(self, *values): + """For input validation + Args: + values: list[object]: A list of values to test for having the same return value for `is None` + """ + expected_is_none = values[0] is None + for value in values[1:]: + if (value is None) != expected_is_none: + return False + return True + def generate( self, histories, @@ -498,12 +508,15 @@ def generate( Generate responses for a list of histories. Either all of past_key_values, past_attention_masks, past_input_ids,last_active_histories are provided or all are None. Args: - histories (list[TextHistory]): + histories (list[TextHistory]): A complete list of the TextHistories past_key_values (Optional[tuple[tuple[torch.Tensor]]]): Batched list of caches in legacy format from the last generation past_attention_masks (Optional[list[torch.Tensor]]): Batched list of attention masks from the last generation past_input_ids (Optional[list[torch.Tensor]]): Batched list of input ids from the last generation last_active_histories (Optional[list[int]]): indices of histories for which generation took place during the last generation turn """ + if not self._same_is_none(past_key_values, past_attention_masks, past_input_ids, last_active_histories): + raise Exception("Either all cache related inputs are supposed to be None or all are not None.") + active_histories = [i for i in range(len(histories)) if not histories[i].completed] combined_past_key_values, combined_past_attention_masks, combined_past_input_ids = (None, None, None) @@ -583,7 +596,7 @@ def _get_batched_cache( self, start_index, end_index, combined_past_key_values, combined_attention_masks, combined_input_ids ): """ - Extract (batch) cache for current batch + Extract (batch) cache, attention_mask and input_ids for current batch Args: start_index (int): start index of current batch end_index (int): end index of current batch (points to first element not in batch) @@ -626,7 +639,7 @@ def _create_new_past_attention_mask(self, sequences, input_attention_mask, gener Args: sequences (torch.Tensor): The sequences returned by model.generate(...) input_attention_mask (torch.Tensor): The attention mask that was input into model.generate(...) - generated_tokens (list[int]): The number of tokens generated for each history in the batch + generated_tokens (list[int]): The number of valid tokens generated for each history in the batch """ new_past_attention_mask = torch.ones_like(sequences) new_past_attention_mask[:, : input_attention_mask.shape[1]] = input_attention_mask @@ -664,6 +677,9 @@ def _generate_batched( combined_past_attention_masks (Optional[torch.Tensor]): The combined (unbatched) attention masks from the last generation combined_past_input_ids (Optional[torch.Tensor]): The combined (unbatched) input ids from the last generation """ + if not self._same_is_none(combined_past_key_values, combined_past_attention_masks, combined_past_input_ids): + raise Exception("Either all cache related inputs are supposed to be None or all are not None.g") + caching_enabled = return_cache or (combined_past_key_values is not None) # Ensures, that the next token is never conditioned on a padding token. This should never be a problem, as empty system prompts are not particularly useful and between segments there is always a response token. for query in query_tensors: @@ -738,20 +754,20 @@ def _generate_batched( elif caching_enabled: generation_kwargs["past_key_values"] = past_key_values + cloned_attention_mask = padded_inputs["attention_mask"].clone() generations = extracted_model.generate(**padded_inputs, **generation_kwargs) if output_logits: logits = generations.logits sequences = generations.sequences for generation, mask, num_generated_tokens in zip( - sequences, padded_inputs["attention_mask"], stopping_criteria.generated_tokens + sequences, cloned_attention_mask, stopping_criteria.generated_tokens ): output = self._extract_generation(generation, mask) # remove chunk generated after stopping criteria in batch mode generated_tokens = output[:num_generated_tokens] if len(generated_tokens) < 1: - input_length = padded_inputs["input_ids"].shape[0] - raise Exception(f"Generation failed to produce any valid token; input length {input_length}") + raise Exception(f"Generation failed to produce any valid tokens") outputs.append(generated_tokens) @@ -760,7 +776,7 @@ def _generate_batched( raise Exception("Cache should not contain keys and values for last generated token") new_past_key_values.append(generations.past_key_values) new_past_attention_mask = self._create_new_past_attention_mask( - sequences, padded_inputs["attention_mask"], stopping_criteria.generated_tokens + sequences, cloned_attention_mask, stopping_criteria.generated_tokens ) new_past_attention_masks.append(new_past_attention_mask) new_past_input_ids.append(sequences.clone()) From a5fc0d754b86d9a3e17fadd26b4b57fc4eb0d115 Mon Sep 17 00:00:00 2001 From: Konrad Gerlach Date: Sun, 26 Jan 2025 11:01:20 +0100 Subject: [PATCH 36/38] fix: add more validation to tests and TextEnvironment caching --- tests/test_environments.py | 9 ++++++--- trl/environment/base_environment.py | 17 +++++++++++++---- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/tests/test_environments.py b/tests/test_environments.py index ad6395fafc..5f4f7d1433 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -531,6 +531,8 @@ def test_cached_generate_batched(self, support_cache_class): outputs_uncached, _, _, _, _, all_logits_uncached = env._generate_batched( model_inputs2_full, batch_size=2, output_logits=True ) + self.assertEqual(len(all_logits_uncached), 3) + self.assertEqual(len(all_logits_cached), 3) for cached, uncached, logits_cached, logits_uncached in zip( outputs_cached, outputs_uncached, all_logits_cached, all_logits_uncached ): @@ -643,6 +645,8 @@ def test_different_sequence_lengths(self, support_cache_class): outputs_uncached, _, _, _, _, all_logits_uncached = env._generate_batched( model_inputs2_full, batch_size=2, output_logits=True ) + self.assertEqual(len(all_logits_uncached), 3) + self.assertEqual(len(all_logits_cached), 3) for cached, uncached, logits_cached, logits_uncached in zip( outputs_cached, outputs_uncached, all_logits_cached, all_logits_uncached ): @@ -670,9 +674,8 @@ def test_output_logits(self): for history in histories: self.assertEqual(len(history.logits), 1) - for logit_segment in history.logits: - self.assertEqual(len(logit_segment), 4) - self.assertEqual(logit_segment.shape[-1], self.model.config.vocab_size) + self.assertEqual(len(history.logits[0]), 4) + self.assertEqual(history.logits[0].shape[-1], self.model.config.vocab_size) if __name__ == "__main__": diff --git a/trl/environment/base_environment.py b/trl/environment/base_environment.py index 593fa5cb35..01e6120c8a 100644 --- a/trl/environment/base_environment.py +++ b/trl/environment/base_environment.py @@ -445,7 +445,9 @@ def _combine_cache(self, example_mask, past_key_values, past_attention_masks, pa if extracted_keys.shape[2] != extracted_values.shape[2]: raise Exception("Cache format incompatible") - start_position = max_sequence_length - 1 - extracted_keys.shape[2] + if extracted_keys.shape[2] > max_sequence_length - 1: + raise Exception("Cache sequence length is too large") + start_position = max(max_sequence_length - 1 - extracted_keys.shape[2], 0) new_values = torch.zeros_like(new_keys).to(self.current_device) new_keys[:, :, start_position:, :] = extracted_keys new_values[:, :, start_position:, :] = extracted_values @@ -459,6 +461,8 @@ def _combine_cache(self, example_mask, past_key_values, past_attention_masks, pa torch.concat([other_new_values, new_values], dim=0), ) example_mask_offset += num_examples + if example_mask_offset != len(example_mask): + raise Exception("example_mask size and cache size are different") combined_cache.append(combined_layer) combined_cache = tuple(combined_cache) @@ -480,9 +484,14 @@ def _combine_cache(self, example_mask, past_key_values, past_attention_masks, pa padded_input_ids[:, start_position:] = input_ids padded_past_input_ids.append(padded_input_ids) - combined_attention_masks = torch.concat(padded_attentions_masks, dim=0)[example_mask] - combined_input_ids = torch.concat(padded_past_input_ids, dim=0)[example_mask] - + combined_attention_masks = torch.concat(padded_attentions_masks, dim=0) + if combined_attention_masks.shape[0] != len(example_mask): + raise Exception("example_mask and attention_masks have varying example counts") + combined_attention_masks = combined_attention_masks[example_mask] + combined_input_ids = torch.concat(padded_past_input_ids, dim=0) + if combined_input_ids.shape[0] != len(example_mask): + raise Exception("example_mask and input ids have varying example counts") + combined_input_ids = combined_input_ids[example_mask] return combined_cache, combined_attention_masks, combined_input_ids def _same_is_none(self, *values): From ac00cea22fba22884d04762426a9849b2cac18cf Mon Sep 17 00:00:00 2001 From: Konrad Gerlach Date: Wed, 29 Jan 2025 11:26:40 +0100 Subject: [PATCH 37/38] refactor: Add more validation to caching in TextEnv --- trl/environment/base_environment.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/trl/environment/base_environment.py b/trl/environment/base_environment.py index 01e6120c8a..8b6c190c26 100644 --- a/trl/environment/base_environment.py +++ b/trl/environment/base_environment.py @@ -447,7 +447,9 @@ def _combine_cache(self, example_mask, past_key_values, past_attention_masks, pa raise Exception("Cache format incompatible") if extracted_keys.shape[2] > max_sequence_length - 1: raise Exception("Cache sequence length is too large") - start_position = max(max_sequence_length - 1 - extracted_keys.shape[2], 0) + start_position = max_sequence_length - 1 - extracted_keys.shape[2] + if start_position < 0: + raise Exception("start position incorrect") new_values = torch.zeros_like(new_keys).to(self.current_device) new_keys[:, :, start_position:, :] = extracted_keys new_values[:, :, start_position:, :] = extracted_values @@ -472,6 +474,8 @@ def _combine_cache(self, example_mask, past_key_values, past_attention_masks, pa if attention_mask.shape[1] != input_ids.shape[1]: raise Exception("Cache format incompatible") start_position = max_sequence_length - attention_mask.shape[1] + if start_position < 0: + raise Exception("start position incorrect") padded_attention_mask = torch.zeros( (attention_mask.shape[0], max_sequence_length), dtype=attention_mask.dtype ).to(self.current_device) From 8e5ac7e34e1f9473dbc57566231c7a7877a247e5 Mon Sep 17 00:00:00 2001 From: Konrad Gerlach Date: Wed, 29 Jan 2025 11:48:11 +0100 Subject: [PATCH 38/38] tests: fix logits comparison --- tests/test_environments.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/test_environments.py b/tests/test_environments.py index 5f4f7d1433..1f0d09cb78 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -46,10 +46,8 @@ def _forward(*args, **kwargs): # max-normalized logit comparison -def almost_equal_logits(logits1, logits2): - return torch.all( - (torch.abs(logits1 - logits2) / torch.max(torch.abs(logits1).max(), torch.abs(logits2).max())) < 1e-6 - ) +def almost_equal_logits(logits_old, logits_new): + return torch.all((torch.abs(logits_old - logits_new) / torch.abs(logits_old).max()) < 1e-6) def reshape_cache(cache): @@ -539,7 +537,7 @@ def test_cached_generate_batched(self, support_cache_class): self.assertTrue(torch.all(cached == uncached)) self.assertEqual(logits_cached.shape[0], 4) self.assertEqual(logits_uncached.shape[0], 4) - self.assertTrue(almost_equal_logits(logits_cached, logits_uncached)) + self.assertTrue(almost_equal_logits(logits_uncached, logits_cached)) @parameterized.expand([(True,), (False,)]) def test_cache_class_support(self, support_cache_class): @@ -653,7 +651,7 @@ def test_different_sequence_lengths(self, support_cache_class): self.assertTrue(torch.all(cached == uncached)) self.assertEqual(logits_cached.shape[0], 4) self.assertEqual(logits_uncached.shape[0], 4) - self.assertTrue(almost_equal_logits(logits_cached, logits_uncached)) + self.assertTrue(almost_equal_logits(logits_uncached, logits_cached)) def test_output_logits(self): generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": self.tokenizer.eos_token_id}