Skip to content

Commit

Permalink
[Pydantic III] Use list/dict annotations (#1021)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Cody Baker <[email protected]>
Co-authored-by: CodyCBakerPhD <[email protected]>
  • Loading branch information
4 people authored Aug 21, 2024
1 parent beb48d9 commit 8e20a25
Show file tree
Hide file tree
Showing 47 changed files with 201 additions and 210 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
* `BaseRecordingInterface` now calls default metadata when metadata is not passing mimicking `run_conversion` behavior. [PR #1012](https://github.com/catalystneuro/neuroconv/pull/1012)
* Added `get_json_schema_from_method_signature` which constructs Pydantic models automatically from the signature of any function with typical annotation types used throughout NeuroConv. [PR #1016](https://github.com/catalystneuro/neuroconv/pull/1016)
* Replaced all interface annotations with Pydantic types. [PR #1017](https://github.com/catalystneuro/neuroconv/pull/1017)
* Changed typehint collections (e.g. `List`) to standard collections (e.g. `list`). [PR #1021](https://github.com/catalystneuro/neuroconv/pull/1021)



Expand Down
6 changes: 3 additions & 3 deletions src/neuroconv/basedatainterface.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import uuid
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Literal, Optional, Tuple, Union
from typing import Literal, Optional, Union

from jsonschema.validators import validate
from pydantic import FilePath
Expand All @@ -30,8 +30,8 @@ class BaseDataInterface(ABC):
"""Abstract class defining the structure of all DataInterfaces."""

display_name: Union[str, None] = None
keywords: Tuple[str] = tuple()
associated_suffixes: Tuple[str] = tuple()
keywords: tuple[str] = tuple()
associated_suffixes: tuple[str] = tuple()
info: Union[str, None] = None

@classmethod
Expand Down
8 changes: 4 additions & 4 deletions src/neuroconv/datainterfaces/behavior/audio/audiointerface.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from pathlib import Path
from typing import List, Literal, Optional
from typing import Literal, Optional

import numpy as np
import scipy
Expand Down Expand Up @@ -28,7 +28,7 @@ class AudioInterface(BaseTemporalAlignmentInterface):
associated_suffixes = (".wav",)
info = "Interface for writing audio recordings to an NWB file."

def __init__(self, file_paths: List[FilePath], verbose: bool = False):
def __init__(self, file_paths: list[FilePath], verbose: bool = False):
"""
Data interface for writing acoustic recordings to an NWB file.
Expand Down Expand Up @@ -105,7 +105,7 @@ def get_original_timestamps(self) -> np.ndarray:
def get_timestamps(self) -> Optional[np.ndarray]:
raise NotImplementedError("The AudioInterface does not yet support timestamps.")

def set_aligned_timestamps(self, aligned_timestamps: List[np.ndarray]):
def set_aligned_timestamps(self, aligned_timestamps: list[np.ndarray]):
raise NotImplementedError("The AudioInterface does not yet support timestamps.")

def set_aligned_starting_time(self, aligned_starting_time: float):
Expand All @@ -132,7 +132,7 @@ def set_aligned_starting_time(self, aligned_starting_time: float):
"Please set them using 'set_aligned_segment_starting_times'."
)

def set_aligned_segment_starting_times(self, aligned_segment_starting_times: List[float]):
def set_aligned_segment_starting_times(self, aligned_segment_starting_times: list[float]):
"""
Align the individual starting time for each audio file in this interface relative to the common session start time.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pickle
import warnings
from pathlib import Path
from typing import List, Optional, Union
from typing import Optional, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -305,7 +305,7 @@ def add_subject_to_nwbfile(
h5file: FilePath,
individual_name: str,
config_file: FilePath,
timestamps: Optional[Union[List, np.ndarray]] = None,
timestamps: Optional[Union[list, np.ndarray]] = None,
pose_estimation_container_kwargs: Optional[dict] = None,
) -> NWBFile:
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import List, Optional, Union
from typing import Optional, Union

import numpy as np
from pydantic import FilePath
Expand Down Expand Up @@ -76,7 +76,7 @@ def get_timestamps(self) -> np.ndarray:
"Unable to retrieve timestamps for this interface! Define the `get_timestamps` method for this interface."
)

def set_aligned_timestamps(self, aligned_timestamps: Union[List, np.ndarray]):
def set_aligned_timestamps(self, aligned_timestamps: Union[list, np.ndarray]):
"""
Set aligned timestamps vector for DLC data with user defined timestamps
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from copy import deepcopy
from typing import List, Optional
from typing import Optional

from pydantic import FilePath
from pynwb import NWBFile
Expand Down Expand Up @@ -106,8 +106,8 @@ def add_to_nwbfile(
reference_frame: Optional[str] = None,
confidence_definition: Optional[str] = None,
external_mode: bool = True,
starting_frames_original_videos: Optional[List[int]] = None,
starting_frames_labeled_videos: Optional[List[int]] = None,
starting_frames_original_videos: Optional[list[int]] = None,
starting_frames_labeled_videos: Optional[list[int]] = None,
stub_test: bool = False,
):
original_video_interface = self.data_interface_objects["OriginalVideo"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import Optional, Tuple
from typing import Optional

import numpy as np
from pydantic import FilePath
Expand Down Expand Up @@ -116,7 +116,7 @@ def _load_source_data(self):
pose_estimation_data = pd.read_csv(self.file_path, header=[0, 1, 2])
return pose_estimation_data

def _get_original_video_shape(self) -> Tuple[int, int]:
def _get_original_video_shape(self) -> tuple[int, int]:
with self._vc(file_path=str(self.original_video_file_path)) as video:
video_shape = video.get_frame_shape()
# image size of the original video is in height x width
Expand Down
8 changes: 4 additions & 4 deletions src/neuroconv/datainterfaces/behavior/neuralynx/nvt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
from datetime import datetime
from shutil import copy
from typing import Dict, List, Union
from typing import Union

import numpy as np
from pydantic import FilePath
Expand All @@ -26,7 +26,7 @@
]


def read_header(filename: str) -> Dict[str, Union[str, datetime, float, int, List[int]]]:
def read_header(filename: str) -> dict[str, Union[str, datetime, float, int, list[int]]]:
"""
Parses a Neuralynx NVT File Header and returns it as a dictionary.
Expand Down Expand Up @@ -83,7 +83,7 @@ def parse_bool(x):
return out


def read_data(filename: str) -> Dict[str, np.ndarray]:
def read_data(filename: str) -> dict[str, np.ndarray]:
"""
Reads a NeuroLynx NVT file and returns its data.
Expand All @@ -97,7 +97,7 @@ def read_data(filename: str) -> Dict[str, np.ndarray]:
Returns
-------
Dict[str, np.ndarray]
dict[str, np.ndarray]
Dictionary containing the parsed data.
Raises
Expand Down
2 changes: 1 addition & 1 deletion src/neuroconv/datainterfaces/behavior/video/video_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def _get_frame_details(self):
min_frame_size_mb = (math.prod(frame_shape) * self._get_dtype().itemsize) / 1e6
return min_frame_size_mb, frame_shape

def _get_data(self, selection: Tuple[slice]) -> np.ndarray:
def _get_data(self, selection: tuple[slice]) -> np.ndarray:
start_frame = selection[0].start
end_frame = selection[0].stop
frames = np.empty(shape=[end_frame - start_frame, *self._maxshape[1:]])
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import warnings
from copy import deepcopy
from pathlib import Path
from typing import List, Literal, Optional
from typing import Literal, Optional

import numpy as np
import psutil
Expand Down Expand Up @@ -30,7 +30,7 @@ class VideoInterface(BaseDataInterface):

def __init__(
self,
file_paths: List[FilePath],
file_paths: list[FilePath],
verbose: bool = False,
*,
metadata_key_name: str = "Videos",
Expand Down Expand Up @@ -104,7 +104,7 @@ def get_metadata(self):

return metadata

def get_original_timestamps(self, stub_test: bool = False) -> List[np.ndarray]:
def get_original_timestamps(self, stub_test: bool = False) -> list[np.ndarray]:
"""
Retrieve the original unaltered timestamps for the data in this interface.
Expand Down Expand Up @@ -159,7 +159,7 @@ def get_timing_type(self) -> Literal["starting_time and rate", "timestamps"]:
"Please specify the temporal alignment of each video."
)

def get_timestamps(self, stub_test: bool = False) -> List[np.ndarray]:
def get_timestamps(self, stub_test: bool = False) -> list[np.ndarray]:
"""
Retrieve the timestamps for the data in this interface.
Expand All @@ -176,7 +176,7 @@ def get_timestamps(self, stub_test: bool = False) -> List[np.ndarray]:
"""
return self._timestamps or self.get_original_timestamps(stub_test=stub_test)

def set_aligned_timestamps(self, aligned_timestamps: List[np.ndarray]):
def set_aligned_timestamps(self, aligned_timestamps: list[np.ndarray]):
"""
Replace all timestamps for this interface with those aligned to the common session start time.
Expand Down Expand Up @@ -221,7 +221,7 @@ def set_aligned_starting_time(self, aligned_starting_time: float, stub_test: boo
else:
raise ValueError("There are no timestamps or starting times set to shift by a common value!")

def set_aligned_segment_starting_times(self, aligned_segment_starting_times: List[float], stub_test: bool = False):
def set_aligned_segment_starting_times(self, aligned_segment_starting_times: list[float], stub_test: bool = False):
"""
Align the individual starting time for each video (segment) in this interface relative to the common session start time.
Expand Down Expand Up @@ -264,7 +264,7 @@ def add_to_nwbfile(
metadata: Optional[dict] = None,
stub_test: bool = False,
external_mode: bool = True,
starting_frames: Optional[List[int]] = None,
starting_frames: Optional[list[int]] = None,
chunk_data: bool = True,
module_name: Optional[str] = None,
module_description: Optional[str] = None,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Literal, Optional, Union
from typing import Literal, Optional, Union

import numpy as np
from pynwb import NWBFile
Expand Down Expand Up @@ -106,7 +106,7 @@ def get_metadata(self) -> DeepDict:

return metadata

def get_original_timestamps(self) -> Union[np.ndarray, List[np.ndarray]]:
def get_original_timestamps(self) -> Union[np.ndarray, list[np.ndarray]]:
"""
Retrieve the original unaltered timestamps for the data in this interface.
Expand All @@ -128,7 +128,7 @@ def get_original_timestamps(self) -> Union[np.ndarray, List[np.ndarray]]:
for segment_index in range(self._number_of_segments)
]

def get_timestamps(self) -> Union[np.ndarray, List[np.ndarray]]:
def get_timestamps(self) -> Union[np.ndarray, list[np.ndarray]]:
"""
Retrieve the timestamps for the data in this interface.
Expand All @@ -152,7 +152,7 @@ def set_aligned_timestamps(self, aligned_timestamps: np.ndarray):

self.recording_extractor.set_times(times=aligned_timestamps)

def set_aligned_segment_timestamps(self, aligned_segment_timestamps: List[np.ndarray]):
def set_aligned_segment_timestamps(self, aligned_segment_timestamps: list[np.ndarray]):
"""
Replace all timestamps for all segments in this interface with those aligned to the common session start time.
Expand Down Expand Up @@ -185,7 +185,7 @@ def set_aligned_starting_time(self, aligned_starting_time: float):
]
)

def set_aligned_segment_starting_times(self, aligned_segment_starting_times: List[float]):
def set_aligned_segment_starting_times(self, aligned_segment_starting_times: list[float]):
"""
Align the starting time for each segment in this interface relative to the common session start time.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from copy import deepcopy
from typing import List, Literal, Optional, Union
from typing import Literal, Optional, Union

import numpy as np
from pynwb import NWBFile
Expand Down Expand Up @@ -83,7 +83,7 @@ def get_original_timestamps(self) -> np.ndarray:
"Unable to fetch original timestamps for a SortingInterface since it relies upon an attached recording."
)

def get_timestamps(self) -> Union[np.ndarray, List[np.ndarray]]:
def get_timestamps(self) -> Union[np.ndarray, list[np.ndarray]]:
if not self.sorting_extractor.has_recording():
raise NotImplementedError(
"In order to align timestamps for a SortingInterface, it must have a recording "
Expand Down Expand Up @@ -138,7 +138,7 @@ def set_aligned_timestamps(self, aligned_timestamps: np.ndarray):
times=aligned_timestamps[segment_index], segment_index=segment_index
)

def set_aligned_segment_timestamps(self, aligned_segment_timestamps: List[np.ndarray]):
def set_aligned_segment_timestamps(self, aligned_segment_timestamps: list[np.ndarray]):
"""
Replace all timestamps for all segments in this interface with those aligned to the common session start time.
Expand Down Expand Up @@ -182,7 +182,7 @@ def set_aligned_starting_time(self, aligned_starting_time: float):
else:
sorting_segment._t_start += aligned_starting_time

def set_aligned_segment_starting_times(self, aligned_segment_starting_times: List[float]):
def set_aligned_segment_starting_times(self, aligned_segment_starting_times: list[float]):
"""
Align the starting time for each segment in this interface relative to the common session start time.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import List, Optional
from typing import Optional

import numpy as np
from pydantic import DirectoryPath
Expand All @@ -18,7 +18,7 @@ class NeuralynxRecordingInterface(BaseRecordingExtractorInterface):
info = "Interface for Neuralynx recording data."

@classmethod
def get_stream_names(cls, folder_path: DirectoryPath) -> List[str]:
def get_stream_names(cls, folder_path: DirectoryPath) -> list[str]:
from spikeinterface.extractors import NeuralynxRecordingExtractor

stream_names, _ = NeuralynxRecordingExtractor.get_streams(folder_path=folder_path)
Expand Down Expand Up @@ -158,16 +158,16 @@ def extract_neo_header_metadata(neo_reader) -> dict:
return common_header


def _dict_intersection(dict_list: List) -> dict:
def _dict_intersection(dict_list: list[dict]) -> dict:
"""
Intersect dict_list and return only common keys and values
Parameters
----------
dict_list: list of dicitionaries each representing a header
dict_list: list of dictionaries each representing a header
Returns
-------
dict:
Dictionary containing key-value pairs common to all input dicitionary_list
Dictionary containing key-value pairs common to all input dictionary_list
"""

# Collect keys appearing in all dictionaries
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def __init__(
self,
folder_path: DirectoryPath,
keep_mua_units: bool = True,
exclude_shanks: Optional[list] = None,
exclude_shanks: Optional[list[int]] = None,
xml_file_path: Optional[FilePath] = None,
verbose: bool = True,
):
Expand All @@ -282,7 +282,7 @@ def __init__(
Path to folder containing .clu and .res files.
keep_mua_units : bool, default: True
Optional. Whether to return sorted spikes from multi-unit activity.
exclude_shanks : list, optional
exclude_shanks : list of integers, optional
List of indices to ignore. The set of all possible indices is chosen by default, extracted as the
final integer of all the .res.%i and .clu.%i pairs.
xml_file_path : FilePathType, optional
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Optional

from pydantic import DirectoryPath

Expand All @@ -20,7 +20,7 @@ class OpenEphysBinaryRecordingInterface(BaseRecordingExtractorInterface):
ExtractorName = "OpenEphysBinaryRecordingExtractor"

@classmethod
def get_stream_names(cls, folder_path: DirectoryPath) -> List[str]:
def get_stream_names(cls, folder_path: DirectoryPath) -> list[str]:
from spikeinterface.extractors import OpenEphysBinaryRecordingExtractor

stream_names, _ = OpenEphysBinaryRecordingExtractor.get_streams(folder_path=folder_path)
Expand Down
Loading

0 comments on commit 8e20a25

Please sign in to comment.