Skip to content

Commit

Permalink
add StratifiedKfold when stratification_coord is given (brain-score#193)
Browse files Browse the repository at this point in the history
  • Loading branch information
mschrimpf authored Jan 21, 2020
1 parent caebeb7 commit c9e8c1e
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions brainscore/metrics/transformations.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import logging
from collections import OrderedDict

import itertools
import logging
import math
import numpy as np
import xarray as xr
from sklearn.model_selection import StratifiedShuffleSplit, ShuffleSplit, KFold
from tqdm import tqdm

from brainio_base.assemblies import DataAssembly, walk_coords
from brainio_collection.transform import subset
from sklearn.model_selection import StratifiedShuffleSplit, ShuffleSplit, KFold, StratifiedKFold
from tqdm import tqdm

from brainscore.metrics import Score
from brainscore.metrics.utils import unique_ordered
from brainscore.utils import fullname
Expand Down Expand Up @@ -180,17 +180,19 @@ def __init__(self,
train_size = self.Defaults.train_size
if kfold:
assert (train_size is None or train_size == self.Defaults.train_size) and test_size is None
assert not bool(stratification_coord)
self._split = KFold(n_splits=splits, shuffle=True, random_state=random_state)
elif stratification_coord:
self._split = StratifiedShuffleSplit(
n_splits=splits, train_size=train_size, test_size=test_size, random_state=random_state)
if stratification_coord:
self._split = StratifiedKFold(n_splits=splits, shuffle=True, random_state=random_state)
else:
self._split = KFold(n_splits=splits, shuffle=True, random_state=random_state)
else:
self._split = ShuffleSplit(
n_splits=splits, train_size=train_size, test_size=test_size, random_state=random_state)
if stratification_coord:
self._split = StratifiedShuffleSplit(
n_splits=splits, train_size=train_size, test_size=test_size, random_state=random_state)
else:
self._split = ShuffleSplit(
n_splits=splits, train_size=train_size, test_size=test_size, random_state=random_state)
self._split_coord = split_coord
self._stratification_coord = stratification_coord
self._kfold = kfold
self._unique_split_values = unique_split_values

self._logger = logging.getLogger(fullname(self))
Expand Down

0 comments on commit c9e8c1e

Please sign in to comment.