separating static batching as a utils function #77

Merged: 1 commit, Jul 8, 2024
10 changes: 3 additions & 7 deletions alfred/fm/model.py
```diff
@@ -11,7 +11,7 @@

 from .query import Query, RankedQuery, CompletionQuery
 from .response import Response, CompletionResponse, RankedResponse
-from .utils import DynamicBatcher, clear_cuda_cache, batch_multimodal
+from .utils import DynamicBatcher, clear_cuda_cache, batch_multimodal, static_batch

 logger = logging.getLogger(__name__)

```

```diff
@@ -131,9 +131,7 @@ def forward(
                 )
             except AttributeError:
                 if batch_policy == "static":
-                    batched_queries = np.array_split(
-                        queries, max(1, len(queries) // batch_size)
-                    )
+                    batched_queries = static_batch(queries, batch_size=batch_size)
                     pretokenized = False
                 elif batch_policy == "dynamic":
                     if pretokenize:
```

```diff
@@ -200,9 +198,7 @@ def forward(
                     clear_cuda_cache()
                     if batch_policy == "static":
                         batch_size = int(batch_size * 0.8)
-                        batched_queries = np.array_split(
-                            queries, len(queries) // batch_size
-                        )
+                        batched_queries = static_batch(queries, batch_size=batch_size)
                         logging.info(f"New batch size: {batch_size}")
                     elif batch_policy == "dynamic":
                         DB = DynamicBatcher(
```
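The retry path above shrinks the static batch size to 80% after a failure (`batch_size = int(batch_size * 0.8)`) and re-batches the queries. A minimal standalone sketch of that backoff loop — `process` and `run_with_backoff` are hypothetical names for illustration, and plain `MemoryError` stands in for the CUDA OOM the real code recovers from:

```python
from typing import Callable, List, TypeVar

T = TypeVar("T")


def static_batch(items: List[T], batch_size: int) -> List[List[T]]:
    # Fixed-size batches; the last batch holds the remainder.
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


def run_with_backoff(items: List[T], process: Callable[[List[T]], None],
                     batch_size: int = 8) -> int:
    # Mirror of the retry in forward(): on failure, shrink the batch size
    # to 80% and re-batch from scratch.
    while True:
        try:
            for batch in static_batch(items, batch_size):
                process(batch)
            return batch_size  # the batch size that finally succeeded
        except MemoryError:
            if batch_size == 1:
                raise  # cannot shrink further
            batch_size = max(1, int(batch_size * 0.8))
```

With a `process` that fails on batches larger than 5, a starting batch size of 8 backs off 8 → 6 → 4 before succeeding.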
28 changes: 28 additions & 0 deletions alfred/fm/utils.py
```diff
@@ -524,3 +524,31 @@ def _process_batch(batch):
         clear_cuda_cache()

     return batches
+
+
+def static_batch(queries: List[Query], batch_size: int = 1024) -> List[List[Query]]:
+    """
+    Static batching utility: batch queries into fixed-size batches.
+
+    :param queries: A list of queries to be batched
+    :type queries: List[Query]
+    :param batch_size: The batch size
+    :type batch_size: int
+    :return: A list of batches
+    :rtype: List[List[Query]]
+    """
+    batches = []
+    batch = []
+    for query in queries:
+        if len(batch) == batch_size:
+            batches.append(batch)
+            batch = []
+        if isinstance(query, CompletionQuery):
+            _q = query.load()[0]
+        elif isinstance(query, RankedQuery):
+            _q = query.prompt
+        batch.append(_q)
+    if len(batch) > 0:
+        batches.append(batch)
+    return batches
```
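To see the batching shape the new helper produces, here is a simplified standalone sketch over plain items — alfred's version additionally unwraps `CompletionQuery`/`RankedQuery` payloads via `query.load()[0]` and `query.prompt` before appending:

```python
from typing import List, TypeVar

T = TypeVar("T")


def static_batch(queries: List[T], batch_size: int = 1024) -> List[List[T]]:
    # Same loop shape as the diff above, minus the query unwrapping:
    # flush the current batch once it reaches batch_size, then keep filling.
    batches: List[List[T]] = []
    batch: List[T] = []
    for q in queries:
        if len(batch) == batch_size:
            batches.append(batch)
            batch = []
        batch.append(q)
    if batch:
        batches.append(batch)
    return batches
```

For example, `static_batch(list(range(10)), 4)` yields batches of sizes 4, 4, and 2 — every batch is full except possibly the last.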
14 changes: 7 additions & 7 deletions docs/alfred/fm/model.md
```diff
@@ -20,7 +20,7 @@

 ## APIAccessFoundationModel

-[Show source in model.py:382](../../../alfred/fm/model.py#L382)
+[Show source in model.py:378](../../../alfred/fm/model.py#L378)

 #### Signature

```

```diff
@@ -49,7 +49,7 @@ class FoundationModel(abc.ABC): ...

 ### FoundationModel().__call__

-[Show source in model.py:360](../../../alfred/fm/model.py#L360)
+[Show source in model.py:356](../../../alfred/fm/model.py#L356)

 This function returns the output of the run function when the
 model is called as a function. It can be used as model(queries),
```

```diff
@@ -157,7 +157,7 @@ def _score_batch(

 ### FoundationModel().encode

-[Show source in model.py:277](../../../alfred/fm/model.py#L277)
+[Show source in model.py:273](../../../alfred/fm/model.py#L273)

 This function is a wrapper around the forward function

```

```diff
@@ -239,7 +239,7 @@ def forward(

 ### FoundationModel().generate

-[Show source in model.py:226](../../../alfred/fm/model.py#L226)
+[Show source in model.py:222](../../../alfred/fm/model.py#L222)

 This function is a wrapper around the forward function for running
 CompletionQuery objects through the foundation model. It returns a list
```

```diff
@@ -275,7 +275,7 @@ def generate(

 ### FoundationModel().run

-[Show source in model.py:308](../../../alfred/fm/model.py#L308)
+[Show source in model.py:304](../../../alfred/fm/model.py#L304)

 This function is the main entry point for users to run queries through the foundation model.
 It accepts raw query content and automatically converts it into query objects.
```

```diff
@@ -308,7 +308,7 @@ def run(

 ### FoundationModel().score

-[Show source in model.py:251](../../../alfred/fm/model.py#L251)
+[Show source in model.py:247](../../../alfred/fm/model.py#L247)

 This function is a wrapper around the forward function
 for running RankedQuery objects through the foundation model.
```

```diff
@@ -346,7 +346,7 @@ def score(

 ## LocalAccessFoundationModel

-[Show source in model.py:397](../../../alfred/fm/model.py#L397)
+[Show source in model.py:393](../../../alfred/fm/model.py#L393)

 #### Signature

```
28 changes: 28 additions & 0 deletions docs/alfred/fm/utils.md
```diff
@@ -21,6 +21,7 @@
 - [normalize_logits](#normalize_logits)
 - [reorder_array](#reorder_array)
 - [retry](#retry)
+- [static_batch](#static_batch)
 - [tokenize](#tokenize)
 - [type_print](#type_print)

```

````diff
@@ -369,6 +370,33 @@ def retry(num_retries=3, wait_time=0.1, exceptions=(Exception)): ...



+## static_batch
+
+[Show source in utils.py:529](../../../alfred/fm/utils.py#L529)
+
+Static batching utility: batch queries into fixed-size batches.
+
+#### Arguments
+
+- `queries` - A list of queries to be batched
+:type queries: List[Query]
+- `batch_size` - The batch size
+:type batch_size: int
+
+#### Returns
+
+A list of batches
+Type: *List[List[Query]]*
+
+#### Signature
+
+```python
+def static_batch(queries: List[Query], batch_size: int = 1024) -> List[List[Query]]: ...
+```
+
+
+
 ## tokenize

 [Show source in utils.py:89](../../../alfred/fm/utils.py#L89)
````
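For contrast with the `np.array_split` calls this PR replaces: `array_split` divides the queries into `len(queries) // batch_size` near-equal parts, so actual batch sizes drift away from the requested `batch_size`, while the new helper emits batches of exactly `batch_size` with only the last one short. A sketch of the difference, with plain integers standing in for queries:

```python
import numpy as np

queries = list(range(10))
batch_size = 4

# Old behavior: len(queries) // batch_size = 2 near-equal parts.
old = [list(part) for part in
       np.array_split(queries, max(1, len(queries) // batch_size))]

# New behavior: fixed-size batches of batch_size, remainder last.
new = [queries[i:i + batch_size] for i in range(0, len(queries), batch_size)]
```

Here `old` has batch sizes [5, 5] and `new` has [4, 4, 2] — only the new helper guarantees no batch exceeds the requested size.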