Skip to content

Commit

Permalink
make min_data_points a command like parameter for fit-average-densi…
Browse files Browse the repository at this point in the history
…ties
  • Loading branch information
mgeplf committed May 15, 2024
1 parent d074b92 commit 74798d5
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 15 deletions.
9 changes: 9 additions & 0 deletions atlas_densities/app/cell_densities.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,6 +808,13 @@ def measurements_to_average_densities(
help="Path to density groups ids config",
show_default=True,
)
@click.option(
"--min-data-points",
type=int,
default=5,
help="minimum number of datapoints required for running the linear regression.",
show_default=True,
)
@log_args(L)
def fit_average_densities(
hierarchy_path,
Expand All @@ -820,6 +827,7 @@ def fit_average_densities(
fitted_densities_output_path,
fitting_maps_output_path,
group_ids_config_path,
min_data_points,
): # pylint: disable=too-many-arguments, too-many-locals
"""
Estimate average cell densities of brain regions in `hierarchy_path` for the cell types
Expand Down Expand Up @@ -935,6 +943,7 @@ def fit_average_densities(
cell_density_stddev,
region_name=region_name,
group_ids_config=group_ids_config,
min_data_points=min_data_points,
)

# Turn index into column to ease off the save and load operations on csv files
Expand Down
26 changes: 18 additions & 8 deletions atlas_densities/densities/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from atlas_densities.densities import utils
from atlas_densities.densities.measurement_to_density import remove_unknown_regions
from atlas_densities.exceptions import AtlasDensitiesError, AtlasDensitiesWarning
import json

L = logging.getLogger(__name__)

Expand Down Expand Up @@ -368,8 +367,10 @@ def compute_average_intensities(


def linear_fitting_xy(
xdata: list[float], ydata: list[float], sigma: Union[list[float], FloatArray],
min_data_points: int = 5
xdata: list[float],
ydata: list[float],
sigma: Union[list[float], FloatArray],
min_data_points: int,
) -> dict[str, float]:
"""
Compute the coefficient of the linear least-squares regression of the point cloud
Expand Down Expand Up @@ -441,7 +442,10 @@ def _optimize_func(x, coefficient):


def compute_fitting_coefficients(
groups: dict[str, set[str]], average_intensities: pd.DataFrame, densities: pd.DataFrame
groups: dict[str, set[str]],
average_intensities: pd.DataFrame,
densities: pd.DataFrame,
min_data_points: int,
) -> FittingData:
"""
Compute the linear fitting coefficient of the cloud of 2D points (average marker intensity,
Expand Down Expand Up @@ -553,9 +557,10 @@ def compute_fitting_coefficients(
for cell_type in tqdm(cell_types):
cloud = clouds[group_name][cell_type]
L.debug(
f"The length of training data for {group_name} and {cell_type} is {cloud['xdata']}")
f"The length of training data for {group_name} and {cell_type} is {cloud['xdata']}"
)
result[group_name][cell_type] = linear_fitting_xy(
cloud["xdata"], cloud["ydata"], cloud["sigma"]
cloud["xdata"], cloud["ydata"], cloud["sigma"], min_data_points
)
if np.isnan(result[group_name][cell_type]["coefficient"]):
warnings.warn(
Expand Down Expand Up @@ -739,6 +744,7 @@ def linear_fitting( # pylint: disable=too-many-arguments
cell_density_stddevs: Optional[dict[str, float]] = None,
region_name: str = "root",
group_ids_config: dict | None = None,
min_data_points: int = 5,
) -> pd.DataFrame:
"""
Estimate the average densities of every region in `region_map` using a linear fitting
Expand Down Expand Up @@ -781,6 +787,10 @@ def linear_fitting( # pylint: disable=too-many-arguments
standard deviations of average cell densities of the corresponding regions.
region_name: (str) name of the root region of interest
group_ids_config: mapping of regions to their constituent ids
min_data_points: minimum number of datapoints required for running
the linear regression. If the number of datapoints is less than
min_data_points then no fitting is done, and np.nan values are
returned.
Returns:
tuple (densities, fitting_coefficients)
Expand Down Expand Up @@ -842,8 +852,8 @@ def linear_fitting( # pylint: disable=too-many-arguments

L.info("Computing fitting coefficients ...")
fitting_coefficients = compute_fitting_coefficients(
groups, average_intensities, densities.drop(densities.index[indexes])
)
groups, average_intensities, densities.drop(densities.index[indexes]),
min_data_points=min_data_points)
L.info("Fitting unknown average densities ...")
fit_unknown_densities(groups, average_intensities, densities, fitting_coefficients)

Expand Down
22 changes: 15 additions & 7 deletions tests/densities/test_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,17 +273,18 @@ def test_compute_average_intensities(region_map, hierarchy_info):


def test_linear_fitting_xy():
actual = tested.linear_fitting_xy([0.0, 1.0, 2.0], [0.0, 2.0, 4.0], [1.0, 1.0, 1.0])
min_data_points = 1
actual = tested.linear_fitting_xy([0.0, 1.0, 2.0], [0.0, 2.0, 4.0], [1.0, 1.0, 1.0], min_data_points=min_data_points)
assert np.allclose(actual["coefficient"], 2.0)
assert np.allclose(actual["r_square"], 1.0)
assert not np.isinf(actual["standard_deviation"])

actual = tested.linear_fitting_xy([0.0, 1.0, 2.0], [0.0, 1.0, 4.0], [1.0, 1.0, 1e-5])
actual = tested.linear_fitting_xy([0.0, 1.0, 2.0], [0.0, 1.0, 4.0], [1.0, 1.0, 1e-5], min_data_points=min_data_points)
assert np.allclose(actual["coefficient"], 2.0)
assert not np.isinf(actual["standard_deviation"])
assert np.allclose(actual["r_square"], 0.89286)

actual = tested.linear_fitting_xy([0.0, 1.0, 2.0], [0.0, 2.0, 4.0], [1.0, 0.0, 1.0])
actual = tested.linear_fitting_xy([0.0, 1.0, 2.0], [0.0, 2.0, 4.0], [1.0, 0.0, 1.0], min_data_points=min_data_points)
assert np.allclose(actual["coefficient"], 2.0)
assert not np.isinf(actual["standard_deviation"])
assert np.allclose(actual["r_square"], 1.0)
Expand Down Expand Up @@ -319,7 +320,8 @@ def test_compute_fitting_coefficients(hierarchy_info):
data = get_fitting_input_data_(hierarchy_info)

actual = tested.compute_fitting_coefficients(
data["groups"], data["intensities"], data["densities"]
data["groups"], data["intensities"], data["densities"],
min_data_points=1,
)

for group_name in ["Whole", "Central lobule"]:
Expand All @@ -340,18 +342,24 @@ def test_compute_fitting_coefficients_exceptions(hierarchy_info):
data["densities"].drop(index=["Central lobule"], inplace=True)

with pytest.raises(AtlasDensitiesError):
tested.compute_fitting_coefficients(data["groups"], data["intensities"], data["densities"])
tested.compute_fitting_coefficients(data["groups"], data["intensities"], data["densities"],
min_data_points=1,
)

data = get_fitting_input_data_(hierarchy_info)
data["densities"].drop(columns=["pv+"], inplace=True)

with pytest.raises(AtlasDensitiesError):
tested.compute_fitting_coefficients(data["groups"], data["intensities"], data["densities"])
tested.compute_fitting_coefficients(data["groups"], data["intensities"], data["densities"],
min_data_points=1,
)

data = get_fitting_input_data_(hierarchy_info)
data["densities"].at["Lobule II", "pv+_standard_deviation"] = np.nan
with pytest.raises(AssertionError):
tested.compute_fitting_coefficients(data["groups"], data["intensities"], data["densities"])
tested.compute_fitting_coefficients(data["groups"], data["intensities"], data["densities"],
min_data_points=1,
)


@pytest.fixture
Expand Down

0 comments on commit 74798d5

Please sign in to comment.