Merge pull request #3 from sinzlab/master
update master
KonstantinWilleke authored Apr 2, 2020
2 parents 6eed380 + 69d02be commit afbefb3
Showing 3 changed files with 28 additions and 22 deletions.
14 changes: 8 additions & 6 deletions nnvision/datasets/monkey_loaders.py
@@ -40,6 +40,7 @@ def __contains__(self, key):
         return key in self.cache

     def __getitem__(self, item):
+        item = item.tolist() if isinstance(item, Iterable) else item
         return [self[i] for i in item] if isinstance(item, Iterable) else self.update(item)

     def update(self, key):
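
The added line matters because a batch sampler hands __getitem__ a tensor of indices rather than a single int, and converting it to a plain list first keeps the recursive indexing clean. A minimal standalone sketch of the indexing logic (ToyDataset is a made-up stand-in for CachedTensorDataset, whose real update() fetches images from the cache):

    import torch
    from collections.abc import Iterable

    class ToyDataset:
        # Made-up stand-in: the real update() pulls an image out of the cache.
        def __init__(self, data):
            self.data = data

        def update(self, key):
            return self.data[key]

        def __getitem__(self, item):
            # Tensors count as Iterable, so turn an index tensor into a plain list first.
            item = item.tolist() if isinstance(item, Iterable) else item
            return [self[i] for i in item] if isinstance(item, Iterable) else self.update(item)

    ds = ToyDataset(['a', 'b', 'c', 'd'])
    print(ds[2])                     # single index -> 'c'
    print(ds[torch.tensor([0, 3])])  # index tensor from a batch sampler -> ['a', 'd']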
@@ -123,12 +124,13 @@ def get_cached_loader(image_ids, responses, batch_size, shuffle=True, image_cach
     image_ids = torch.tensor(image_ids.astype(np.int32))
     responses = torch.tensor(responses).to(torch.float)
     dataset = CachedTensorDataset(image_ids, responses, image_cache=image_cache)
-    sampler = RepeatsBatchSampler(torch.tensor(repeat_condition.astype(np.int32))) if repeat_condition is not None else None
+    sampler = RepeatsBatchSampler(repeat_condition) if repeat_condition is not None else None

-    return utils.DataLoader(dataset,
-                            batch_size=batch_size,
-                            shuffle=shuffle,
-                            batch_sampler=sampler)
+    dataloader = utils.DataLoader(dataset, batch_sampler=sampler) if batch_size is None else utils.DataLoader(dataset,
+                                                                                                              batch_size=batch_size,
+                                                                                                              shuffle=shuffle,
+                                                                                                              )
+    return dataloader

 def monkey_static_loader(dataset,
                          neuronal_data_files,
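
For context (not part of the diff): torch.utils.data.DataLoader rejects a batch_sampler combined with a non-default batch_size, shuffle=True, sampler, or drop_last, so the two loader configurations have to be built in separate branches as above. A small sketch of the two valid constructions:

    import torch
    from torch.utils import data as utils

    dataset = utils.TensorDataset(torch.arange(10.0).unsqueeze(1))

    # Branch 1: a batch_sampler alone decides the batches (any iterable of index lists works).
    batches = [[0, 1, 2], [3, 4], [5, 6, 7, 8, 9]]
    loader_with_sampler = utils.DataLoader(dataset, batch_sampler=batches)

    # Branch 2: automatic batching via batch_size/shuffle, no batch_sampler.
    loader_plain = utils.DataLoader(dataset, batch_size=4, shuffle=True)

    # Mixing them in one call raises a ValueError:
    # utils.DataLoader(dataset, batch_size=4, shuffle=True, batch_sampler=batches)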
@@ -256,7 +258,7 @@ def monkey_static_loader(dataset,

     train_loader = get_cached_loader(training_image_ids, responses_train, batch_size=batch_size, image_cache=cache)
     val_loader = get_cached_loader(validation_image_ids, responses_val, batch_size=batch_size, image_cache=cache)
-    test_loader = get_cached_loader(testing_image_ids, responses_test, batch_size=1, shuffle=False,
+    test_loader = get_cached_loader(testing_image_ids, responses_test, batch_size=None, shuffle=None,
                                     image_cache=cache, repeat_condition=testing_image_ids)

     dataloaders["train"][data_key] = train_loader
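
RepeatsBatchSampler itself is defined outside this diff; judging from its use with repeat_condition=testing_image_ids, it groups dataset indices by a condition label so that each test batch contains every repeat of one image. A hypothetical re-implementation under that assumption (a plain iterable of index lists is accepted as a batch_sampler):

    import numpy as np

    class RepeatsBatchSamplerSketch:
        # Hypothetical stand-in: one batch per unique condition value,
        # holding the indices of all of its repeats.
        def __init__(self, conditions):
            conditions = np.asarray(conditions)
            self.batches = [np.where(conditions == c)[0].tolist()
                            for c in np.unique(conditions)]

        def __iter__(self):
            return iter(self.batches)

        def __len__(self):
            return len(self.batches)

    # Image ids [7, 9, 7, 9]: both presentations of each image land in the same batch.
    print(list(RepeatsBatchSamplerSketch([7, 9, 7, 9])))  # [[0, 2], [1, 3]]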
10 changes: 5 additions & 5 deletions nnvision/training/trainers.py
@@ -76,7 +76,7 @@ def full_objective(model, dataloader, data_key, *args):
     model.train()

     criterion = getattr(mlmeasures, loss_function)(avg=avg_loss)
-    stop_closure = partial(getattr(measures, stop_function), dataloaders=dataloaders["validation"], device=device, avg=True, as_dict=False)
+    stop_closure = partial(getattr(measures, stop_function), dataloaders=dataloaders["validation"], device=device, per_neuron=False, avg=True)

     n_iterations = len(LongCycler(dataloaders["train"]))

@@ -90,8 +90,8 @@ def full_objective(model, dataloader, data_key, *args):
     optim_step_count = len(dataloaders["train"].keys()) if loss_accum_batch_n is None else loss_accum_batch_n

     if track_training:
-        tracker_dict = dict(correlation=partial(get_correlations(), model, dataloaders["validation"], device=device, avg=True),
-                            poisson_loss=partial(get_poisson_loss(), model, dataloaders["validation"], device=device, avg=True))
+        tracker_dict = dict(correlation=partial(get_correlations(), model, dataloaders["validation"], device=device, per_neuron=False),
+                            poisson_loss=partial(get_poisson_loss(), model, dataloaders["validation"], device=device, per_neuron=False, avg=False))
         if hasattr(model, 'tracked_values'):
             tracker_dict.update(model.tracked_values)
         tracker = MultipleObjectiveTracker(**tracker_dict)
@@ -130,8 +130,8 @@ def full_objective(model, dataloader, data_key, *args):
     tracker.finalize() if track_training else None

     # Compute avg validation and test correlation
-    validation_correlation = get_correlations(model, dataloaders["validation"], device=device, as_dict=False, avg=False)
-    test_correlation = get_correlations(model, dataloaders["test"], device=device, as_dict=False, avg=False)
+    validation_correlation = get_correlations(model, dataloaders["validation"], device=device, as_dict=False, per_neuron=False)
+    test_correlation = get_correlations(model, dataloaders["test"], device=device, as_dict=False, per_neuron=False)

     # return the whole tracker output as a dict
     output = {k: v for k, v in tracker.log.items()} if track_training else {}
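
The net effect of the renamed flags in this file: every call site now asks for a single scalar (the mean over all neurons of all data keys) rather than relying on avg. The reduction that as_dict=False, per_neuron=False triggers inside get_correlations (see measures.py below) amounts to:

    import numpy as np

    # Per-neuron correlation vectors, one entry per data key, as built inside get_correlations.
    correlations = {"session_1": np.array([0.4, 0.6]), "session_2": np.array([0.8])}

    per_neuron = np.hstack(list(correlations.values()))  # as_dict=False, per_neuron=True
    scalar = np.mean(per_neuron)                         # as_dict=False, per_neuron=False
    print(per_neuron, scalar)                            # [0.4 0.6 0.8] 0.6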
26 changes: 15 additions & 11 deletions nnvision/utility/measures.py
@@ -7,7 +7,7 @@
 import contextlib


-def model_predictions(loader, model, data_key, device='cpu'):
+def model_predictions(dataloader, model, data_key, device='cpu'):
     """
     computes model predictions for a given dataloader and a model
     Returns:
@@ -16,7 +16,7 @@ def model_predictions(loader, model, data_key, device='cpu'):
"""

target, output = torch.empty(0), torch.empty(0)
for images, responses in loader:
for images, responses in dataloader:
if len(images.shape) == 5:
images = images.squeeze(dim=0)
responses = responses.squeeze(dim=0)
@@ -27,32 +27,36 @@ def model_predictions(loader, model, data_key, device='cpu'):
     return target.numpy(), output.numpy()


-def get_correlations(model, dataloaders, device='cpu', as_dict=False, avg=False):
+def get_correlations(model, dataloaders, device='cpu', as_dict=False, per_neuron=True, **kwargs):
     correlations = {}
     with eval_state(model) if not isinstance(model, types.FunctionType) else contextlib.nullcontext():
         for k, v in dataloaders.items():
-            target, output = model_predictions(loader=v, model=model, data_key=k, device=device)
+            target, output = model_predictions(dataloader=v, model=model, data_key=k, device=device)
             correlations[k] = corr(target, output, axis=0)

             if np.any(np.isnan(correlations[k])):
                 warnings.warn('{}% NaNs , NaNs will be set to Zero.'.format(np.isnan(correlations[k]).mean() * 100))
                 correlations[k][np.isnan(correlations[k])] = 0

     if not as_dict:
-        correlations = np.mean(np.hstack([v for v in correlations.values()])) if avg else np.hstack([v for v in correlations.values()])
+        correlations = np.hstack([v for v in correlations.values()]) if per_neuron else np.mean(np.hstack([v for v in correlations.values()]))
     return correlations


-def get_poisson_loss(model, dataloaders, device='cpu', as_dict=False, avg=True, eps=1e-12):
+def get_poisson_loss(model, dataloaders, device='cpu', as_dict=False, avg=False, per_neuron=False, eps=1e-12):
     poisson_loss = {}
     with eval_state(model) if not isinstance(model, types.FunctionType) else contextlib.nullcontext():
         for k, v in dataloaders.items():
-            target, output = model_predictions(loader=v, model=model, data_key=k, device=device)
-            poisson_loss[k] = output - target * np.log(output + eps)
-    if not as_dict:
-        return np.mean(np.hstack([v for v in poisson_loss.values()])) if avg else np.hstack([v for v in poisson_loss.values()])
-    else:
+            target, output = model_predictions(dataloader=v, model=model, data_key=k, device=device)
+            loss = output - target * np.log(output + eps)
+            poisson_loss[k] = np.mean(loss, axis=0) if avg else np.sum(loss, axis=0)
+    if as_dict:
         return poisson_loss
+    else:
+        if per_neuron:
+            return np.hstack([v for v in poisson_loss.values()])
+        else:
+            return np.mean(np.hstack([v for v in poisson_loss.values()])) if avg else np.sum(np.hstack([v for v in poisson_loss.values()]))


 def get_repeats(dataloader, min_repeats=2):
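
To make the new flag combinations in get_poisson_loss concrete: the loss is first reduced over the image axis inside the loop (mean if avg, else sum), then over neurons unless per_neuron=True. A toy trace with fabricated target/output arrays standing in for model predictions:

    import numpy as np

    eps = 1e-12
    target = np.array([[1.0, 2.0], [3.0, 4.0]])  # (images, neurons), made-up responses
    output = np.array([[1.5, 2.5], [2.5, 3.5]])  # made-up model predictions

    loss = output - target * np.log(output + eps)  # elementwise Poisson loss
    per_neuron_sum = np.sum(loss, axis=0)          # avg=False: sum over images, one value per neuron
    per_neuron_mean = np.mean(loss, axis=0)        # avg=True: mean over images

    print(per_neuron_sum)            # per_neuron=True, avg=False
    print(np.sum(per_neuron_sum))    # per_neuron=False, avg=False: grand sum
    print(np.mean(per_neuron_mean))  # per_neuron=False, avg=True: grand mean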
