Skip to content

Commit

Permalink
Fixed bug in vector metric computation
Browse files Browse the repository at this point in the history
  • Loading branch information
Old-Shatterhand committed Apr 19, 2024
1 parent 9404bce commit 62d81c6
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
7 changes: 3 additions & 4 deletions datasail/cluster/vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def run_vector(dataset: DataSet, method: SIM_OPTIONS = "tanimoto") -> None:
if method in get_args(SIM_OPTIONS):
if isinstance(embed, (list, tuple, np.ndarray)):
if isinstance(embed[0], int) or np.issubdtype(embed[0].dtype, int):
if method in ["allbit", "asymmetric", "braunblanquet", "cosine", "kulczynski", "mcconnaughey", "onbit",
if method in ["allbit", "asymmetric", "braunblanquet", "cosine", "kulczynski", "onbit",
"rogotgoldberg", "russel", "sokal"]:
dataset.data = {k: iterable2bitvect(v) for k, v in dataset.data.items()}
else:
Expand All @@ -138,15 +138,14 @@ def run_vector(dataset: DataSet, method: SIM_OPTIONS = "tanimoto") -> None:
f"Unsupported embedding type {type(embed)}. Please use either RDKit datastructures, lists, "
f"tuples or one-dimensional numpy arrays.")
elif method in get_args(DIST_OPTIONS):
dtype = np.bool_ if ["jaccard", "rogerstanimoto", "sokalmichener", "yule"] else np.float64
if isinstance(embed, (
list, tuple, DataStructs.ExplicitBitVect, DataStructs.LongSparseIntVect, DataStructs.IntSparseIntVect)):
dataset.data = {k: np.array(list(v), dtype=np.float64) for k, v in dataset.data.items()}
dataset.data = {k: np.array(list(v), dtype=dtype) for k, v in dataset.data.items()}
if not isinstance(dataset.data[dataset.names[0]], np.ndarray):
raise ValueError(
f"Unsupported embedding type {type(embed)}. Please use either RDKit datastructures, lists, "
f"tuples or one-dimensional numpy arrays.")
if method in ["rogerstanimoto", "sokalmichener", "yule"]:
dataset.data = {k: np.array(list(v), dtype=np.bool_) for k, v in dataset.data.items()}
else:
raise ValueError(f"Unknown method {method}")
fps = [dataset.data[name] for name in dataset.names]
Expand Down
7 changes: 4 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,18 @@
author="Roman Joeres",
maintainer="Roman Joeres",
classifiers=[
"Development Status :: 4 - Beta",
"Development Status :: 5 - Production/Stable",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Intended Audience :: Science/Research",
"Natural Language :: English",
"Topic :: Scientific/Engineering :: Bio-Informatics",
],
packages=find_packages(),
setup_requires=['setuptools_scm'],
include_package_data=True,
include_package_data=False,
install_requires=[],
package_data={},
python_requires=">=3.8, <4.0.0",
Expand Down

0 comments on commit 62d81c6

Please sign in to comment.