diff --git a/atlas_densities/app/cell_densities.py b/atlas_densities/app/cell_densities.py index 8334bde..f60a87b 100644 --- a/atlas_densities/app/cell_densities.py +++ b/atlas_densities/app/cell_densities.py @@ -82,7 +82,6 @@ from atlas_densities.densities.measurement_to_density import ( measurement_to_average_density, remove_non_density_measurements, - remove_unknown_regions, ) from atlas_densities.exceptions import AtlasDensitiesError @@ -622,6 +621,12 @@ def compile_measurements( @app.command() @common_atlas_options +@click.option( + "--region-name", + type=str, + default="root", + help="Name of the root region in the hierarchy", +) @click.option( "--cell-density-path", type=EXISTING_FILE_PATH, @@ -655,6 +660,7 @@ def compile_measurements( def measurements_to_average_densities( annotation_path, hierarchy_path, + region_name, cell_density_path, neuron_density_path, measurements_path, @@ -714,7 +720,6 @@ def measurements_to_average_densities( region_map = RegionMap.load_json(hierarchy_path) L.info("Loading measurements ...") measurements_df = pd.read_csv(measurements_path) - remove_unknown_regions(measurements_df, region_map, annotation.raw) L.info("Measurement to average density: started") average_cell_densities_df = measurement_to_average_density( @@ -725,6 +730,7 @@ def measurements_to_average_densities( overall_cell_density.raw, neuron_density.raw, measurements_df, + region_name, ) remove_non_density_measurements(average_cell_densities_df) @@ -912,8 +918,6 @@ def fit_average_densities( L.info("Loading average densities dataframe ...") average_densities_df = pd.read_csv(average_densities_path) homogenous_regions_df = pd.read_csv(homogenous_regions_path) - remove_unknown_regions(average_densities_df, region_map, annotation.raw, root=region_name) - remove_unknown_regions(homogenous_regions_df, region_map, annotation.raw, root=region_name) L.info("Fitting of average densities: started") fitted_densities_df, fitting_maps = linear_fitting( diff --git a/atlas_densities/densities/fitting.py b/atlas_densities/densities/fitting.py index 1fa8dc6..14428ad 100644 --- a/atlas_densities/densities/fitting.py +++ b/atlas_densities/densities/fitting.py @@ -29,6 +29,7 @@ from tqdm import tqdm from atlas_densities.densities import utils +from atlas_densities.densities.measurement_to_density import remove_unknown_regions from atlas_densities.exceptions import AtlasDensitiesError, AtlasDensitiesWarning if TYPE_CHECKING: # pragma: no cover @@ -625,6 +626,9 @@ def linear_fitting( # pylint: disable=too-many-arguments _check_homogenous_regions_sanity(homogenous_regions) hierarchy_info = utils.get_hierarchy_info(region_map, root=region_name) + remove_unknown_regions(average_densities, region_map, annotation, hierarchy_info) + remove_unknown_regions(homogenous_regions, region_map, annotation, hierarchy_info) + L.info("Creating a data frame from known densities ...") densities = create_dataframe_from_known_densities( hierarchy_info["brain_region"].to_list(), average_densities diff --git a/atlas_densities/densities/measurement_to_density.py b/atlas_densities/densities/measurement_to_density.py index 4cb1ff5..65170e4 100644 --- a/atlas_densities/densities/measurement_to_density.py +++ b/atlas_densities/densities/measurement_to_density.py @@ -261,7 +261,7 @@ def remove_unknown_regions( measurements: "pd.DataFrame", region_map: RegionMap, annotation: AnnotationT, - root: str = "root", + hierarchy_info: "pd.DataFrame", ): """ Drop lines from the measurements dataframe which brain regions are not in the AIBS brain region @@ -274,9 +274,9 @@ def remove_unknown_regions( region_map: RegionMap object to navigate the brain regions hierarchy. annotation: int array of shape (W, H, D) holding the annotation of the whole AIBS mouse brain. (The integers W, H and D are the dimensions of the array). - root: name of the root region to consider in the hierarchy. + hierarchy_info: data frame returned by + :func:`atlas_densities.densities.utils.get_hierarchy_info`. """ - hierarchy_info = get_hierarchy_info(region_map, root) pd.set_option("display.max_colwidth", None) indices_ids = measurements.index[ ~measurements["brain_region"].isin(hierarchy_info["brain_region"]) @@ -309,7 +309,7 @@ def remove_unknown_regions( measurements.drop(indices_ann, inplace=True) -def measurement_to_average_density( +def measurement_to_average_density( # pylint: disable=too-many-arguments region_map: RegionMap, annotation: AnnotationT, voxel_dimensions: Tuple[float, float, float], @@ -317,6 +317,7 @@ def measurement_to_average_density( cell_density: FloatArray, neuron_density: FloatArray, measurements: "pd.DataFrame", + root_region: str = "Basic cell groups and regions", ) -> "pd.DataFrame": """ Compute average cell densities in AIBS brain regions based on experimental `measurements`. @@ -342,6 +343,7 @@ def measurement_to_average_density( in that voxel expressed in number of neurons per mm^3. measurements: dataframe whose columns are described in :func:`atlas_densities.app.densities.compile_measurements`. + root_region: name of the root region in the brain region hierarchy. Returns: dataframe of the same format as `measurements` but where all measurements of type @@ -349,7 +351,8 @@ def measurement_to_average_density( type "cell density". Densities are expressed in number of cells per mm^3. """ - hierarchy_info = get_hierarchy_info(region_map) + hierarchy_info = get_hierarchy_info(region_map, root_region) + remove_unknown_regions(measurements, region_map, annotation, hierarchy_info) # Replace NaN standard deviations by measurement values nan_mask = measurements["standard_deviation"].isna() diff --git a/tests/app/test_cell_densities.py b/tests/app/test_cell_densities.py index ddc894d..df2d5dc 100644 --- a/tests/app/test_cell_densities.py +++ b/tests/app/test_cell_densities.py @@ -271,6 +271,8 @@ def _get_measurements_to_average_densities_result(runner, hierarchy_path, measur hierarchy_path, "--annotation-path", "annotation.nrrd", + "--region-name", + "Basic cell groups and regions", "--cell-density-path", "cell_density.nrrd", "--neuron-density-path", diff --git a/tests/densities/test_measurement_to_density.py b/tests/densities/test_measurement_to_density.py index 4104b68..440d918 100644 --- a/tests/densities/test_measurement_to_density.py +++ b/tests/densities/test_measurement_to_density.py @@ -75,9 +75,7 @@ def test_remove_unknown_regions(region_map, annotations): "source_title": ["Article 1", "Article 2", "Article 1"], } ) - tested.remove_unknown_regions( - measurements, region_map, annotations, root="Basic cell groups and regions" - ) + tested.remove_unknown_regions(measurements, region_map, annotations, get_hierarchy_info()) expected = pd.DataFrame( { "brain_region": [