Skip to content

Commit

Permalink
Merge pull request #178 from alan-turing-institute/run-with-fail
Browse files Browse the repository at this point in the history
run compare despite model failing
  • Loading branch information
mastoffel authored Feb 20, 2024
2 parents 31c040c + c59ad62 commit 842fe66
Showing 1 changed file with 22 additions and 17 deletions.
39 changes: 22 additions & 17 deletions autoemulate/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,30 +194,35 @@ def compare(self):
)

for i in range(len(self.models)):
# hyperparameter search
if self.param_search:
self.models[i] = optimize_params(
try:
# hyperparameter search
if self.param_search:
self.models[i] = optimize_params(
X=self.X[self.train_idxs],
y=self.y[self.train_idxs],
cv=self.cv,
model=self.models[i],
search_type=self.search_type,
niter=self.param_search_iters,
param_space=None,
n_jobs=self.n_jobs,
logger=self.logger,
)

# run cross validation
fitted_model, cv_results = run_cv(
X=self.X[self.train_idxs],
y=self.y[self.train_idxs],
cv=self.cv,
model=self.models[i],
search_type=self.search_type,
niter=self.param_search_iters,
param_space=None,
metrics=self.metrics,
n_jobs=self.n_jobs,
logger=self.logger,
)

# run cross validation
fitted_model, cv_results = run_cv(
X=self.X[self.train_idxs],
y=self.y[self.train_idxs],
cv=self.cv,
model=self.models[i],
metrics=self.metrics,
n_jobs=self.n_jobs,
logger=self.logger,
)
except Exception as e:
print(f"Error fitting model {get_model_name(self.models[i])}")
print(e) # should be replaced with logging
continue

self.models[i] = fitted_model
self.cv_results[get_model_name(self.models[i])] = cv_results
Expand Down

0 comments on commit 842fe66

Please sign in to comment.