diff --git a/nl_server/config_reader.py b/nl_server/config_reader.py index cb03cf6e82..4a04781bad 100644 --- a/nl_server/config_reader.py +++ b/nl_server/config_reader.py @@ -90,7 +90,7 @@ def read_catalog_config() -> CatalogConfig: if gcs.is_gcs_path(user_data_path): full_gcs_path = os.path.join(user_data_path, _CATALOG_USER_PATH_SUFFIX) - gcs.download_file_by_path(full_gcs_path, _CATALOG_TMP_PATH) + gcs.download_blob_by_path(full_gcs_path, _CATALOG_TMP_PATH) all_paths.append(_CATALOG_TMP_PATH) else: full_user_path = os.path.join(user_data_path, _CATALOG_USER_PATH_SUFFIX) diff --git a/nl_server/registry.py b/nl_server/registry.py index 8f7425b026..e1ed87fd3f 100644 --- a/nl_server/registry.py +++ b/nl_server/registry.py @@ -41,11 +41,12 @@ # class ResourceRegistry: - # Input is parsed runnable config. + # Input is server config object. def __init__(self, server_config: ServerConfig): self.name_to_emb: dict[str, Embeddings] = {} self.name_to_emb_model: Dict[str, EmbeddingsModel] = {} self.name_to_rank_model: Dict[str, RerankingModel] = {} + self._attribute_model = AttributeModel() self.load(server_config) # Note: The caller takes care of exceptions. @@ -65,14 +66,12 @@ def attribute_model(self) -> AttributeModel: def get_model(self, model_name: str) -> EmbeddingsModel: return self.name_to_emb_model.get(model_name) - # Adds the new models and indexes in a RunnableConfig object to the - # embeddings + # Load the registry from the server config def load(self, server_config: ServerConfig): self._server_config = server_config self._load_models(server_config.models) for idx_name, idx_info in server_config.indexes.items(): self._set_embeddings(idx_name, idx_info) - self._attribute_model = AttributeModel() # Loads a dict of model name -> model info def _load_models(self, models: dict[str, ModelConfig]): diff --git a/nl_server/routes.py b/nl_server/routes.py index c5ca4e1220..263f87b249 100644 --- a/nl_server/routes.py +++ b/nl_server/routes.py @@ -22,6 +22,7 @@ from flask import request from markupsafe import escape +from nl_server import registry from nl_server import search from nl_server.embeddings import Embeddings from nl_server.registry import REGISTRY_KEY @@ -35,12 +36,12 @@ @bp.route('/healthz') def healthz(): - registry: ResourceRegistry = current_app.config[REGISTRY_KEY] - default_indexes = registry.server_config().default_indexes + r: ResourceRegistry = current_app.config[REGISTRY_KEY] + default_indexes = r.server_config().default_indexes if not default_indexes: logging.warning('Health Check Failed: Default index name empty!') return 'Service Unavailable', 500 - embeddings: Embeddings = registry.get_index(default_indexes[0]) + embeddings: Embeddings = r.get_index(default_indexes[0]) if embeddings: query = embeddings.store.healthcheck_query result: VarCandidates = search.search_vars([embeddings], [query]).get(query) @@ -70,18 +71,19 @@ def search_vars(): if request.args.get('skip_topics'): skip_topics = True + r: ResourceRegistry = current_app.config[REGISTRY_KEY] + reranker_name = str(escape(request.args.get('reranker', ''))) - reranker_model = registry.get_reranking_model( + reranker_model = r.get_reranking_model( reranker_name) if reranker_name else None - registry = current_app.config[REGISTRY_KEY] - default_indexes = registry.server_config().default_indexes + default_indexes = r.server_config().default_indexes idx_type_str = str(escape(request.args.get('idx', ''))) if not idx_type_str: idx_types = default_indexes else: idx_types = idx_type_str.split(',') - embeddings = _get_indexes(registry, idx_types) + embeddings = _get_indexes(r, idx_types) debug_logs = {'sv_detection_query_index_type': idx_types} results = search.search_vars(embeddings, queries, skip_topics, reranker_model, @@ -101,34 +103,33 @@ def detect_verbs(): List[str] """ query = str(escape(request.args.get('q'))) - registry = current_app.config[REGISTRY_KEY] - return json.dumps(registry.attribute_model().detect_verbs(query.strip())) + r: ResourceRegistry = current_app.config[REGISTRY_KEY] + return json.dumps(r.attribute_model().detect_verbs(query.strip())) @bp.route('/api/embeddings_version_map/', methods=['GET']) def embeddings_version_map(): - registry: ResourceRegistry = current_app.config[REGISTRY_KEY] - server_config = registry.server_config() + r: ResourceRegistry = current_app.config[REGISTRY_KEY] + server_config = r.server_config() return json.dumps(asdict(server_config)) @bp.route('/api/load/', methods=['GET']) def load(): try: - current_app.config[registry.REGISTRY_KEY] = registry.build() + current_app.config[REGISTRY_KEY] = registry.build() except Exception as e: logging.error(f'Custom embeddings not loaded due to error: {str(e)}') - registry = current_app.config[REGISTRY_KEY] - server_config = registry.server_config() + r: ResourceRegistry = current_app.config[REGISTRY_KEY] + server_config = r.server_config() return json.dumps(asdict(server_config)) -def _get_indexes(registry: ResourceRegistry, - idx_types: List[str]) -> List[Embeddings]: +def _get_indexes(r: ResourceRegistry, idx_types: List[str]) -> List[Embeddings]: embeddings: List[Embeddings] = [] for idx in idx_types: try: - emb = registry.get_index(idx) + emb = r.get_index(idx) if emb: embeddings.append(emb) except Exception as e: diff --git a/run_test.sh b/run_test.sh index 8e478869f5..4215c515ed 100755 --- a/run_test.sh +++ b/run_test.sh @@ -106,10 +106,8 @@ function run_py_test { # Tests within tools/nl/embeddings echo "Running tests within tools/nl/embeddings:" - cd tools/nl/embeddings - pip3 install -r requirements.txt - python3 -m pytest ./ -s - cd ../../.. + pip3 install -r tools/nl/embeddings/requirements.txt + python3 -m pytest tools/nl/embeddings/ -s pip3 install yapf==0.40.2 -q if ! command -v isort &> /dev/null diff --git a/tools/nl/embeddings/build_custom_dc_embeddings.py b/tools/nl/embeddings/build_custom_dc_embeddings.py index 0bc410fae1..ae73dbfcf1 100644 --- a/tools/nl/embeddings/build_custom_dc_embeddings.py +++ b/tools/nl/embeddings/build_custom_dc_embeddings.py @@ -17,10 +17,7 @@ from absl import app from absl import flags -from file_util import create_file_handler -from file_util import FileHandler import pandas as pd -import utils import yaml from nl_server import config @@ -28,6 +25,9 @@ from nl_server.config import CatalogConfig from nl_server.config import MemoryIndexConfig from nl_server.registry import ResourceRegistry +from tools.nl.embeddings import utils +from tools.nl.embeddings.file_util import create_file_handler +from tools.nl.embeddings.file_util import FileHandler class Mode: diff --git a/tools/nl/embeddings/build_custom_dc_embeddings_test.py b/tools/nl/embeddings/build_custom_dc_embeddings_test.py index bd911840d9..19180534a0 100644 --- a/tools/nl/embeddings/build_custom_dc_embeddings_test.py +++ b/tools/nl/embeddings/build_custom_dc_embeddings_test.py @@ -13,21 +13,24 @@ # limitations under the License. import os +from pathlib import Path import tempfile import unittest -from build_custom_dc_embeddings import EMBEDDINGS_CSV_FILENAME_PREFIX -from build_custom_dc_embeddings import EMBEDDINGS_YAML_FILE_NAME -import build_custom_dc_embeddings as builder -from file_util import create_file_handler from sentence_transformers import SentenceTransformer -import utils from nl_server.config import LocalModelConfig +from tools.nl.embeddings import utils +from tools.nl.embeddings.build_custom_dc_embeddings import \ + EMBEDDINGS_CSV_FILENAME_PREFIX +from tools.nl.embeddings.build_custom_dc_embeddings import \ + EMBEDDINGS_YAML_FILE_NAME +import tools.nl.embeddings.build_custom_dc_embeddings as builder +from tools.nl.embeddings.file_util import create_file_handler MODEL_NAME = "all-MiniLM-L6-v2" -INPUT_DIR = "testdata/custom_dc/input" -EXPECTED_DIR = "testdata/custom_dc/expected" +INPUT_DIR = Path(__file__).parent / "testdata/custom_dc/input" +EXPECTED_DIR = Path(__file__).parent / "testdata/custom_dc/expected" def _compare_files(test: unittest.TestCase, output_path, expected_path): diff --git a/tools/nl/embeddings/build_embeddings.py b/tools/nl/embeddings/build_embeddings.py index 88c0e39c80..f0e7770740 100644 --- a/tools/nl/embeddings/build_embeddings.py +++ b/tools/nl/embeddings/build_embeddings.py @@ -31,7 +31,8 @@ import lancedb import pandas as pd from sentence_transformers import SentenceTransformer -import utils + +from tools.nl.embeddings import utils VERTEX_AI_PROJECT = 'datcom-nl' VERTEX_AI_PROJECT_LOCATION = 'us-central1' diff --git a/tools/nl/embeddings/build_embeddings_test.py b/tools/nl/embeddings/build_embeddings_test.py index ba17a1d864..0705120932 100644 --- a/tools/nl/embeddings/build_embeddings_test.py +++ b/tools/nl/embeddings/build_embeddings_test.py @@ -17,12 +17,14 @@ import tempfile import unittest from unittest import mock +from pathlib import Path -import build_embeddings as be import pandas as pd from parameterized import parameterized from sentence_transformers import SentenceTransformer -import utils + +from tools.nl.embeddings import utils +import tools.nl.embeddings.build_embeddings as be def get_test_sv_data(): @@ -96,7 +98,7 @@ def testFailure(self): input_sheets_svs = [] # Filepaths all correspond to the testdata folder. - input_dir = "testdata/input" + input_dir = Path(__file__).parent / "testdata/input" input_alternatives_filepattern = os.path.join(input_dir, "*_alternatives.csv") input_autogen_filepattern = os.path.join(input_dir, 'unknown_*.csv') @@ -120,8 +122,8 @@ def testSuccess(self): tmp="/tmp") # Filepaths all correspond to the testdata folder. - input_dir = "testdata/input" - expected_dir = "testdata/expected" + input_dir = Path(__file__).parent / "testdata/input" + expected_dir = Path(__file__).parent / "testdata/expected" input_alternatives_filepattern = os.path.join(input_dir, "*_alternatives.csv") input_autogen_filepattern = os.path.join(input_dir, "autogen_*.csv") @@ -156,12 +158,16 @@ class TestEndToEndActualDataFiles(unittest.TestCase): @parameterized.expand(["small", "medium"]) def testInputFilesValidations(self, sz): # Verify that the required files exist. - sheets_filepath = "data/curated_input/main/sheets_svs.csv" + sheets_filepath = Path( + __file__).parent / "data/curated_input/main/sheets_svs.csv" # TODO: Fix palm_batch13k_alternatives.csv to not have duplicate # descriptions. Its technically okay since build_embeddings will take # care of dups. - input_alternatives_filepattern = "data/alternatives/(palm|other)_alternaties.csv" - output_dcid_sentences_filepath = f'data/preindex/{sz}/sv_descriptions.csv' + parent_folder = str(Path(__file__).parent) + input_alternatives_filepattern = os.path.join( + parent_folder, "data/alternatives/(palm|other)_alternaties.csv") + output_dcid_sentences_filepath = os.path.join( + parent_folder, f'data/preindex/{sz}/sv_descriptions.csv') # Check that all the files exist. self.assertTrue(os.path.exists(sheets_filepath)) @@ -192,7 +198,8 @@ def testInputFilesValidations(self, sz): @parameterized.expand(["small", "medium"]) def testOutputFileValidations(self, sz): - output_dcid_sentences_filepath = f'data/preindex/{sz}/sv_descriptions.csv' + output_dcid_sentences_filepath = Path( + __file__).parent / f'data/preindex/{sz}/sv_descriptions.csv' dcid_sentence_df = pd.read_csv(output_dcid_sentences_filepath).fillna("") diff --git a/tools/nl/embeddings/run.sh b/tools/nl/embeddings/run.sh index 978d1db582..59d247a5c2 100755 --- a/tools/nl/embeddings/run.sh +++ b/tools/nl/embeddings/run.sh @@ -97,24 +97,22 @@ done cd ../../.. python3 -m venv .env source .env/bin/activate -cd tools/nl/embeddings -python3 -m pip install --upgrade pip pip3 install torch==2.2.2 --extra-index-url https://download.pytorch.org/whl/cpu -pip3 install -r requirements.txt +pip3 install -r tools/nl/embeddings/requirements.txt if [[ "$MODEL_ENDPOINT_ID" != "" ]];then - python3 build_embeddings.py --embeddings_size=$2 \ + python3 -m tools.nl.embeddings.build_embeddings --embeddings_size=$2 \ --vertex_ai_prediction_endpoint_id=$MODEL_ENDPOINT_ID \ --curated_input_dirs="data/curated_input/main" \ --autogen_input_basedir="" \ --alternatives_filepattern="" elif [[ "$CURATED_INPUT_DIRS" != "" ]]; then - python3 build_embeddings.py --embeddings_size=$2 --finetuned_model_gcs=$FINETUNED_MODEL --curated_input_dirs=$CURATED_INPUT_DIRS --alternatives_filepattern=$ALTERNATIVES_FILE_PATTERN + python3 -m tools.nl.embeddings.build_embeddings --embeddings_size=$2 --finetuned_model_gcs=$FINETUNED_MODEL --curated_input_dirs=$CURATED_INPUT_DIRS --alternatives_filepattern=$ALTERNATIVES_FILE_PATTERN elif [[ "$LANCEDB_OUTPUT_PATH" != "" ]]; then - python3 build_embeddings.py --embeddings_size=$2 --finetuned_model_gcs=$FINETUNED_MODEL --lancedb_output_path=$LANCEDB_OUTPUT_PATH --dry_run=True + python3 -m tools.nl.embeddings.build_embeddings --embeddings_size=$2 --finetuned_model_gcs=$FINETUNED_MODEL --lancedb_output_path=$LANCEDB_OUTPUT_PATH --dry_run=True elif [[ "$FINETUNED_MODEL" != "" ]]; then - python3 build_embeddings.py --embeddings_size=$2 --finetuned_model_gcs=$FINETUNED_MODEL + python3 -m tools.nl.embeddings.build_embeddings --embeddings_size=$2 --finetuned_model_gcs=$FINETUNED_MODEL else - python3 build_embeddings.py --embeddings_size=$2 + python3 -m tools.nl.embeddings.build_embeddings --embeddings_size=$2 fi diff --git a/tools/nl/embeddings/run_custom.sh b/tools/nl/embeddings/run_custom.sh index cc98169296..b617dd803d 100755 --- a/tools/nl/embeddings/run_custom.sh +++ b/tools/nl/embeddings/run_custom.sh @@ -15,10 +15,10 @@ set -e +cd ../../.. python3 -m venv .env source .env/bin/activate -python3 -m pip install --upgrade pip pip3 install torch==2.2.2 --extra-index-url https://download.pytorch.org/whl/cpu -pip3 install -r requirements.txt +pip3 install -r tools/nl/embeddings/requirements.txt -python3 build_custom_dc_embeddings.py "$@" \ No newline at end of file +python3 -m tools.nl.embeddings.build_custom_dc_embeddings "$@" \ No newline at end of file diff --git a/tools/nl/embeddings/utils.py b/tools/nl/embeddings/utils.py index 0643fcf7e1..15826892ca 100644 --- a/tools/nl/embeddings/utils.py +++ b/tools/nl/embeddings/utils.py @@ -16,14 +16,13 @@ from dataclasses import dataclass import itertools import logging -import os -from pathlib import Path from typing import Any, Dict, List, Tuple -from file_util import create_file_handler from google.cloud import aiplatform import pandas as pd +from tools.nl.embeddings.file_util import create_file_handler + # Col names in the input files/sheets. DCID_COL = 'dcid' NAME_COL = 'Name'