Skip to content

Commit

Permalink
Merge branch 'main' into fabiansinz-patch-1
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxFBurg authored Mar 7, 2024
2 parents f28192e + 29206ec commit 331c5dd
Show file tree
Hide file tree
Showing 33 changed files with 2,777 additions and 136 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/black.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ on: [push, pull_request]

jobs:
black:
runs-on: ubuntu-18.04
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/isort.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ on: [push, pull_request]

jobs:
isort:
runs-on: ubuntu-18.04
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/mypy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ on: [push, pull_request]

jobs:
mypy:
runs-on: ubuntu-18.04
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ on:
jobs:
test:
if: github.repository_owner == 'sinzlab'
runs-on: ubuntu-18.04
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v2
- name: Run tests and generate coverage report
Expand Down
67 changes: 67 additions & 0 deletions neuralpredictors/data/datasets/movies.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import h5py
import numpy as np
from scipy.signal import convolve2d
from torch.utils.data import Dataset

from ..transforms import DataTransform, Delay, MovieTransform, Subsequence
from ..utils import recursively_load_dict_contents_from_group
Expand Down Expand Up @@ -233,3 +234,69 @@ def transformed_mean(self, stats_source=None):
if self.rename_output:
x = self._output_point(*x)
return x


class NRandomSubSequenceDataset(Dataset):
    """
    Data augmentation for training on movie data.

    Wraps ``original_dat`` and builds a new index over it:

    * items whose tier is ``"none"`` are dropped,
    * every ``"train"`` item is exposed ``num_random_subsequence`` times, each
      time cropped along axis 1 to a different window of ``subsequence_length``
      steps,
    * items of any other tier are exposed once, unchanged.

    The window start positions are drawn once at construction time and shared
    by all train items.

    Args:
        original_dat: the original dataset. It must provide
            ``trial_info.tiers`` (one tier name per item), integer indexing
            that returns a namedtuple-like item (``_fields`` plus keyword
            construction), and a ``neurons`` attribute.
        num_random_subsequence: number of subsequences sampled from each
            train item of ``original_dat``.
        subsequence_length: the length of each subsequence.
        sequence_length: full sequence length of a train item of
            ``original_dat``.
        seed: random seed used to draw the start positions.

    Raises:
        ValueError: if ``subsequence_length`` exceeds ``sequence_length``.
    """

    def __init__(
        self,
        original_dat,
        num_random_subsequence=10,
        subsequence_length=100,
        sequence_length=300,
        seed=10,
    ):
        if subsequence_length > sequence_length:
            raise ValueError(
                f"subsequence_length ({subsequence_length}) must not exceed "
                f"sequence_length ({sequence_length})"
            )

        new_tiers = []  # tier of each item in the new dataset
        new_inds = []  # index into original_dat for each item in the new dataset
        for ii, tier in enumerate(original_dat.trial_info.tiers):
            if tier == "none":
                continue
            # Train items are repeated once per random window; all others appear once.
            repeats = num_random_subsequence if tier == "train" else 1
            new_tiers.extend([tier] * repeats)
            new_inds.extend([ii] * repeats)

        self.original_dat = original_dat
        self.new_tiers = new_tiers
        self.new_inds = new_inds

        # NOTE: this reseeds numpy's *global* RNG. Kept as is so that existing
        # seeds keep producing the same windows (and the same downstream state).
        np.random.seed(seed)
        max_start = sequence_length - subsequence_length
        if max_start > 0:
            # ``high`` is exclusive, so ``max_start`` itself is never drawn.
            # Left unchanged to keep previously sampled windows reproducible.
            self.random_start = np.random.randint(low=0, high=max_start, size=num_random_subsequence)
        else:
            # Subsequence spans the whole sequence: the only valid start is 0
            # (``randint`` would raise because low == high).
            self.random_start = np.zeros(num_random_subsequence, dtype=int)
        self.num4rand = len(self.random_start)
        self.random_end = self.random_start + subsequence_length

    def __getitem__(self, index):
        """Return the item at ``index``; train items are cropped along axis 1."""
        # Fetch the underlying item exactly once: indexing the wrapped dataset
        # may be expensive (or stochastic), and every field must come from the
        # same fetched item.
        item = self.original_dat[self.new_inds[index]]
        if self.new_tiers[index] != "train":
            return item

        # Consecutive copies of a train item cycle through all start positions.
        slot = index % self.num4rand
        window = slice(self.random_start[slot], self.random_end[slot])
        return item.__class__(**{k: getattr(item, k)[:, window] for k in item._fields})

    def __len__(self):
        """Number of items in the augmented dataset."""
        return len(self.new_tiers)

    @property
    def neurons(self):
        """The ``neurons`` attribute of the wrapped dataset."""
        return self.original_dat.neurons
Loading

0 comments on commit 331c5dd

Please sign in to comment.