Skip to content

Commit

Permalink
Debug local mode
Browse files Browse the repository at this point in the history
  • Loading branch information
angerhang committed Apr 28, 2023
1 parent 0ece7c2 commit 76c76a2
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 12 deletions.
5 changes: 3 additions & 2 deletions hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ def sleepnet(pretrained=True, my_device="cpu", num_classes=2, lstm_nn_size=128,
if pretrained:
if len(local_weight_path) > 0:
print("Loading local weight from %s" % local_weight_path)
state_dict = torch.load(local_weight_path)
state_dict = torch.load(local_weight_path,
map_location=torch.device(my_device))
model.load_state_dict(
state_dict, map_location=torch.device(my_device))
state_dict)
else:
checkpoint = 'https://github.com/OxWearables/asleep/' \
'releases/download/0.0.3/bi_sleepnet.mdl'
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ name = "asleep" # Required
#
# For a discussion on single-sourcing the version, see
# https://packaging.python.org/guides/single-sourcing-package-version/
version = "0.3.7" # Required
version = "0.3.8" # Required

# This is a one-line description or tagline of what your project does. This
# corresponds to the "Summary" metadata field:
Expand Down
7 changes: 6 additions & 1 deletion src/asleep/get_sleep.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,10 @@ def main():
"you might want to specify the path to the model weight file",
type=str,
default='')
parser.add_argument(
"--local",
help="Load model definition from local repo",
action="store_true")
parser.add_argument(
"--min_wear",
"-m",
Expand Down Expand Up @@ -248,7 +252,8 @@ def main():
master_acc, master_npids) = get_sleep_windows(data2model, times, args)

y_pred, test_pids = start_sleep_net(
master_acc, master_npids, args.model_weight_path, args.outdir)
master_acc, master_npids, args.outdir,
args.model_weight_path, is_local_repo=args.local)
sleepnet_output = binary_y

for block_id in range(len(all_sleep_wins_df)):
Expand Down
21 changes: 13 additions & 8 deletions src/asleep/sleepnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import os
import gzip
import os.path
from pathlib import Path


# Model utils
from asleep.utils import cnnLSTMInFerDataset, cnn_lstm_infer_collate, \
Expand Down Expand Up @@ -102,11 +104,12 @@ def config_device(cfg):


def setup_cnn(cfg, my_device, weight_path, is_local_repo=False):

print("setting up cnn")
if is_local_repo:
dirname = os.path.dirname(__file__)
repo_path = os.path.join(dirname, '..', '..') # path containing hubconf.py
model = torch.hub.load(repo_path,
print("access local repo")
dirname = Path(__file__).parent.parent.parent
print(dirname)
model = torch.hub.load(dirname,
'sleepnet',
source='local',
num_classes=cfg.data.num_classes,
Expand All @@ -119,6 +122,8 @@ def setup_cnn(cfg, my_device, weight_path, is_local_repo=False):
trust_repo=True
)
else:
print("access remote repo")

repo = 'OxWearables/asleep'
model = torch.hub.load(repo,
'sleepnet',
Expand Down Expand Up @@ -161,11 +166,11 @@ def align_output(y_red, real_pid, test_pid):
return np.array(aligned_pred)


def sleepnet_inference(X, pid, weight_path, cfg):
def sleepnet_inference(X, pid, weight_path, cfg, is_local_repo=False):
start = time.time()
my_device = config_device(cfg)

model = setup_cnn(cfg, my_device, weight_path)
model = setup_cnn(cfg, my_device, weight_path, is_local_repo=is_local_repo)
test_loader = setup_dataset(X, pid, cfg)

test_y_pred, test_pid, test_probs = forward_batches(
Expand Down Expand Up @@ -195,7 +200,7 @@ def sleepnet_inference(X, pid, weight_path, cfg):
return aligned_y_pred, test_pid


def start_sleep_net(X, pid, data_root, weight_path, device_id=-1):
def start_sleep_net(X, pid, data_root, weight_path, device_id=-1, is_local_repo=False):
initialize(config_path="conf")
cfg = compose(
"config_eval",
Expand All @@ -209,4 +214,4 @@ def start_sleep_net(X, pid, data_root, weight_path, device_id=-1):
)
if cfg.verbose:
print(OmegaConf.to_yaml(cfg, resolve=True))
return sleepnet_inference(X, pid, weight_path, cfg)
return sleepnet_inference(X, pid, weight_path, cfg, is_local_repo=is_local_repo)

0 comments on commit 76c76a2

Please sign in to comment.