From 123adbc2524375cc9c7ff3b1d3cce4fe03c28477 Mon Sep 17 00:00:00 2001 From: Georgijs Vilums Date: Fri, 16 Aug 2024 16:58:30 -0700 Subject: [PATCH 1/2] fix lifetimes, add tensors iterator --- safetensors/src/tensor.rs | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/safetensors/src/tensor.rs b/safetensors/src/tensor.rs index e5b8b4b8..1e8d5d8f 100644 --- a/safetensors/src/tensor.rs +++ b/safetensors/src/tensor.rs @@ -345,10 +345,10 @@ impl<'data> SafeTensors<'data> { Ok(Self { metadata, data }) } - /// Allow the user to iterate over tensors within the SafeTensors. + /// Returns the tensors contained within the SafeTensors. /// The tensors returned are merely views and the data is not owned by this /// structure. - pub fn tensors(&self) -> Vec<(String, TensorView<'_>)> { + pub fn tensors(&self) -> Vec<(String, TensorView<'data>)> { let mut tensors = Vec::with_capacity(self.metadata.index_map.len()); for (name, &index) in &self.metadata.index_map { let info = &self.metadata.tensors[index]; @@ -362,10 +362,24 @@ impl<'data> SafeTensors<'data> { tensors } + /// Returns an iterator over the tensors contained within the SafeTensors. + /// The tensors returned are merely views and the data is not owned by this + /// structure. + pub fn iter<'a>(&'a self) -> impl Iterator<Item = (&'a str, TensorView<'data>)> { + self.metadata.index_map.iter().map(|(name, &idx)| { + let info = &self.metadata.tensors[idx]; + (name.as_str(), TensorView { + dtype: info.dtype, + shape: info.shape.clone(), + data: &self.data[info.data_offsets.0..info.data_offsets.1], + }) + }) + } + /// Allow the user to get a specific tensor within the SafeTensors. /// The tensor returned is merely a view and the data is not owned by this /// structure. 
- pub fn tensor(&self, tensor_name: &str) -> Result<TensorView<'_>, SafeTensorError> { + pub fn tensor(&self, tensor_name: &str) -> Result<TensorView<'data>, SafeTensorError> { if let Some(index) = &self.metadata.index_map.get(tensor_name) { if let Some(info) = &self.metadata.tensors.get(**index) { Ok(TensorView { @@ -541,7 +555,7 @@ impl Metadata { /// A view of a Tensor within the file. /// Contains references to data within the full byte-buffer /// And is thus a readable view of a single tensor -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, PartialEq, Eq, Clone)] pub struct TensorView<'data> { dtype: Dtype, shape: Vec<usize>, From a3011347745db0516da01ae85d5c4239da6e3cc8 Mon Sep 17 00:00:00 2001 From: Georgijs Vilums Date: Wed, 4 Sep 2024 14:49:03 -0700 Subject: [PATCH 2/2] add test for lifetimes --- safetensors/src/tensor.rs | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/safetensors/src/tensor.rs b/safetensors/src/tensor.rs index 1e8d5d8f..a288663e 100644 --- a/safetensors/src/tensor.rs +++ b/safetensors/src/tensor.rs @@ -1052,6 +1052,21 @@ mod tests { assert_eq!(tensor.data(), b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0"); } + #[test] + fn test_lifetimes() { + let serialized = b"<\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,2],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"; + + let tensor = { + let loaded = SafeTensors::deserialize(serialized).unwrap(); + loaded.tensor("test").unwrap() + }; + + assert_eq!(tensor.shape(), vec![2, 2]); + assert_eq!(tensor.dtype(), Dtype::I32); + // 16 bytes + assert_eq!(tensor.data(), b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0"); + } + #[test] fn test_json_attack() { let mut tensors = HashMap::new();