Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

USearch indexing for Hoplite DB. #699

Merged
merged 1 commit into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions chirp/projects/agile2/1_embed_audio_v2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
"from chirp.projects.agile2 import colab_utils\n",
"from chirp.projects.agile2 import embed\n",
"from chirp.projects.agile2 import source_info\n",
"from chirp.projects.hoplite import interface\n"
"from chirp.projects.hoplite import interface\n",
"from chirp.projects.hoplite import brutalism\n",
"from chirp.projects.hoplite import db_loader\n",
"from chirp.projects.hoplite import sqlite_usearch_impl\n"
]
},
{
Expand Down Expand Up @@ -55,7 +58,7 @@
"#@markdown like '/home/me/myproject/site_XYZ/audio_ABC.wav'\n",
"dataset_name = '' #@param {type:'string'}\n",
"dataset_base_path = '' #@param {type:'string'}\n",
"dataset_fileglob = '' #@param {type:'string'}\n",
"dataset_fileglob = '*.wav' #@param {type:'string'}\n",
"\n",
"#@markdown Choose a supported model: `perch_8` or `birdnet_v2.3` are most common\n",
"#@markdown for birds. Other choices include `surfperch` for coral reefs or\n",
Expand All @@ -82,7 +85,8 @@
"configs = colab_utils.load_configs(\n",
" source_info.AudioSources((audio_glob,)),\n",
" db_path,\n",
" model_config_key=model_choice)\n",
" model_config_key=model_choice,\n",
" db_key = 'sqlite_usearch')\n",
"configs"
]
},
Expand Down Expand Up @@ -164,10 +168,14 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hr_AUAfI7UG_"
"id": "ihBNRbwuuwal"
},
"outputs": [],
"source": []
"source": [
"q = db.get_embedding(444)\n",
"%time results, scores = brutalism.brute_search(worker.db, query_embedding=q, search_list_size=128, score_fn=np.dot)\n",
"print([r.embedding_id for r in results])"
]
}
],
"metadata": {
Expand Down
22 changes: 15 additions & 7 deletions chirp/projects/agile2/colab_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from chirp.projects.zoo import model_configs
from etils import epath
from ml_collections import config_dict
import numpy as np


@dataclasses.dataclass
Expand All @@ -49,15 +50,17 @@ def load_configs(
audio_sources: source_info.AudioSources,
db_path: str | None = None,
model_config_key: str = 'perch_8',
db_key: str = 'sqlite_usearch',
) -> AgileConfigs:
"""Load default configs for the notebook and return them as an AgileConfigs.

Args:
audio_globs: Mapping from dataset name to pairs of `(root directory, file
audio_sources: Mapping from dataset name to pairs of `(root directory, file
glob)`.
db_path: Location of the database. If None, the database will be created in
the same directory as the audio.
model_config_key: Name of the embedding model to use.
db_key: Key selecting the database backend; passed to `db_loader.DBConfig`.

Returns:
AgileConfigs object with the loaded configs.
Expand All @@ -68,10 +71,7 @@ def load_configs(
'db_path must be specified when embedding multiple datasets.'
)
# Put the DB in the same directory as the audio.
db_path = (
epath.Path(next(iter(audio_sources.audio_globs)).base_path)
/ 'hoplite_db.sqlite'
)
db_path = epath.Path(next(iter(audio_sources.audio_globs)).base_path)

model_key, embedding_dim, model_config = (
model_configs.get_preset_model_config(model_config_key)
Expand All @@ -83,11 +83,19 @@ def load_configs(
)
db_config = config_dict.ConfigDict({
'db_path': db_path,
'embedding_dim': embedding_dim,
})
if db_key == 'sqlite_usearch':
# A sane default.
db_config.usearch_cfg = config_dict.ConfigDict({
'embedding_dim': embedding_dim,
'metric_name': 'IP',
'expansion_add': 256,
'expansion_search': 128,
'dtype': 'float16',
})

return AgileConfigs(
audio_sources_config=audio_sources,
db_config=db_loader.DBConfig('sqlite', db_config),
db_config=db_loader.DBConfig(db_key, db_config),
model_config=db_model_config,
)
4 changes: 4 additions & 0 deletions chirp/projects/hoplite/db_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
"""Database configuration and constructor."""

import dataclasses

from chirp.projects.hoplite import in_mem_impl
from chirp.projects.hoplite import interface
from chirp.projects.hoplite import sqlite_impl
from chirp.projects.hoplite import sqlite_usearch_impl
from ml_collections import config_dict
import numpy as np
import tqdm
Expand All @@ -40,6 +42,8 @@ def load_db(self) -> interface.GraphSearchDBInterface:
"""Load the database from the specified path."""
if self.db_key == 'sqlite':
return sqlite_impl.SQLiteGraphSearchDB.create(**self.db_config)
elif self.db_key == 'sqlite_usearch':
return sqlite_usearch_impl.SQLiteUsearchDB.create(**self.db_config)
elif self.db_key == 'in_mem':
return in_mem_impl.InMemoryGraphSearchDB.create(**self.db_config)
else:
Expand Down
Loading
Loading