diff --git a/pysaliency/baseline_utils.py b/pysaliency/baseline_utils.py index 370ef25..406ebe3 100644 --- a/pysaliency/baseline_utils.py +++ b/pysaliency/baseline_utils.py @@ -355,12 +355,17 @@ def _normalize_regularization_factors(args): class CrossvalMultipleRegularizations(object): - """ Class for computing crossvalidation scores of a fixation KDE with multiple regularization models""" - def __init__(self, stimuli, fixations, regularization_models: OrderedDict, crossvalidation, verbose=False): + """Class for computing crossvalidation scores of a fixation KDE with multiple regularization models + + n_jobs: number of parallel jobs to use in cross_val_score + verbose: verbosity level for cross_val_score + """ + def __init__(self, stimuli, fixations, regularization_models: OrderedDict, crossvalidation, n_jobs=None, verbose=False): self.stimuli = stimuli self.fixations = fixations self.cv = crossvalidation + self.n_jobs = n_jobs self.verbose = verbose X_areas = fixations_to_scikit_learn( @@ -413,13 +418,13 @@ def score(self, log_bandwidth, *args, **kwargs): class CrossvalGoldMultipleRegularizations(CrossvalMultipleRegularizations): - def __init__(self, stimuli, fixations, regularization_models, verbose=False): + def __init__(self, stimuli, fixations, regularization_models, n_jobs=None, verbose=False): if fixations.subject_count > 1: crossvalidation_factory = ScikitLearnImageSubjectCrossValidationGenerator else: crossvalidation_factory = ScikitLearnWithinImageCrossValidationGenerator - super().__init__(stimuli, fixations, regularization_models, crossvalidation_factory=crossvalidation_factory, verbose=verbose) + super().__init__(stimuli, fixations, regularization_models, crossvalidation_factory=crossvalidation_factory, n_jobs=n_jobs, verbose=verbose) # baseline models