Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
Merge pull request #109 from nipreps/maint/refactor-models
Browse files Browse the repository at this point in the history
ENH: Model building refactor
  • Loading branch information
oesteban authored Dec 9, 2022
2 parents 5426311 + b2a3ced commit cc5f547
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 243 deletions.
25 changes: 14 additions & 11 deletions src/eddymotion/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,20 +84,23 @@ def fit(
index_order = np.arange(len(dwdata))
np.random.shuffle(index_order)

if model.lower().startswith("fulldki"):
kwargs["data"] = dwdata.dataobj

single_model = (
model.lower() in ("b0", "s0", "avg", "average", "mean", "fulldki")
model.lower() in ("b0", "s0", "avg", "average", "mean")
or model.lower().startswith("full")
)

# Factory creates the appropriate model and pipes arguments
dwmodel = ModelFactory.init(
gtab=dwdata.gradients,
model=model,
n_jobs=n_jobs,
**kwargs,
) if single_model else None
dwmodel = None
if single_model:
if model.lower().startswith("full"):
model = model[4:]

# Factory creates the appropriate model and pipes arguments
dwmodel = ModelFactory.init(
gtab=dwdata.gradients,
model=model,
**kwargs,
)
dwmodel.fit(dwdata.dataobj, n_jobs=n_jobs)

with TemporaryDirectory() as tmpdir:
print(f"Processing in <{tmpdir}>")
Expand Down
Loading

0 comments on commit cc5f547

Please sign in to comment.