add model specific prompt and gen kwargs in sqa (EvolvingLMMs-Lab#19)
* add mme

* black

* add model specific prompt and gen kwargs

* black

* add yaml config to support multi-model eval

* print table at the end

* refactor multi model code
jzhang38 authored Jan 25, 2024
1 parent 52ee4a1 commit f7a7db5
Showing 8 changed files with 107 additions and 25 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -14,8 +14,8 @@ pip install -e .
```

```bash
accelerate launch --num_processes=8 -m lmms_eval --model llava --model_args pretrained="liuhaotian/llava-v1.5-13b" --tasks gqa --batch_size 1 --log_samples --log_samples_sufix debug --output_path ./logs/ # Eactly reproduce llava results
accelerate launch --num_processes=8 -m lmms_eval --model llava --model_args pretrained="liuhaotian/llava-v1.5-13b" --tasks mme --batch_size 1 --log_samples --log_samples_sufix debug --output_path ./logs/ # Eactly reproduce llava results
accelerate launch --num_processes=8 -m lmms_eval --model llava --model_args pretrained="liuhaotian/llava-v1.5-13b" --tasks mme --batch_size 1 --log_samples --log_samples_suffix debug --output_path ./logs/ # Exactly reproduce llava results
accelerate launch --num_processes=8 -m lmms_eval --config example_eval.yaml # Exactly reproduce llava results
```
## Current models

15 changes: 15 additions & 0 deletions example_eval.yaml
@@ -0,0 +1,15 @@
- model: llava
model_args: pretrained=liuhaotian/llava-v1.5-7b
tasks: mme
batch_size: 1
log_samples: true
log_samples_suffix: debug
output_path: "./logs/"

- model: llava
model_args: pretrained=liuhaotian/llava-v1.5-13b
tasks: mme
batch_size: 1
log_samples: true
log_samples_suffix: debug
output_path: "./logs/"
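
For reference, a minimal sketch of how a multi-entry config like this expands into one set of evaluation arguments per model, mirroring the loop added to `lmms_eval/__main__.py` below (the baseline `Namespace` fields here are illustrative; in lmms_eval they come from `parse_eval_args()`):

```python
import argparse
import yaml

# Illustrative defaults; in lmms_eval these come from parse_eval_args().
base_args = argparse.Namespace(model="hf", model_args="", tasks="", batch_size=1,
                               log_samples=False, log_samples_suffix="", output_path="./logs/")

with open("example_eval.yaml", "r") as f:
    configs = yaml.safe_load(f)  # the file above parses to a list: one dict per run
configs = configs if isinstance(configs, list) else [configs]

runs = []
for config in configs:
    run_args = argparse.Namespace(**vars(base_args))  # copy the defaults...
    for key, value in config.items():
        setattr(run_args, key, value)                 # ...then override per entry
    runs.append(run_args)

for run_args in runs:
    print(run_args.model, run_args.model_args, run_args.tasks)
```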
50 changes: 39 additions & 11 deletions lmms_eval/__main__.py
@@ -1,5 +1,5 @@
import os
import re
import yaml
import sys
import json
import logging
@@ -26,6 +26,7 @@ def _handle_non_serializable(o):

def parse_eval_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument("--config", default="", help="Path to a yaml file specifying all eval arguments, will ignore cli arguments if specified")
parser.add_argument("--model", default="hf", help="Name of model e.g. `hf`")
parser.add_argument(
"--tasks",
@@ -109,14 +110,12 @@ def parse_eval_args() -> argparse.Namespace:
default="INFO",
help="Log error when tasks are not registered.",
)
return parser.parse_args()
args = parser.parse_args()

return args

def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
if not args:
# we allow for args to be passed externally, else we parse them ourselves
args = parse_eval_args()

def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
eval_logger = utils.eval_logger
eval_logger.setLevel(getattr(logging, f"{args.verbosity}"))
eval_logger.info(f"Verbosity set to {args.verbosity}")
@@ -204,11 +203,40 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
filename.open("w").write(samples_dumped)
eval_logger.info(f"Saved samples to {filename}")

print(f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, " f"batch_size: {args.batch_size}")
print(evaluator.make_table(results))
if "groups" in results:
print(evaluator.make_table(results, "groups"))
return results
return None


def print_results(args, results):
print(f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, " f"batch_size: {args.batch_size}")
print(evaluator.make_table(results))
if "groups" in results:
print(evaluator.make_table(results, "groups"))


if __name__ == "__main__":
cli_evaluate()
args = parse_eval_args()
args_list = []
results_list = []
if args.config and os.path.exists(args.config):
with open(args.config, "r") as file:
config_args = yaml.safe_load(file)
config_args = [config_args] if type(config_args) != list else config_args
# multiple configs, create args list first
for config in config_args:
args_copy = argparse.Namespace(**vars(args))
for key, value in config.items():
setattr(args_copy, key, value)
args_list.append(args_copy)
else:
args_list.append(args)

# run each config
for args in args_list:
results = cli_evaluate(args)
results_list.append(results)
# print results
for args, results in zip(args_list, results_list):
# cli_evaluate will return none if the process is not the main process (rank 0)
if results is not None:
print_results(args, results)
34 changes: 32 additions & 2 deletions lmms_eval/api/task.py
@@ -89,6 +89,9 @@ class TaskConfig(dict):

metadata: Union[str, list] = None # by default, not used in the code. allows for users to pass arbitrary info to tasks

model_specific_prompt_kwargs: dict = None
model_specific_generation_kwargs: dict = None

def __post_init__(self) -> None:
if self.dataset_path and os.path.exists(os.path.dirname(self.dataset_path)):
import inspect
@@ -495,9 +498,13 @@ class ConfigurableTask(Task):
OUTPUT_TYPE = None
CONFIG = None

def __init__(self) -> None: # TODO no super() call here
def __init__(self, model_name) -> None: # TODO no super() call here
# Get pre-configured attributes
self._config = self.CONFIG
# different model requires different prompt, we have to take those into account.

self.model_name = model_name
self._prepare_model_specific_config()

assert self.config.output_type in ALL_OUTPUT_TYPES
self.OUTPUT_TYPE = self.config.output_type
@@ -579,6 +586,23 @@ def __init__(self) -> None: # TODO no super() call here
elif (not delimiter_has_whitespace) and (not choice_has_whitespace):
eval_logger.warning(f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" do not have whitespace, ignore if the language you are evaluating on does not require/use whitespace')

def _prepare_model_specific_config(self):
self.model_specific_prompt_kwargs = self.config.model_specific_prompt_kwargs
if self.model_specific_prompt_kwargs is not None:
if self.model_name in self.model_specific_prompt_kwargs:
self.model_specific_prompt_kwargs = self.model_specific_prompt_kwargs[self.model_name]
else:
self.model_specific_prompt_kwargs = self.model_specific_prompt_kwargs["default"]

self.model_specific_generation_kwargs = self.config.model_specific_generation_kwargs
if self.model_specific_generation_kwargs is not None:
if self.model_name in self.model_specific_generation_kwargs:
self.model_specific_generation_kwargs = self.model_specific_generation_kwargs[self.model_name]
else:
self.model_specific_generation_kwargs = self.model_specific_generation_kwargs["default"]

self.config.generation_kwargs.update(self.model_specific_generation_kwargs)

def _prepare_metric_and_aggregation(self):
self._metric_fn_list = {}
self._metric_fn_kwargs = {}
@@ -764,7 +788,13 @@ def doc_to_text(self, doc):
else:
return text_string
elif callable(doc_to_text):
return doc_to_text(doc)
return (
doc_to_text(doc, self.model_specific_prompt_kwargs)
if self.model_specific_prompt_kwargs is not None
else doc_to_text(
doc,
)
)
# Used when applying a Promptsource template
elif hasattr(doc_to_text, "apply"):
applied_prompt = doc_to_text.apply(doc)
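
A standalone sketch of the lookup `_prepare_model_specific_config` performs: the entry matching the current model name wins, otherwise the `default` entry applies (the `resolve` helper and the sample dicts are illustrative, with values taken from the ScienceQA YAML below):

```python
def resolve(kwargs_by_model, model_name):
    """Pick the entry for model_name, falling back to 'default' (None if neither exists)."""
    if kwargs_by_model is None:
        return None
    return kwargs_by_model.get(model_name, kwargs_by_model.get("default"))

prompt_kwargs = {
    "default": {
        "pre_prompt": "",
        "post_prompt": "\nAnswer with the option's letter from the given choices directly.",
    },
}
generation_kwargs = {"llava": {"image_aspect_ratio": "original"}}

print(resolve(prompt_kwargs, "llava"))      # no "llava" entry -> falls back to "default"
print(resolve(generation_kwargs, "llava"))  # {'image_aspect_ratio': 'original'}
```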
2 changes: 1 addition & 1 deletion lmms_eval/evaluator.py
@@ -94,7 +94,7 @@ def simple_evaluate(
},
)

task_dict = lmms_eval.tasks.get_task_dict(tasks)
task_dict = lmms_eval.tasks.get_task_dict(tasks, model_name=model)
for task_name in task_dict.keys():
task_obj = task_dict[task_name]
if type(task_obj) == tuple:
8 changes: 4 additions & 4 deletions lmms_eval/tasks/__init__.py
@@ -116,9 +116,9 @@ def initialize_tasks(verbosity="INFO"):
include_path(task_dir)


def get_task(task_name):
def get_task(task_name, model_name):
try:
return TASK_REGISTRY[task_name]()
return TASK_REGISTRY[task_name](model_name=model_name)
except KeyError:
eval_logger.info("Available tasks:")
eval_logger.info(list(TASK_REGISTRY) + list(GROUP_REGISTRY))
@@ -136,7 +136,7 @@ def get_task_name_from_object(task_object):


# TODO: pass num_fewshot and other cmdline overrides in a better way
def get_task_dict(task_name_list: List[Union[str, Dict, Task]]):
def get_task_dict(task_name_list: List[Union[str, Dict, Task]], model_name: str):
all_task_dict = {}

if type(task_name_list) != list:
@@ -167,7 +167,7 @@ def get_task_dict(task_name_list: List[Union[str, Dict, Task]]):
if task_name not in all_task_dict:
all_task_dict = {
**all_task_dict,
task_name: get_task(task_name=task_element),
task_name: get_task(task_name=task_element, model_name=model_name),
}

return all_task_dict
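
With this change the model name is threaded from the evaluator down to task construction; a hedged usage sketch (the task name is illustrative and assumes the task registry has been initialized):

```python
import lmms_eval.tasks

lmms_eval.tasks.initialize_tasks()
# Each ConfigurableTask is now built with the model name, so a task's YAML can
# select model-specific prompt and generation kwargs for that model.
task_dict = lmms_eval.tasks.get_task_dict(["scienceqa_img"], model_name="llava")
```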
12 changes: 9 additions & 3 deletions lmms_eval/tasks/scienceqa_img/scienceqa.yaml
@@ -8,9 +8,6 @@ doc_to_visual: !function utils.sqa_doc_to_visual
doc_to_text: !function utils.sqa_doc_to_text
doc_to_target: !function utils.sqa_doc_to_target
generation_kwargs:
until:
- "ASSISTANT:"
image_aspect_ratio: original
max_new_tokens: 16
temperature: 0
do_sample: False
@@ -23,3 +20,12 @@ metric_list:
process_results: !function utils.sqa_process_results
metadata:
- version: 0.0

model_specific_prompt_kwargs:
default:
pre_prompt: ""
post_prompt: "\nAnswer with the option's letter from the given choices directly."
model_specific_generation_kwargs:
llava:
image_aspect_ratio: original

7 changes: 5 additions & 2 deletions lmms_eval/tasks/scienceqa_img/utils.py
@@ -1,11 +1,14 @@
def sqa_doc_to_text(doc):
def sqa_doc_to_text(doc, model_specific_prompt_kwargs=None):
context, question, choices = doc["hint"], doc["question"], doc["choices"]
len_choices = len(choices)
options = [chr(ord("A") + i) for i in range(len_choices)]
choices_str = "\n".join([f"{option}. {choice}" for option, choice in zip(options, choices)])
if context:
context = f"Context: {context}\n"
return f"{context}{question}\n{choices_str}\nAnswer with the option's letter from the given choices directly."

post_prompt = model_specific_prompt_kwargs["post_prompt"]
pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
return f"{pre_prompt}{context}{question}\n{choices_str}{post_prompt}"


def sqa_doc_to_visual(doc):

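A quick sanity check of the new prompt construction, using the `default` prompt kwargs from `scienceqa.yaml` above and a made-up document (the import path assumes the task module is importable as shown; the doc fields match those read by `sqa_doc_to_text`):

```python
from lmms_eval.tasks.scienceqa_img.utils import sqa_doc_to_text

doc = {
    "hint": "The diagram shows a simple food web.",
    "question": "Which organism is a producer?",
    "choices": ["grass", "fox", "rabbit"],
}
prompt_kwargs = {
    "pre_prompt": "",
    "post_prompt": "\nAnswer with the option's letter from the given choices directly.",
}
print(sqa_doc_to_text(doc, prompt_kwargs))
# Context: The diagram shows a simple food web.
# Which organism is a producer?
# A. grass
# B. fox
# C. rabbit
# Answer with the option's letter from the given choices directly.
```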