Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix lifetimes, add tensors iterator #518

Merged
merged 2 commits into from
Sep 5, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 33 additions & 4 deletions safetensors/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,10 +345,10 @@ impl<'data> SafeTensors<'data> {
Ok(Self { metadata, data })
}

/// Allow the user to iterate over tensors within the SafeTensors.
/// Returns the tensors contained within the SafeTensors.
/// The tensors returned are merely views and the data is not owned by this
/// structure.
pub fn tensors(&self) -> Vec<(String, TensorView<'_>)> {
pub fn tensors(&self) -> Vec<(String, TensorView<'data>)> {
let mut tensors = Vec::with_capacity(self.metadata.index_map.len());
for (name, &index) in &self.metadata.index_map {
let info = &self.metadata.tensors[index];
Expand All @@ -362,10 +362,24 @@ impl<'data> SafeTensors<'data> {
tensors
}

/// Lazily walks every tensor stored in this `SafeTensors`.
///
/// Each item pairs the tensor's name (borrowed from `self`) with a
/// [`TensorView`] whose bytes borrow from the underlying `'data` buffer, so
/// nothing yielded here owns tensor data. The shape is cloned per item.
pub fn iter<'a>(&'a self) -> impl Iterator<Item = (&'a str, TensorView<'data>)> {
    self.metadata.index_map.iter().map(|(key, &position)| {
        // `index_map` maps a tensor name to its slot in `metadata.tensors`.
        let entry = &self.metadata.tensors[position];
        let start = entry.data_offsets.0;
        let end = entry.data_offsets.1;
        let view = TensorView {
            dtype: entry.dtype,
            shape: entry.shape.clone(),
            data: &self.data[start..end],
        };
        (key.as_str(), view)
    })
}

/// Allow the user to get a specific tensor within the SafeTensors.
/// The tensor returned is merely a view and the data is not owned by this
/// structure.
pub fn tensor(&self, tensor_name: &str) -> Result<TensorView<'_>, SafeTensorError> {
pub fn tensor(&self, tensor_name: &str) -> Result<TensorView<'data>, SafeTensorError> {
if let Some(index) = &self.metadata.index_map.get(tensor_name) {
if let Some(info) = &self.metadata.tensors.get(**index) {
Ok(TensorView {
Expand Down Expand Up @@ -541,7 +555,7 @@ impl Metadata {
/// A view of a single tensor within the file.
/// It holds a reference to data within the full byte buffer
/// and is thus a readable view of that tensor.
#[derive(Debug, PartialEq, Eq)]
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct TensorView<'data> {
dtype: Dtype,
shape: Vec<usize>,
Expand Down Expand Up @@ -1038,6 +1052,21 @@ mod tests {
assert_eq!(tensor.data(), b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0");
}

/// A `TensorView` must remain usable after the `SafeTensors` it came from has
/// been dropped: the view is returned out of the inner block in which the
/// `SafeTensors` lives. This only compiles if `tensor` hands back a view tied
/// to the data buffer rather than to the borrow of the `SafeTensors` itself.
#[test]
fn test_lifetimes() {
    let buffer = b"<\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,2],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";

    let view = {
        // `container` goes out of scope at the end of this block.
        let container = SafeTensors::deserialize(buffer).unwrap();
        container.tensor("test").unwrap()
    };

    assert_eq!(view.shape(), vec![2, 2]);
    assert_eq!(view.dtype(), Dtype::I32);
    // The payload is 16 zero bytes.
    let expected_bytes = [0u8; 16];
    assert_eq!(view.data(), &expected_bytes);
}

#[test]
fn test_json_attack() {
let mut tensors = HashMap::new();
Expand Down