From 094540837fc52b4a7aa2b1a43132ab961c806214 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jonatan=20K=C5=82osko?=
Date: Thu, 22 Feb 2024 23:52:30 +0700
Subject: [PATCH] Add option to read tensors lazily

---
 lib/safetensors.ex             | 76 +++++++++++++++++++---------------
 lib/safetensors/file_tensor.ex | 20 +++++++++
 lib/safetensors/shared.ex      | 36 ++++++++++++++++
 mix.lock                       |  2 +-
 test/safetensors_test.exs      | 17 ++++++++
 5 files changed, 116 insertions(+), 35 deletions(-)
 create mode 100644 lib/safetensors/file_tensor.ex
 create mode 100644 lib/safetensors/shared.ex

diff --git a/lib/safetensors.ex b/lib/safetensors.ex
index 88d3730..cbf150f 100644
--- a/lib/safetensors.ex
+++ b/lib/safetensors.ex
@@ -22,6 +22,8 @@ defmodule Safetensors do
 
   """
 
+  alias Safetensors.Shared
+
   @header_metadata_key "__metadata__"
 
   @type_to_dtype %{
@@ -60,7 +62,7 @@ defmodule Safetensors do
       :ok = :file.write(file, header_binary(header_entries))
 
       for {_tensor_name, tensor} <- tensors do
-        :ok = :file.write(file, tensor_to_binary(tensor))
+        :ok = :file.write(file, tensor_to_iodata(tensor))
       end
     end)
   end
@@ -97,12 +99,12 @@ defmodule Safetensors do
     Nx.size(tensor) * elem_byte_size
   end
 
-  defp tensor_to_binary(tensor) do
+  defp tensor_to_iodata(tensor) do
     {_, elem_size} = Nx.type(tensor)
 
     tensor
     |> Nx.to_binary()
-    |> new_byte_order(elem_size, :little)
+    |> Shared.new_byte_order(elem_size, :little)
   end
 
   @doc """
@@ -119,7 +121,7 @@ defmodule Safetensors do
     {header_entries, {buffer, _offset}} =
       Enum.map_reduce(tensors, {[], 0}, fn {tensor_name, tensor}, {buffer, offset} ->
         {header_entry, end_offset} = tensor_header_entry(tensor_name, tensor, offset)
-        binary = tensor_to_binary(tensor)
+        binary = tensor_to_iodata(tensor)
         {header_entry, {[buffer, binary], end_offset}}
       end)
 
@@ -131,9 +133,19 @@ defmodule Safetensors do
 
   Tensors are loaded into Nx one by one, without the need
   to load the entire file from disk into memory.
