Skip to content

Commit

Permalink
feat: allow passing device when loading a model
Browse files Browse the repository at this point in the history
  • Loading branch information
percevalw committed Sep 1, 2023
1 parent 2d117f3 commit 95ba47c
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions edspdf/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,7 @@ def load_state_from_disk(
path: Union[str, Path],
*,
exclude: Set[str] = None,
device: Optional[Union[str, "torch.device"]] = "cpu", # noqa F821
) -> "Pipeline":
"""
Load the pipeline from a directory. Components will be updated in-place.
Expand Down Expand Up @@ -836,7 +837,9 @@ def deserialize_tensors(path: Path):
# are expected to be shared
pipe = trainable_components[pipe_names[0]]
tensor_dict = {}
for keys, tensor in safetensors.torch.load_file(file_name).items():
for keys, tensor in safetensors.torch.load_file(
file_name, device=device
).items():
split_keys = [split_path(key) for key in keys.split("+")]
key = next(key for key in split_keys if key[0] == pipe_names[0])
tensor_dict[join_path(key[1:])] = tensor
Expand Down Expand Up @@ -875,11 +878,12 @@ def load(
path: Union[str, Path],
*,
exclude: Optional[Set[str]] = None,
device: Optional[Union[str, "torch.device"]] = "cpu", # noqa F821
):
path = Path(path) if isinstance(path, str) else path
config = Config.from_disk(path / "config.cfg")
self = Pipeline.from_config(config)
self.load_state_from_disk(path, exclude=exclude)
self.load_state_from_disk(path, exclude=exclude, device=device)
return self

# override config property getter to remove "factory" key from components
Expand Down Expand Up @@ -941,12 +945,15 @@ def select_pipes(
self._disabled = disabled_before


def load(config: Union[Path, str, Config]):
def load(
config: Union[Path, str, Config],
device: Optional[Union[str, "torch.device"]] = "cpu", # noqa F821
):
error = "The load function expects a Config or a path to a config file"
if isinstance(config, (Path, str)):
path = Path(config)
if path.is_dir():
return Pipeline.load(path)
return Pipeline.load(path, device=device)
elif path.is_file():
config = Config.from_disk(path)
else:
Expand Down

0 comments on commit 95ba47c

Please sign in to comment.