From 3c59c1b11bbacf6f79af546afd9ec9a2fe06599a Mon Sep 17 00:00:00 2001 From: Prashanth R Date: Thu, 26 Oct 2023 10:40:48 -0700 Subject: [PATCH] Refactor NL Server bootstrap code for readability (#3725) This is to make follow on custom DC specific changes cleaner. There is no intentional logic change other than using a single cache entry (instead of one per index) in test environment. A key simplification is to more directly rely on the embeddings name to extract the base / tuned-model (see `config._parse()`) instead of the current code which uses `models.yaml` and also relies on the naming partially. --- nl_server/__init__.py | 57 +++++------- nl_server/config.py | 106 ++++++++++++++++++++++ nl_server/embeddings.py | 6 +- nl_server/embeddings_store.py | 37 ++++++++ nl_server/loader.py | 111 ++++++------------------ nl_server/routes.py | 14 +-- nl_server/tests/embeddings_test.py | 24 +++-- nl_server/tests/ner_place_model_test.py | 8 +- nl_server/tests/verb_test.py | 8 +- 9 files changed, 220 insertions(+), 151 deletions(-) create mode 100644 nl_server/config.py create mode 100644 nl_server/embeddings_store.py diff --git a/nl_server/__init__.py b/nl_server/__init__.py index d82a2ec547..96c1982526 100644 --- a/nl_server/__init__.py +++ b/nl_server/__init__.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging import os import sys @@ -20,54 +19,46 @@ import torch import yaml +from nl_server import config import nl_server.loader as loader import nl_server.routes as routes +_MODEL_YAML = 'models.yaml' +_EMBEDDINGS_YAML = 'embeddings.yaml' + def create_app(): app = Flask(__name__) app.register_blueprint(routes.bp) - flask_env = os.environ.get('FLASK_ENV') - # https://github.com/UKPLab/sentence-transformers/issues/1318 if sys.version_info >= (3, 8) and sys.platform == "darwin": torch.set_num_threads(1) - # Download existing finetuned models (if not already downloaded). - models_downloaded_paths = {} - models_config_path = '/datacommons/nl/models.yaml' - if flask_env in ['local', 'test', 'integration_test', 'webdriver']: - models_config_path = os.path.join( - os.path.dirname(os.path.dirname(os.path.abspath(__file__))), - 'deploy/nl/models.yaml') - app.config['MODELS_CONFIG_PATH'] = models_config_path - with open(app.config['MODELS_CONFIG_PATH']) as f: + with open(get_env_path(_MODEL_YAML)) as f: models_map = yaml.full_load(f) - if not models_map: - logging.error("No configuration found for model") - return + assert models_map, 'No models.yaml found!' - models_downloaded_paths = loader.download_models(models_map) + with open(get_env_path(_EMBEDDINGS_YAML)) as f: + embeddings_map = yaml.full_load(f) + assert embeddings_map, 'No embeddings.yaml found!' + app.config[config.NL_EMBEDDINGS_VERSION_KEY] = embeddings_map - assert models_downloaded_paths, "No models were found/downloaded. Check deploy/nl/models.yaml" + loader.load_server_state(app, embeddings_map, models_map) - # Download existing embeddings (if not already downloaded). - embeddings_config_path = '/datacommons/nl/embeddings.yaml' - if flask_env in ['local', 'test', 'integration_test', 'webdriver']: - embeddings_config_path = os.path.join( - os.path.dirname(os.path.dirname(os.path.abspath(__file__))), - 'deploy/nl/embeddings.yaml') - app.config['EMBEDDINGS_CONFIG_PATH'] = embeddings_config_path + return app - # Initialize the NL module. - with open(app.config['EMBEDDINGS_CONFIG_PATH']) as f: - embeddings_map = yaml.full_load(f) - if not embeddings_map: - logging.error("No configuration found for embeddings") - return - app.config['EMBEDDINGS_VERSION_MAP'] = embeddings_map - loader.load_embeddings(app, embeddings_map, models_downloaded_paths) +# +# On prod the yaml files are in /datacommons/nl/, whereas +# in test-like environments it is the checked in path +# (deploy/nl/). +# +def get_env_path(file_name: str) -> str: + flask_env = os.environ.get('FLASK_ENV') + if flask_env in ['local', 'test', 'integration_test', 'webdriver']: + return os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))), + f'deploy/nl/{file_name}') - return app + return f'/datacommons/nl/{file_name}' diff --git a/nl_server/config.py b/nl_server/config.py new file mode 100644 index 0000000000..0967135f17 --- /dev/null +++ b/nl_server/config.py @@ -0,0 +1,106 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Dict, List + +from nl_server import embeddings +from nl_server import gcs + +# Index constants. Passed in `url=` +CUSTOM_DC_INDEX = 'custom' +DEFAULT_INDEX_TYPE = 'medium_ft' + +# The default base model we use. +EMBEDDINGS_BASE_MODEL_NAME = 'all-MiniLM-L6-v2' + +# App Config constants. +NL_MODEL_KEY = 'NL_MODEL' +NL_EMBEDDINGS_KEY = 'NL_EMBEDDINGS' +NL_EMBEDDINGS_VERSION_KEY = 'NL_EMBEDDINGS_VERSION_MAP' + + +# Defines one embeddings index config. +@dataclass +class EmbeddingsIndex: + # Name provided in the yaml file, and set in `idx=` URL param. + name: str + + # File name provided in the yaml file. + embeddings_file_name: str + # Local path. + embeddings_local_path: str = "" + + # Fine-tuned model name ("" if embeddings uses base model). + tuned_model: str = "" + # Fine-tuned model local path. + tuned_model_local_path: str = "" + + +# +# Validates the config input, downloads all the files and returns a list of Indexes to load. +# +def load(embeddings_map: Dict[str, str], + models_map: Dict[str, str]) -> List[EmbeddingsIndex]: + # Create Index objects. + indexes = _parse(embeddings_map) + + # This is just a sanity, we can soon deprecate models.yaml + tuned_models_provided = list(set(models_map.values())) + tuned_models_configured = list( + set([i.tuned_model for i in indexes if i.tuned_model])) + assert sorted(tuned_models_configured) == sorted(tuned_models_provided), \ + f'{tuned_models_configured} vs. {tuned_models_provided}' + + # + # Download all the models. + # + model2path = {d: gcs.download_model_folder(d) for d in tuned_models_configured} + for idx in indexes: + if idx.tuned_model: + idx.tuned_model_local_path = model2path[idx.tuned_model] + + # + # Download all the embeddings. + # + for idx in indexes: + idx.embeddings_local_path = gcs.download_embeddings( + idx.embeddings_file_name) + + return indexes + + +def _parse(embeddings_map: Dict[str, str]) -> List[EmbeddingsIndex]: + indexes: List[EmbeddingsIndex] = [] + + for key, value in embeddings_map.items(): + idx = EmbeddingsIndex(name=key, embeddings_file_name=value) + + parts = value.split('.') + assert parts[ + -1] == 'csv', f'Embeddings file {value} name does not end with .csv!' + + if len(parts) == 4: + # Expect: ...csv + # Example: embeddings_sdg_2023_09_12_16_38_04.ft_final_v20230717230459.all-MiniLM-L6-v2.csv + assert parts[ + 2] == EMBEDDINGS_BASE_MODEL_NAME, f'Unexpected base model {parts[3]}' + idx.tuned_model = f'{parts[1]}.{parts[2]}' + else: + # Expect: .csv + # Example: embeddings_small_2023_05_24_23_17_03.csv + assert len(parts) == 2, f'Unexpected file name format {value}' + indexes.append(idx) + + return indexes diff --git a/nl_server/embeddings.py b/nl_server/embeddings.py index 3d1a4f8f21..c469c37006 100644 --- a/nl_server/embeddings.py +++ b/nl_server/embeddings.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """Managing the embeddings.""" -from dataclasses import dataclass import logging import os from typing import Dict, List, Union @@ -22,13 +21,12 @@ from sentence_transformers.util import semantic_search import torch +from nl_server import config from nl_server import query_util from shared.lib import constants from shared.lib import detected_variables as vars from shared.lib import utils -MODEL_NAME = 'all-MiniLM-L6-v2' - # A value higher than the highest score. _HIGHEST_SCORE = 1.0 _INIT_SCORE = (_HIGHEST_SCORE + 0.1) @@ -52,7 +50,7 @@ def __init__(self, assert os.path.exists(existing_model_path) self.model = SentenceTransformer(existing_model_path) else: - self.model = SentenceTransformer(MODEL_NAME) + self.model = SentenceTransformer(config.EMBEDDINGS_BASE_MODEL_NAME) self.dataset_embeddings: torch.Tensor = None self.dcids: List[str] = [] self.sentences: List[str] = [] diff --git a/nl_server/embeddings_store.py b/nl_server/embeddings_store.py new file mode 100644 index 0000000000..04817c5936 --- /dev/null +++ b/nl_server/embeddings_store.py @@ -0,0 +1,37 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +from nl_server.config import DEFAULT_INDEX_TYPE +from nl_server.config import EmbeddingsIndex +from nl_server.embeddings import Embeddings + + +# +# A simple wrapper class around multiple embeddings indexes. +# +# TODO: Handle custom DC specific logic here. +# +class Store: + + def __init__(self, indexes: List[EmbeddingsIndex]): + self.embeddings_map = {} + for idx in indexes: + self.embeddings_map[idx.name] = Embeddings(idx.embeddings_local_path, + idx.tuned_model_local_path) + + # Note: The caller takes care of exceptions. + def get(self, index_type: str = DEFAULT_INDEX_TYPE) -> Embeddings: + return self.embeddings_map[index_type] diff --git a/nl_server/loader.py b/nl_server/loader.py index 07c66df6a9..95116eea17 100644 --- a/nl_server/loader.py +++ b/nl_server/loader.py @@ -12,113 +12,52 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging import os +from typing import Any, Dict -from nl_server import gcs -from nl_server.embeddings import Embeddings +from nl_server import config +from nl_server import embeddings_store from nl_server.nl_attribute_model import NLAttributeModel -nl_embeddings_cache_key_base = 'nl_embeddings' -nl_model_cache_key = 'nl_model' -nl_cache_path = '~/.datacommons/' -nl_cache_expire = 3600 * 24 # Cache for 1 day -nl_cache_size_limit = 16e9 # 16Gb local cache size - -DEFAULT_INDEX_TYPE = 'medium_ft' - - -def nl_embeddings_cache_key(index_type=DEFAULT_INDEX_TYPE): - return f'{nl_embeddings_cache_key_base}_{index_type}' - - -def embeddings_config_key(index_type): - return f'NL_EMBEDDINGS_{index_type.upper()}' +NL_CACHE_PATH = '~/.datacommons/' +NL_EMBEDDINGS_CACHE_KEY = 'nl_embeddings' +NL_MODEL_CACHE_KEY = 'nl_model' +_NL_CACHE_EXPIRE = 3600 * 24 # Cache for 1 day +_NL_CACHE_SIZE_LIMIT = 16e9 # 16Gb local cache size def _use_cache(flask_env): return flask_env in ['local', 'integration_test', 'webdriver'] -def download_models(models_map): - # Download existing models (if not already downloaded). - models_downloaded_paths = {} - for m in models_map: - # Only downloads if not already done so. - download_path = gcs.download_model_folder(models_map[m]) - models_downloaded_paths[models_map[m]] = download_path - - return models_downloaded_paths - - -def load_embeddings(app, embeddings_map, models_downloaded_paths): +def load_server_state(app: Any, embeddings_map: Dict[str, str], + models_map: Dict[str, str]): flask_env = os.environ.get('FLASK_ENV') - # Sanity check that file names aren't mispresented - for sz in embeddings_map.keys(): - if '_ft' in sz: - assert '.ft_final' in embeddings_map[ - sz], f'ft_final not found {embeddings_map[sz]}' - size_str = sz.split("_ft")[0] - assert size_str in embeddings_map[ - sz], f'{size_str} not found {embeddings_map[sz]}' - else: - assert sz in embeddings_map[sz], f'{sz} not found in {embeddings_map[sz]}' - # In local dev, cache the embeddings on disk so each hot reload won't download # the embeddings again. if _use_cache(flask_env): from diskcache import Cache - cache = Cache(nl_cache_path, size_limit=nl_cache_size_limit) + cache = Cache(NL_CACHE_PATH, size_limit=_NL_CACHE_SIZE_LIMIT) cache.expire() - nl_model = cache.get(nl_model_cache_key) - app.config['NL_MODEL'] = nl_model - - missing_embeddings = False - for sz in embeddings_map.keys(): - nl_embeddings = cache.get(nl_embeddings_cache_key(sz)) - if not nl_embeddings: - missing_embeddings = True - break - app.config[embeddings_config_key(sz)] = nl_embeddings - - if nl_model and not missing_embeddings: - logging.info("Using cached embeddings and NL Model in: " + - cache.directory) + nl_model = cache.get(NL_MODEL_CACHE_KEY) + nl_embeddings = cache.get(NL_EMBEDDINGS_CACHE_KEY) + if nl_model and nl_embeddings: + app.config[config.NL_MODEL_KEY] = nl_model + app.config[config.NL_EMBEDDINGS_KEY] = nl_embeddings return - # Download the embeddings from GCS - for sz in sorted(embeddings_map.keys()): - assert sz in embeddings_map, f'{sz} missing from {embeddings_map}' - - existing_model_path = "" - for model_key, model_path in models_downloaded_paths.items(): - if model_key in embeddings_map[sz]: - existing_model_path = model_path - print( - f"Using existing model {model_key} for embeddings ({sz}) version: {embeddings_map[sz]}" - ) - break - - # Checking that the finetuned embeddings have the finetuned model. - if "_ft" in sz: - assert existing_model_path, f"Could not find a finetuned model for finetuned embeddings ({sz}) version: {embeddings_map[sz]}" - - print( - f"Building an Embeddings object with the model: {existing_model_path} (empty means default) and embeddings file ({sz}) version: {embeddings_map[sz]}." - ) - nl_embeddings = Embeddings(gcs.download_embeddings(embeddings_map[sz]), - existing_model_path) - app.config[embeddings_config_key(sz)] = nl_embeddings + nl_embeddings = embeddings_store.Store(config.load(embeddings_map, + models_map)) + app.config[config.NL_EMBEDDINGS_KEY] = nl_embeddings nl_model = NLAttributeModel() - app.config["NL_MODEL"] = nl_model + app.config[config.NL_MODEL_KEY] = nl_model if _use_cache(flask_env): - with Cache(cache.directory, size_limit=nl_cache_size_limit) as reference: - for sz in embeddings_map.keys(): - reference.set(nl_embeddings_cache_key(sz), - app.config[embeddings_config_key(sz)], - expire=nl_cache_expire) - reference.set(nl_model_cache_key, nl_model, expire=nl_cache_expire) + with Cache(cache.directory, size_limit=_NL_CACHE_SIZE_LIMIT) as reference: + reference.set(NL_EMBEDDINGS_CACHE_KEY, + nl_embeddings, + expire=_NL_CACHE_EXPIRE) + reference.set(NL_MODEL_CACHE_KEY, nl_model, expire=_NL_CACHE_EXPIRE) diff --git a/nl_server/routes.py b/nl_server/routes.py index 08473b7f75..32d4fde706 100644 --- a/nl_server/routes.py +++ b/nl_server/routes.py @@ -20,7 +20,7 @@ from flask import request from markupsafe import escape -from nl_server import loader as ld +from nl_server import config bp = Blueprint('main', __name__, url_prefix='/') @@ -41,14 +41,14 @@ def search_sv(): } """ query = str(escape(request.args.get('q'))) - sz = str(escape(request.args.get('sz', ld.DEFAULT_INDEX_TYPE))) + sz = str(escape(request.args.get('sz', config.DEFAULT_INDEX_TYPE))) if not sz: - sz = ld.DEFAULT_INDEX_TYPE + sz = config.DEFAULT_INDEX_TYPE skip_multi_sv = False if request.args.get('skip_multi_sv'): skip_multi_sv = True try: - nl_embeddings = current_app.config[ld.embeddings_config_key(sz)] + nl_embeddings = current_app.config[config.NL_EMBEDDINGS_KEY].get(sz) return json.dumps(nl_embeddings.detect_svs(query, skip_multi_sv)) except Exception as e: logging.error(f'Embeddings-based SV detection failed with error: {e}') @@ -69,7 +69,7 @@ def search_places(): } """ query = str(escape(request.args.get('q'))) - nl_model = current_app.config['NL_MODEL'] + nl_model = current_app.config[config.NL_MODEL_KEY] try: res = nl_model.detect_places_ner(query) return json.dumps({'places': res}) @@ -85,10 +85,10 @@ def search_verbs(): List[str] """ query = str(escape(request.args.get('q'))) - nl_model = current_app.config['NL_MODEL'] + nl_model = current_app.config[config.NL_MODEL_KEY] return json.dumps(nl_model.detect_verbs(query.strip())) @bp.route('/api/embeddings_version_map/', methods=['GET']) def embeddings_version_map(): - return json.dumps(current_app.config['EMBEDDINGS_VERSION_MAP']) \ No newline at end of file + return json.dumps(current_app.config[config.NL_EMBEDDINGS_VERSION_KEY]) diff --git a/nl_server/tests/embeddings_test.py b/nl_server/tests/embeddings_test.py index c9bccef2f6..0a70fef107 100644 --- a/nl_server/tests/embeddings_test.py +++ b/nl_server/tests/embeddings_test.py @@ -19,13 +19,13 @@ from diskcache import Cache from parameterized import parameterized -from sklearn.metrics.pairwise import cosine_similarity import yaml +from nl_server import embeddings_store as store from nl_server import gcs -from nl_server import loader from nl_server.embeddings import Embeddings -from nl_server.loader import nl_cache_path +from nl_server.loader import NL_CACHE_PATH +from nl_server.loader import NL_EMBEDDINGS_CACHE_KEY _root_dir = os.path.dirname( os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -33,15 +33,13 @@ _test_data = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'test_data') -_tuned_model_key = "tuned_model" - # TODO(pradh): Expand tests to other index sizes. def _get_embeddings_file_path() -> str: embeddings_config_path = os.path.join(_root_dir, 'deploy/nl/embeddings.yaml') with open(embeddings_config_path) as f: embeddings = yaml.full_load(f) - embeddings_file = embeddings[loader.DEFAULT_INDEX_TYPE] + embeddings_file = embeddings[store.DEFAULT_INDEX_TYPE] return gcs.download_embeddings(embeddings_file) @@ -49,9 +47,7 @@ def _get_tuned_model_path() -> str: models_config_path = os.path.join(_root_dir, 'deploy/nl/models.yaml') with open(models_config_path) as f: models_map = yaml.full_load(f) - tuned_model_dict = {_tuned_model_key: models_map[_tuned_model_key]} - models_downloaded_paths = loader.download_models(tuned_model_dict) - return models_downloaded_paths[models_map[_tuned_model_key]] + return gcs.download_model_folder(models_map['tuned_model']) class TestEmbeddings(unittest.TestCase): @@ -59,11 +55,11 @@ class TestEmbeddings(unittest.TestCase): @classmethod def setUpClass(cls) -> None: # Look for the Embeddings in the cache if it exists. - cache = Cache(nl_cache_path) + cache = Cache(NL_CACHE_PATH) cache.expire() - cls.nl_embeddings = cache.get(loader.nl_embeddings_cache_key()) + embeddings_store = cache.get(NL_EMBEDDINGS_CACHE_KEY) - if not cls.nl_embeddings: + if not embeddings_store: print( "Could not load the embeddings from the cache for these tests. Loading a new embeddings object." ) @@ -73,11 +69,13 @@ def setUpClass(cls) -> None: # model pointed to in models.yaml. # If the default index is not a "finetuned" index, then the default model can be used. tuned_model_path = "" - if "ft" in loader.DEFAULT_INDEX_TYPE: + if "ft" in store.DEFAULT_INDEX_TYPE: tuned_model_path = _get_tuned_model_path() cls.nl_embeddings = Embeddings(_get_embeddings_file_path(), tuned_model_path) + else: + cls.nl_embeddings = embeddings_store.get() @parameterized.expand([ # All these queries should detect one of the SVs as the top choice. diff --git a/nl_server/tests/ner_place_model_test.py b/nl_server/tests/ner_place_model_test.py index 9a6e3c6438..bb19804e5a 100644 --- a/nl_server/tests/ner_place_model_test.py +++ b/nl_server/tests/ner_place_model_test.py @@ -19,8 +19,8 @@ from diskcache import Cache from parameterized import parameterized -from nl_server.loader import nl_cache_path -from nl_server.loader import nl_model_cache_key +from nl_server.loader import NL_CACHE_PATH +from nl_server.loader import NL_MODEL_CACHE_KEY from nl_server.nl_attribute_model import NLAttributeModel import shared.lib.utils as nl_utils @@ -31,9 +31,9 @@ class TestNERPlaces(unittest.TestCase): def setUpClass(cls) -> None: # Look for the Embeddings model in the cache if it exists. - cache = Cache(nl_cache_path) + cache = Cache(NL_CACHE_PATH) cache.expire() - cls.nl_model = cache.get(nl_model_cache_key) + cls.nl_model = cache.get(NL_MODEL_CACHE_KEY) if not cls.nl_model: logging.error( "Could not load models from the cache for these tests. Loading a new model object." diff --git a/nl_server/tests/verb_test.py b/nl_server/tests/verb_test.py index 92923f76be..4bee51a569 100644 --- a/nl_server/tests/verb_test.py +++ b/nl_server/tests/verb_test.py @@ -19,8 +19,8 @@ from diskcache import Cache from parameterized import parameterized -from nl_server.loader import nl_cache_path -from nl_server.loader import nl_model_cache_key +from nl_server.loader import NL_CACHE_PATH +from nl_server.loader import NL_MODEL_CACHE_KEY from nl_server.nl_attribute_model import NLAttributeModel @@ -30,9 +30,9 @@ class TestVerbs(unittest.TestCase): def setUpClass(cls) -> None: # Look for the Embeddings model in the cache if it exists. - cache = Cache(nl_cache_path) + cache = Cache(NL_CACHE_PATH) cache.expire() - cls.nl_model = cache.get(nl_model_cache_key) + cls.nl_model = cache.get(NL_MODEL_CACHE_KEY) if not cls.nl_model: logging.error( 'Could not load model from the cache for these tests. Loading a new model object.'