diff --git a/lmms_eval/models/llava_hf.py b/lmms_eval/models/llava_hf.py index e5a88da60..effb94fb4 100644 --- a/lmms_eval/models/llava_hf.py +++ b/lmms_eval/models/llava_hf.py @@ -46,9 +46,9 @@ def __init__( revision: str = "main", device: str = "cuda", dtype: Optional[Union[str, torch.dtype]] = "auto", - batch_size: Union[int, str] = 1, + batch_size: int = 1, trust_remote_code: Optional[bool] = False, - attn_implementation: Optional[str] = "flash_attention_2", + attn_implementation: Optional[str] = None, device_map: str = "", chat_template: Optional[str] = None, use_cache: bool = True,