+
+  ## Options
+
+    * `:lazy` - when `true`, instead of returning tensors, the function
+      returns lazy containers. Such a container can be converted to a
+      tensor using `Nx.to_tensor/1` and it is only at that point that
+      it is read from the file. Defaults to `false`.
+
   """
-  @spec read!(path :: Path.t()) :: %{String.t() => Nx.Tensor.t()}
-  def read!(path) do
+  @spec read!(path :: Path.t(), keyword()) :: %{String.t() => Nx.LazyContainer.t()}
+  def read!(path, opts \\ []) do
+    opts = Keyword.validate!(opts, lazy: false)
+
     File.open!(path, [:read, :raw], fn file ->
       {:ok, <<header_size::unsigned-64-integer-little>>} = :file.read(file, 8)
       {:ok, header_json} = :file.read(file, header_size)
@@ -143,10 +155,26 @@ defmodule Safetensors do
       for {tensor_name, tensor_info} <- header, into: %{} do
         %{"data_offsets" => [offset_start, offset_end]} = tensor_info
 
-        {:ok, binary} =
-          :file.pread(file, header_size + 8 + offset_start, offset_end - offset_start)
-
-        {tensor_name, build_tensor(binary, tensor_info)}
+        {shape, type} = shape_and_type(tensor_info)
+
+        byte_offset = header_size + 8 + offset_start
+        byte_size = offset_end - offset_start
+
+        value =
+          if opts[:lazy] do
+            %Safetensors.FileTensor{
+              shape: shape,
+              type: type,
+              path: path,
+              byte_offset: byte_offset,
+              byte_size: byte_size
+            }
+          else
+            {:ok, binary} = :file.pread(file, byte_offset, byte_size)
+            Shared.build_tensor(binary, shape, type)
+          end
+
+        {tensor_name, value}
       end
     end)
   end
@@ -170,11 +198,12 @@ defmodule Safetensors do
 
     for {tensor_name, tensor_info} <- header, into: %{} do
       %{"data_offsets" => [offset_start, offset_end]} = tensor_info
+      {shape, type} = shape_and_type(tensor_info)
 
       tensor =
         buffer
         |> binary_slice(offset_start, offset_end - offset_start)
-        |> build_tensor(tensor_info)
+        |> Shared.build_tensor(shape, type)
 
       {tensor_name, tensor}
     end
@@ -189,14 +218,8 @@ defmodule Safetensors do
     header
   end
 
-  defp build_tensor(binary, tensor_info) do
-    %{"dtype" => dtype, "shape" => shape} = tensor_info
-    {_, elem_size} = type = dtype_to_type(dtype)
-
-    binary
-    |> new_byte_order(elem_size, :little)
-    |> Nx.from_binary(type)
-    |> Nx.reshape(List.to_tuple(shape))
+  defp shape_and_type(%{"dtype" => dtype, "shape" => shape}) do
+    {List.to_tuple(shape), dtype_to_type(dtype)}
   end
 
   defp type_to_dtype(type) do
@@ -206,19 +229,4 @@ defmodule Safetensors do
   defp dtype_to_type(dtype) do
     @dtype_to_type[dtype] || raise "unrecognized dtype #{inspect(dtype)}"
   end
-
-  defp new_byte_order(binary, size, endianness) do
-    if System.endianness() == endianness do
-      binary
-    else
-      data =
-        for <<data::size(size)-binary <- binary>> do
-          data
-          |> :binary.decode_unsigned()
-          |> :binary.encode_unsigned(endianness)
-        end
-
-      IO.iodata_to_binary(data)
-    end
-  end
 end
diff --git a/lib/safetensors/file_tensor.ex b/lib/safetensors/file_tensor.ex
new file mode 100644
index 0000000..439425f
--- /dev/null
+++ b/lib/safetensors/file_tensor.ex
@@ -0,0 +1,20 @@
+defmodule Safetensors.FileTensor do
+  @moduledoc false
+
+  defstruct [:shape, :type, :path, :byte_offset, :byte_size]
+end
+
+defimpl Nx.LazyContainer, for: Safetensors.FileTensor do
+  def traverse(lazy_tensor, acc, fun) do
+    template = Nx.template(lazy_tensor.shape, lazy_tensor.type)
+
+    load = fn ->
+      File.open!(lazy_tensor.path, [:read, :raw], fn file ->
+        {:ok, binary} = :file.pread(file, lazy_tensor.byte_offset, lazy_tensor.byte_size)
+        Safetensors.Shared.build_tensor(binary, lazy_tensor.shape, lazy_tensor.type)
+      end)
+    end
+
+    fun.(template, load, acc)
+  end
+end
diff --git a/lib/safetensors/shared.ex b/lib/safetensors/shared.ex
new file mode 100644
index 0000000..d984141
--- /dev/null
+++ b/lib/safetensors/shared.ex
@@ -0,0 +1,36 @@
+defmodule Safetensors.Shared do
+  @moduledoc false
+
+  @doc """
+  Builds an Nx tensor from the given safetensors binary.
+  """
+  @spec build_tensor(binary(), tuple(), Nx.Type.t()) :: Nx.Tensor.t()
+  def build_tensor(binary, shape, type) do
+    {_, elem_size} = type
+
+    binary
+    |> new_byte_order(elem_size, :little)
+    |> IO.iodata_to_binary()
+    |> Nx.from_binary(type)
+    |> Nx.reshape(shape)
+  end
+
+  @doc """
+  Changes the endianness of `binary` if `endianness` does not match the system endianness.
+  """
+  @spec new_byte_order(binary(), pos_integer(), :little | :big) :: iodata()
+  def new_byte_order(binary, size, endianness) do
+    if System.endianness() == endianness do
+      binary
+    else
+      data =
+        for <<data::size(size)-binary <- binary>> do
+          data
+          |> :binary.decode_unsigned()
+          |> :binary.encode_unsigned(endianness)
+        end
+
+      IO.iodata_to_binary(data)
+    end
+  end
+end
diff --git a/mix.lock b/mix.lock
index bca9561..f497851 100644
--- a/mix.lock
+++ b/mix.lock
@@ -7,6 +7,6 @@
   "makeup_elixir": {:hex, :makeup_elixir, "0.16.1", "cc9e3ca312f1cfeccc572b37a09980287e243648108384b97ff2b76e505c3555", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "e127a341ad1b209bd80f7bd1620a15693a9908ed780c3b763bccf7d200c767c6"},
   "makeup_erlang": {:hex, :makeup_erlang, "0.1.2", "ad87296a092a46e03b7e9b0be7631ddcf64c790fa68a9ef5323b6cbb36affc72", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "f3f5a1ca93ce6e092d92b6d9c049bcda58a3b617a8d888f8e7231c85630e8108"},
   "nimble_parsec": {:hex, :nimble_parsec, "1.3.1", "2c54013ecf170e249e9291ed0a62e5832f70a476c61da16f6aac6dca0189f2af", [:mix], [], "hexpm", "2682e3c0b2eb58d90c6375fc0cc30bc7be06f365bf72608804fb9cffa5e1b167"},
-  "nx": {:hex, :nx, "0.5.3", "6ad5534f9b82429dafa12329952708c2fdd6ab01b306e86333fdea72383147ee", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "d1072fc4423809ed09beb729e73c200ce177ddecac425d9eb6fba643669623ec"},
+  "nx": {:hex, :nx, "0.6.4", "948d9f42f81e63fc901d243ac0a985c8bb87358be62e27826cfd67f58bc640af", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "bb9c2e2e3545b5eb4739d69046a988daaa212d127dba7d97801c291616aff6d6"},
   "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"},
 }
diff --git a/test/safetensors_test.exs b/test/safetensors_test.exs
index a20430b..a27808e 100644
--- a/test/safetensors_test.exs
+++ b/test/safetensors_test.exs
@@ -44,6 +44,23 @@ defmodule SafetensorsTest do
     assert Safetensors.read!(path) == %{"test" => Nx.tensor([[0, 0], [0, 0]], type: :s32)}
   end
 
+  @tag :tmp_dir
+  test "read lazy", %{tmp_dir: tmp_dir} do
+    path = Path.join(tmp_dir, "safetensor")
+
+    # source:
+    # https://github.com/huggingface/safetensors/blob/1a65a3fdebcf280ef0ca32934901d3e2ad3b2c65/bindings/python/tests/test_simple.py#L35-L40
+    File.write!(
+      path,
+      ~s(<\x00\x00\x00\x00\x00\x00\x00{"test":{"dtype":"I32","shape":[2,2],"data_offsets":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00)
+    )
+
+    assert %{"test" => %Safetensors.FileTensor{} = file_tensor} =
+             Safetensors.read!(path, lazy: true)
+
+    assert Nx.to_tensor(file_tensor) == Nx.tensor([[0, 0], [0, 0]], type: :s32)
+  end
+
   test "load" do
     # source:
     # https://github.com/huggingface/safetensors/blob/1a65a3fdebcf280ef0ca32934901d3e2ad3b2c65/bindings/python/tests/test_simple.py#L35-L40