Skip to content

Commit

Permalink
Draft signature
Browse files Browse the repository at this point in the history
  • Loading branch information
caiw committed Sep 21, 2024
1 parent 6b33c97 commit bee48b9
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 36 deletions.
32 changes: 16 additions & 16 deletions demos/demo_ippm.ipynb

Large diffs are not rendered by default.

93 changes: 73 additions & 20 deletions kymata/ippm/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
A graphing functions used to construct a dictionary that contains the nodes and all relevant
information to construct a dict containing node names as keys and Node objects (see namedtuple) as values.
"""

from collections import defaultdict, Counter
from copy import deepcopy
from enum import StrEnum
from typing import NamedTuple

import numpy as np
Expand All @@ -30,13 +31,20 @@ class IPPMNode(NamedTuple):
IPPMGraph = dict[str, IPPMNode]


class _YOrdinateMethods(StrEnum):
progressive = "progressive"
centred = "centred"


class IPPMBuilder:
def __init__(
self,
spikes: SpikeDict,
inputs: list[str],
hierarchy: TransformHierarchy,
hemisphere: str,
y_ordinate_method: str = _YOrdinateMethods.progressive,
levels: dict[str, int] = None
):
self._spikes: SpikeDict = deepcopy(spikes)
self._inputs: list[str] = inputs
Expand All @@ -46,21 +54,52 @@ def __init__(
self._sort_spikes_by_latency_asc()

self.graph: IPPMGraph = dict()
self.graph = self._build_graph_dict(deepcopy(self._hierarchy))

def _build_graph_dict(self, hierarchy: TransformHierarchy) -> IPPMGraph:
y_axis_partition_size = (
1 / len(hierarchy.keys()) if len(hierarchy.keys()) > 0 else 1
)
partition_ptr = 0
graph = dict()
while childless_functions := self._get_childless_functions(hierarchy):
for childless_func in childless_functions:
graph = self._create_nodes_and_edges_for_function(
childless_func, partition_ptr, y_axis_partition_size
)
hierarchy.pop(childless_func)
partition_ptr += 1
self.graph = self._build_graph_dict(deepcopy(self._hierarchy), y_ordinate_method, levels)

def _build_graph_dict(self,
hierarchy: TransformHierarchy,
y_ordinate_method: str,
levels: dict[str, int],
) -> IPPMGraph:
"""
y_ordinate_method == "progressive" for y ordinates to be selected progressively from the input
y_ordinate_method == "centred" for y ordinates to be centred vertically based on assigned levels in the
hierarchy
levels: maps node names in the hierarchy to level-idxs of vertically centred nodes
"""
if y_ordinate_method == _YOrdinateMethods.progressive:
y_axis_partition_size = (
1 / len(hierarchy.keys()) if len(hierarchy.keys()) > 0 else 1
)
partition_ptr = 0
graph = dict()
while childless_functions := self._get_childless_functions(hierarchy):
for childless_func in childless_functions:
graph = self._create_nodes_and_edges_for_function_progressive(
childless_func, partition_ptr, y_axis_partition_size
)
hierarchy.pop(childless_func)
partition_ptr += 1

elif y_ordinate_method == _YOrdinateMethods.centred:
if levels is None:
raise ValueError(f"Supply `levels` when using {_YOrdinateMethods.centred} option")
totals_within_level = Counter(levels.values())
idxs_within_level = defaultdict(int)
graph = dict()
while childless_functions := self._get_childless_functions(hierarchy):
for childless_func in childless_functions:
graph = self._create_nodes_and_edges_for_function_centred(
childless_func,
idxs_within_level[childless_func],
totals_within_level[levels[childless_func]],
max(totals_within_level.values()),
)
hierarchy.pop(childless_func)
idxs_within_level[childless_func] += 1

else:
raise NotImplementedError()

return graph

Expand All @@ -81,10 +120,24 @@ def __unpack_dict_values_into_list(dict_to_unpack):
# When no functions left, it returns empty set.
return current_functions.difference(functions_with_children)

def _create_nodes_and_edges_for_function(self, function_name: str, partition_ptr: int, partition_size: float) -> IPPMGraph:
def __get_y_coordinate(
curr_partition_number: int, partition_size: float
) -> float:
def _create_nodes_and_edges_for_function_centred(self,
function_name: str,
function_idx_within_level: int,
function_total_within_level: int,
max_function_total_within_level: int,
) -> IPPMGraph:
"""
x_batch_size: how many nodes in a vertically-centred batch.
x_batch_idx: which node this is in the batch
"""


def _create_nodes_and_edges_for_function_progressive(self,
function_name: str,
partition_ptr: int,
partition_size: float,
) -> IPPMGraph:
def __get_y_coordinate(curr_partition_number: int, partition_size: float) -> float:
return 1 - partition_size * curr_partition_number

func_parents = self._hierarchy[function_name]
Expand Down

0 comments on commit bee48b9

Please sign in to comment.