From 43e9d9483ab92cf8a2135fc354ab3be66cd9c785 Mon Sep 17 00:00:00 2001 From: achaiah Date: Tue, 26 Mar 2019 11:30:38 -0500 Subject: [PATCH] fixing an issue with checkpoint loading --- pywick/models/model_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pywick/models/model_utils.py b/pywick/models/model_utils.py index 410eef0..e01af70 100644 --- a/pywick/models/model_utils.py +++ b/pywick/models/model_utils.py @@ -410,7 +410,7 @@ def load_checkpoint(checkpoint_path, model=None, device='cpu', strict=True, igno which device to load model onto (default:'cpu') :param strict: bool whether to ensure strict key matching (True) or to ignore non-matching keys. (default: True) - :param ignore_chkpt_layers: one of {string, list) + :param ignore_chkpt_layers: one of {string, list) -- CURRENTLY UNIMPLEMENTED whether to ignore some subset of layers from checkpoint. This is usually done when loading checkpoint data into a model with a different number of final classes. In that case, you can pass in a special string: 'last_layer' which will trigger the logic to chop off the last layer of the checkpoint dictionary. Otherwise @@ -439,7 +439,7 @@ def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, bac checkpoint_path = os.path.expanduser(checkpoint_path) if os.path.isfile(checkpoint_path): print('=> Loading checkpoint: {} onto device: {}'.format(checkpoint_path, device)) - checkpoint = torch.load(checkpoint_path) + checkpoint = torch.load(checkpoint_path, map_location=device) pretrained_state = checkpoint['state_dict'] print("INFO: => loaded checkpoint {} (epoch {})".format(checkpoint_path, checkpoint.get('epoch'))) @@ -469,5 +469,5 @@ def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, bac # finally load the model weights if model: print('INFO: => Attempting to load checkpoint data onto model. Device: {} Strict: {}'.format(device, strict)) - model.load_state_dict(checkpoint['state_dict'], map_location=device, strict=strict) + model.load_state_dict(checkpoint['state_dict'], strict=strict) return checkpoint