diff --git a/webshop/main.py b/webshop/main.py index b9f0550..91b20b1 100644 --- a/webshop/main.py +++ b/webshop/main.py @@ -338,7 +338,7 @@ def generate_embeddings(memory): continue # extract embeddings embeddings[key] = model_embedding.encode(retrieve_info) - return embeddings + return memory, embeddings def generate_examples(info, actions, memory, embeddings, reasoning='', k=3, act_len=0, use_act_obs=False): @@ -558,7 +558,7 @@ def webshop_run_rap(idx, prompt, memory, embeddings, to_print=True): sr_games = [] if trial != 0: memory = current_memory[:] - embeddings = generate_embeddings(memory) + memory, embeddings = generate_embeddings(memory) current_memory = [] for i in range(start, start+n): print('-----------------')