diff --git a/sapicore/engine/network/__init__.py b/sapicore/engine/network/__init__.py
index 8d3a979..933037d 100644
--- a/sapicore/engine/network/__init__.py
+++ b/sapicore/engine/network/__init__.py
@@ -460,9 +460,9 @@ def draw(self, path: str, node_size: int = 750):
         """
         plt.figure()
 
-        nx.draw(self.graph, node_size=node_size, with_labels=True, pos=nx.kamada_kawai_layout(self.network.graph))
+        nx.draw(self.graph, node_size=node_size, with_labels=True, pos=nx.kamada_kawai_layout(self.graph))
 
-        plt.savefig(fname=os.path.join(path, self.network.identifier + ".svg"))
+        plt.savefig(fname=os.path.join(path, self.identifier + ".svg"))
         plt.clf()
 
     def _in_edges(self, node: str) -> list[Synapse]:
diff --git a/sapicore/model/__init__.py b/sapicore/model/__init__.py
index 4b48a7c..a56a047 100644
--- a/sapicore/model/__init__.py
+++ b/sapicore/model/__init__.py
@@ -5,7 +5,7 @@
 :class:`~engine.network.Network` output for practical purposes.
 
 """
-from typing import Callable
+from typing import Callable, Sequence
 
 import dill
 import os
@@ -38,14 +38,14 @@ def __init__(self, network: Network = None, **kwargs):
         # store a reference to one network object.
         self.network = network
 
-    def _serve(self, data: Tensor | list[Tensor], duration: int | list[int], rinse: int | list[int] = 0):
+    def _serve(self, data: Tensor | Sequence[Tensor], duration: int | Sequence[int], rinse: int | Sequence[int] = 0):
         """Serves a batch of data to this model's network.
 
         Each sample `i` is presented for `duration[i]`, followed by all 0s stimulation for `rinse[i]`.
 
         """
         # wrap 2D tensor data in a list if need be, to make subsequent single- and multi-root operations uniform.
-        if not isinstance(data, list):
+        if not isinstance(data, Sequence):
             data = [data]
 
         num_samples = data[0].shape[0]
@@ -67,7 +67,9 @@ def _serve(self, data: Tensor | list[Tensor], duration: int | list[int], rinse:
                 # advance progress bar.
                 bar()
 
-    def fit(self, data: Tensor | list[Tensor], duration: int | list[int], rinse: int | list[int] = 0, **kwargs):
+    def fit(
+        self, data: Tensor | Sequence[Tensor], duration: int | Sequence[int], rinse: int | Sequence[int] = 0, **kwargs
+    ):
         """Applies :meth:`engine.network.Network.forward` sequentially on a block of buffer `data`,
         then turns off learning for the network.
 
@@ -75,15 +77,15 @@ def fit(self, data: Tensor | list[Tensor], duration: int | list[int], rinse: int
 
         Parameters
         ----------
-        data: Tensor or list of Tensor
+        data: Tensor or Sequence of Tensor
             2D tensor(s) of data buffer to be fed to the root ensemble(s) of this object's `network`,
             formatted sample X feature.
 
-        duration: int or list of int
+        duration: int or Sequence of int
             Duration of sample presentation. Simulates duration of exposure to a particular input.
             If a list or a tensor is provided, the i-th sample in the batch is maintained for `duration[i]` steps.
 
-        rinse: int or list of int
+        rinse: int or Sequence of int
             Null stimulation steps (0s in-between samples). If a list or a tensor is provided,
             the i-th sample is followed by `rinse[i]` rinse steps.
 
@@ -99,7 +101,7 @@ def fit(self, data: Tensor | list[Tensor], duration: int | list[int], rinse: int
         for synapse in self.network.get_synapses():
             synapse.set_learning(False)
 
-    def predict(self, data: Data | Tensor) -> Tensor:
+    def predict(self, data: Data | Tensor, **kwargs) -> Tensor:
         """Predicts the labels of `data` by feeding the buffer to a trained network and applying some procedure
         to the resulting population/readout layer response.
 
@@ -116,7 +118,7 @@ def predict(self, data: Data | Tensor) -> Tensor:
         """
         raise NotImplementedError
 
-    def similarity(self, data: Tensor, metric: str | Callable) -> Tensor:
+    def similarity(self, data: Tensor, metric: str | Callable, **kwargs) -> Tensor:
         """Performs rudimentary similarity analysis on the network's responses to `data`,
         yielding a pairwise distance matrix.
 
diff --git a/sapicore/pipeline/simple.py b/sapicore/pipeline/simple.py
index 4feaf4b..901f731 100644
--- a/sapicore/pipeline/simple.py
+++ b/sapicore/pipeline/simple.py
@@ -141,7 +141,7 @@ def run(self):
             model.network.add_data_hook(data_dir, steps)
 
             # save an SVG plot of the network architecture.
-            model.draw(path=run_dir)
+            model.network.draw(path=run_dir)
 
         if not self.data:
             steps = self.configuration.get("simulation", {}).get("steps", {})
@@ -159,7 +159,7 @@ def run(self):
         # fit the model by passing `data` buffer to the network (`duration` controls exposure time to each sample).
         logging.info(f"Simulating {steps} steps at a resolution of {DT} ms.")
 
-        model.fit(data=self.data, repetitions=self.configuration.get("duration", 1))
+        model.fit(data=self.data, duration=self.configuration.get("duration", 1))
 
         # optional tensorboard logging.
         if tensorboard:
diff --git a/sapicore/tests/data/test_data.py b/sapicore/tests/data/test_data.py
index 91e5dc5..e9bd4ac 100644
--- a/sapicore/tests/data/test_data.py
+++ b/sapicore/tests/data/test_data.py
@@ -57,7 +57,7 @@ def test_data_pipeline(self):
             models.append(Model(Network()))
 
             # repeats each sample for a random number of steps (simulating variable exposure durations).
-            models[i].fit(data[train], repetitions=torch.randint(low=2, high=7, size=(data[train].shape[0],)))
+            models[i].fit(data[train], duration=torch.randint(low=2, high=7, size=(data[train].shape[0],)))
 
     @pytest.mark.parametrize(
         "url_",
diff --git a/sapicore/utils/io.py b/sapicore/utils/io.py
index 084177f..bdb6e0d 100644
--- a/sapicore/utils/io.py
+++ b/sapicore/utils/io.py
@@ -75,25 +75,28 @@ def monitor_hook(self) -> Callable:
 
         def fn(_, __, output):
             # add current outputs to this data accumulator instance's cache dictionary, whose values are tensors.
             for attr in self.attributes:
+                odim = output[attr].dim()
                 if attr not in self.cache.keys():
                     # the cache dictionary is empty because this is the first iteration.
                     if self.entries is None:
                         # expand output by one dimension (zero axis) to fit.
-                        self.cache[attr] = output[attr][None, :]
+                        self.cache[attr] = output[attr][None, :] if odim == 1 else output[attr][None, :, :]
                     else:
                         # preallocate if number of steps is known.
-                        self.cache[attr] = torch.empty((self.entries, len(output[attr])))
+                        dim = [self.entries, len(output[attr])] + ([output[attr].shape[1]] if odim == 2 else [])
+                        self.cache[attr] = torch.empty(dim)
                 else:
                     if self.entries is None:
                         # vertically stack output attribute to cache tensor at the appropriate key.
-                        self.cache[attr] = torch.vstack([self.cache[attr], output[attr][None, :]])
+                        target = output[attr][None, :] if odim == 1 else output[attr][None, :, :]
+                        self.cache[attr] = torch.vstack([self.cache[attr], target])
                     else:
                         # update appropriate row in preallocated tensor.
-                        self.cache[attr][self.iteration, :] = output[attr]
+                        self.cache[attr][self.iteration] = output[attr]
 
-                # advance iteration counter.
-                self.iteration += 1
+            # advance iteration counter.
+            self.iteration += 1
 
         return fn
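
For reference, a minimal usage sketch of the corrected `fit` keyword (`duration`, previously passed as the nonexistent `repetitions` at the call sites fixed above). It mirrors the updated call in `test_data.py`; the import paths are inferred from the file paths in this diff, and the data shape and rinse value are illustrative.

import torch

from sapicore.engine.network import Network
from sapicore.model import Model

# illustrative batch: 10 samples x 4 features (shape and values are arbitrary).
data = torch.rand(10, 4)
model = Model(Network())

# present the i-th sample for duration[i] steps, with 2 null (all-zero) rinse steps in between,
# using the `duration` keyword accepted by Model.fit as patched above.
model.fit(data, duration=torch.randint(low=2, high=7, size=(data.shape[0],)), rinse=2)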