Skip to content

Commit

Permalink
Add support for device in safetensors.torch.load_model (#449)
Browse files Browse the repository at this point in the history
* fix typo

* allow loading torch model to device

* fix device type

* update device type
  • Loading branch information
Wauplin authored Apr 15, 2024
1 parent c3fca01 commit ff643a8
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
2 changes: 1 addition & 1 deletion bindings/python/py_src/safetensors/paddle.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def load_file(filename: Union[str, os.PathLike], device="cpu") -> Dict[str, padd
Args:
filename (`str`, or `os.PathLike`)):
The name of the file which contains the tensors
device (`Dict[str, any]`, *optional*, defaults to `cpu`):
device (`Union[Dict[str, any], str]`, *optional*, defaults to `cpu`):
The device where the tensors need to be located after load.
available options are all regular paddle device locations
Expand Down
15 changes: 9 additions & 6 deletions bindings/python/py_src/safetensors/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def save_model(
raise ValueError(msg)


def load_model(model: torch.nn.Module, filename: Union[str, os.PathLike], strict=True) -> Tuple[List[str], List[str]]:
def load_model(model: torch.nn.Module, filename: Union[str, os.PathLike], strict: bool = True, device: Union[str, int] = "cpu") -> Tuple[List[str], List[str]]:
"""
Loads a given filename onto a torch model.
This method exists specifically to avoid tensor sharing issues which are
Expand All @@ -185,16 +185,19 @@ def load_model(model: torch.nn.Module, filename: Union[str, os.PathLike], strict
filename (`str`, or `os.PathLike`):
The filename location to load the file from.
strict (`bool`, *optional*, defaults to True):
Wether to fail if you're missing keys or having unexpected ones
Whether to fail if you're missing keys or having unexpected ones.
When false, the function simply returns missing and unexpected names.
device (`Union[str, int]`, *optional*, defaults to `cpu`):
The device where the tensors need to be located after load.
available options are all regular torch device locations.
Returns:
`(missing, unexpected): (List[str], List[str])`
`missing` are names in the model which were not modified during loading
`unexpected` are names that are on the file, but weren't used during
the load.
"""
state_dict = load_file(filename)
state_dict = load_file(filename, device=device)
model_state_dict = model.state_dict()
to_removes = _remove_duplicate_names(model_state_dict, preferred_names=state_dict.keys())
missing, unexpected = model.load_state_dict(state_dict, strict=False)
Expand Down Expand Up @@ -281,16 +284,16 @@ def save_file(
serialize_file(_flatten(tensors), filename, metadata=metadata)


def load_file(filename: Union[str, os.PathLike], device="cpu") -> Dict[str, torch.Tensor]:
def load_file(filename: Union[str, os.PathLike], device: Union[str, int] = "cpu") -> Dict[str, torch.Tensor]:
"""
Loads a safetensors file into torch format.
Args:
filename (`str`, or `os.PathLike`):
The name of the file which contains the tensors
device (`Dict[str, any]`, *optional*, defaults to `cpu`):
device (`Union[str, int]`, *optional*, defaults to `cpu`):
The device where the tensors need to be located after load.
available options are all regular torch device locations
available options are all regular torch device locations.
Returns:
`Dict[str, torch.Tensor]`: dictionary that contains name as key, value as `torch.Tensor`
Expand Down

0 comments on commit ff643a8

Please sign in to comment.