Skip to content

Commit

Permalink
Sleepnet (#13)
Browse files Browse the repository at this point in the history
* Include sleepnet and refactoring
  • Loading branch information
angerhang authored Apr 10, 2023
1 parent ac13267 commit 581e67a
Show file tree
Hide file tree
Showing 21 changed files with 1,114 additions and 152 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,6 @@ outputs

src/asleep/torch_hub_cache
assets/*.joblib.*
*.ipynb_checkpoints
*.ipynb_checkpoints
*.ipynb
src/asleep/models
5 changes: 4 additions & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
include tox.ini
recursive-include tests *.py
recursive-include assets *.jpg
recursive-include assets *.jpg *.lzma
recursive-include src *.yaml
exclude src/asleep/models *.mdl
recursive-exclude src/asleep/torch_hub_cache *
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ $ pip install asleep
```

# Usage
All the processing will be much faster after the first time because the model weights will to have to be downloaded
the first time that the package is used.
```shell
# Process an AX3 file
$ get_sleep sample.cwa
Expand Down Expand Up @@ -88,5 +90,5 @@ TBD

# Acknowledgements
We would like to thank all our code contributors, manuscript co - authors, and research participants for their help in making this work possible. The
data processing pipeline of this repository is based on the[step_count](https: // github.com / OxWearables / stepcount # processing-csv-files) package from our group. Special
thanks to @ chanshing for his help in developing the package.
data processing pipeline of this repository is based on the [step_count](https://github.com/OxWearables/stepcount) package from our group. Special
thanks to @chanshing for his help in developing the package.
Binary file added assets/ssl.joblib.lzma
Binary file not shown.
26 changes: 26 additions & 0 deletions hubconf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import os
from src.asleep.models import CNNLSTM, weight_init
import torch

dependencies = ["torch"]


def sleepnet(pretrained=True, my_device="cpu", class_num=2, lstm_nn_size=128,
dropout_p=0.5, bi_lstm=True, lstm_layer=1):
model = CNNLSTM(
num_classes=class_num,
model_device=my_device,
lstm_nn_size=lstm_nn_size,
dropout_p=dropout_p,
bidrectional=bi_lstm,
lstm_layer=lstm_layer,
)
weight_init(model)

if pretrained:
checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint,
progress=True,
map_location=torch.device(my_device)))
model.to(my_device, dtype=torch.float)
return model
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ dependencies = [ # Optional
"torch",
"torchvision",
"transforms3d",
"stepcount"
"stepcount",
"hydra-core",
]

# List additional groups of dependencies here (e.g. development
Expand Down
Empty file added src/__init__.py
Empty file.
Empty file added src/asleep/conf/__init__.py
Empty file.
18 changes: 18 additions & 0 deletions src/asleep/conf/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
defaults:
- model: baseline
- data: vm_data
- _self_

gpu: 0
num_epoch: 200
multi_gpu: false
gpu_ids: [0, 1, 2, 3] # if more than one gpu
verbose: false
patience: 10
deployment: false
num_split: 5
test_size: 0.2
validation_fold: -1 # validation fold < num_split used for cnn cv
isDebug: false
train_master: false
specific_fold: -1
15 changes: 15 additions & 0 deletions src/asleep/conf/config_eval.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
defaults:
- model: cnn_lstm_eval
- data: deployment
- _self_

gpu: -1
num_epoch: 200
multi_gpu: false
gpu_ids: [0, 1, 2, 3] # if more than one gpu
verbose: false
patience: 10
deployment: true
num_split: 5
test_size: 0.2
validation_fold: -1 # validation fold < num_split used for cnn cv
16 changes: 16 additions & 0 deletions src/asleep/conf/data/deployment.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
data_root: /Users/hangy/data
X_zip_path: "${data.data_root}/X.npy.gz"
X_path: "${data.data_root}/X.npy"
PID_path: "${data.data_root}/npid.npy"
subject_file: ""
y_pred_path: "${data.data_root}/y_pred.npy"
y_prob_path: "${data.data_root}/pred_prob.npy"
pid_pred_path: "${data.data_root}/PID_pred.npy"
use_gen2: False
context_path: "${data.data_root}/context.pkl"
log_path: /home/cxx579/raine/logs
win_length: 1
meta_data_size: 0
test_ratio: 0.2
val_ratio: 0.2
num_classes: 5
3 changes: 3 additions & 0 deletions src/asleep/conf/model/baseline.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
learning_rate: 0.0001
isSmall: false
name: "baseline_WinLen${data.win_length}_BiDi${data.isBidirectional}_Meta${data.meta_data_size}_Aug${augment}"
12 changes: 12 additions & 0 deletions src/asleep/conf/model/cnn_lstm_eval.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
deployment: false
deployment_path: "/home/cxx579/raine/predictions"
# eval_weight_path: "/data/UKBB/final_models/sleepnet_12_16.mdl"
isSmall: false
bi_lstm: True
lstm_nn_size: 1024
lstm_layer: 2
batch_size: 2 # indicate the number of subjects for cnnlstm
model_name: 'cnn_lstm'
dropout_p: 0
multi_task: false
augment: 'cnn_lstm'
Loading

0 comments on commit 581e67a

Please sign in to comment.