From 320a0626925bb2fcc2eec1c358a6b232142f6ad5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Perceval=20Wajsb=C3=BCrt?= Date: Fri, 1 Sep 2023 19:36:42 +0200 Subject: [PATCH] feat: allow passing device when loading a model --- edspdf/pipeline.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/edspdf/pipeline.py b/edspdf/pipeline.py index 51359064..c60aa859 100644 --- a/edspdf/pipeline.py +++ b/edspdf/pipeline.py @@ -700,6 +700,7 @@ def load_state_from_disk( path: Union[str, Path], *, exclude: Set[str] = None, + device: Optional[Union[str, "torch.device"]] = "cpu", # noqa F821 ) -> "Pipeline": """ Load the pipeline from a directory. Components will be updated in-place. @@ -730,7 +731,9 @@ def deserialize_tensors(path: Path): # are expected to be shared pipe = trainable_components[pipe_names[0]] tensor_dict = {} - for keys, tensor in safetensors.torch.load_file(file_name).items(): + for keys, tensor in safetensors.torch.load_file( + file_name, device=device + ).items(): split_keys = [split_path(key) for key in keys.split("+")] key = next(key for key in split_keys if key[0] == pipe_names[0]) tensor_dict[join_path(key[1:])] = tensor @@ -769,11 +772,12 @@ def load( path: Union[str, Path], *, exclude: Optional[Set[str]] = None, + device: Optional[Union[str, "torch.device"]] = "cpu", # noqa F821 ): path = Path(path) if isinstance(path, str) else path config = Config.from_disk(path / "config.cfg") self = Pipeline.from_config(config) - self.load_state_from_disk(path, exclude=exclude) + self.load_state_from_disk(path, exclude=exclude, device=device) return self # override config property getter to remove "factory" key from components @@ -854,12 +858,15 @@ def __exit__(ctx_self, type, value, traceback): return context() -def load(config: Union[Path, str, Config]) -> Pipeline: +def load( + config: Union[Path, str, Config], + device: Optional[Union[str, "torch.device"]] = "cpu", # noqa F821 +) -> Pipeline: error = "The load function expects a Config or a path to a config file" if isinstance(config, (Path, str)): path = Path(config) if path.is_dir(): - return Pipeline.load(path) + return Pipeline.load(path, device=device) elif path.is_file(): config = Config.from_disk(path) else: