Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: Handling Unrecognized Arguments for Prepare #696

Merged
merged 45 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
e21677e
remove unnecessary kwargs absorption
selmanozleyen May 21, 2024
6582ab8
ensure birth death estimate marginals throws error on unrecognized args
selmanozleyen May 21, 2024
0785fa2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 21, 2024
ff1f25f
if cost is not specified no need to set to None, otherwise goes to es…
selmanozleyen May 21, 2024
0892c6b
now any passed and not used kwargs are going to throw an error for pr…
selmanozleyen May 21, 2024
375d3f7
saving the state currently in progress. will document changes in PR b…
selmanozleyen May 27, 2024
9105415
shouldn't specify joint_attr in GWProblem
selmanozleyen May 27, 2024
5b6c252
remove kwargs from _lineage
selmanozleyen May 27, 2024
ba77fac
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 27, 2024
e1718c9
Fix `scale_cost` given to prepare in tests/problems/time/test_mixins.py
selmanozleyen May 27, 2024
e15c510
fix lint issue
selmanozleyen May 27, 2024
ad35962
fix mapping proxy problem
selmanozleyen May 27, 2024
9a21bb3
Merge branch 'main' into refactor/arg_check
selmanozleyen Jun 9, 2024
6a3021d
refactor handle_joint_attrs
selmanozleyen Jun 9, 2024
7270058
get rid of unused functions
selmanozleyen Jun 9, 2024
e72aa2a
fix for CI fail
selmanozleyen Jun 12, 2024
029d5b8
update scripts
selmanozleyen Jun 12, 2024
84567b8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 12, 2024
2dc9f94
update scripts
selmanozleyen Jun 12, 2024
9e2bf29
update the script
selmanozleyen Jun 12, 2024
8ba07ef
Merge branch 'main' into refactor/arg_check
selmanozleyen Jun 12, 2024
d9a790a
remove tutorials
selmanozleyen Jun 13, 2024
b796d84
update the docs
selmanozleyen Jun 13, 2024
5ff3005
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 13, 2024
7535065
update pyproject
selmanozleyen Jun 13, 2024
6fc0b68
Merge branch 'main' into refactor/arg_check
selmanozleyen Jun 14, 2024
9bb6566
Merge branch 'main' into refactor/arg_check
selmanozleyen Jun 17, 2024
6228773
remove prepare of BDProblem
selmanozleyen Jun 18, 2024
424bdc4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 18, 2024
09c281e
added throwing a warning
selmanozleyen Jun 18, 2024
810f371
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 18, 2024
094507e
Merge branch 'main' into refactor/arg_check
selmanozleyen Jun 25, 2024
7d0a6f5
linting
selmanozleyen Jun 25, 2024
42a6ccf
fix for sparse
selmanozleyen Jun 25, 2024
c90ac6f
set version for docstring inheritence
selmanozleyen Jun 25, 2024
74d8062
fix sparse array
selmanozleyen Jun 25, 2024
c52b5e7
update kwargs
selmanozleyen Jun 25, 2024
f41ec32
fix linting
selmanozleyen Jun 25, 2024
b11efac
update commit
selmanozleyen Jun 28, 2024
c7b126e
Merge branch 'main' into refactor/arg_check
selmanozleyen Jul 3, 2024
a9902e1
update nb
selmanozleyen Jul 3, 2024
a0984cf
fix typo
selmanozleyen Jul 3, 2024
94e6cdb
Merge branch 'main' into refactor/arg_check
selmanozleyen Jul 3, 2024
add9ca6
Merge branch 'main' into refactor/arg_check
selmanozleyen Jul 3, 2024
085eae9
Merge branch 'main' into refactor/arg_check
selmanozleyen Jul 3, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ dependencies = [
"ott-jax>=0.4.6",
"cloudpickle>=2.2.0",
"rich>=13.5",
"docstring_inheritance" # TODO: Set version
]

