diff --git a/pysaliency/baseline_utils.py b/pysaliency/baseline_utils.py index f81cbe6..370ef25 100644 --- a/pysaliency/baseline_utils.py +++ b/pysaliency/baseline_utils.py @@ -356,11 +356,12 @@ def _normalize_regularization_factors(args): class CrossvalMultipleRegularizations(object): """ Class for computing crossvalidation scores of a fixation KDE with multiple regularization models""" - def __init__(self, stimuli, fixations, regularization_models: OrderedDict, crossvalidation): + def __init__(self, stimuli, fixations, regularization_models: OrderedDict, crossvalidation, verbose=False): self.stimuli = stimuli self.fixations = fixations self.cv = crossvalidation + self.verbose = verbose X_areas = fixations_to_scikit_learn( self.fixations, normalize=stimuli, @@ -406,19 +407,19 @@ def score(self, log_bandwidth, *args, **kwargs): bandwidth=10**log_bandwidth, regularizations=10**log_regularizations, regularizing_log_likelihoods=self.regularization_log_likelihoods), - self.X, cv=self.cv, verbose=1).sum() / len(self.X) / np.log(2) + self.X, cv=self.cv, verbose=self.verbose).sum() / len(self.X) / np.log(2) val += np.log2(self.mean_area) return val class CrossvalGoldMultipleRegularizations(CrossvalMultipleRegularizations): - def __init__(self, stimuli, fixations, regularization_models): + def __init__(self, stimuli, fixations, regularization_models, verbose=False): if fixations.subject_count > 1: crossvalidation_factory = ScikitLearnImageSubjectCrossValidationGenerator else: crossvalidation_factory = ScikitLearnWithinImageCrossValidationGenerator - super().__init__(stimuli, fixations, regularization_models, crossvalidation_factory=crossvalidation_factory) + super().__init__(stimuli, fixations, regularization_models, crossvalidation_factory=crossvalidation_factory, verbose=verbose) # baseline models