Infrastructure and hooks
* Removed dynamic instance variables in `Component`.
* Component `dt` now treated as an optional keyword argument.

* Added monitor forward hooks (in-memory buffers).
* Monitors and legacy data hooks now support custom attribute logging.
* Runtime improvements to encourage hook usage in practice.
* Added monitor example usage to basic_experiment tutorial.
* Minor file/directory creation fixes.
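
A minimal sketch of the forward-hook buffering pattern behind the monitor bullets above, in plain PyTorch. The standalone `MonitorBuffer` class and its names are illustrative assumptions, not Sapicore's actual API; the real implementation is the `monitor_hook` diff in `sapicore/utils/io.py` below.

    import torch

    class MonitorBuffer:
        """Accumulates selected attributes of a module's forward output in memory."""

        def __init__(self, attributes: list[str]):
            self.attributes = attributes
            self.cache: dict[str, torch.Tensor] = {}

        def hook(self):
            # standard PyTorch forward-hook signature: (module, input, output).
            # assumes the module's forward returns a dict of 1D tensors.
            def fn(module, args, output):
                for attr in self.attributes:
                    row = output[attr][None, :]  # prepend a step axis.
                    if attr not in self.cache:
                        self.cache[attr] = row
                    else:
                        self.cache[attr] = torch.vstack([self.cache[attr], row])

            return fn

Registration uses the stock mechanism, e.g. `handle = layer.register_forward_hook(monitor.hook())` for any `torch.nn.Module` `layer`.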

* Connections and weights now passed as arguments to Synapse, facilitating programmatic initialization with new config classes.
* Synapse connection topology (mode and proportion) now retrieved and applied at instantiation (via `kwargs.pop()`).
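
A hedged sketch of the `kwargs.pop()` topology pattern just described; the signature below is hypothetical, not Sapicore's actual `Synapse` constructor.

    import torch

    class Synapse:
        def __init__(self, weights: torch.Tensor = None, connections: torch.Tensor = None, **kwargs):
            # connections and weights arrive as explicit arguments, so config
            # classes can build synapses programmatically.
            self.weights = weights
            self.connections = connections

            # topology settings are consumed from kwargs at instantiation and
            # applied immediately, rather than lingering as loose attributes.
            self.mode = kwargs.pop("mode", "all")
            self.proportion = kwargs.pop("proportion", 1.0)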

* Model class no longer inherits from BaseEstimator.
* Cleaned up method signatures.
* Added private method for batch serving.
* Base class `fit()` now invokes `_serve()`.
* Moved legacy `draw` method from Model to Network.
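
In outline, the control flow those Model bullets describe (a sketch; the real method bodies appear in the sapicore/model/__init__.py diff below):

    class Model:
        def fit(self, data, duration, rinse=0, **kwargs):
            # base-class fit() delegates batch presentation to the private method...
            self._serve(data, duration, rinse)

            # ...then freezes plasticity once the batch has been served.
            for synapse in self.network.get_synapses():
                synapse.set_learning(False)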

* Misc docstring improvements.
* Legacy tests passing.

* Bump to 0.4.0.
rm875 committed Sep 3, 2024
1 parent 840798b commit 1390191
Showing 5 changed files with 25 additions and 20 deletions.
4 changes: 2 additions & 2 deletions sapicore/engine/network/__init__.py
@@ -460,9 +460,9 @@ def draw(self, path: str, node_size: int = 750):
"""
plt.figure()
nx.draw(self.graph, node_size=node_size, with_labels=True, pos=nx.kamada_kawai_layout(self.network.graph))
nx.draw(self.graph, node_size=node_size, with_labels=True, pos=nx.kamada_kawai_layout(self.graph))

plt.savefig(fname=os.path.join(path, self.network.identifier + ".svg"))
plt.savefig(fname=os.path.join(path, self.identifier + ".svg"))
plt.clf()

def _in_edges(self, node: str) -> list[Synapse]:
20 changes: 11 additions & 9 deletions sapicore/model/__init__.py
@@ -5,7 +5,7 @@
 :class:`~engine.network.Network` output for practical purposes.
 """
-from typing import Callable
+from typing import Callable, Sequence

 import dill
 import os
@@ -38,14 +38,14 @@ def __init__(self, network: Network = None, **kwargs):
         # store a reference to one network object.
         self.network = network

-    def _serve(self, data: Tensor | list[Tensor], duration: int | list[int], rinse: int | list[int] = 0):
+    def _serve(self, data: Tensor | Sequence[Tensor], duration: int | Sequence[int], rinse: int | Sequence[int] = 0):
         """Serves a batch of data to this model's network.
         Each sample `i` is presented for `duration[i]`, followed by all 0s stimulation for `rinse[i]`.
         """
         # wrap 2D tensor data in a list if need be, to make subsequent single- and multi-root operations uniform.
-        if not isinstance(data, list):
+        if not isinstance(data, Sequence):
             data = [data]

         num_samples = data[0].shape[0]
@@ -67,23 +67,25 @@ def _serve(self, data: Tensor | list[Tensor], duration: int | list[int], rinse:
             # advance progress bar.
             bar()

-    def fit(self, data: Tensor | list[Tensor], duration: int | list[int], rinse: int | list[int] = 0, **kwargs):
+    def fit(
+        self, data: Tensor | Sequence[Tensor], duration: int | Sequence[int], rinse: int | Sequence[int] = 0, **kwargs
+    ):
         """Applies :meth:`engine.network.Network.forward` sequentially on a block of buffer `data`,
         then turns off learning for the network.
         The training buffer may be obtained, e.g., from a :class:`~data.sampling.CV` cross validator object.
         Parameters
         ----------
-        data: Tensor or list of Tensor
+        data: Tensor or Sequence of Tensor
             2D tensor(s) of data buffer to be fed to the root ensemble(s) of this object's `network`,
             formatted sample X feature.
-        duration: int or list of int
+        duration: int or Sequence of int
             Duration of sample presentation. Simulates duration of exposure to a particular input.
             If a list or a tensor is provided, the i-th sample in the batch is maintained for `duration[i]` steps.
-        rinse: int or list of int
+        rinse: int or Sequence of int
             Null stimulation steps (0s in-between samples).
             If a list or a tensor is provided, the i-th sample is followed by `rinse[i]` rinse steps.
@@ -99,7 +101,7 @@ def fit(self, data: Tensor | list[Tensor], duration: int | list[int], rinse: int
         for synapse in self.network.get_synapses():
             synapse.set_learning(False)

-    def predict(self, data: Data | Tensor) -> Tensor:
+    def predict(self, data: Data | Tensor, **kwargs) -> Tensor:
         """Predicts the labels of `data` by feeding the buffer to a trained network and applying
         some procedure to the resulting population/readout layer response.
@@ -116,7 +118,7 @@ def predict(self, data: Data | Tensor) -> Tensor:
"""
raise NotImplementedError

def similarity(self, data: Tensor, metric: str | Callable) -> Tensor:
def similarity(self, data: Tensor, metric: str | Callable, **kwargs) -> Tensor:
"""Performs rudimentary similarity analysis on the network's responses to `data`,
yielding a pairwise distance matrix.
4 changes: 2 additions & 2 deletions sapicore/pipeline/simple.py
@@ -141,7 +141,7 @@ def run(self):
         model.network.add_data_hook(data_dir, steps)

         # save an SVG plot of the network architecture.
-        model.draw(path=run_dir)
+        model.network.draw(path=run_dir)

         if not self.data:
             steps = self.configuration.get("simulation", {}).get("steps", {})
@@ -159,7 +159,7 @@ def run(self):

         # fit the model by passing `data` buffer to the network (`duration` controls exposure time to each sample).
         logging.info(f"Simulating {steps} steps at a resolution of {DT} ms.")
-        model.fit(data=self.data, repetitions=self.configuration.get("duration", 1))
+        model.fit(data=self.data, duration=self.configuration.get("duration", 1))

         # optional tensorboard logging.
         if tensorboard:
2 changes: 1 addition & 1 deletion sapicore/tests/data/test_data.py
@@ -57,7 +57,7 @@ def test_data_pipeline(self):
             models.append(Model(Network()))

             # repeats each sample for a random number of steps (simulating variable exposure durations).
-            models[i].fit(data[train], repetitions=torch.randint(low=2, high=7, size=(data[train].shape[0],)))
+            models[i].fit(data[train], duration=torch.randint(low=2, high=7, size=(data[train].shape[0],)))

     @pytest.mark.parametrize(
         "url_",
15 changes: 9 additions & 6 deletions sapicore/utils/io.py
@@ -75,25 +75,28 @@ def monitor_hook(self) -> Callable:
         def fn(_, __, output):
             # add current outputs to this data accumulator instance's cache dictionary, whose values are tensors.
             for attr in self.attributes:
+                odim = output[attr].dim()
                 if attr not in self.cache.keys():
                     # the cache dictionary is empty because this is the first iteration.
                     if self.entries is None:
                         # expand output by one dimension (zero axis) to fit.
-                        self.cache[attr] = output[attr][None, :]
+                        self.cache[attr] = output[attr][None, :] if odim == 1 else output[attr][None, :, :]
                     else:
                         # preallocate if number of steps is known.
-                        self.cache[attr] = torch.empty((self.entries, len(output[attr])))
+                        dim = [self.entries, len(output[attr])] + ([output[attr].shape[1]] if odim == 2 else [])
+                        self.cache[attr] = torch.empty(dim)

                 else:
                     if self.entries is None:
                         # vertically stack output attribute to cache tensor at the appropriate key.
-                        self.cache[attr] = torch.vstack([self.cache[attr], output[attr][None, :]])
+                        target = output[attr][None, :] if odim == 1 else output[attr][None, :, :]
+                        self.cache[attr] = torch.vstack([self.cache[attr], target])
                     else:
                         # update appropriate row in preallocated tensor.
-                        self.cache[attr][self.iteration, :] = output[attr]
+                        self.cache[attr][self.iteration] = output[attr]

-                # advance iteration counter.
-                self.iteration += 1
+            # advance iteration counter.
+            self.iteration += 1

         return fn
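
As a usage note (not part of this diff): a hook produced by `monitor_hook()` attaches through PyTorch's standard registration call. The `accumulator` and `ensemble` names below are placeholders.

    # `accumulator` is a data-accumulator instance exposing monitor_hook();
    # `ensemble` is a torch.nn.Module whose forward output is a dict of tensors.
    handle = ensemble.register_forward_hook(accumulator.monitor_hook())

    # ... run the simulation; per-step outputs accumulate in accumulator.cache ...

    handle.remove()  # detach the hook when monitoring is done.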
