Skip to content

Commit

Permalink
test pipeline fix
Browse files Browse the repository at this point in the history
  • Loading branch information
KyloRen1 committed Mar 13, 2023
1 parent f9a41ad commit af7d839
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
2 changes: 1 addition & 1 deletion configs/mobilenet_lstm_m2o.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ model_kwargs:
pretrained: True
temporal:
name: lstm
hidden_size: 384
hidden_size: 256
dropout: 0.3
n_layers: 2
data_kwargs:
Expand Down
6 changes: 4 additions & 2 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,11 @@ def parse_arguments():
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = create_model(experiment_config).to(device)
model.load_state_dict(torch.load(args.checkpoint_path))
pretrained_weights = torch.load(args.checkpoint_path, map_location=device)
model.load_state_dict(pretrained_weights)
print('loaded weights')


test_model(
experiment_config,
Expand Down

0 comments on commit af7d839

Please sign in to comment.