Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Oct 19, 2024
1 parent 9a2a33e commit c049ea1
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 6 deletions.
4 changes: 3 additions & 1 deletion src/scvi/external/sysvi/_base_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,9 @@ def __init__(
Dropout(p=dropout_rate) if dropout_rate > 0 else None,
),
)
for i, (n_in, n_out) in enumerate(zip(layers_dim[:-1], layers_dim[1:]))
for i, (n_in, n_out) in enumerate(
zip(layers_dim[:-1], layers_dim[1:], strict=False)
)
]
)
)
Expand Down
8 changes: 5 additions & 3 deletions src/scvi/external/sysvi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
from collections.abc import Sequence
from typing import Literal, Tuple
from typing import Literal

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -118,7 +118,7 @@ def get_latent_representation(
give_mean: bool = True,
batch_size: int | None = None,
return_dist: bool = False,
) -> np.ndarray | Tuple[np.ndarray, np.ndarray]:
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
"""Return the latent representation for each cell.
Parameters
Expand Down Expand Up @@ -300,7 +300,9 @@ def setup_anndata(
)

# Make one-hot embedding with specified order
systems_dict = dict(zip(batch_order, ([float(i) for i in range(0, len(batch_order))])))
systems_dict = dict(
zip(batch_order, ([float(i) for i in range(0, len(batch_order))]), strict=False)
)
adata.uns["batch_order"] = batch_order
system_cat = pd.Series(
pd.Categorical(values=adata.obs[batch_key], categories=batch_order, ordered=True),
Expand Down
3 changes: 2 additions & 1 deletion src/scvi/external/sysvi/_module.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

from typing import Literal

import torch
from typing_extensions import Literal

from scvi import REGISTRY_KEYS
from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/external/sysvi/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def dummies_categories(values: pd.Series, categories: list):
cov_embed_data = []
for cov_cat_embed_key in cov_cat_embed_keys:
cat_order = categ_orders[cov_cat_embed_key]
cat_map = dict(zip(cat_order, range(len(cat_order))))
cat_map = dict(zip(cat_order, range(len(cat_order)), strict=False))
cov_embed_data.append(meta_data[cov_cat_embed_key].map(cat_map))
cov_embed_data = pd.concat(cov_embed_data, axis=1)
else:
Expand Down

0 comments on commit c049ea1

Please sign in to comment.