[project.optional-dependencies]
Expand Down
80 changes: 70 additions & 10 deletions src/moscot/base/problems/birth_death.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import types
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Literal,
Mapping,
Optional,
Protocol,
Sequence,
Expand Down Expand Up @@ -159,14 +163,53 @@ class BirthDeathProblem(BirthDeathMixin, OTProblem):
Keyword arguments for :class:`~moscot.base.problems.OTProblem`.
""" # noqa: D205

def prepare(
selmanozleyen marked this conversation as resolved.
Show resolved Hide resolved
self,
xy: Mapping[str, Any],
x: Mapping[str, Any],
y: Mapping[str, Any],
a: Optional[Union[bool, str, ArrayLike]] = None,
b: Optional[Union[bool, str, ArrayLike]] = None,
marginal_kwargs: Dict[str, Any] = types.MappingProxyType({}),
proliferation_key: Optional[str] = None,
apoptosis_key: Optional[str] = None,
) -> "BirthDeathProblem":
"""Prepare the problem by scoring genes for proliferation and apoptosis.

All arguments except for `proliferation_key` and `apoptosis_key` are inherited from :meth:`OTProblem.prepare`.

Parameters
----------
proliferation_key
Key in :attr:`~anndata.AnnData.obs` where proliferation scores are stored.
apoptosis_key
Key in :attr:`~anndata.AnnData.obs` where apoptosis scores are stored.

