From 8559a6b9b6234c68a36613b3f0ed49a78529cff2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthias=20K=C3=BCmmmerer?= Date: Sun, 19 Nov 2023 22:55:00 +0100 Subject: [PATCH] specify number of parallel jobs for CrossvalMultipleRegularizations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Matthias Kümmmerer --- pysaliency/baseline_utils.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/pysaliency/baseline_utils.py b/pysaliency/baseline_utils.py index 370ef25..406ebe3 100644 --- a/pysaliency/baseline_utils.py +++ b/pysaliency/baseline_utils.py @@ -355,12 +355,17 @@ def _normalize_regularization_factors(args): class CrossvalMultipleRegularizations(object): - """ Class for computing crossvalidation scores of a fixation KDE with multiple regularization models""" - def __init__(self, stimuli, fixations, regularization_models: OrderedDict, crossvalidation, verbose=False): + """Class for computing crossvalidation scores of a fixation KDE with multiple regularization models + + n_jobs: number of parallel jobs to use in cross_val_score + verbose: verbosity level for cross_val_score + """ + def __init__(self, stimuli, fixations, regularization_models: OrderedDict, crossvalidation, n_jobs=None, verbose=False): self.stimuli = stimuli self.fixations = fixations self.cv = crossvalidation + self.n_jobs = n_jobs self.verbose = verbose X_areas = fixations_to_scikit_learn( @@ -413,13 +418,13 @@ def score(self, log_bandwidth, *args, **kwargs): class CrossvalGoldMultipleRegularizations(CrossvalMultipleRegularizations): - def __init__(self, stimuli, fixations, regularization_models, verbose=False): + def __init__(self, stimuli, fixations, regularization_models, n_jobs=None, verbose=False): if fixations.subject_count > 1: crossvalidation_factory = ScikitLearnImageSubjectCrossValidationGenerator else: crossvalidation_factory = ScikitLearnWithinImageCrossValidationGenerator - super().__init__(stimuli, fixations, regularization_models, crossvalidation_factory=crossvalidation_factory, verbose=verbose) + super().__init__(stimuli, fixations, regularization_models, crossvalidation_factory=crossvalidation_factory, n_jobs=n_jobs, verbose=verbose) # baseline models