Skip to content

Commit

Permalink
Fix miceforest (#800)
Browse files Browse the repository at this point in the history
Signed-off-by: zethson <[email protected]>
  • Loading branch information
Zethson authored Sep 18, 2024
1 parent 2525089 commit c939530
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 22 deletions.
4 changes: 4 additions & 0 deletions ehrapy/anndata/anndata_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,18 @@ def df_to_anndata(

# Prepare the AnnData object
X = df.to_numpy(copy=True)
obs.index = obs.index.astype(str)
var = pd.DataFrame(index=df.columns)
var.index = var.index.astype(str)
uns = OrderedDict() # type: ignore

# Handle dtype of X based on presence of numerical columns only
all_numeric = df.select_dtypes(include=[np.number]).shape[1] == df.shape[1]
X = X.astype(np.float32 if all_numeric else object)

adata = AnnData(X=X, obs=obs, var=var, uns=uns, layers={"original": X.copy()})
adata.obs_names = adata.obs_names.astype(str)
adata.var_names = adata.var_names.astype(str)

return adata

Expand Down
36 changes: 22 additions & 14 deletions ehrapy/preprocessing/_imputation.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ def mice_forest_impute(
var_names: Iterable[str] | None = None,
*,
warning_threshold: int = 70,
save_all_iterations: bool = True,
save_all_iterations_data: bool = True,
random_state: int | None = None,
inplace: bool = False,
iterations: int = 5,
Expand All @@ -485,7 +485,7 @@ def mice_forest_impute(
adata: The AnnData object containing the data to impute.
var_names: A list of variable names to impute. If None, impute all variables.
warning_threshold: Threshold of percentage of missing values to display a warning for.
save_all_iterations: Whether to save all imputed values from all iterations or just the latest.
save_all_iterations_data: Whether to save all imputed values from all iterations or just the latest.
Saving all iterations allows for additional plotting, but may take more memory.
random_state: The random state ensures script reproducibility.
inplace: If True, modify the input AnnData object in-place and return None.
Expand Down Expand Up @@ -520,7 +520,7 @@ def mice_forest_impute(
_miceforest_impute(
adata,
var_names,
save_all_iterations,
save_all_iterations_data,
random_state,
inplace,
iterations,
Expand All @@ -536,7 +536,7 @@ def mice_forest_impute(
_miceforest_impute(
adata,
var_names,
save_all_iterations,
save_all_iterations_data,
random_state,
inplace,
iterations,
Expand All @@ -555,31 +555,39 @@ def mice_forest_impute(


def _miceforest_impute(
adata, var_names, save_all_iterations, random_state, inplace, iterations, variable_parameters, verbose
adata, var_names, save_all_iterations_data, random_state, inplace, iterations, variable_parameters, verbose
) -> None:
import miceforest as mf

data_df = pd.DataFrame(adata.X, columns=adata.var_names, index=adata.obs_names)
data_df = data_df.apply(pd.to_numeric, errors="coerce")

if isinstance(var_names, Iterable):
column_indices = _get_column_indices(adata, var_names)
selected_columns = data_df.iloc[:, column_indices]
selected_columns = selected_columns.reset_index(drop=True)

# Create kernel.
kernel = mf.ImputationKernel(
adata.X[::, column_indices], datasets=1, save_all_iterations=save_all_iterations, random_state=random_state
selected_columns,
num_datasets=1,
save_all_iterations_data=save_all_iterations_data,
random_state=random_state,
)

kernel.mice(iterations=iterations, variable_parameters=variable_parameters, verbose=verbose)

adata.X[::, column_indices] = kernel.complete_data(dataset=0, inplace=inplace)
kernel.mice(iterations=iterations, variable_parameters=variable_parameters or {}, verbose=verbose)
data_df.iloc[:, column_indices] = kernel.complete_data(dataset=0, inplace=inplace)

else:
# Create kernel.
data_df = data_df.reset_index(drop=True)

kernel = mf.ImputationKernel(
adata.X, datasets=1, save_all_iterations=save_all_iterations, random_state=random_state
data_df, num_datasets=1, save_all_iterations_data=save_all_iterations_data, random_state=random_state
)

kernel.mice(iterations=iterations, variable_parameters=variable_parameters, verbose=verbose)
kernel.mice(iterations=iterations, variable_parameters=variable_parameters or {}, verbose=verbose)
data_df = kernel.complete_data(dataset=0, inplace=inplace)

adata.X = kernel.complete_data(dataset=0, inplace=inplace)
adata.X = data_df.values


def _warn_imputation_threshold(adata: AnnData, var_names: Iterable[str] | None, threshold: int = 75) -> dict[str, int]:
Expand Down
16 changes: 8 additions & 8 deletions tests/preprocessing/test_imputation.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,35 +167,35 @@ def test_missforest_impute_dict(impute_adata):
assert not (np.all([item != item for item in adata_imputed.X]))


@pytest.mark.skipif(os.name == "posix", reason="miceforest Imputation not supported by MacOS.")
def test_miceforest_impute_no_copy(impute_iris_adata_adata):
adata_imputed = mice_forest_impute(impute_iris_adata_adata)
@pytest.mark.skipif(os.name == "Darwin", reason="miceforest Imputation not supported by MacOS.")
def test_miceforest_impute_no_copy(impute_iris_adata):
adata_imputed = mice_forest_impute(impute_iris_adata)

assert id(impute_iris_adata_adata) == id(adata_imputed)
assert id(impute_iris_adata) == id(adata_imputed)


@pytest.mark.skipif(os.name == "posix", reason="miceforest Imputation not supported by MacOS.")
@pytest.mark.skipif(os.name == "Darwin", reason="miceforest Imputation not supported by MacOS.")
def test_miceforest_impute_copy(impute_iris_adata):
adata_imputed = mice_forest_impute(impute_iris_adata, copy=True)

assert id(impute_iris_adata) != id(adata_imputed)


@pytest.mark.skipif(os.name == "posix", reason="miceforest Imputation not supported by MacOS.")
@pytest.mark.skipif(os.name == "Darwin", reason="miceforest Imputation not supported by MacOS.")
def test_miceforest_impute_non_numerical_data(impute_titanic_adata):
adata_imputed = mice_forest_impute(impute_titanic_adata)

assert not (np.all([item != item for item in adata_imputed.X]))


@pytest.mark.skipif(os.name == "posix", reason="miceforest Imputation not supported by MacOS.")
@pytest.mark.skipif(os.name == "Darwin", reason="miceforest Imputation not supported by MacOS.")
def test_miceforest_impute_numerical_data(impute_iris_adata):
adata_imputed = mice_forest_impute(impute_iris_adata)

assert not (np.all([item != item for item in adata_imputed.X]))


@pytest.mark.skipif(os.name == "posix", reason="miceforest Imputation not supported by MacOS.")
@pytest.mark.skipif(os.name == "Darwin", reason="miceforest Imputation not supported by MacOS.")
def test_miceforest_impute_list_str(impute_titanic_adata):
adata_imputed = mice_forest_impute(impute_titanic_adata, var_names=["Cabin", "Age"])

Expand Down

0 comments on commit c939530

Please sign in to comment.