diff --git a/src/scvi/external/sysvi/_base_components.py b/src/scvi/external/sysvi/_base_components.py index ed4cfe214e..09c0569f45 100644 --- a/src/scvi/external/sysvi/_base_components.py +++ b/src/scvi/external/sysvi/_base_components.py @@ -197,7 +197,9 @@ def __init__( Dropout(p=dropout_rate) if dropout_rate > 0 else None, ), ) - for i, (n_in, n_out) in enumerate(zip(layers_dim[:-1], layers_dim[1:])) + for i, (n_in, n_out) in enumerate( + zip(layers_dim[:-1], layers_dim[1:], strict=False) + ) ] ) ) diff --git a/src/scvi/external/sysvi/_model.py b/src/scvi/external/sysvi/_model.py index 0da03933fe..7277a76b64 100644 --- a/src/scvi/external/sysvi/_model.py +++ b/src/scvi/external/sysvi/_model.py @@ -2,7 +2,7 @@ import logging from collections.abc import Sequence -from typing import Literal, Tuple +from typing import Literal import numpy as np import pandas as pd @@ -118,7 +118,7 @@ def get_latent_representation( give_mean: bool = True, batch_size: int | None = None, return_dist: bool = False, - ) -> np.ndarray | Tuple[np.ndarray, np.ndarray]: + ) -> np.ndarray | tuple[np.ndarray, np.ndarray]: """Return the latent representation for each cell. Parameters @@ -300,7 +300,9 @@ def setup_anndata( ) # Make one-hot embedding with specified order - systems_dict = dict(zip(batch_order, ([float(i) for i in range(0, len(batch_order))]))) + systems_dict = dict( + zip(batch_order, ([float(i) for i in range(0, len(batch_order))]), strict=False) + ) adata.uns["batch_order"] = batch_order system_cat = pd.Series( pd.Categorical(values=adata.obs[batch_key], categories=batch_order, ordered=True), diff --git a/src/scvi/external/sysvi/_module.py b/src/scvi/external/sysvi/_module.py index bec63c9fb3..92ebfde80e 100644 --- a/src/scvi/external/sysvi/_module.py +++ b/src/scvi/external/sysvi/_module.py @@ -1,7 +1,8 @@ from __future__ import annotations +from typing import Literal + import torch -from typing_extensions import Literal from scvi import REGISTRY_KEYS from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data diff --git a/src/scvi/external/sysvi/_utils.py b/src/scvi/external/sysvi/_utils.py index 5eadcb8819..ad2e8bcab0 100644 --- a/src/scvi/external/sysvi/_utils.py +++ b/src/scvi/external/sysvi/_utils.py @@ -140,7 +140,7 @@ def dummies_categories(values: pd.Series, categories: list): cov_embed_data = [] for cov_cat_embed_key in cov_cat_embed_keys: cat_order = categ_orders[cov_cat_embed_key] - cat_map = dict(zip(cat_order, range(len(cat_order)))) + cat_map = dict(zip(cat_order, range(len(cat_order)), strict=False)) cov_embed_data.append(meta_data[cov_cat_embed_key].map(cat_map)) cov_embed_data = pd.concat(cov_embed_data, axis=1) else: