This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

ENH: Make models inherit from base model #176

Merged 6 commits into nipreps:main on Jun 13, 2024

Conversation

jhlegarreta
Collaborator

Make models inherit from base model.

@jhlegarreta
Collaborator Author

@effigies @oesteban Following this comment #166 (comment), I had a look at inheriting the models from ModelBase. Things were unclear, and I ended up adding my questions as comments in the code. Not good practice for reviewing, but I would be grateful if you had a look and commented.

The main point is that the BaseClass' fit and predict methods contain a lot of code that is not being called at all. The initialization method is not used either. The latter may be easy to fix with a call to the superclass __init__ from the child classes, but it is unclear how the code in the fit and predict methods of BaseClass should be reused.
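For the initialization part, the usual pattern would be to call up the chain with super() (a minimal, hypothetical sketch; BaseModel and AverageModel here are illustrative, not the actual eddymotion classes):

```python
import numpy as np


class BaseModel:
    """Hypothetical base class holding shared state."""

    def __init__(self, mask=None, **kwargs):
        self._mask = mask
        self._is_fitted = False

    def fit(self, data, **kwargs):
        # Shared bookkeeping every child can reuse.
        self._is_fitted = True


class AverageModel(BaseModel):
    """Hypothetical child reusing the base initialization and fit."""

    def __init__(self, mask=None, **kwargs):
        super().__init__(mask=mask, **kwargs)  # reuse the base __init__
        self._average = None

    def fit(self, data, **kwargs):
        super().fit(data, **kwargs)  # reuse the shared bookkeeping
        self._average = np.mean(data, axis=-1)


model = AverageModel()
model.fit(np.ones((2, 2, 2, 5)))
print(model._is_fitted)  # True
```

Whether fit/predict can be shared this way depends on how much per-model logic they contain, which is exactly the open question.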

This is also related to issue nipreps/nifreeze#16.

So comments would be appreciated.

Member

@oesteban oesteban left a comment


I think this will make the code much more readable -- left some comments.

src/eddymotion/model/base.py: 4 review threads (outdated, resolved)
@jhlegarreta jhlegarreta force-pushed the InheritModelsFromBase branch 2 times, most recently from df51602 to 5d58f56 Compare April 30, 2024 23:16

def _exec_fit(model, data, chunk=None):
retval = model.fit(data)
return retval, chunk


def _exec_predict(model, gradient, chunk=None, **kwargs):
def _exec_predict_dwi(model, gradient, chunk=None, **kwargs):
Collaborator Author


I renamed this method to contain the dwi label, as it requires a gradient and optionally uses an S0 argument. If gradient should actually be an index, it may be renamed back. Also, as things stand right now, I do not see the need to pass an S0 either.

Member


I think we should move this back to be general, and make gradient an index -- will work on this through a PR to this branch.


gradient = _rasb2dipy(gradient)
self._gtab = _rasb2dipy(self._gtab)
Collaborator Author

@jhlegarreta jhlegarreta Apr 30, 2024


I think there is a naming confusion here: the models expect a RAS+b gradient object (which is different from the dipy gtab object) into a gtab parameter. Am I correct @oesteban ?

"""Predict asynchronously chunk-by-chunk the diffusion signal."""
if self._b_max is not None:
gradient[-1] = min(gradient[-1], self._b_max)
index[-1] = min(index[-1], self._b_max)
Collaborator Author


I think this is not OK. If the gradients are capped, not sure how the indices get affected/how they should be checked.

Member


This is capping the b-value (only the last item of gradient). For some models, very high b-values 'saturate', and it's better to model them as if they were lower. This only kicks in after setting b_max, so you need to be explicit about it.
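In code, the capping is just a clamp on the last element of the gradient vector (a sketch with a made-up RAS+b gradient; in the actual code the ceiling comes from self._b_max):

```python
# Hypothetical single RAS+b gradient: direction (3 components) + b-value.
gradient = [0.0, 0.0, 1.0, 3000.0]

# b_max is only applied when the user sets it explicitly.
b_max = 2500
gradient[-1] = min(gradient[-1], b_max)
print(gradient)  # [0.0, 0.0, 1.0, 2500]
```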

((gtab[3] >= self._th_low) & (gtab[3] <= self._th_high))
if gtab is not None
((self._gtab[3] >= self._th_low) & (self._gtab[3] <= self._th_high))
if self._gtab is not None
else np.ones((data.shape[-1],), dtype=bool)
Collaborator Author

@jhlegarreta jhlegarreta Apr 30, 2024


I would dare to say that self._gtab will not be None, so this if/else block does not seem necessary to me.

Member


Here we do want to use the input gtab, as opposed to the global gtab, which potentially contains the left-out gradient.

@jhlegarreta
Collaborator Author

jhlegarreta commented Apr 30, 2024

Gave this another go. Adjusted some docstrings.

More questions/comments (inline and below, long to digest, sorry):

  • I am not sure whether the parent slots are inherited by the child classes

  • I am not sure how the AverageDWModel __init__ method's kwargs documentation should be understood, as (i) it only contains kwargs (we may need to say something like 'kwargs can contain ...'; not sure how this is done properly in Sphinx); and (ii) the parameters default to some values if not found, so maybe they could just be listed as regular keyword arguments.

  • I am not sure I follow the timepoint/index rationale for the PET model. Do we have multiple PET volumes for a single session? If we do not, this would be like requiring DWI volumes from different sessions/timepoints to be aligned, which is different from what we try to do in the DWI case (correcting data from the same session).

    Edit: understood after talking to Martin. So essentially, the case is the same for PET: multiple 3D volumes in one session, each taken at some interval, much like the different gradient directions in DWI.

  • The following:
    https://github.com/nipreps/eddymotion/pull/176/files#diff-a875f501910044a7d95658fb83740e2c5c6c1693e7e6808703d282441db82be8R117

    Looks like it does not apply to PET (see https://github.com/nipreps/eddymotion/pull/176/files#diff-a875f501910044a7d95658fb83740e2c5c6c1693e7e6808703d282441db82be8R382). So should it be moved to the DWI base class?

    Edit: could apply to PET, so no need to pay attention for now.
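On the first bullet: slots declared in a parent class are inherited in the sense that child instances can use them; a child's __slots__ should list only its new attributes. A minimal sketch (class names and attributes are illustrative, not the actual eddymotion classes):

```python
class BaseModel:
    __slots__ = ("_mask", "_is_fitted")

    def __init__(self):
        self._mask = None
        self._is_fitted = False


class DWIModel(BaseModel):
    # Only the *new* attributes go here; the parent's slots remain usable.
    __slots__ = ("_gtab",)

    def __init__(self):
        super().__init__()
        self._gtab = None


model = DWIModel()
model._mask = "brain"  # parent slot, accessible from the child instance
model._gtab = "rasb"   # child slot
print(model._mask, model._gtab)  # brain rasb
```

As long as no class in the hierarchy omits __slots__ (which would reintroduce a __dict__), assigning an undeclared attribute raises AttributeError.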

@jhlegarreta
Collaborator Author

jhlegarreta commented May 15, 2024

Sorry to ping you again this morning @oesteban.


model_str = getattr(self, "_model_class", None)
if not model_str:
raise TypeError("No model defined")
Collaborator Author


@oesteban Some tests are now failing because the _model_class is None in this base class, and I am not setting any particular value in the derived classes. What is this property supposed to contain?

e.g.
https://app.circleci.com/pipelines/github/nipreps/eddymotion/1070/workflows/c9747d35-0cb1-49d9-963f-207d20887ce8/jobs/1038

Collaborator Author

@jhlegarreta jhlegarreta May 27, 2024


I am confused about the use of this. I now see that _model_class and _modelargs are properties of the wrappers DTIModel and DKIModel. This adds to this docstring https://github.com/nipreps/eddymotion/pull/176/files#diff-a875f501910044a7d95658fb83740e2c5c6c1693e7e6808703d282441db82be8L79

If I say here

kwargs = {k: v for k, v in kwargs.items() if k in self._modelargs}
from importlib import import_module
model_str = "eddymotion.model.AverageDWModel"
module_name, class_name = model_str.rsplit(".", 1)
my_model = getattr(import_module(module_name), class_name)(**kwargs)

and leaving aside that AverageDWModel requires a gtab (as it inherits from BaseDWIModel; setting it to None in its __init__ would do), the above statement produces a recursive call, since instantiating AverageDWModel calls the superclass __init__ method.

So I am not following what was intended with this block.

Also, I am not sure what we want to do here with the DTI and DKI wrappers either.

Edit: if the DTI/DKI wrappers make sense here, it looks as if BaseDWIModel should not inherit from BaseModel; or at least, the __init__ method of the latter and its docstring suggest that it is intended to be a superclass for the wrapper classes. However, TrivialB0Model, AverageDWModel, etc. are not intended to be wrappers around dipy objects, and IMO it does not make sense for them to have _model_class and _modelargs properties. So there seem to be two things mixed here. The model factory will also need to be adapted accordingly.

@oesteban Can you please clarify these aspects?

Collaborator Author


and leaving aside that AverageDWModel requires a gtab (as it inherits from BaseDWIModel;

The other two tests fail because of this reason.

Member


I'll have a look ASAP - sorry for my slow turnaround

Member


Sorry for the delayed response. I can now answer this question. I'm sorry for this particular case—you are prey to an undocumented feature.

The idea was that models can be fit in two ways:

  • Pure leave-one-out fashion: at every iteration of the Estimator, a fully fledged model is fit without the particular index/orientation. This is typically very slow.
  • Single model: the model is fit once on all the data, and each iteration predicts the left-out index. These are enabled by adding the prefix Full to the model name.

This is implemented in the Estimator, under the understanding that the model is the same; what changes is how you use it.

single_model = model.lower() in (
    "b0",
    "s0",
    "avg",
    "average",
    "mean",
) or model.lower().startswith("full")

dwmodel = None
if single_model:
    if model.lower().startswith("full"):
        model = model[4:]

    # Factory creates the appropriate model and pipes arguments
    dwmodel = ModelFactory.init(
        model=model,
        **kwargs,
    )
    dwmodel.fit(dwdata.dataobj, n_jobs=n_jobs)

ATM I cannot comment on why this has some effect on the model itself such that a _model_class is necessary; I can't recall the reason. I bet it is just to inform the estimator that fit should not be called every time (which probably should be handled here!).

That said, let's take the average model for example. When instantiated as FullAverage, then it is fit only once before entering the iterator loop of the estimator. If not, at every iteration an average without the particular direction will be calculated in the fit call.

Member


Okay, I'll leave my above comment because it explains something useful --- but it is totally unrelated to @jhlegarreta's question. Apologies for the confusion.

After working on the PR and re-reading the code, I understand that _model_class and _modelargs enable using DIPY models without much overhead (see DKI and DTI at the end).
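That mechanism boils down to resolving a dotted class path at runtime. A sketch (using a standard-library class as the target so the snippet is self-contained; in eddymotion, _model_class would point at a DIPY model class instead):

```python
from importlib import import_module

# _model_class would hold a dotted path to the wrapped model class;
# here we resolve a standard-library class so the sketch runs anywhere.
model_str = "collections.OrderedDict"
module_name, class_name = model_str.rsplit(".", 1)
model_class = getattr(import_module(module_name), class_name)
instance = model_class()
print(type(instance).__name__)  # OrderedDict
```

Combined with _modelargs filtering the kwargs, this lets one thin wrapper delegate to any DIPY model without per-model boilerplate.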

Member

@oesteban oesteban left a comment


I have to check out this code to make a better review of it. A nitpick for the time being.

src/eddymotion/model/base.py: 2 review threads (resolved)


Make models inherit from base model.
@jhlegarreta jhlegarreta force-pushed the InheritModelsFromBase branch from 5d58f56 to 324d2ee Compare June 3, 2024 23:48
@oesteban
Member

oesteban commented Jun 7, 2024

Let's get #166 over the final line and then I move onto this.

Member

@oesteban oesteban left a comment


I like where this is going. I'm going to add the docstring of constants and then work locally on this PR.

src/eddymotion/model/base.py: 7 review threads (outdated, resolved)
Improving the documentation of constants. cc/ @jhlegarreta
Member

@oesteban oesteban left a comment


I responded to the two major questions in this PR. Happy to chat about _model_class, as it seems the feature may not be implemented in an intuitive way, and it is definitely not sufficiently documented.


@jhlegarreta
Collaborator Author

@oesteban Have gone through the comments. Will wait after this #176 (comment).

The main difficulty in making this work now lies in https://github.com/nipreps/eddymotion/pull/176/files#r1616403642. Although you answered in the thread, I am not sure the question was addressed: the point is that I do not see why DTIModel and DKIModel exist here; if they are meant to be removed, the issue related to _model_class would go away, and the inheritance would be easier, I think, since the DWI models (e.g. TrivialB0Model and AverageDWModel) are the ones we are interested in subclassing from BaseDWIModel <- BaseModel.

jhlegarreta pushed a commit to jhlegarreta/eddymotion that referenced this pull request Jun 12, 2024

* enh: revise code
* sty: ruff format

@jhlegarreta jhlegarreta force-pushed the InheritModelsFromBase branch from 47cd1e6 to 250d23a Compare June 12, 2024 18:00
@jhlegarreta jhlegarreta force-pushed the InheritModelsFromBase branch from 250d23a to abae1cc Compare June 12, 2024 20:01
@jhlegarreta jhlegarreta force-pushed the InheritModelsFromBase branch from abae1cc to ee73cf9 Compare June 12, 2024 20:30
@jhlegarreta jhlegarreta force-pushed the InheritModelsFromBase branch from ee73cf9 to 104bc3d Compare June 12, 2024 21:09
@jhlegarreta jhlegarreta force-pushed the InheritModelsFromBase branch from 104bc3d to a768e72 Compare June 12, 2024 21:12
@jhlegarreta jhlegarreta marked this pull request as ready for review June 12, 2024 21:13
Do not overwrite the gradient table in prediction.

Co-authored-by: Oscar Esteban <[email protected]>
@oesteban oesteban merged commit 8c0bf36 into nipreps:main Jun 13, 2024
6 checks passed
@jhlegarreta jhlegarreta deleted the InheritModelsFromBase branch June 13, 2024 13:13