Commit
MAINT: Check and format with NPY201 (mne-tools#12353)
larsoner authored Jan 11, 2024
1 parent 8eb10e3 commit eefd179
Showing 18 changed files with 36 additions and 33 deletions.
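
The commit adds a preview ruff hook selecting NPY201 (ruff's NumPy 2.0 deprecation rule) and applies the resulting lint and format fixes. Most hunks below swap comparisons against bare builtin types (bool, int, float, str) for either `is` identity checks or comparisons against explicit `np.dtype(...)` objects, so the stricter preview type-comparison lint (E721, visible in the final hunk's noqa) no longer fires. A minimal standalone sketch of the dtype side of the pattern (assumed example, not MNE code):

```python
import numpy as np

arr = np.zeros(3, dtype=bool)

# Both checks are True, but the second compares a dtype to a dtype instead
# of a dtype to the builtin `bool` type, which E721-style linting flags:
assert arr.dtype == bool
assert arr.dtype == np.dtype(bool)
```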
9 changes: 4 additions & 5 deletions .pre-commit-config.yaml
@@ -7,14 +7,13 @@ repos:
name: ruff lint mne
args: ["--fix"]
files: ^mne/
+ - id: ruff
+ name: ruff lint mne preview
+ args: ["--fix", "--preview", "--select=NPY201"]
+ files: ^mne/
- id: ruff-format
name: ruff format mne
files: ^mne/

# Ruff tutorials and examples
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.11
hooks:
- id: ruff
name: ruff lint tutorials and examples
# D103: missing docstring in public function
2 changes: 1 addition & 1 deletion mne/_fiff/open.py
@@ -263,7 +263,7 @@ def show_fiff(
tag_id=tag,
show_bytes=show_bytes,
)
- if output == str:
+ if output is str:
out = "\n".join(out)
return out

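For genuine type objects, the fix is an `is` identity check rather than `==`. A hedged sketch of the `show_fiff` output handling above (simplified, assumed names):

```python
def show_fiff_sketch(output=str):
    out = ["line 1", "line 2"]
    if output is str:  # identity check on the type object, not `==`
        out = "\n".join(out)
    return out

print(show_fiff_sketch())      # "line 1\nline 2"
print(show_fiff_sketch(list))  # ["line 1", "line 2"]
```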
16 changes: 8 additions & 8 deletions mne/decoding/tests/test_search_light.py
@@ -63,19 +63,19 @@ def test_search_light():
# transforms
pytest.raises(ValueError, sl.predict, X[:, :, :2])
y_trans = sl.transform(X)
- assert X.dtype == y_trans.dtype == float
+ assert X.dtype == y_trans.dtype == np.dtype(float)
y_pred = sl.predict(X)
- assert y_pred.dtype == int
+ assert y_pred.dtype == np.dtype(int)
assert_array_equal(y_pred.shape, [n_epochs, n_time])
y_proba = sl.predict_proba(X)
- assert y_proba.dtype == float
+ assert y_proba.dtype == np.dtype(float)
assert_array_equal(y_proba.shape, [n_epochs, n_time, 2])

# score
score = sl.score(X, y)
assert_array_equal(score.shape, [n_time])
assert np.sum(np.abs(score)) != 0
- assert score.dtype == float
+ assert score.dtype == np.dtype(float)

sl = SlidingEstimator(logreg)
assert_equal(sl.scoring, None)
@@ -122,7 +122,7 @@ def test_search_light():
X = rng.randn(*X.shape) # randomize X to avoid AUCs in [0, 1]
score_sl = sl1.score(X, y)
assert_array_equal(score_sl.shape, [n_time])
- assert score_sl.dtype == float
+ assert score_sl.dtype == np.dtype(float)

# Check that scoring was applied adequately
scoring = make_scorer(roc_auc_score, needs_threshold=True)
@@ -195,9 +195,9 @@ def test_generalization_light():
# transforms
y_pred = gl.predict(X)
assert_array_equal(y_pred.shape, [n_epochs, n_time, n_time])
- assert y_pred.dtype == int
+ assert y_pred.dtype == np.dtype(int)
y_proba = gl.predict_proba(X)
- assert y_proba.dtype == float
+ assert y_proba.dtype == np.dtype(float)
assert_array_equal(y_proba.shape, [n_epochs, n_time, n_time, 2])

# transform to different datasize
@@ -208,7 +208,7 @@ def test_generalization_light():
score = gl.score(X[:, :, :3], y)
assert_array_equal(score.shape, [n_time, 3])
assert np.sum(np.abs(score)) != 0
- assert score.dtype == float
+ assert score.dtype == np.dtype(float)

gl = GeneralizingEstimator(logreg, scoring="roc_auc")
gl.fit(X, y)
2 changes: 1 addition & 1 deletion mne/epochs.py
@@ -1537,7 +1537,7 @@ def drop(self, indices, reason="USER", verbose=None):
if indices.ndim > 1:
raise ValueError("indices must be a scalar or a 1-d array")

- if indices.dtype == bool:
+ if indices.dtype == np.dtype(bool):
indices = np.where(indices)[0]
try_idx = np.where(indices < 0, indices + len(self.events), indices)

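The surrounding `drop` logic normalizes whatever the caller passed into positive integer indices. A small sketch of that normalization under assumed inputs:

```python
import numpy as np

n_events = 5
indices = np.array([True, False, True, False, False])

if indices.dtype == np.dtype(bool):
    indices = np.where(indices)[0]  # bool mask -> integer indices [0, 2]
# wrap negative indices so -1 refers to the last event
try_idx = np.where(indices < 0, indices + n_events, indices)
print(try_idx)  # [0 2]
```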
2 changes: 1 addition & 1 deletion mne/fixes.py
@@ -223,7 +223,7 @@ def get_params(self, deep=True):
try:
with warnings.catch_warnings(record=True) as w:
value = getattr(self, key, None)
- if len(w) and w[0].category == DeprecationWarning:
+ if len(w) and w[0].category is DeprecationWarning:
# if the parameter is deprecated, don't show it
continue
finally:
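The `get_params` hunk filters out deprecated attributes by recording warnings during attribute access. A self-contained sketch of that trick (hypothetical class, assumed behavior):

```python
import warnings


class Model:
    @property
    def old_param(self):
        warnings.warn("old_param is deprecated", DeprecationWarning)
        return 1


m = Model()
with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    value = getattr(m, "old_param", None)
if len(w) and w[0].category is DeprecationWarning:
    print("skipping deprecated parameter")  # mirrors the `continue` above
```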
2 changes: 1 addition & 1 deletion mne/preprocessing/_annotate_amplitude.py
@@ -249,7 +249,7 @@ def _check_min_duration(min_duration, raw_duration):

def _reject_short_segments(arr, min_duration_samples):
"""Check if flat or peak segments are longer than the minimum duration."""
- assert arr.dtype == bool and arr.ndim == 2
+ assert arr.dtype == np.dtype(bool) and arr.ndim == 2
for k, ch in enumerate(arr):
onsets, offsets = _mask_to_onsets_offsets(ch)
_mark_inner(arr[k], onsets, offsets, min_duration_samples)
2 changes: 1 addition & 1 deletion mne/preprocessing/artifact_detection.py
@@ -599,7 +599,7 @@ def annotate_break(
# Log some info
n_breaks = len(break_annotations)
break_times = [
f"{o:.1f}{o+d:.1f} s [{d:.1f} s]"
f"{o:.1f}{o + d:.1f} s [{d:.1f} s]"
for o, d in zip(break_annotations.onset, break_annotations.duration)
]
break_times = "\n ".join(break_times)
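This hunk and several below only change spacing inside f-string replacement fields: ruff-format's preview style writes {o + d} instead of {o+d}. The rendered string is identical, as a quick sketch shows (assumed values):

```python
k = 0
# spacing inside the braces does not affect the rendered string
assert f"virtual {k+1}" == f"virtual {k + 1}" == "virtual 1"
assert f"PCA-{k + 1:02d}" == "PCA-01"
```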
4 changes: 2 additions & 2 deletions mne/preprocessing/interpolate.py
@@ -163,7 +163,7 @@ def interpolate_bridged_electrodes(inst, bridged_idx, bad_limit=4):
# compute centroid position in spherical "head" coordinates
pos_virtual = _find_centroid_sphere(pos["ch_pos"], group_names)
# create the virtual channel info and set the position
- virtual_info = create_info([f"virtual {k+1}"], inst.info["sfreq"], "eeg")
+ virtual_info = create_info([f"virtual {k + 1}"], inst.info["sfreq"], "eeg")
virtual_info["chs"][0]["loc"][:3] = pos_virtual
# create virtual channel
data = inst.get_data(picks=group_names)
@@ -182,7 +182,7 @@ def interpolate_bridged_electrodes(inst, bridged_idx, bad_limit=4):
nave=inst.nave,
kind=inst.kind,
)
virtual_chs[f"virtual {k+1}"] = virtual_ch
virtual_chs[f"virtual {k + 1}"] = virtual_ch

# add the virtual channels
inst.add_channels(list(virtual_chs.values()), force_update_info=True)
2 changes: 1 addition & 1 deletion mne/preprocessing/tests/test_realign.py
@@ -158,7 +158,7 @@ def _assert_similarity(raw, other, n_events, ratio_other, events_raw=None):
evoked_other = Epochs(other, events_other, **kwargs).average()
assert evoked_raw.nave == evoked_other.nave == len(events_raw)
assert len(evoked_raw.data) == len(evoked_other.data) == 1 # just EEG
- if 0.99 <= ratio_other <= 1.01:  # when drift is not too large
+ if 0.99 <= ratio_other <= 1.01:  # when drift is not too large
corr = np.corrcoef(evoked_raw.data[0], evoked_other.data[0])[0, 1]
assert 0.9 <= corr <= 1.0
return evoked_raw, events_raw, evoked_other, events_other
2 changes: 1 addition & 1 deletion mne/proj.py
@@ -151,7 +151,7 @@ def _compute_proj(
nrow=1,
ncol=u.size,
)
desc = f"{kind}-{desc_prefix}-PCA-{k+1:02d}"
desc = f"{kind}-{desc_prefix}-PCA-{k + 1:02d}"
logger.info("Adding projection: %s", desc)
proj = Projection(
active=False,
4 changes: 2 additions & 2 deletions mne/report/report.py
@@ -2271,7 +2271,7 @@ def add_figure(
elif caption is None and len(figs) == 1:
captions = [None]
elif caption is None and len(figs) > 1:
captions = [f"Figure {i+1}" for i in range(len(figs))]
captions = [f"Figure {i + 1}" for i in range(len(figs))]
else:
captions = tuple(caption)

@@ -3143,7 +3143,7 @@ def _add_raw_butterfly_segments(

del orig_annotations

captions = [f"Segment {i+1} of {len(images)}" for i in range(len(images))]
captions = [f"Segment {i + 1} of {len(images)}" for i in range(len(images))]

self._add_slider(
figs=None,
2 changes: 1 addition & 1 deletion mne/source_space/_source_space.py
@@ -2408,7 +2408,7 @@ def _grid_interp(from_shape, to_shape, trans, order=1, inuse=None):
shape = (np.prod(to_shape), np.prod(from_shape))
if inuse is None:
inuse = np.ones(shape[1], bool)
- assert inuse.dtype == bool
+ assert inuse.dtype == np.dtype(bool)
assert inuse.shape == (shape[1],)
data, indices, indptr = _grid_interp_jit(from_shape, to_shape, trans, order, inuse)
data = np.concatenate(data)
4 changes: 2 additions & 2 deletions mne/stats/cluster_level.py
@@ -479,7 +479,7 @@ def _find_clusters(
len_c = c.stop - c.start
elif isinstance(c, tuple):
len_c = len(c)
- elif c.dtype == bool:
+ elif c.dtype == np.dtype(bool):
len_c = np.sum(c)
else:
len_c = len(c)
@@ -1634,7 +1634,7 @@ def _reshape_clusters(clusters, sample_shape):
"""Reshape cluster masks or indices to be of the correct shape."""
# format of the bool mask and indices are ndarrays
if len(clusters) > 0 and isinstance(clusters[0], np.ndarray):
- if clusters[0].dtype == bool: # format of mask
+ if clusters[0].dtype == np.dtype(bool): # format of mask
clusters = [c.reshape(sample_shape) for c in clusters]
else: # format of indices
clusters = [np.unravel_index(c, sample_shape) for c in clusters]
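`_reshape_clusters` handles two cluster encodings: boolean masks, which are reshaped to the sample shape, and flat indices, which are unraveled into per-dimension coordinates. A compact sketch of the two branches (assumed shapes):

```python
import numpy as np

sample_shape = (2, 3)

mask = np.zeros(6, dtype=bool)
mask[[1, 2]] = True
print(mask.reshape(sample_shape))  # bool-mask format

idx = np.array([1, 2])
print(np.unravel_index(idx, sample_shape))  # index format: (array([0, 0]), array([1, 2]))
```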
4 changes: 2 additions & 2 deletions mne/stats/tests/test_cluster_level.py
@@ -610,7 +610,7 @@ def test_permutation_adjacency_equiv(numba_conditional):
)
# make sure our output datatype is correct
assert isinstance(clusters[0], np.ndarray)
- assert clusters[0].dtype == bool
+ assert clusters[0].dtype == np.dtype(bool)
assert_array_equal(clusters[0].shape, X.shape[1:])

# make sure all comparisons were done; for TFCE, no perm
@@ -847,7 +847,7 @@ def test_output_equiv(shape, out_type, adjacency):
assert isinstance(clu[0], slice)
else:
assert isinstance(clu, np.ndarray)
- assert clu.dtype == bool
+ assert clu.dtype == np.dtype(bool)
assert clu.shape == shape
got_mask[clu] = n
else:
4 changes: 3 additions & 1 deletion mne/tests/test_annotations.py
@@ -278,7 +278,9 @@ def test_crop(tmp_path):
assert raw_read.annotations is not None
assert len(raw_read.annotations.onset) == 0
# test saving and reloading cropped annotations in raw instance
info = create_info([f"EEG{i+1}" for i in range(3)], ch_types=["eeg"] * 3, sfreq=50)
info = create_info(
[f"EEG{i + 1}" for i in range(3)], ch_types=["eeg"] * 3, sfreq=50
)
raw = RawArray(np.zeros((3, 50 * 20)), info)
annotation = mne.Annotations([8, 12, 15], [2] * 3, [1, 2, 3])
raw = raw.set_annotations(annotation)
2 changes: 1 addition & 1 deletion mne/utils/numerics.py
@@ -938,7 +938,7 @@ def _fit(self, X):

def _mask_to_onsets_offsets(mask):
"""Group boolean mask into contiguous onset:offset pairs."""
- assert mask.dtype == bool and mask.ndim == 1
+ assert mask.dtype == np.dtype(bool) and mask.ndim == 1
mask = mask.astype(int)
diff = np.diff(mask)
onsets = np.where(diff > 0)[0] + 1
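The hunk cuts off before the offsets are computed. A plausible completion of the onset/offset grouping, under the assumption that the offsets mirror the onsets logic:

```python
import numpy as np

mask = np.array([0, 1, 1, 0, 1], dtype=bool)
m = mask.astype(int)
diff = np.diff(m)
onsets = np.where(diff > 0)[0] + 1   # rising edges
if mask[0]:
    onsets = np.concatenate([[0], onsets])
offsets = np.where(diff < 0)[0] + 1  # falling edges (assumed symmetry)
if mask[-1]:
    offsets = np.concatenate([offsets, [len(mask)]])
print(list(zip(onsets, offsets)))  # [(1, 3), (4, 5)]
```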
2 changes: 1 addition & 1 deletion mne/viz/_mpl_figure.py
@@ -1847,7 +1847,7 @@ def _draw_one_scalebar(self, x, y, ch_type):
color = "#AA3377" # purple
kwargs = dict(color=color, zorder=self.mne.zorder["scalebar"])
if ch_type == "time":
label = f"{self.mne.boundary_times[1]/2:.2f} s"
label = f"{self.mne.boundary_times[1] / 2:.2f} s"
text = self.mne.ax_main.text(
x[0] + 0.015,
y[1] - 0.05,
4 changes: 3 additions & 1 deletion mne/viz/evoked_field.py
@@ -473,7 +473,9 @@ def _on_colormap_range(self, event):
if self._show_density:
surf_map["mesh"].update_overlay(name="field", rng=[vmin, vmax])
# Update the GUI widgets
if type == "meg":
# TODO: type is undefined here and only avoids a flake warning because it's
# a builtin!
if type == "meg": # noqa: E721
scaling = DEFAULTS["scalings"]["grad"]
else:
scaling = DEFAULTS["scalings"]["eeg"]
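The new TODO comment is worth unpacking: no local named `type` is ever assigned in this method, so the comparison silently hits the builtin `type` and is always False, meaning the MEG branch can never run. A tiny demonstration of the failure mode (assumed, simplified):

```python
def pick_scaling(ch_type="meg"):
    # BUG (as flagged in the TODO): `type` is the builtin, not ch_type,
    # so this comparison is always False and "eeg" is always returned.
    if type == "meg":  # noqa: E721
        return "grad scaling"
    return "eeg scaling"

print(pick_scaling("meg"))  # eeg scaling
```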
