Skip to content

Commit

Permalink
Update and document GPU buffer handling (#2751)
Browse files Browse the repository at this point in the history
* Update GPU handling

This updates how we handle GPU buffers. See the new docs page for a
simple example.

The basic idea, as discussed in ..., is to use host buffers for all
metadata objects and device buffers for data.

Zarr has two types of buffers: plain buffers (used for a stream of
bytes) and ndbuffers (used for bytes that represent ndarrays). To make
it easier for users, I've added a new config option
`zarr.config.enable_gpu()` that can be used to update those both. If
we need additional customizations in the future, we can add them here.

* fixed doc

* Fixup

* changelog

* doctest, skip

* removed not gpu

* assert that the type matches

* Added changelog notes

---------

Co-authored-by: Davis Bennett <davis.v.bennett@gmail.com>
  • Loading branch information
TomAugspurger and d-v-b authored Feb 14, 2025
1 parent 47003d7 commit 24ef221
Show file tree
Hide file tree
Showing 12 changed files with 130 additions and 9 deletions.
1 change: 1 addition & 0 deletions changes/2751.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed bug with Zarr using device memory, instead of host memory, for storing metadata when using GPUs.
1 change: 1 addition & 0 deletions changes/2751.doc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added new user guide on :ref:`user-guide-gpu`.
1 change: 1 addition & 0 deletions changes/2751.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added :meth:`zarr.config.enable_gpu` to update Zarr's configuration to use GPUs.
21 changes: 21 additions & 0 deletions docs/developers/contributing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,27 @@ during development at `http://0.0.0.0:8000/ <http://0.0.0.0:8000/>`_. This can b

$ hatch --env docs run serve

.. _changelog:

Changelog
~~~~~~~~~

zarr-python uses `towncrier`_ to manage release notes. Most pull requests should
include at least one news fragment describing the changes. To add a release
note, you'll need the GitHub issue or pull request number and the type of your
change (``feature``, ``bugfix``, ``doc``, ``removal``, ``misc``). With that, run
```towncrier create``` with your development environment, which will prompt you
for the issue number, change type, and the news text::

towncrier create

Alternatively, you can manually create the files in the ``changes`` directory
using the naming convention ``{issue-number}.{change-type}.rst``.

See the `towncrier`_ docs for more.

.. _towncrier: https://towncrier.readthedocs.io/en/stable/tutorial.html

Development best practices, policies and procedures
---------------------------------------------------

Expand Down
1 change: 1 addition & 0 deletions docs/user-guide/config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Configuration options include the following:
- Whether empty chunks are written to storage ``array.write_empty_chunks``
- Async and threading options, e.g. ``async.concurrency`` and ``threading.max_workers``
- Selections of implementations of codecs, codec pipelines and buffers
- Enabling GPU support with ``zarr.config.enable_gpu()``. See :ref:`user-guide-gpu` for more.

For selecting custom implementations of codecs, pipelines, buffers and ndbuffers,
first register the implementations in the registry and then select them in the config.
Expand Down
37 changes: 37 additions & 0 deletions docs/user-guide/gpu.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
.. _user-guide-gpu:

Using GPUs with Zarr
====================

Zarr can use GPUs to accelerate your workload by running
:meth:`zarr.config.enable_gpu`.

.. note::

`zarr-python` currently supports reading the ndarray data into device (GPU)
memory as the final stage of the codec pipeline. Data will still be read into
or copied to host (CPU) memory for encoding and decoding.

In the future, codecs will be available compressing and decompressing data on
the GPU, avoiding the need to move data between the host and device for
compression and decompression.

Reading data into device memory
-------------------------------

:meth:`zarr.config.enable_gpu` configures Zarr to use GPU memory for the data
buffers used internally by Zarr.

.. code-block:: python
>>> import zarr
>>> import cupy as cp # doctest: +SKIP
>>> zarr.config.enable_gpu() # doctest: +SKIP
>>> store = zarr.storage.MemoryStore() # doctest: +SKIP
>>> z = zarr.create_array( # doctest: +SKIP
... store=store, shape=(100, 100), chunks=(10, 10), dtype="float32",
... )
>>> type(z[:10, :10]) # doctest: +SKIP
cupy.ndarray
Note that the output type is a ``cupy.ndarray`` rather than a NumPy array.
1 change: 1 addition & 0 deletions docs/user-guide/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ Advanced Topics
performance
consolidated_metadata
extending
gpu


.. Coming soon
Expand Down
16 changes: 9 additions & 7 deletions src/zarr/core/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
NDBuffer,
default_buffer_prototype,
)
from zarr.core.buffer.cpu import buffer_prototype as cpu_buffer_prototype
from zarr.core.chunk_grids import RegularChunkGrid, _auto_partition, normalize_chunks
from zarr.core.chunk_key_encodings import (
ChunkKeyEncoding,
Expand Down Expand Up @@ -163,19 +164,20 @@ async def get_array_metadata(
) -> dict[str, JSON]:
if zarr_format == 2:
zarray_bytes, zattrs_bytes = await gather(
(store_path / ZARRAY_JSON).get(), (store_path / ZATTRS_JSON).get()
(store_path / ZARRAY_JSON).get(prototype=cpu_buffer_prototype),
(store_path / ZATTRS_JSON).get(prototype=cpu_buffer_prototype),
)
if zarray_bytes is None:
raise FileNotFoundError(store_path)
elif zarr_format == 3:
zarr_json_bytes = await (store_path / ZARR_JSON).get()
zarr_json_bytes = await (store_path / ZARR_JSON).get(prototype=cpu_buffer_prototype)
if zarr_json_bytes is None:
raise FileNotFoundError(store_path)
elif zarr_format is None:
zarr_json_bytes, zarray_bytes, zattrs_bytes = await gather(
(store_path / ZARR_JSON).get(),
(store_path / ZARRAY_JSON).get(),
(store_path / ZATTRS_JSON).get(),
(store_path / ZARR_JSON).get(prototype=cpu_buffer_prototype),
(store_path / ZARRAY_JSON).get(prototype=cpu_buffer_prototype),
(store_path / ZATTRS_JSON).get(prototype=cpu_buffer_prototype),
)
if zarr_json_bytes is not None and zarray_bytes is not None:
# warn and favor v3
Expand Down Expand Up @@ -1348,7 +1350,7 @@ async def _save_metadata(self, metadata: ArrayMetadata, ensure_parents: bool = F
"""
Asynchronously save the array metadata.
"""
to_save = metadata.to_buffer_dict(default_buffer_prototype())
to_save = metadata.to_buffer_dict(cpu_buffer_prototype)
awaitables = [set_or_delete(self.store_path / key, value) for key, value in to_save.items()]

if ensure_parents:
Expand All @@ -1360,7 +1362,7 @@ async def _save_metadata(self, metadata: ArrayMetadata, ensure_parents: bool = F
[
(parent.store_path / key).set_if_not_exists(value)
for key, value in parent.metadata.to_buffer_dict(
default_buffer_prototype()
cpu_buffer_prototype
).items()
]
)
Expand Down
7 changes: 7 additions & 0 deletions src/zarr/core/buffer/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@

from zarr.core.buffer import core
from zarr.core.buffer.core import ArrayLike, BufferPrototype, NDArrayLike
from zarr.registry import (
register_buffer,
register_ndbuffer,
)

if TYPE_CHECKING:
from collections.abc import Iterable
Expand Down Expand Up @@ -215,3 +219,6 @@ def __setitem__(self, key: Any, value: Any) -> None:


buffer_prototype = BufferPrototype(buffer=Buffer, nd_buffer=NDBuffer)

register_buffer(Buffer)
register_ndbuffer(NDBuffer)
13 changes: 12 additions & 1 deletion src/zarr/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,13 @@

from __future__ import annotations

from typing import Any, Literal, cast
from typing import TYPE_CHECKING, Any, Literal, cast

from donfig import Config as DConfig

if TYPE_CHECKING:
from donfig.config_obj import ConfigSet


class BadConfigError(ValueError):
_msg = "bad Config: %r"
Expand All @@ -56,6 +59,14 @@ def reset(self) -> None:
self.clear()
self.refresh()

def enable_gpu(self) -> ConfigSet:
"""
Configure Zarr to use GPUs where possible.
"""
return self.set(
{"buffer": "zarr.core.buffer.gpu.Buffer", "ndbuffer": "zarr.core.buffer.gpu.NDBuffer"}
)


# The default configuration for zarr
config = Config(
Expand Down
2 changes: 1 addition & 1 deletion src/zarr/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def has_cupy() -> bool:
return False


T_Callable = TypeVar("T_Callable", bound=Callable[[], Coroutine[Any, Any, None]])
T_Callable = TypeVar("T_Callable", bound=Callable[..., Coroutine[Any, Any, None] | None])


# Decorator for GPU tests
Expand Down
38 changes: 38 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from zarr.errors import MetadataValidationError
from zarr.storage import MemoryStore
from zarr.storage._utils import normalize_path
from zarr.testing.utils import gpu_test


def test_create(memory_store: Store) -> None:
Expand Down Expand Up @@ -1121,3 +1122,40 @@ def test_open_array_with_mode_r_plus(store: Store) -> None:
assert isinstance(z2, Array)
assert (z2[:] == 1).all()
z2[:] = 3


@gpu_test
@pytest.mark.parametrize(
"store",
["local", "memory", "zip"],
indirect=True,
)
@pytest.mark.parametrize("zarr_format", [None, 2, 3])
def test_gpu_basic(store: Store, zarr_format: ZarrFormat | None) -> None:
import cupy as cp

if zarr_format == 2:
# Without this, the zstd codec attempts to convert the cupy
# array to bytes.
compressors = None
else:
compressors = "auto"

with zarr.config.enable_gpu():
src = cp.random.uniform(size=(100, 100)) # allocate on the device
z = zarr.create_array(
store,
name="a",
shape=src.shape,
chunks=(10, 10),
dtype=src.dtype,
overwrite=True,
zarr_format=zarr_format,
compressors=compressors,
)
z[:10, :10] = src[:10, :10]

result = z[:10, :10]
# assert_array_equal doesn't check the type
assert isinstance(result, type(src))
cp.testing.assert_array_equal(result, src[:10, :10])

0 comments on commit 24ef221

Please sign in to comment.