Skip to content

Commit

Permalink
Merge pull request #3089 from h-mayorquin/recording_dict_iterator
Browse files Browse the repository at this point in the history
extractor_dict_iterator for solving path detection in object `kwargs`
  • Loading branch information
samuelgarcia authored Jun 28, 2024
2 parents c24c966 + 3eee955 commit 47dd371
Show file tree
Hide file tree
Showing 3 changed files with 251 additions and 107 deletions.
125 changes: 113 additions & 12 deletions src/spikeinterface/core/core_tools.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations
from pathlib import Path, WindowsPath
from typing import Union
from typing import Union, Generator
import os
import sys
import datetime
import json
from copy import deepcopy
import importlib
from math import prod
from collections import namedtuple

import numpy as np

Expand Down Expand Up @@ -183,6 +184,75 @@ def is_dict_extractor(d: dict) -> bool:
return is_extractor


extractor_dict_element = namedtuple(typename="extractor_dict_element", field_names=["value", "name", "access_path"])


def extractor_dict_iterator(extractor_dict: dict) -> Generator[extractor_dict_element]:
"""
Iterator for recursive traversal of a dictionary.
This function explores the dictionary recursively and yields the path to each value along with the value itself.
By path here we mean the keys that lead to the value in the dictionary:
e.g. for the dictionary {'a': {'b': 1}}, the path to the value 1 is ('a', 'b').
See `BaseExtractor.to_dict()` for a description of `extractor_dict` structure.
Parameters
----------
extractor_dict : dict
Input dictionary
Yields
------
extractor_dict_element
Named tuple containing the value, the name, and the access_path to the value in the dictionary.
"""

def _extractor_dict_iterator(dict_list_or_value, access_path=(), name=""):
if isinstance(dict_list_or_value, dict):
for k, v in dict_list_or_value.items():
yield from _extractor_dict_iterator(v, access_path + (k,), name=k)
elif isinstance(dict_list_or_value, list):
for i, v in enumerate(dict_list_or_value):
yield from _extractor_dict_iterator(
v, access_path + (i,), name=name
) # Propagate name of list to children
else:
yield extractor_dict_element(
value=dict_list_or_value,
name=name,
access_path=access_path,
)

yield from _extractor_dict_iterator(extractor_dict)


def set_value_in_extractor_dict(extractor_dict: dict, access_path: tuple, new_value):
"""
In place modification of a value in a nested dictionary given its access path.
Parameters
----------
extractor_dict : dict
The dictionary to modify
access_path : tuple
The path to the value in the dictionary
new_value : object
The new value to set
Returns
-------
dict
The modified dictionary
"""

current = extractor_dict
for key in access_path[:-1]:
current = current[key]
current[access_path[-1]] = new_value


def recursive_path_modifier(d, func, target="path", copy=True) -> dict:
"""
Generic function for recursive modification of paths in an extractor dict.
Expand Down Expand Up @@ -250,15 +320,17 @@ def recursive_path_modifier(d, func, target="path", copy=True) -> dict:
raise ValueError(f"{k} key for path must be str or list[str]")


def _get_paths_list(d):
# this explore a dict and get all paths flatten in a list
# the trick is to use a closure func called by recursive_path_modifier()
path_list = []
# This is the current definition that an element in a extractor_dict is a path
# This is shared across a couple of definition so it is here for DNRY
element_is_path = lambda element: "path" in element.name and isinstance(element.value, (str, Path))


def append_to_path(p):
path_list.append(p)
def _get_paths_list(d: dict) -> list[str | Path]:
path_list = [e.value for e in extractor_dict_iterator(d) if element_is_path(e)]

# if check_if_exists: TODO: Enable this once container_tools test uses proper mocks
# path_list = [p for p in path_list if Path(p).exists()]

recursive_path_modifier(d, append_to_path, target="path", copy=True)
return path_list


Expand Down Expand Up @@ -318,7 +390,7 @@ def check_paths_relative(input_dict, relative_folder) -> bool:
return len(not_possible) == 0


def make_paths_relative(input_dict, relative_folder) -> dict:
def make_paths_relative(input_dict: dict, relative_folder: str | Path) -> dict:
"""
Recursively transform a dict describing an BaseExtractor to make every path relative to a folder.
Expand All @@ -334,9 +406,22 @@ def make_paths_relative(input_dict, relative_folder) -> dict:
output_dict: dict
A copy of the input dict with modified paths.
"""

relative_folder = Path(relative_folder).resolve().absolute()
func = lambda p: _relative_to(p, relative_folder)
output_dict = recursive_path_modifier(input_dict, func, target="path", copy=True)

path_elements_in_dict = [e for e in extractor_dict_iterator(input_dict) if element_is_path(e)]
# Only paths that exist are made relative
path_elements_in_dict = [e for e in path_elements_in_dict if Path(e.value).exists()]

output_dict = deepcopy(input_dict)
for element in path_elements_in_dict:
new_value = _relative_to(element.value, relative_folder)
set_value_in_extractor_dict(
extractor_dict=output_dict,
access_path=element.access_path,
new_value=new_value,
)

return output_dict


Expand All @@ -359,12 +444,28 @@ def make_paths_absolute(input_dict, base_folder):
base_folder = Path(base_folder)
# use as_posix instead of str to make the path unix like even on window
func = lambda p: (base_folder / p).resolve().absolute().as_posix()
output_dict = recursive_path_modifier(input_dict, func, target="path", copy=True)

path_elements_in_dict = [e for e in extractor_dict_iterator(input_dict) if element_is_path(e)]
output_dict = deepcopy(input_dict)

output_dict = deepcopy(input_dict)
for element in path_elements_in_dict:
absolute_path = (base_folder / element.value).resolve()
if Path(absolute_path).exists():
new_value = absolute_path.as_posix() # Not so sure about this, Sam
set_value_in_extractor_dict(
extractor_dict=output_dict,
access_path=element.access_path,
new_value=new_value,
)

return output_dict


def recursive_key_finder(d, key):
# Find all values for a key on a dictionary, even if nested
# TODO refactor to use extractor_dict_iterator

for k, v in d.items():
if isinstance(v, dict):
yield from recursive_key_finder(v, key)
Expand Down
Loading

0 comments on commit 47dd371

Please sign in to comment.