Skip to content

Commit

Permalink
fixing an issue with checkpoint loading
Browse files Browse the repository at this point in the history
  • Loading branch information
achaiah authored and achaiah committed Mar 26, 2019
1 parent afbbd54 commit 43e9d94
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions pywick/models/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ def load_checkpoint(checkpoint_path, model=None, device='cpu', strict=True, igno
which device to load model onto (default:'cpu')
:param strict: bool
whether to ensure strict key matching (True) or to ignore non-matching keys. (default: True)
:param ignore_chkpt_layers: one of {string, list)
:param ignore_chkpt_layers: one of {string, list) -- CURRENTLY UNIMPLEMENTED
whether to ignore some subset of layers from checkpoint. This is usually done when loading
checkpoint data into a model with a different number of final classes. In that case, you can pass in a
special string: 'last_layer' which will trigger the logic to chop off the last layer of the checkpoint dictionary. Otherwise
Expand Down Expand Up @@ -439,7 +439,7 @@ def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, bac
checkpoint_path = os.path.expanduser(checkpoint_path)
if os.path.isfile(checkpoint_path):
print('=> Loading checkpoint: {} onto device: {}'.format(checkpoint_path, device))
checkpoint = torch.load(checkpoint_path)
checkpoint = torch.load(checkpoint_path, map_location=device)

pretrained_state = checkpoint['state_dict']
print("INFO: => loaded checkpoint {} (epoch {})".format(checkpoint_path, checkpoint.get('epoch')))
Expand Down Expand Up @@ -469,5 +469,5 @@ def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, bac
# finally load the model weights
if model:
print('INFO: => Attempting to load checkpoint data onto model. Device: {} Strict: {}'.format(device, strict))
model.load_state_dict(checkpoint['state_dict'], map_location=device, strict=strict)
model.load_state_dict(checkpoint['state_dict'], strict=strict)
return checkpoint

0 comments on commit 43e9d94

Please sign in to comment.