diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml
index 3ec7775..761547e 100644
--- a/.github/workflows/test_and_deploy.yml
+++ b/.github/workflows/test_and_deploy.yml
@@ -26,9 +26,6 @@ jobs:
needs: [linting, manifest]
name: ${{ matrix.os }} py${{ matrix.python-version }}
runs-on: ${{ matrix.os }}
- env:
- # pooch cache dir
- BRAINGLOBE_TEST_DATA_DIR: "~/.pooch_cache"
strategy:
matrix:
@@ -45,12 +42,13 @@ jobs:
python-version: "3.12"
steps:
- - name: Cache pooch data
+ - name: Cache data
uses: actions/cache@v4
with:
- path: "~/.pooch_cache"
+        path: "~/.brainglobe"
# hash on conftest in case url changes
- key: ${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/conftest.py') }}
+      key: brainglobe-dir-${{ runner.os }}-${{ hashFiles('**/conftest.py') }}
+
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
@@ -72,16 +70,15 @@ jobs:
runs-on: ubuntu-latest
env:
NUMBA_DISABLE_JIT: "1"
- # pooch cache dir
- BRAINGLOBE_TEST_DATA_DIR: "~/.pooch_cache"
+
steps:
- - name: Cache pooch data
+ - name: Cache data
uses: actions/cache@v4
with:
- path: "~/.pooch_cache"
+        path: "~/.brainglobe"
# hash on conftest in case url changes
- key: ${{ runner.os }}-3.11-${{ hashFiles('**/conftest.py') }}
+      key: brainglobe-dir-${{ runner.os }}-${{ hashFiles('**/conftest.py') }}
- name: Set up Python
uses: actions/setup-python@v4
diff --git a/MANIFEST.in b/MANIFEST.in
index bc165f3..a7ba596 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -5,6 +5,7 @@ include LICENSE
include README.md
graft brainglobe_utils *.py
+include brainglobe_utils/napari.yaml
include brainglobe_utils/qtpy/brainglobe.png
prune tests
diff --git a/brainglobe_utils/brainmapper/analysis.py b/brainglobe_utils/brainmapper/analysis.py
index e78369a..60067f7 100644
--- a/brainglobe_utils/brainmapper/analysis.py
+++ b/brainglobe_utils/brainmapper/analysis.py
@@ -3,9 +3,10 @@
by Charly Rousseau (https://github.com/crousseau).
"""
+import os
from dataclasses import dataclass
from pathlib import Path
-from typing import List, Set, Union
+from typing import List, Optional, Set, Union
import numpy as np
import pandas as pd
@@ -104,15 +105,10 @@ def combine_df_hemispheres(df: pd.DataFrame) -> pd.DataFrame:
return both
-def create_all_cell_csv(
- points: List[Point], output_filename: Union[str, Path]
-) -> None:
+def create_all_cell_df(points: List[Point]) -> pd.DataFrame:
"""
- Create a CSV file with cell data from a list of Point objects.
-
This function takes a list of Point objects, each representing cell
coordinates and brain region and converts this into a pandas DataFrame.
- The DataFrame is then saved to a CSV file at the specified filename.
Parameters
----------
@@ -120,16 +116,12 @@ def create_all_cell_csv(
A list of Point objects, each containing cell data such as
raw and atlas coordinates,
structure name, and hemisphere information.
- output_filename : Union[str, Path]
- The filename (including path) where the CSV file will be saved.
- Can be a string or a Path object.
Returns
-------
- None
+ df: pd.DataFrame
"""
- ensure_directory_exists(Path(output_filename).parent)
df = pd.DataFrame(
columns=(
"coordinate_raw_axis_0",
@@ -155,14 +147,13 @@ def create_all_cell_csv(
temp_matrix[i].append(point.hemisphere)
df = pd.DataFrame(temp_matrix, columns=df.columns, index=None)
- df.to_csv(output_filename, index=False)
+ return df
def count_points_per_brain_region(
points: List[Point],
structures_with_points: Set[str],
brainreg_volume_csv_path: Union[str, Path],
- output_filename: Union[str, Path],
) -> None:
"""
Count the number of points per brain region.
@@ -177,12 +168,11 @@ def count_points_per_brain_region(
brainreg_volume_csv_path : Union[str, Path]
The path to the CSV file containing volume information from the
brainreg registration.
- output_filename : Union[str, Path]
- The path where the summary of points by atlas region will be saved.
+
Returns
-------
- None
+ df: pd.DataFrame
"""
structures_with_points = list(structures_with_points)
@@ -219,17 +209,16 @@ def count_points_per_brain_region(
combined_hemispheres = combine_df_hemispheres(sorted_point_numbers)
df = calculate_densities(combined_hemispheres, brainreg_volume_csv_path)
df = sanitise_df(df)
-
- df.to_csv(output_filename, index=False)
+ return df
def summarise_points_by_atlas_region(
points_in_raw_data_space: np.ndarray,
points_in_atlas_space: np.ndarray,
atlas: BrainGlobeAtlas,
- brainreg_volume_csv_path: Union[str, Path],
- points_list_output_filename: Union[str, Path],
- summary_filename: Union[str, Path],
+ brainreg_volume_csv_path: Optional[os.PathLike] = None,
+ points_list_output_filename: Optional[os.PathLike] = None,
+ summary_filename: Optional[os.PathLike] = None,
) -> None:
"""
Summarise points data by atlas region.
@@ -282,11 +271,19 @@ def summarise_points_by_atlas_region(
except Exception:
continue
- create_all_cell_csv(points, points_list_output_filename)
+ all_cell_df = create_all_cell_df(points)
- count_points_per_brain_region(
+ if points_list_output_filename is not None:
+ ensure_directory_exists(Path(points_list_output_filename).parent)
+ all_cell_df.to_csv(points_list_output_filename, index=False)
+
+ points_per_region_df = count_points_per_brain_region(
points,
structures_with_points,
brainreg_volume_csv_path,
- summary_filename,
)
+
+ if summary_filename is not None:
+ points_per_region_df.to_csv(summary_filename, index=False)
+
+ return all_cell_df, points_per_region_df
diff --git a/brainglobe_utils/brainmapper/transform_widget.py b/brainglobe_utils/brainmapper/transform_widget.py
new file mode 100644
index 0000000..09d071c
--- /dev/null
+++ b/brainglobe_utils/brainmapper/transform_widget.py
@@ -0,0 +1,550 @@
+import json
+from pathlib import Path
+from typing import Any, Dict, List, Union
+
+import napari
+import pandas as pd
+import tifffile
+from brainglobe_atlasapi import BrainGlobeAtlas
+from brainglobe_atlasapi.list_atlases import get_downloaded_atlases
+from brainglobe_space import AnatomicalSpace
+from qtpy import QtCore
+from qtpy.QtWidgets import (
+ QComboBox,
+ QFileDialog,
+ QGridLayout,
+ QLabel,
+ QTableView,
+ QWidget,
+)
+
+from brainglobe_utils.brainmapper.analysis import (
+ summarise_points_by_atlas_region,
+)
+from brainglobe_utils.brainreg.transform import (
+ transform_points_from_downsampled_to_atlas_space,
+)
+from brainglobe_utils.general.system import ensure_extension
+from brainglobe_utils.qtpy.dialog import display_info
+from brainglobe_utils.qtpy.interaction import add_button, add_combobox
+from brainglobe_utils.qtpy.logo import header_widget
+from brainglobe_utils.qtpy.table import DataFrameModel
+
+
+class TransformPoints(QWidget):
+ def __init__(self, viewer: napari.viewer.Viewer):
+ super(TransformPoints, self).__init__()
+ self.viewer = viewer
+ self.raw_data = None
+ self.points_layer = None
+ self.atlas = None
+ self.transformed_points = None
+
+ self.image_layer_names = self._get_layer_names()
+ self.points_layer_names = self._get_layer_names(
+ layer_type=napari.layers.Points
+ )
+ self.setup_main_layout()
+
+ @self.viewer.layers.events.connect
+ def update_layer_list(v: napari.viewer.Viewer):
+ """
+ Update internal list of layers whenever the napari layers list
+ is updated.
+ """
+ self.image_layer_names = self._get_layer_names()
+ self.points_layer_names = self._get_layer_names(
+ layer_type=napari.layers.Points
+ )
+
+ self._update_combobox_options(
+ self.raw_data_choice, self.image_layer_names
+ )
+
+ self._update_combobox_options(
+ self.points_layer_choice, self.points_layer_names
+ )
+
+ @staticmethod
+ def _update_combobox_options(combobox: QComboBox, options_list: List[str]):
+ original_text = combobox.currentText()
+ combobox.clear()
+ combobox.addItems(options_list)
+ combobox.setCurrentText(original_text)
+
+ def _get_layer_names(
+ self,
+ layer_type: napari.layers.Layer = napari.layers.Image,
+ default: str = "",
+ ) -> List[str]:
+ """
+ Get list of layer names of a given layer type.
+ """
+ layer_names = [
+ layer.name
+ for layer in self.viewer.layers
+ if type(layer) == layer_type
+ ]
+
+ if layer_names:
+ return [default] + layer_names
+ else:
+ return [default]
+
+ def setup_main_layout(self):
+ self.layout = QGridLayout()
+ self.layout.setContentsMargins(10, 10, 10, 10)
+ self.layout.setAlignment(QtCore.Qt.AlignTop)
+ self.layout.setSpacing(4)
+ self.add_header()
+ self.add_points_combobox(row=1, column=0)
+ self.add_raw_data_combobox(row=2, column=0)
+ self.add_transform_button(row=3, column=0)
+
+ self.add_points_summary_table(row=4, column=0)
+ self.add_save_all_points_button(row=6, column=0)
+ self.add_save_points_summary_button(row=6, column=1)
+ self.add_status_label(row=7, column=0)
+
+ self.setLayout(self.layout)
+
+ def add_header(self):
+ """
+ Header including brainglobe logo and documentation links.
+ """
+        # <br> is included in the package_name to make the label under the logo
+ # more compact, by splitting it onto two lines
+ header = header_widget(
+            package_name="brainglobe-<br>utils",
+ package_tagline="Transform points to atlas space",
+ github_repo_name="brainglobe-utils",
+ citation_doi="https://doi.org/10.1038/s41598-021-04676-9",
+ help_text="For help, hover the cursor over each parameter.",
+ )
+ self.layout.addWidget(header, 0, 0, 1, 2)
+
+ def add_points_combobox(self, row, column):
+ self.points_layer_choice, _ = add_combobox(
+ self.layout,
+ "Points layer",
+ self.points_layer_names,
+ column=column,
+ row=row,
+ callback=self.set_points_layer,
+ )
+
+ def add_raw_data_combobox(self, row, column):
+ self.raw_data_choice, _ = add_combobox(
+ self.layout,
+ "Raw data layer",
+ self.image_layer_names,
+ column=column,
+ row=row,
+ callback=self.set_raw_data_layer,
+ )
+
+ def add_transform_button(self, row, column):
+ self.transform_button = add_button(
+ "Transform points",
+ self.layout,
+ self.transform_points_to_atlas_space,
+ row=row,
+ column=column,
+ visibility=True,
+ tooltip="Transform points layer to atlas space",
+ )
+
+ def add_points_summary_table(self, row, column):
+ self.points_per_region_table_title = QLabel(
+ "Points distribution summary"
+ )
+ self.points_per_region_table_title.setVisible(False)
+ self.layout.addWidget(self.points_per_region_table_title, row, column)
+ self.points_per_region_table = QTableView()
+ self.points_per_region_table.setVisible(False)
+ self.layout.addWidget(self.points_per_region_table, row + 1, column)
+
+ def add_save_all_points_button(self, row, column):
+ self.save_all_points_button = add_button(
+ "Save all points information",
+ self.layout,
+ self.save_all_points_csv,
+ row=row,
+ column=column,
+ visibility=False,
+ tooltip="Save all points information as a csv file",
+ )
+
+ def add_save_points_summary_button(self, row, column):
+ self.save_points_summary_button = add_button(
+ "Save points summary",
+ self.layout,
+ self.save_points_summary_csv,
+ row=row,
+ column=column,
+ visibility=False,
+ tooltip="Save points summary as a csv file",
+ )
+
+ def add_status_label(self, row, column):
+ self.status_label = QLabel()
+ self.status_label.setText("Ready")
+ self.layout.addWidget(self.status_label, row, column)
+
+ def set_raw_data_layer(self):
+ """
+ Set background layer from current background text box selection.
+ """
+ if self.raw_data_choice.currentText() != "":
+ self.raw_data = self.viewer.layers[
+ self.raw_data_choice.currentText()
+ ]
+
+ def set_points_layer(self):
+ """
+ Set background layer from current background text box selection.
+ """
+ if self.points_layer_choice.currentText() != "":
+ self.points_layer = self.viewer.layers[
+ self.points_layer_choice.currentText()
+ ]
+
+ def transform_points_to_atlas_space(self):
+ layers_in_place = self.check_layers()
+ if not layers_in_place:
+ return
+
+ self.status_label.setText("Loading brainreg data ...")
+ data_loaded = self.load_brainreg_directory()
+
+ if not data_loaded:
+ self.status_label.setText("Ready")
+ return
+
+ self.status_label.setText("Transforming points ...")
+
+ self.run_transform_points_to_downsampled_space()
+ self.run_transform_downsampled_points_to_atlas_space()
+
+ self.status_label.setText("Analysing point distribution ...")
+ self.analyse_points()
+ self.status_label.setText("Ready")
+
+ def check_layers(self):
+ if self.raw_data is None and self.points_layer is None:
+ display_info(
+ self,
+ "No layers selected",
+ "Please select the layers corresponding to the points "
+ "you would like to transform and the raw data (registered by "
+ "brainreg)",
+ )
+ return False
+
+ if self.raw_data is None:
+ display_info(
+ self,
+ "No raw data layer selected",
+ "Please select a layer that corresponds to the raw "
+ "data (registered by brainreg)",
+ )
+ return False
+
+ if self.points_layer is None:
+ display_info(
+ self,
+ "No points layer selected",
+ "Please select a points layer you would like to transform",
+ )
+ return False
+
+ return True
+
+ def load_brainreg_directory(self):
+ brainreg_directory = QFileDialog.getExistingDirectory(
+ self,
+ "Select brainreg directory",
+ )
+ if brainreg_directory == "":
+ return False
+ else:
+ self.brainreg_directory = Path(brainreg_directory)
+
+ self.initialise_brainreg_data()
+ self.status_label.setText("Ready")
+ return True
+
+ def initialise_brainreg_data(self):
+ self.get_brainreg_paths()
+ self.check_brainreg_directory()
+ self.get_registration_metadata()
+ self.load_atlas()
+
+ def get_brainreg_paths(self):
+ self.paths = Paths(self.brainreg_directory)
+
+ def check_brainreg_directory(self):
+ try:
+ with open(self.paths.brainreg_metadata_file) as json_file:
+ self.brainreg_metadata = json.load(json_file)
+
+ if "atlas" not in self.brainreg_metadata:
+ self.display_brainreg_directory_warning()
+
+ except FileNotFoundError:
+ self.display_brainreg_directory_warning()
+
+ def display_brainreg_directory_warning(self):
+ display_info(
+ self,
+ "Not a brainreg directory",
+ "This directory does not appear to be a valid brainreg "
+ "directory. Please try loading another brainreg output directory.",
+ )
+
+ def get_registration_metadata(self):
+ self.metadata = Metadata(self.brainreg_metadata)
+
+ def load_atlas(self):
+ if not self.is_atlas_installed(self.metadata.atlas_string):
+ display_info(
+ self,
+ "Atlas not downloaded",
+ f"Atlas: {self.metadata.atlas_string} needs to be "
+ f"downloaded. This may take some time depending on "
+ f"the size of the atlas and your network speed.",
+ )
+ self.atlas = BrainGlobeAtlas(self.metadata.atlas_string)
+
+ def run_transform_points_to_downsampled_space(self):
+ downsampled_space = self.get_downsampled_space()
+ raw_data_space = self.get_raw_data_space()
+ self.points_in_downsampled_space = raw_data_space.map_points_to(
+ downsampled_space, self.points_layer.data
+ )
+ self.viewer.add_points(
+ self.points_in_downsampled_space,
+ name="Points in downsampled space",
+ visible=False,
+ )
+
+ def run_transform_downsampled_points_to_atlas_space(self):
+ deformation_field_paths = [
+ self.paths.deformation_field_0,
+ self.paths.deformation_field_1,
+ self.paths.deformation_field_2,
+ ]
+ self.points_in_atlas_space, points_out_of_bounds = (
+ transform_points_from_downsampled_to_atlas_space(
+ self.points_in_downsampled_space,
+ self.atlas,
+ deformation_field_paths,
+ warn_out_of_bounds=False,
+ )
+ )
+ self.viewer.add_points(
+ self.points_in_atlas_space,
+ name="Points in atlas space",
+ visible=True,
+ )
+
+ if len(points_out_of_bounds) > 0:
+ display_info(
+ self,
+ "Points outside atlas",
+ f"{len(points_out_of_bounds)} "
+ f"points fell outside the atlas space",
+ )
+
+ def get_downsampled_space(self):
+ target_shape = tifffile.imread(self.paths.downsampled_image).shape
+
+ downsampled_space = AnatomicalSpace(
+ self.atlas.orientation,
+ shape=target_shape,
+ resolution=self.atlas.resolution,
+ )
+ return downsampled_space
+
+ def get_raw_data_space(self):
+ raw_data_space = AnatomicalSpace(
+ self.metadata.orientation,
+ shape=self.raw_data.data.shape,
+ resolution=[float(i) for i in self.metadata.voxel_sizes],
+ )
+ return raw_data_space
+
+ def analyse_points(self):
+ self.all_points_df, self.points_per_region_df = (
+ summarise_points_by_atlas_region(
+ self.points_layer.data,
+ self.points_in_atlas_space,
+ self.atlas,
+ self.paths.volume_csv_path,
+ )
+ )
+
+ self.populate_summary_table()
+ self.save_all_points_button.setVisible(True)
+ self.save_points_summary_button.setVisible(True)
+
+ print("Analysing points")
+
+ def populate_summary_table(
+ self,
+ columns_to_keep=[
+ "structure_name",
+ "left_cell_count",
+ "right_cell_count",
+ ],
+ ):
+ summary_df = self.points_per_region_df[columns_to_keep]
+ self.points_per_region_table_model = DataFrameModel(summary_df)
+ self.points_per_region_table.setModel(
+ self.points_per_region_table_model
+ )
+ self.points_per_region_table_title.setVisible(True)
+ self.points_per_region_table.setVisible(True)
+
+ def save_all_points_csv(self):
+ self.save_df_to_csv(self.all_points_df)
+
+ def save_points_summary_csv(self):
+ self.save_df_to_csv(self.points_per_region_df)
+
+ def save_df_to_csv(self, df: pd.DataFrame) -> None:
+ """
+ Save the given DataFrame to a CSV file.
+
+ Prompts the user to choose a filename and ensures the file has a
+ .csv extension.
+ The DataFrame is then saved to the specified file.
+
+ Parameters
+ ----------
+ df : pd.DataFrame
+ The DataFrame to be saved.
+
+ Returns
+ -------
+ None
+ """
+ path, _ = QFileDialog.getSaveFileName(
+ self,
+ "Choose filename",
+ "",
+ "CSV Files (*.csv)",
+ )
+
+ if path:
+ path = ensure_extension(path, ".csv")
+ df.to_csv(path, index=False)
+
+ @staticmethod
+ def is_atlas_installed(atlas):
+ downloaded_atlases = get_downloaded_atlases()
+ if atlas in downloaded_atlases:
+ return True
+ else:
+ return False
+
+
+class Paths:
+ """
+ A class to hold all brainreg-related file paths.
+
+ N.B. this could be imported from brainreg, but it is copied here to
+ prevent a circular dependency
+
+ Attributes
+ ----------
+ brainreg_directory : Path
+ Path to brainreg output directory (or brainmapper
+ "registration" directory)
+ brainreg_metadata_file : Path
+ The path to the brainreg metadata (brainreg.json) file
+ deformation_field_0 : Path
+ The path to the deformation field (0th dimension)
+ deformation_field_1 : Path
+ The path to the deformation field (1st dimension)
+ deformation_field_2 : Path
+ The path to the deformation field (2nd dimension)
+ downsampled_image : Path
+ The path to the downsampled.tiff image file
+ volume_csv_path : Path
+ The path to the csv file containing region volumes
+
+ Parameters
+ ----------
+ brainreg_directory : Union[str, Path]
+ Path to brainreg output directory (or brainmapper
+ "registration" directory)
+ """
+
+ def __init__(self, brainreg_directory: Union[str, Path]) -> None:
+ self.brainreg_directory: Path = Path(brainreg_directory)
+ self.brainreg_metadata_file: Path = self.make_filepaths(
+ "brainreg.json"
+ )
+ self.deformation_field_0: Path = self.make_filepaths(
+ "deformation_field_0.tiff"
+ )
+ self.deformation_field_1: Path = self.make_filepaths(
+ "deformation_field_1.tiff"
+ )
+ self.deformation_field_2: Path = self.make_filepaths(
+ "deformation_field_2.tiff"
+ )
+ self.downsampled_image: Path = self.make_filepaths("downsampled.tiff")
+ self.volume_csv_path: Path = self.make_filepaths("volumes.csv")
+
+ def make_filepaths(self, filename: str) -> Path:
+ """
+ Create a full file path by combining the directory with a filename.
+
+ Parameters
+ ----------
+ filename : str
+ The name of the file to create a path for.
+
+ Returns
+ -------
+ Path
+ The full path to the specified file.
+ """
+ return self.brainreg_directory / filename
+
+
+class Metadata:
+ """
+ A class to represent brainreg registration metadata
+ (loaded from brainreg.json)
+
+ Attributes
+ ----------
+ orientation : str
+ The orientation of the input data (in brainglobe-space format)
+ atlas_string : str
+ The BrainGlobe atlas used for brain registration.
+ voxel_sizes : List[float]
+ The voxel sizes of the input data
+
+ Parameters
+ ----------
+ brainreg_metadata : Dict[str, Any]
+ A dictionary containing metadata information,
+ loaded from brainreg.json
+ """
+
+ def __init__(self, brainreg_metadata: Dict[str, Any]) -> None:
+ """
+ Initialize the Metadata instance with brainreg metadata.
+
+ Parameters
+ ----------
+ brainreg_metadata : Dict[str, Any]
+ A dictionary containing metadata information from brainreg.json
+ """
+ self.orientation: str = brainreg_metadata["orientation"]
+ self.atlas_string: str = brainreg_metadata["atlas"]
+ self.voxel_sizes: List[float] = brainreg_metadata["voxel_sizes"]
diff --git a/brainglobe_utils/general/system.py b/brainglobe_utils/general/system.py
index f33046f..ee4ba20 100644
--- a/brainglobe_utils/general/system.py
+++ b/brainglobe_utils/general/system.py
@@ -6,6 +6,7 @@
import subprocess
from pathlib import Path
from tempfile import gettempdir
+from typing import Union
import psutil
from natsort import natsorted
@@ -20,6 +21,33 @@
MAX_PROCESSES_WINDOWS = 61
+def ensure_extension(
+ file_path: Union[str, os.PathLike], extension: str
+) -> Path:
+ """
+ Ensure that the given file path has the specified extension.
+
+ If the file path does not already have the specified extension,
+ it changes the file path to have that extension.
+
+ Parameters
+ ----------
+ file_path : Union[str, os.PathLike]
+ The path to the file.
+ extension : str
+ The desired file extension (should include the dot, e.g., '.txt').
+
+ Returns
+ -------
+ Path
+ The Path object with the ensured extension.
+ """
+ path = Path(file_path)
+ if path.suffix != extension:
+ path = path.with_suffix(extension)
+ return path
+
+
def replace_extension(file, new_extension, check_leading_period=True):
"""
Replaces the file extension of a given file.
diff --git a/brainglobe_utils/napari.yaml b/brainglobe_utils/napari.yaml
new file mode 100644
index 0000000..638c257
--- /dev/null
+++ b/brainglobe_utils/napari.yaml
@@ -0,0 +1,12 @@
+name: brainglobe-utils
+display_name: BrainGlobe
+
+contributions:
+ commands:
+ - id: brainglobe-utils.TransformPoints
+ title: Open points transformation widget
+ python_name: brainglobe_utils.brainmapper.transform_widget:TransformPoints
+
+ widgets:
+ - command: brainglobe-utils.TransformPoints
+ display_name: Transform points to a BrainGlobe atlas
diff --git a/brainglobe_utils/qtpy/table.py b/brainglobe_utils/qtpy/table.py
new file mode 100644
index 0000000..95280d5
--- /dev/null
+++ b/brainglobe_utils/qtpy/table.py
@@ -0,0 +1,118 @@
+from typing import Any, Optional
+
+import pandas as pd
+from qtpy.QtCore import QAbstractTableModel, QModelIndex, Qt
+
+
+class DataFrameModel(QAbstractTableModel):
+ """
+ A Qt table model that wraps a pandas DataFrame for use with Qt
+ view widgets.
+
+ Parameters
+ ----------
+ df : pd.DataFrame
+ The DataFrame to be displayed in the Qt view.
+
+ """
+
+ def __init__(self, df: pd.DataFrame):
+ """
+ Initialize the model with a DataFrame.
+
+ Parameters
+ ----------
+ df : pd.DataFrame
+ The DataFrame to be displayed.
+ """
+ super().__init__()
+ self._df = df
+
+ def rowCount(self, parent: Optional[QModelIndex] = None) -> int:
+ """
+ Return the number of rows in the model.
+
+ Parameters
+ ----------
+ parent : Optional[QModelIndex], optional
+ The parent index, by default None.
+
+ Returns
+ -------
+ int
+ The number of rows in the DataFrame.
+ """
+ return self._df.shape[0]
+
+ def columnCount(self, parent: Optional[QModelIndex] = None) -> int:
+ """
+ Return the number of columns in the model.
+
+ Parameters
+ ----------
+ parent : Optional[QModelIndex], optional
+ The parent index, by default None.
+
+ Returns
+ -------
+ int
+ The number of columns in the DataFrame.
+ """
+ return self._df.shape[1]
+
+ def data(
+ self, index: QModelIndex, role: int = Qt.DisplayRole
+ ) -> Optional[Any]:
+ """
+ Return the data at the given index for the specified role.
+
+ Parameters
+ ----------
+ index : QModelIndex
+ The index of the data to be retrieved.
+ role : int, optional
+ The role for which the data is being requested,
+ by default Qt.DisplayRole.
+
+ Returns
+ -------
+ Optional[Any]
+ The data at the specified index, or None if the role
+ is not Qt.DisplayRole.
+ """
+ if role == Qt.DisplayRole:
+ return str(self._df.iloc[index.row(), index.column()])
+ return None
+
+ def headerData(
+ self,
+ section: int,
+ orientation: Qt.Orientation,
+ role: int = Qt.DisplayRole,
+ ) -> Optional[Any]:
+ """
+ Return the header data for the specified section and orientation.
+
+ Parameters
+ ----------
+ section : int
+ The section (column or row) for which the header data is requested.
+ orientation : Qt.Orientation
+ The orientation (horizontal or vertical) of the header.
+ role : int, optional
+ The role for which the header data is being requested, by
+ default Qt.DisplayRole.
+
+ Returns
+ -------
+ Optional[Any]
+ The header data for the specified section and orientation, or
+ None if the role is not Qt.DisplayRole.
+ """
+
+ if role == Qt.DisplayRole:
+ if orientation == Qt.Horizontal:
+ return self._df.columns[section]
+ if orientation == Qt.Vertical:
+ return self._df.index[section]
+ return None
diff --git a/pyproject.toml b/pyproject.toml
index a1a8140..485b4f8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -47,8 +47,12 @@ bug_tracker = "https://github.com/brainglobe/brainglobe-utils/issues"
source_code = "https://github.com/brainglobe/brainglobe-utils"
user_support = "https://github.com/brainglobe/brainglobe-utils/issues"
+[project.entry-points."napari.manifest"]
+brainglobe-utils = "brainglobe_utils:napari.yaml"
+
[project.optional-dependencies]
qt = ["qtpy", "superqt"]
+napari = ["brainglobe-utils[qt]", "napari[all]"]
dev = [
"black",
@@ -65,7 +69,7 @@ dev = [
"setuptools_scm",
"tox",
"pooch",
- "brainglobe-utils[qt]",
+ "brainglobe-utils[napari]",
]
@@ -125,5 +129,4 @@ passenv =
DISPLAY
XAUTHORITY
PYVISTA_OFF_SCREEN
- BRAINGLOBE_TEST_DATA_DIR
"""
diff --git a/tests/tests/conftest.py b/tests/tests/conftest.py
index f0a78cc..a554762 100644
--- a/tests/tests/conftest.py
+++ b/tests/tests/conftest.py
@@ -20,12 +20,12 @@ def test_data_registry():
"""
registry = pooch.create(
- path=pooch.os_cache("brainglobe_test_data"),
+ path=Path.home() / ".brainglobe" / "test_data",
base_url="https://gin.g-node.org/BrainGlobe/test-data/raw/master/",
registry={
"cellfinder/cells-z-1000-1050.xml": None,
"cellfinder/other-cells-z-1000-1050.xml": None,
+ "brainglobe-utils/points_transform_brainreg_directory.zip": "a1997f61a5efa752584ea91b7c479506343215bb91f5be09a72349f24e21fc54", # noqa: E501
},
- env="BRAINGLOBE_TEST_DATA_DIR",
)
return registry
diff --git a/tests/tests/test_brainmapper/test_analysis.py b/tests/tests/test_brainmapper/test_analysis.py
index f20d9eb..356f953 100644
--- a/tests/tests/test_brainmapper/test_analysis.py
+++ b/tests/tests/test_brainmapper/test_analysis.py
@@ -169,9 +169,12 @@ def test_get_region_totals(
)
output_path = Path(tmp_path / "tmp_region_totals.csv")
- count_points_per_brain_region(
- points, structures_with_points, volumes_path, output_path
+ points_per_region_df = count_points_per_brain_region(
+ points,
+ structures_with_points,
+ volumes_path,
)
+ points_per_region_df.to_csv(output_path, index=False)
assert output_path.exists()
# Read data back in, and sort rows by the structures for comparison.
diff --git a/tests/tests/test_brainmapper/test_transform_widget.py b/tests/tests/test_brainmapper/test_transform_widget.py
new file mode 100644
index 0000000..d3050d5
--- /dev/null
+++ b/tests/tests/test_brainmapper/test_transform_widget.py
@@ -0,0 +1,567 @@
+import json
+from pathlib import Path
+from typing import Any, Dict
+
+import numpy as np
+import pandas as pd
+import pooch
+import pytest
+from brainglobe_atlasapi import BrainGlobeAtlas
+
+from brainglobe_utils.brainmapper.transform_widget import (
+ Metadata,
+ Paths,
+ TransformPoints,
+)
+
+RAW_DATA_ORIENTATION = ATLAS_ORIENTATION = "asr"
+points = np.array(
+ [
+ [10, 68, 105],
+ [10, 90, 134],
+ [10, 105, 157],
+ [36, 69, 86],
+ [36, 72, 155],
+ [36, 112, 128],
+ [74, 54, 60],
+ [74, 121, 50],
+ [74, 87, 153],
+ [74, 84, 169],
+ [74, 108, 156],
+ [74, 75, 148],
+ [74, 98, 169],
+ [74, 76, 159],
+ [74, 99, 156],
+ [74, 91, 146],
+ [74, 87, 160],
+ [112, 44, 60],
+ [112, 76, 136],
+ [156, 77, 54],
+ [173, 126, 159],
+ [201, 66, 130],
+ [219, 132, 199],
+ [219, 1, 1],
+ ]
+)
+
+points_outside_brain = np.array(
+ [
+ [10000, 10000, 10000],
+ [100001, 100001, 100001],
+ ]
+)
+
+
+points_in_downsampled_space = np.array(
+ [
+ [10.0, 68.0, 105.0],
+ [10.0, 90.0, 134.0],
+ [10.0, 105.0, 157.0],
+ [36.0, 69.0, 86.0],
+ [36.0, 72.0, 155.0],
+ [36.0, 112.0, 128.0],
+ [74.0, 54.0, 60.0],
+ [74.0, 121.0, 50.0],
+ [74.0, 87.0, 153.0],
+ [74.0, 84.0, 169.0],
+ [74.0, 108.0, 156.0],
+ [74.0, 75.0, 148.0],
+ [74.0, 98.0, 169.0],
+ [74.0, 76.0, 159.0],
+ [74.0, 99.0, 156.0],
+ [74.0, 91.0, 146.0],
+ [74.0, 87.0, 160.0],
+ [112.0, 44.0, 60.0],
+ [112.0, 76.0, 136.0],
+ [156.0, 77.0, 54.0],
+ [173.0, 126.0, 159.0],
+ [201.0, 66.0, 130.0],
+ [219.0, 132.0, 199.0],
+ [219.0, 1.0, 1.0],
+ ]
+)
+
+points_in_atlas_space = np.array(
+ [
+ [36, 54, 97],
+ [34, 76, 124],
+ [34, 90, 146],
+ [61, 58, 82],
+ [62, 60, 145],
+ [58, 101, 120],
+ [100, 47, 60],
+ [93, 113, 53],
+ [95, 80, 146],
+ [95, 76, 161],
+ [93, 100, 148],
+ [97, 67, 141],
+ [94, 90, 161],
+ [97, 68, 151],
+ [94, 92, 148],
+ [95, 84, 139],
+ [95, 80, 152],
+ [139, 42, 60],
+ [131, 72, 132],
+ [173, 81, 56],
+ [177, 135, 155],
+ [214, 79, 129],
+ [218, 150, 194],
+ [249, 17, 10],
+ ]
+)
+
+
+@pytest.fixture
+def sample_dataframe():
+ return pd.DataFrame({"column1": [1, 2, 3], "column2": ["a", "b", "c"]})
+
+
+@pytest.fixture
+def random_json_path(tmp_path):
+ json_path = tmp_path / "random_json.json"
+ content = {
+ "name": "Pooh Bear",
+ "location": "100 acre wood",
+ "food": "Honey",
+ }
+ with open(json_path, "w") as f:
+ json.dump(content, f)
+
+ return json_path
+
+
+@pytest.fixture
+def mock_display_info(mocker):
+ return mocker.patch(
+ "brainglobe_utils.brainmapper.transform_widget.display_info"
+ )
+
+
+@pytest.fixture(scope="function")
+def transformation_widget_with_transformed_points(
+ transformation_widget_with_data,
+):
+ transformation_widget_with_data.run_transform_points_to_downsampled_space()
+ transformation_widget_with_data.run_transform_downsampled_points_to_atlas_space()
+ return transformation_widget_with_data
+
+
+@pytest.fixture(scope="function")
+def transformation_widget_with_data(
+ transformation_widget_with_napari_layers, brainreg_directory
+):
+ transformation_widget_with_napari_layers.brainreg_directory = (
+ brainreg_directory
+ )
+ transformation_widget_with_napari_layers.initialise_brainreg_data()
+ return transformation_widget_with_napari_layers
+
+
+@pytest.fixture(scope="function")
+def transformation_widget(make_napari_viewer):
+ viewer = make_napari_viewer()
+ widget = TransformPoints(viewer)
+ viewer.window.add_dock_widget(widget)
+ return widget
+
+
+@pytest.fixture(scope="function")
+def transformation_widget_with_napari_layers(
+ transformation_widget, brainreg_directory
+):
+ points_layer = transformation_widget.viewer.add_points(points)
+ transformation_widget.points_layer = points_layer
+
+ raw_data = brainreg_directory / "downsampled.tiff"
+ raw_data_layer = transformation_widget.viewer.open(raw_data)
+ transformation_widget.raw_data = raw_data_layer[0]
+ return transformation_widget
+
+
+@pytest.fixture
+def brainreg_directory(test_data_registry) -> Path:
+ _ = test_data_registry.fetch(
+ "brainglobe-utils/points_transform_brainreg_directory.zip",
+ progressbar=True,
+ processor=pooch.Unzip(extract_dir=""),
+ )
+ return (
+ Path.home()
+ / ".brainglobe"
+ / "test_data"
+ / "brainglobe-utils"
+ / "points_transform_brainreg_directory"
+ )
+
+
+def test_download_brainreg_directory(brainreg_directory):
+ assert brainreg_directory.exists()
+
+
+def test_atlas_download():
+ atlas_name = "allen_mouse_50um"
+ atlas = BrainGlobeAtlas(atlas_name)
+ assert atlas.atlas_name == atlas_name
+
+
+@pytest.fixture
+def dummy_brainreg_directory() -> Path:
+ return Path("/path/to/brainreg_directory")
+
+
+@pytest.fixture
+def dummy_brainreg_file_paths(dummy_brainreg_directory) -> Paths:
+ return Paths(dummy_brainreg_directory)
+
+
+def test_initialise_brainreg_data(
+ transformation_widget_with_data, brainreg_directory
+):
+
+ assert (
+ transformation_widget_with_data.paths.brainreg_directory
+ == brainreg_directory
+ )
+ assert (
+ transformation_widget_with_data.metadata.orientation
+ == ATLAS_ORIENTATION
+ )
+ assert (
+ transformation_widget_with_data.atlas.atlas_name == "allen_mouse_50um"
+ )
+
+
+def test_get_downsampled_space(transformation_widget_with_data):
+ downsampled_space = transformation_widget_with_data.get_downsampled_space()
+ assert downsampled_space.origin_string == ATLAS_ORIENTATION
+
+
+def test_get_raw_data_space(transformation_widget_with_data):
+ raw_data_space = transformation_widget_with_data.get_raw_data_space()
+ assert raw_data_space.origin_string == RAW_DATA_ORIENTATION
+
+
+def test_call_transform_points_to_atlas_space(
+ mocker, transformation_widget_with_data
+):
+ mock_load_brainreg = mocker.patch.object(
+ transformation_widget_with_data, "load_brainreg_directory"
+ )
+ mock_transform_downsampled = mocker.patch.object(
+ transformation_widget_with_data,
+ "run_transform_points_to_downsampled_space",
+ )
+ mock_transform_atlas = mocker.patch.object(
+ transformation_widget_with_data,
+ "run_transform_downsampled_points_to_atlas_space",
+ )
+ mock_analyse_points = mocker.patch.object(
+ transformation_widget_with_data, "analyse_points"
+ )
+
+ transformation_widget_with_data.transform_points_to_atlas_space()
+ mock_load_brainreg.assert_called_once()
+ mock_transform_downsampled.assert_called_once()
+ mock_transform_atlas.assert_called_once()
+ mock_analyse_points.assert_called_once()
+
+
+def test_transform_points_to_atlas_space(
+ transformation_widget_with_transformed_points,
+):
+ np.testing.assert_array_equal(
+ transformation_widget_with_transformed_points.viewer.layers[
+ "Points in downsampled space"
+ ].data,
+ points_in_downsampled_space,
+ )
+ np.testing.assert_array_equal(
+ transformation_widget_with_transformed_points.viewer.layers[
+ "Points in atlas space"
+ ].data,
+ points_in_atlas_space,
+ )
+
+
+def test_transformation_raises_info_points_out_of_bounds(
+ transformation_widget_with_data, mock_display_info
+):
+ points_layer = transformation_widget_with_data.viewer.add_points(
+ points_outside_brain
+ )
+ transformation_widget_with_data.points_layer = points_layer
+ transformation_widget_with_data.run_transform_points_to_downsampled_space()
+ transformation_widget_with_data.run_transform_downsampled_points_to_atlas_space()
+ mock_display_info.assert_called_once_with(
+ transformation_widget_with_data,
+ "Points outside atlas",
+ "2 points fell outside the atlas space",
+ )
+
+
+def test_check_layers(transformation_widget_with_data):
+ assert transformation_widget_with_data.check_layers()
+
+
+def test_check_layers_no_layers(transformation_widget, mock_display_info):
+ transformation_widget.check_layers()
+
+ mock_display_info.assert_called_once_with(
+ transformation_widget,
+ "No layers selected",
+ "Please select the layers corresponding to the points "
+ "you would like to transform and the raw data (registered by "
+ "brainreg)",
+ )
+
+
+def test_check_layers_no_raw_data(transformation_widget, mock_display_info):
+ points_layer = transformation_widget.viewer.add_points(points)
+ transformation_widget.points_layer = points_layer
+
+ transformation_widget.check_layers()
+
+ mock_display_info.assert_called_once_with(
+ transformation_widget,
+ "No raw data layer selected",
+ "Please select a layer that corresponds to the raw "
+ "data (registered by brainreg)",
+ )
+
+
+def test_check_layers_no_points_data(
+ transformation_widget, brainreg_directory, mock_display_info
+):
+ raw_data = brainreg_directory / "downsampled.tiff"
+ raw_data_layer = transformation_widget.viewer.open(raw_data)
+ transformation_widget.raw_data = raw_data_layer[0]
+
+ transformation_widget.check_layers()
+
+ mock_display_info.assert_called_once_with(
+ transformation_widget,
+ "No points layer selected",
+ "Please select a points layer you would like to transform",
+ )
+
+
+def test_load_brainreg_directory(
+ transformation_widget_with_napari_layers, brainreg_directory, mocker
+):
+ # Mock dialog to avoid need for UI
+ mock_get_save_file_name = mocker.patch(
+ "brainglobe_utils.brainmapper.transform_widget.QFileDialog.getExistingDirectory"
+ )
+ mock_get_save_file_name.return_value = brainreg_directory
+
+ transformation_widget_with_napari_layers.load_brainreg_directory()
+ assert (
+ transformation_widget_with_napari_layers.atlas.atlas_name
+ == "allen_mouse_50um"
+ )
+
+
+def test_load_brainreg_directory_no_input(
+ transformation_widget_with_napari_layers, mocker
+):
+ # Mock dialog to avoid need for UI
+ mock_get_save_file_name = mocker.patch(
+ "brainglobe_utils.brainmapper.transform_widget.QFileDialog.getExistingDirectory"
+ )
+ mock_get_save_file_name.return_value = ""
+
+ transformation_widget_with_napari_layers.load_brainreg_directory()
+ assert not hasattr(
+ transformation_widget_with_napari_layers.atlas, "atlas_name"
+ )
+
+
+def test_check_brainreg_directory_correct_metadata(
+ mocker, transformation_widget_with_data
+):
+ mock_method = mocker.patch.object(
+ transformation_widget_with_data, "display_brainreg_directory_warning"
+ )
+
+ transformation_widget_with_data.check_brainreg_directory()
+ mock_method.assert_not_called()
+
+
+def test_check_brainreg_directory_random_data(
+ mocker, transformation_widget_with_data, random_json_path
+):
+ mock_method = mocker.patch.object(
+ transformation_widget_with_data, "display_brainreg_directory_warning"
+ )
+ transformation_widget_with_data.paths.brainreg_metadata_file = (
+ random_json_path
+ )
+ transformation_widget_with_data.check_brainreg_directory()
+ mock_method.assert_called_once()
+
+
+def test_check_brainreg_directory_false_path(
+ mocker, transformation_widget_with_data
+):
+ mock_method = mocker.patch.object(
+ transformation_widget_with_data, "display_brainreg_directory_warning"
+ )
+
+ transformation_widget_with_data.paths.brainreg_metadata_file = "/some/file"
+ transformation_widget_with_data.check_brainreg_directory()
+ mock_method.assert_called_once()
+
+
+def test_display_brainreg_directory_warning_calls_display_info(
+ transformation_widget_with_napari_layers, mock_display_info
+):
+ transformation_widget_with_napari_layers.display_brainreg_directory_warning()
+
+ # Assert display_info was called once with the expected arguments
+ mock_display_info.assert_called_once_with(
+ transformation_widget_with_napari_layers,
+ "Not a brainreg directory",
+ "This directory does not appear to be a valid brainreg directory. "
+ "Please try loading another brainreg output directory.",
+ )
+
+
+def test_analysis(transformation_widget_with_transformed_points):
+ transformation_widget_with_transformed_points.analyse_points()
+
+ assert (
+ transformation_widget_with_transformed_points.all_points_df.shape[0]
+ == 21
+ )
+
+ df = transformation_widget_with_transformed_points.points_per_region_df
+ assert (
+ df.loc[
+ df["structure_name"] == "Caudoputamen", "left_cell_count"
+ ].values[0]
+ == 9
+ )
+ assert (
+ df.loc[df["structure_name"] == "Pons", "left_cell_count"].values[0]
+ == 1
+ )
+ assert (
+ df.loc[
+ df["structure_name"]
+ == "Primary somatosensory area, upper limb, layer 5",
+ "left_cells_per_mm3",
+ ].values[0]
+ == 0
+ )
+
+
+def test_save_df_to_csv(
+ mocker,
+ transformation_widget_with_transformed_points,
+ sample_dataframe,
+ tmp_path,
+):
+ # Mock dialog to avoid need for UI
+ mock_get_save_file_name = mocker.patch(
+ "brainglobe_utils.brainmapper.transform_widget.QFileDialog.getSaveFileName"
+ )
+
+ save_path = tmp_path / "test.csv"
+ mock_get_save_file_name.return_value = (save_path, "CSV Files (*.csv)")
+
+ transformation_widget_with_transformed_points.save_df_to_csv(
+ sample_dataframe
+ )
+
+ # Ensure the file dialog was called
+ mock_get_save_file_name.assert_called_once_with(
+ transformation_widget_with_transformed_points,
+ "Choose filename",
+ "",
+ "CSV Files (*.csv)",
+ )
+
+ assert save_path.exists()
+
+
+def test_save_all_points_and_summary_csv(
+ mocker,
+ transformation_widget_with_transformed_points,
+ tmp_path,
+):
+ transformation_widget_with_transformed_points.analyse_points()
+
+ # Mock dialog to avoid need for UI
+ mock_get_save_file_name = mocker.patch(
+ "brainglobe_utils.brainmapper.transform_widget.QFileDialog.getSaveFileName"
+ )
+
+ save_path = tmp_path / "all_points.csv"
+ mock_get_save_file_name.return_value = (save_path, "CSV Files (*.csv)")
+ transformation_widget_with_transformed_points.save_all_points_csv()
+ assert save_path.exists()
+
+ save_path = tmp_path / "points_per_region.csv"
+ mock_get_save_file_name.return_value = (save_path, "CSV Files (*.csv)")
+ transformation_widget_with_transformed_points.save_points_summary_csv()
+ assert save_path.exists()
+
+
+def test_is_atlas_installed(mocker, transformation_widget):
+ mock_get_downloaded_atlases = mocker.patch(
+ "brainglobe_utils.brainmapper.transform_widget.get_downloaded_atlases"
+ )
+ mock_get_downloaded_atlases.return_value = [
+ "allen_mouse_10um",
+ "allen_mouse_50um",
+ ]
+
+ assert transformation_widget.is_atlas_installed("allen_mouse_10um")
+ assert not transformation_widget.is_atlas_installed("allen_mouse_25um")
+
+
+def test_paths_initialisation(
+ dummy_brainreg_file_paths, dummy_brainreg_directory
+):
+ assert (
+ dummy_brainreg_file_paths.brainreg_directory
+ == dummy_brainreg_directory
+ )
+ assert (
+ dummy_brainreg_file_paths.brainreg_metadata_file
+ == dummy_brainreg_directory / "brainreg.json"
+ )
+ assert (
+ dummy_brainreg_file_paths.deformation_field_0
+ == dummy_brainreg_directory / "deformation_field_0.tiff"
+ )
+ assert (
+ dummy_brainreg_file_paths.downsampled_image
+ == dummy_brainreg_directory / "downsampled.tiff"
+ )
+ assert (
+ dummy_brainreg_file_paths.volume_csv_path
+ == dummy_brainreg_directory / "volumes.csv"
+ )
+
+
+def test_make_filepaths(dummy_brainreg_file_paths, dummy_brainreg_directory):
+ filename = "test_file.txt"
+ expected_path = dummy_brainreg_directory / filename
+ assert dummy_brainreg_file_paths.make_filepaths(filename) == expected_path
+
+
+@pytest.fixture
+def sample_metadata() -> Dict[str, Any]:
+ return {
+ "orientation": "prs",
+ "atlas": "allen_mouse_25um",
+ "voxel_sizes": [5, 2, 2],
+ }
+
+
+def test_metadata_initialisation(sample_metadata):
+ metadata = Metadata(sample_metadata)
+ assert metadata.orientation == sample_metadata["orientation"]
+ assert metadata.atlas_string == sample_metadata["atlas"]
+ assert metadata.voxel_sizes == sample_metadata["voxel_sizes"]
diff --git a/tests/tests/test_brainreg/__init__.py b/tests/tests/test_brainreg/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/tests/test_general/test_system.py b/tests/tests/test_general/test_system.py
index 16ef4cf..ea851b8 100644
--- a/tests/tests/test_general/test_system.py
+++ b/tests/tests/test_general/test_system.py
@@ -64,6 +64,22 @@ def mock_statvfs():
return mock_stats
+def test_ensure_extension():
+ assert system.ensure_extension("example.txt", ".txt") == Path(
+ "example.txt"
+ )
+ assert system.ensure_extension(Path("example.txt"), ".txt") == Path(
+ "example.txt"
+ )
+
+ assert system.ensure_extension("example.md", ".txt") == Path("example.txt")
+ assert system.ensure_extension(Path("example.md"), ".txt") == Path(
+ "example.txt"
+ )
+
+ assert system.ensure_extension("example", ".txt") == Path("example.txt")
+
+
def test_replace_extension():
test_file = "test_file.sh"
test_ext = "txt"
diff --git a/tests/tests/test_qtpy/test_table.py b/tests/tests/test_qtpy/test_table.py
new file mode 100644
index 0000000..952e4d6
--- /dev/null
+++ b/tests/tests/test_qtpy/test_table.py
@@ -0,0 +1,46 @@
+import pandas as pd
+import pytest
+from qtpy.QtCore import Qt
+
+from brainglobe_utils.qtpy.table import DataFrameModel
+
+
+@pytest.fixture
+def sample_df():
+ return pd.DataFrame(
+ {"A": [1, 2, 3], "B": ["cat", "dog", "rabbit"], "C": [7, 8, 9]}
+ )
+
+
+@pytest.fixture
+def model(sample_df):
+ return DataFrameModel(sample_df)
+
+
+def test_row_count(model, sample_df):
+ assert model.rowCount() == sample_df.shape[0]
+
+
+def test_column_count(model, sample_df):
+ assert model.columnCount() == sample_df.shape[1]
+
+
+def test_data(model):
+ index = model.index(0, 0)
+ assert model.data(index, Qt.DisplayRole) == "1"
+ index = model.index(1, 1)
+ assert model.data(index, Qt.DisplayRole) == "dog"
+ index = model.index(2, 2)
+ assert model.data(index, Qt.DisplayRole) == "9"
+
+
+def test_header_data(model, sample_df):
+ assert (
+ model.headerData(0, Qt.Vertical, Qt.DisplayRole) == sample_df.index[0]
+ )
+ assert (
+ model.headerData(1, Qt.Vertical, Qt.DisplayRole) == sample_df.index[1]
+ )
+ assert (
+ model.headerData(2, Qt.Vertical, Qt.DisplayRole) == sample_df.index[2]
+ )