Skip to content

Commit

Permalink
iterator and concatenate_two_sets that can deal with per-time-point y
Browse files Browse the repository at this point in the history
  • Loading branch information
robintibor committed Oct 3, 2017
1 parent 3694b49 commit a294170
Show file tree
Hide file tree
Showing 11 changed files with 97 additions and 41 deletions.
1 change: 1 addition & 0 deletions braindecode/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from braindecode.version import __version__
13 changes: 9 additions & 4 deletions braindecode/datautil/iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def _yield_block_batches(self, X, y, start_stop_blocks_per_trial, shuffle):
for i_blocks in blocks_per_batch:
start_stop_blocks = i_trial_start_stop_block[i_blocks]
batch = _create_batch_from_i_trial_start_stop_blocks(
X, y, start_stop_blocks)
X, y, start_stop_blocks, self.n_preds_per_input)
yield batch


Expand All @@ -231,7 +231,7 @@ def _compute_start_stop_block_inds(i_trial_starts, i_trial_stops,
----------
i_trial_starts: 1darray/list of int
Indices of first samples to predict(!).
i_trial_stops: 1daray/list of int
i_trial_stops: 1darray/list of int
Indices one past last sample to predict.
input_time_length: int
n_preds_per_input: int
Expand Down Expand Up @@ -299,12 +299,17 @@ def _get_start_stop_blocks_for_trial(i_trial_start, i_trial_stop,
return start_stop_blocks


def _create_batch_from_i_trial_start_stop_blocks(X, y, i_trial_start_stop_block):
def _create_batch_from_i_trial_start_stop_blocks(X, y, i_trial_start_stop_block,
n_preds_per_input=None):
Xs = []
ys = []
for i_trial, start, stop in i_trial_start_stop_block:
Xs.append(X[i_trial][:,start:stop])
ys.append(y[i_trial])
if not hasattr(y[i_trial], '__len__'):
ys.append(y[i_trial])
else:
assert n_preds_per_input is not None
ys.append(y[i_trial][stop-n_preds_per_input:stop])
batch_X = np.array(Xs)
batch_y = np.array(ys)
# add empty fourth dimension if necessary
Expand Down
23 changes: 14 additions & 9 deletions braindecode/datautil/splitters.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,22 @@ def concatenate_two_sets(set_a, set_b):
-------
concatenated_set: :class:`.SignalAndTarget`
"""
if hasattr(set_a.X, 'ndim') and hasattr(set_b.X, 'ndim'):
new_X = np.concatenate((set_a.X, set_b.X), axis=0)
else:
if hasattr(set_a.X, 'ndim'):
set_a.X = set_a.X.tolist()
if hasattr(set_b.X, 'ndim'):
set_b.X = set_b.X.tolist()
new_X = set_a.X + set_b.X
new_y = np.concatenate((set_a.y, set_b.y), axis=0)
new_X = concatenate_np_array_or_add_lists(set_a.X, set_b.X)
new_y = concatenate_np_array_or_add_lists(set_a.y, set_b.y)
return SignalAndTarget(new_X, new_y)

def concatenate_np_array_or_add_lists(a, b):
if hasattr(a, 'ndim') and hasattr(b, 'ndim'):
new = np.concatenate((a, b), axis=0)
else:
if hasattr(a, 'ndim'):
a = a.tolist()
if hasattr(b, 'ndim'):
b = b.tolist()
new = a + b
return new



def split_into_two_sets(dataset, first_set_fraction=None, n_first_set=None):
"""
Expand Down
39 changes: 37 additions & 2 deletions braindecode/torch_ext/losses.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import torch as th


def log_categorical_crossentropy(logpreds, targets, dims=None):
def log_categorical_crossentropy_1_hot(logpreds, targets, dims=None):
"""
Returns log categorical crossentropy for given log-predictions and targets.
Returns log categorical crossentropy for given log-predictions and targets,
targets should be one-hot-encoded.
Computes :math:`-\mathrm{logpreds} \cdot \mathrm{targets}`
Expand All @@ -12,6 +13,7 @@ def log_categorical_crossentropy(logpreds, targets, dims=None):
logpreds: `torch.autograd.Variable`
Logarithm of softmax output.
targets: `torch.autograd.Variable`
One-hot encoded targets
dims: int or iterable of int, optional.
Compute sum across these dims
Expand All @@ -31,6 +33,39 @@ def log_categorical_crossentropy(logpreds, targets, dims=None):
return result


def log_categorical_crossentropy(log_preds, targets):
"""
Returns log categorical crossentropy for given log-predictions and targets.
Computes :math:`-\mathrm{logpreds} \cdot \mathrm{targets}` if you assume
targets to be one-hot-encoded. Also works for targets that are not
one-hot-encoded, in this case only uses targets that are in the range
of the expected class labels, i.e., [0,log_preds.size()[1]-1].
Parameters
----------
log_preds: torch.autograd.Variable`
Logarithm of softmax output.
targets: `torch.autograd.Variable`
Returns
-------
loss: `torch.autograd.Variable`
"""
if log_preds.size() == targets.size():
return log_categorical_crossentropy_1_hot(log_preds, targets)
n_classes = log_preds.size()[1]
n_elements = 0
losses = []
for i_class in range(n_classes):
mask = targets == i_class
mask = mask.type_as(log_preds)
n_elements -= th.sum(mask)
losses.append(th.sum(mask * log_preds[:,i_class]))
return th.sum(th.stack(losses)) / n_elements


def l2_loss(model):
losses = [th.sum(p * p) for p in model.parameters()]
return sum(losses)
Expand Down
1 change: 1 addition & 0 deletions braindecode/version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__version__ = "0.2.0"
10 changes: 10 additions & 0 deletions docs/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,13 @@ apidoc:
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)


removeipynbcheckpoints: Makefile
rm -rf notebooks/.ipynb_checkpoints/ notebooks/visualization/.ipynb_checkpoints/

removesource: Makefile
rm -rf source/

rmanddoc: removesource removeipynbcheckpoints apidoc html
echo "Done"
27 changes: 6 additions & 21 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,9 @@ def find_source():
## Default flags used by autodoc directives
autodoc_default_flags = ['members', 'show-inheritance']

exclude_patterns = ['_build', '_templates',]

exclude_patterns = ['_build', '_templates']


napoleon_google_docstring = False
napoleon_use_param = False
napoleon_use_ivar = True
Expand All @@ -110,9 +111,10 @@ def find_source():
# built documents.
#
# The short X.Y version.
version = '0.1.9'
import braindecode
version = braindecode.__version__
# The full version, including alpha/beta/rc tags.
release = '0.1.9'
release = version

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
Expand Down Expand Up @@ -211,20 +213,3 @@ def find_source():
author, 'Braindecode', 'One line description of project.',
'Miscellaneous'),
]


## mock stuff
"""
import sys
from mock import Mock as MagicMock
#from unittest.mock import MagicMock
class Mock(MagicMock):
@classmethod
def __getattr__(cls, name):
return MagicMock()
MOCK_MODULES = ['torch', 'h5py',]
sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
"""
2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ Welcome to Braindecode

A deep learning toolbox to decode raw time-domain EEG.

For EEG researchers that want to want to work with deep learning and
For EEG researchers that want to work with deep learning and
deep learning researchers that want to work with EEG data.
For now focussed on convolutional networks.

Expand Down
3 changes: 3 additions & 0 deletions docs/notebooks/Cropped_Decoding.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,9 @@
}
],
"metadata": {
"git": {
"keep_outputs": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
Expand Down
8 changes: 8 additions & 0 deletions docs/source/braindecode.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,12 @@ braindecode\.util module
:undoc-members:
:show-inheritance:

braindecode\.version module
---------------------------

.. automodule:: braindecode.version
:members:
:undoc-members:
:show-inheritance:


11 changes: 7 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@
long_description = f.read()


# This will add __version__ to version dict
version = {}
with open(path.join(here, 'braindecode/version.py'), encoding='utf-8') as (
version_file):
exec(version_file.read(), version)

setup(
name='Braindecode',

# Versions should comply with PEP440. For a discussion on single-sourcing
# the version across setup.py and the project code, see
# http://packaging.python.org/en/latest/tutorial.html#version
version='0.1.9', # TODO: read from __init__.py?
version=version['__version__'],

description='A deep learning toolbox to decode raw time-domain EEG.',
long_description=long_description,
Expand Down

0 comments on commit a294170

Please sign in to comment.