"""
self.proliferation_key = proliferation_key
self.apoptosis_key = apoptosis_key
marginal_kwargs = dict(marginal_kwargs)
if proliferation_key is not None:
marginal_kwargs["proliferation_key"] = proliferation_key
if apoptosis_key is not None:
marginal_kwargs["apoptosis_key"] = apoptosis_key
return super().prepare(xy=xy, x=x, y=y, a=a, b=b, marginal_kwargs=marginal_kwargs)

def estimate_marginals(
self, # type: BirthDeathProblemProtocol
adata: AnnData,
source: bool,
proliferation_key: Optional[str] = None,
apoptosis_key: Optional[str] = None,
scaling: Optional[float] = None,
**kwargs: Any,
beta_max: float = 1.7,
beta_min: float = 0.3,
beta_center: float = 0.25,
beta_width: float = 0.5,
delta_max: float = 1.7,
delta_min: float = 0.3,
delta_center: float = 0.1,
delta_width: float = 0.2,
) -> ArrayLike:
"""Estimate the source or target :term:`marginals` based on marker genes, either with the
`birth-death process <https://en.wikipedia.org/wiki/Birth%E2%80%93death_process>`_,
Expand All @@ -189,9 +232,22 @@ def estimate_marginals(
If :obj:`float` is passed, it will be used as a scaling parameter in an exponential kernel
with proliferation and apoptosis scores.
If :obj:`None`, parameters corresponding to the birth and death processes will be used.
kwargs
Keyword arguments for :func:`~moscot.base.problems.birth_death.beta` and
:func:`~moscot.base.problems.birth_death.delta`.
beta_max
Argument for :func:`~moscot.base.problems.birth_death.beta`
beta_min
Argument for :func:`~moscot.base.problems.birth_death.beta`
beta_center
Argument for :func:`~moscot.base.problems.birth_death.beta`
beta_width
Argument for :func:`~moscot.base.problems.birth_death.beta`
delta_max
Argument for :func:`~moscot.base.problems.birth_death.delta`
delta_min
Argument for :func:`~moscot.base.problems.birth_death.delta`
delta_center
Argument for :func:`~moscot.base.problems.birth_death.delta`
delta_width
Argument for :func:`~moscot.base.problems.birth_death.delta`

Returns
-------
Expand Down Expand Up @@ -223,12 +279,18 @@ def estimate(key: Optional[str], *, fn: Callable[..., ArrayLike], **kwargs: Any)
self.apoptosis_key = apoptosis_key

if scaling:
beta_fn = delta_fn = lambda x, *_, **__: x
beta_fn = delta_fn = lambda x: x
else:
beta_fn, delta_fn = beta, delta
beta_fn = partial(
beta, beta_max=beta_max, beta_min=beta_min, beta_center=beta_center, beta_width=beta_width
)
delta_fn = partial(
delta, delta_max=delta_max, delta_min=delta_min, delta_center=delta_center, delta_width=delta_width
)

scaling = 1.0
birth = estimate(proliferation_key, fn=beta_fn, **kwargs)
death = estimate(apoptosis_key, fn=delta_fn, **kwargs)
birth = estimate(proliferation_key, fn=beta_fn)
death = estimate(apoptosis_key, fn=delta_fn)

prior_growth = np.exp((birth - death) * self.delta / scaling)

Expand Down Expand Up @@ -287,7 +349,6 @@ def beta(
beta_min: float = 0.3,
beta_center: float = 0.25,
beta_width: float = 0.5,
**_: Any,
) -> ArrayLike:
"""Birth process."""
return _gen_logistic(p, beta_max, beta_min, beta_center, beta_width)
Expand All @@ -299,7 +360,6 @@ def delta(
delta_min: float = 0.3,
delta_center: float = 0.1,
delta_width: float = 0.2,
**_: Any,
) -> ArrayLike:
"""Death process."""
return _gen_logistic(a, delta_max, delta_min, delta_center, delta_width)
54 changes: 42 additions & 12 deletions src/moscot/base/problems/compound_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,10 @@ def _create_problems(
xy_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
x_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
y_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
**kwargs: Any,
a: Optional[Union[bool, str, ArrayLike]] = None,
b: Optional[Union[bool, str, ArrayLike]] = None,
marginal_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
) -> Dict[Tuple[K, K], B]:
from moscot.base.problems.birth_death import BirthDeathProblem

if TYPE_CHECKING:
assert isinstance(self._policy, SubsetPolicy)
Expand Down Expand Up @@ -187,10 +188,7 @@ def _create_problems(
if y_data:
y = dict(y)
y["tagged_array"] = y_data
if isinstance(problem, BirthDeathProblem):
kwargs["proliferation_key"] = self.proliferation_key # type: ignore[attr-defined]
kwargs["apoptosis_key"] = self.apoptosis_key # type: ignore[attr-defined]
problems[src_name, tgt_name] = problem.prepare(xy=xy, x=x, y=y, **kwargs)
problems[src_name, tgt_name] = problem.prepare(xy=xy, x=x, y=y, a=a, b=b, marginal_kwargs=marginal_kwargs)

return problems

Expand All @@ -200,13 +198,18 @@ def prepare(
key: Optional[str],
subset: Optional[Sequence[Tuple[K, K]]] = None,
reference: Optional[Any] = None,
xy: Mapping[str, Any] = types.MappingProxyType({}),
x: Mapping[str, Any] = types.MappingProxyType({}),
y: Mapping[str, Any] = types.MappingProxyType({}),
xy_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
x_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
y_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
xy_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
x_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
y_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
**kwargs: Any,
a: Optional[Union[bool, str, ArrayLike]] = None,
b: Optional[Union[bool, str, ArrayLike]] = None,
marginal_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
) -> "BaseCompoundProblem[K, B]":
"""Prepare the individual :term:`OT` subproblems.

Expand All @@ -224,6 +227,12 @@ def prepare(
for the :class:`~moscot.utils.subset_policy.ExplicitPolicy`. Only used when ``policy = 'explicit'``.
reference
Reference for the :class:`~moscot.utils.subset_policy.SubsetPolicy`. Only used when ``policy = 'star'``.
xy
Data for the :term:`linear term`.
x
Data for the source :term:`quadratic term`.
y
Data for the target :term:`quadratic term`.
xy_callback
Callback function used to prepare the data in the :term:`linear term`.
x_callback
Expand All @@ -236,8 +245,24 @@ def prepare(
Keyword arguments for the ``x_callback``.
y_callback_kwargs
Keyword arguments for the ``y_callback``.
kwargs
Keyword arguments for the subproblems' :meth:`~moscot.base.problems.OTProblem.prepare` method.
a
Source :term:`marginals`. Valid options are:

- :class:`str` - key in :attr:`~anndata.AnnData.obs` where the source marginals are stored.
- :class:`bool` - if :obj:`True`,
:meth:`estimate the marginals <moscot.base.problems.OTProblem.estimate_marginals>`,
otherwise use uniform marginals.
- :obj:`None` - uniform marginals.
b
Target :term:`marginals`. Valid options are:

- :class:`str` - key in :attr:`~anndata.AnnData.obs` where the target marginals are stored.
- :class:`bool` - if :obj:`True`,
:meth:`estimate the marginals <moscot.base.problems.OTProblem.estimate_marginals>`,
otherwise use uniform marginals.
- :obj:`None` - uniform marginals.
marginal_kwargs
Keyword arguments for the :meth:`~moscot.base.problems.OTProblem.estimate_marginals` method.

Returns
-------
Expand All @@ -264,13 +289,18 @@ def prepare(
# when refactoring the callback, consider changing this
self._problem_manager = ProblemManager(self, policy=policy)
problems = self._create_problems(
xy_callback=xy_callback,
x=x,
y=y,
xy=xy,
a=a,
b=b,
x_callback=x_callback,
y_callback=y_callback,
xy_callback_kwargs=xy_callback_kwargs,
xy_callback=xy_callback,
x_callback_kwargs=x_callback_kwargs,
y_callback_kwargs=y_callback_kwargs,
**kwargs,
xy_callback_kwargs=xy_callback_kwargs,
marginal_kwargs=marginal_kwargs,
)
self._problem_manager.add_problems(problems)

Expand Down
25 changes: 14 additions & 11 deletions src/moscot/base/problems/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)

import cloudpickle
from docstring_inheritance import NumpyDocstringInheritanceMeta

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -43,7 +44,11 @@
__all__ = ["BaseProblem", "OTProblem"]


class BaseProblem(abc.ABC):
class CombinedMeta(abc.ABCMeta, NumpyDocstringInheritanceMeta):
pass


class BaseProblem(abc.ABC, metaclass=CombinedMeta):
"""Base class for all :term:`OT` problems."""

def __init__(self):
Expand Down Expand Up @@ -289,7 +294,7 @@ def prepare(
y: Mapping[str, Any],
a: Optional[Union[bool, str, ArrayLike]] = None,
b: Optional[Union[bool, str, ArrayLike]] = None,
**kwargs: Any,
marginal_kwargs: Dict[str, Any] = types.MappingProxyType({}),
) -> "OTProblem":
"""Prepare the :term:`OT` problem.

Expand Down Expand Up @@ -326,7 +331,7 @@ def prepare(
from :attr:`adata_tgt`, otherwise use uniform marginals.
- :class:`~numpy.ndarray` - array of shape ``[m,]`` containing the target marginals.
- :obj:`None` - uniform marginals.
kwargs
marginal_kwargs
Keyword arguments for :meth:`estimate_marginals` when ``a = True`` or ``b = True``.

Returns
Expand Down Expand Up @@ -370,8 +375,8 @@ def prepare(
else:
raise ValueError("Unable to prepare the data. Either only supply `xy=...`, or `x=..., y=...`, or all.")
# fmt: on
self._a = self._create_marginals(self.adata_src, data=a, source=True, **kwargs)
self._b = self._create_marginals(self.adata_tgt, data=b, source=False, **kwargs)
self._a = self._create_marginals(self.adata_src, data=a, source=True, marginal_kwargs=marginal_kwargs)
self._b = self._create_marginals(self.adata_tgt, data=b, source=False, marginal_kwargs=marginal_kwargs)
return self

@wrap_solve
Expand Down Expand Up @@ -626,7 +631,7 @@ def _spatial_norm_callback(
raise ValueError("When `term` is `y`, `adata_y` cannot be `None`.")
adata = adata_y
if attr is None:
raise ValueError("`attrs` cannot be `None` with this callback.")
raise ValueError("`attr` cannot be `None` with this callback.")
spatial = TaggedArray._extract_data(adata, attr=attr, key=key)

logger.info(f"Normalizing spatial coordinates of `{term}`.")
Expand Down Expand Up @@ -669,10 +674,9 @@ def _create_marginals(
source: bool,
data: Optional[Union[bool, str, ArrayLike]] = None,
marginal_kwargs: Dict[str, Any] = types.MappingProxyType({}),
**kwargs: Any,
) -> ArrayLike:
if data is True:
marginals = self.estimate_marginals(adata, source=source, **marginal_kwargs, **kwargs)
if data is True: # this is the only case when kwargs are passed
marginals = self.estimate_marginals(adata, source=source, **marginal_kwargs)
elif data is False or data is None:
marginals = np.ones((adata.n_obs,), dtype=float) / adata.n_obs
elif isinstance(data, str):
Expand All @@ -690,7 +694,7 @@ def _create_marginals(
)
return marginals

def estimate_marginals(self, adata: AnnData, *, source: bool, **kwargs: Any) -> ArrayLike:
def estimate_marginals(self, adata: AnnData, *, source: bool) -> ArrayLike:
"""Estimate the source or target :term:`marginals`.

.. note::
Expand All @@ -709,7 +713,6 @@ def estimate_marginals(self, adata: AnnData, *, source: bool, **kwargs: Any) ->
-------
The estimated source or target marginals of shape ``[n,]`` or ``[m,]``, depending on the ``source``.
"""
del kwargs
return np.ones((adata.n_obs,), dtype=float) / adata.n_obs

def set_graph_xy(
Expand Down
Loading
Loading