Skip to content

Commit

Permalink
Lint transform.
Browse files Browse the repository at this point in the history
  • Loading branch information
arokem committed Sep 10, 2024
1 parent 615ec89 commit 3f2e1f1
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions afqinsight/transform.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Transform AFQ data."""

from collections import OrderedDict

import numpy as np
import pandas as pd
from collections import OrderedDict
from sklearn_pandas import DataFrameMapper

from .utils import CANONICAL_TRACT_NAMES
Expand Down Expand Up @@ -33,7 +34,8 @@ class AFQDataFrameMapper(DataFrameMapper):
Keyword arguments passed to sklearn_pandas.DataFrameMapper. You will
probably not need to change these defaults.
pd_interpolate_params : kwargs, default=dict(method="linear", limit_direction="both", limit_area="inside")
pd_interpolate_params : kwargs,
default=dict(method="linear", limit_direction="both", limit_area="inside")
Keyword arguments passed to pandas.DataFrame.interpolate. Missing
values are interpolated within the tract profile so that no data is
used from other subjects, tracts, or metrics, minimizing the chance
Expand Down Expand Up @@ -280,7 +282,7 @@ def multicol2sets(columns, tract_symmetry=True):
col_vals = np.array([x + (bilateral_symmetry[x[tract_idx]],) for x in col_vals])

col_vals = np.array([tuple([str(el) for el in tup]) for tup in col_vals])
col_sets = np.array(list(map(lambda c: set(c), col_vals)))
col_sets = np.array([set(c) for c in col_vals])

return col_sets

Expand Down Expand Up @@ -343,7 +345,7 @@ def sort_features(features, scores):
Sorted list of columns and scores
"""
res = sorted(
[(feat, score) for feat, score in zip(features, scores)],
zip(features, scores),
key=lambda s: np.abs(s[1]),
reverse=True,
)
Expand Down Expand Up @@ -383,12 +385,12 @@ def beta_hat_by_groups(beta_hat, columns, drop_zeros=False):
label_sets = multicol2sets(columns, tract_symmetry=False)

for tract in columns.levels[columns.names.index("tractID")]:
tract_mask = set([tract]) <= label_sets
tract_mask = {tract} <= label_sets
all_metrics = np.copy(beta_hat[tract_mask])
if not drop_zeros or any(all_metrics != 0):
betas[tract] = OrderedDict()
for metric in columns.levels[columns.names.index("metric")]:
metric_mask = set([tract, metric]) <= label_sets
metric_mask = set([tract, metric]) <= label_sets # noqa: C405
x = np.copy(beta_hat[metric_mask])
if not drop_zeros or any(x != 0):
betas[tract][metric] = x
Expand Down

0 comments on commit 3f2e1f1

Please sign in to comment.