From 3a2854f075b4614baea488410469ba2feffd16e2 Mon Sep 17 00:00:00 2001 From: chejennifer <69875368+chejennifer@users.noreply.github.com> Date: Wed, 15 May 2024 16:12:20 -0700 Subject: [PATCH] [nl server] get default and enabled indexes from deployment (#4243) Get default index and enabled indexes from deployment because this could be instance specific --- .../dc_website/templates/config_maps.yaml | 2 +- deploy/helm_charts/dc_website/values.yaml | 6 +- deploy/helm_charts/envs/autopush.yaml | 55 +++++--- deploy/helm_charts/envs/bard.yaml | 8 ++ deploy/helm_charts/envs/biomedical.yaml | 8 ++ deploy/helm_charts/envs/climate_trace.yaml | 8 ++ deploy/helm_charts/envs/dev.yaml | 55 +++++--- deploy/helm_charts/envs/internal.yaml | 8 ++ deploy/helm_charts/envs/magic_eye.yaml | 8 ++ deploy/helm_charts/envs/prod.yaml | 8 ++ deploy/helm_charts/envs/staging.yaml | 8 ++ deploy/helm_charts/envs/unsdg.yaml | 8 ++ deploy/helm_charts/envs/unsdg_staging.yaml | 8 ++ deploy/helm_charts/envs/worldbank.yaml | 8 ++ nl_server/config.py | 122 +++++++----------- nl_server/custom_dc_constants.py | 22 ++++ nl_server/embeddings_map.py | 16 +-- nl_server/loader.py | 105 ++++++++++----- nl_server/routes.py | 12 +- nl_server/tests/custom_embeddings_test.py | 113 ++++++++-------- nl_server/tests/embeddings_test.py | 19 ++- 21 files changed, 384 insertions(+), 223 deletions(-) create mode 100644 nl_server/custom_dc_constants.py diff --git a/deploy/helm_charts/dc_website/templates/config_maps.yaml b/deploy/helm_charts/dc_website/templates/config_maps.yaml index 1b0e825f66..76d91b9514 100644 --- a/deploy/helm_charts/dc_website/templates/config_maps.yaml +++ b/deploy/helm_charts/dc_website/templates/config_maps.yaml @@ -81,7 +81,7 @@ metadata: data: embeddings.yaml: {{ required "NL embeddings file is required" .Values.nl.embeddings | quote }} models.yaml: {{ required "NL models file is required" .Values.nl.models | quote }} - vertex_ai_models.json: {{ .Values.nl.vertex_ai_models | toJson | quote }} + 
embeddings_spec.json: {{ .Values.nl.embeddingsSpec | toJson | quote }} {{- end }} {{- if .Values.website.redis.enabled }} diff --git a/deploy/helm_charts/dc_website/values.yaml b/deploy/helm_charts/dc_website/values.yaml index ffe2caf537..f579816a0c 100644 --- a/deploy/helm_charts/dc_website/values.yaml +++ b/deploy/helm_charts/dc_website/values.yaml @@ -136,7 +136,11 @@ nl: models: memory: "2G" workers: 1 - vertex_ai_models: + embeddingsSpec: + defaultIndex: "" + enabledIndexes: [] + vertexAIModels: + enableReranking: false ############################################################################### diff --git a/deploy/helm_charts/envs/autopush.yaml b/deploy/helm_charts/envs/autopush.yaml index ea15278681..5b9ca739cf 100644 --- a/deploy/helm_charts/envs/autopush.yaml +++ b/deploy/helm_charts/envs/autopush.yaml @@ -44,27 +44,40 @@ serviceAccount: nl: enabled: true - vertex_ai_models: - dc-all-minilm-l6-v2-model: - project_id: datcom-website-dev - location: us-central1 - prediction_endpoint_id: "8518340991868993536" - uae-large-v1-model: - project_id: datcom-nl - location: us-central1 - prediction_endpoint_id: "8110162693219942400" - sfr-embedding-mistral-model: - project_id: datcom-website-dev - location: us-central1 - prediction_endpoint_id: "224012300019826688" - cross-encoder-ms-marco-miniilm-l6-v2: - project_id: datcom-website-dev - location: us-central1 - prediction_endpoint_id: "3977846152316846080" - cross-encoder-mxbai-rerank-base-v1: - project_id: datcom-website-dev - location: us-central1 - prediction_endpoint_id: "284894457873039360" + embeddingsSpec: + defaultIndex: "medium_ft" + enabledIndexes: [ + "base_uae_mem", + "bio_ft", + "medium_ft", + "medium_lance_ft", + "medium_vertex_ft", + "medium_vertex_mistral", + "sdg_ft", + "undata_ft", + ] + vertexAIModels: + dc-all-minilm-l6-v2-model: + project_id: datcom-website-dev + location: us-central1 + prediction_endpoint_id: "8518340991868993536" + uae-large-v1-model: + project_id: datcom-nl + location: 
us-central1 + prediction_endpoint_id: "8110162693219942400" + sfr-embedding-mistral-model: + project_id: datcom-website-dev + location: us-central1 + prediction_endpoint_id: "224012300019826688" + cross-encoder-ms-marco-miniilm-l6-v2: + project_id: datcom-website-dev + location: us-central1 + prediction_endpoint_id: "3977846152316846080" + cross-encoder-mxbai-rerank-base-v1: + project_id: datcom-website-dev + location: us-central1 + prediction_endpoint_id: "284894457873039360" + enableReranking: true serviceGroups: recon: null diff --git a/deploy/helm_charts/envs/bard.yaml b/deploy/helm_charts/envs/bard.yaml index 36513bdb44..0ba2a1c282 100644 --- a/deploy/helm_charts/envs/bard.yaml +++ b/deploy/helm_charts/envs/bard.yaml @@ -44,6 +44,14 @@ nl: enabled: true memory: "10G" workers: 3 + embeddingsSpec: + defaultIndex: "medium_ft" + enabledIndexes: [ + "bio_ft", + "medium_ft", + "sdg_ft", + "undata_ft", + ] nodejs: enabled: true diff --git a/deploy/helm_charts/envs/biomedical.yaml b/deploy/helm_charts/envs/biomedical.yaml index 3ada5b2df3..73d73c7a52 100644 --- a/deploy/helm_charts/envs/biomedical.yaml +++ b/deploy/helm_charts/envs/biomedical.yaml @@ -42,3 +42,11 @@ serviceGroups: nl: enabled: true + embeddingsSpec: + defaultIndex: "bio_ft" + enabledIndexes: [ + "bio_ft", + "medium_ft", + "sdg_ft", + "undata_ft", + ] diff --git a/deploy/helm_charts/envs/climate_trace.yaml b/deploy/helm_charts/envs/climate_trace.yaml index 6aa16898c4..02d4494f64 100644 --- a/deploy/helm_charts/envs/climate_trace.yaml +++ b/deploy/helm_charts/envs/climate_trace.yaml @@ -34,6 +34,14 @@ serviceAccount: nl: enabled: true + embeddingsSpec: + defaultIndex: "medium_ft" + enabledIndexes: [ + "bio_ft", + "medium_ft", + "sdg_ft", + "undata_ft", + ] serviceGroups: recon: null diff --git a/deploy/helm_charts/envs/dev.yaml b/deploy/helm_charts/envs/dev.yaml index cbd16ee4e1..3b610f9255 100644 --- a/deploy/helm_charts/envs/dev.yaml +++ b/deploy/helm_charts/envs/dev.yaml @@ -41,27 +41,40 @@ 
serviceGroups: nl: enabled: true - vertex_ai_models: - dc-all-minilm-l6-v2-model: - project_id: datcom-website-dev - location: us-central1 - prediction_endpoint_id: "8518340991868993536" - uae-large-v1-model: - project_id: datcom-nl - location: us-central1 - prediction_endpoint_id: "1400502935879680000" - sfr-embedding-mistral-model: - project_id: datcom-website-dev - location: us-central1 - prediction_endpoint_id: "224012300019826688" - cross-encoder-ms-marco-miniilm-l6-v2: - project_id: datcom-website-dev - location: us-central1 - prediction_endpoint_id: "3977846152316846080" - cross-encoder-mxbai-rerank-base-v1: - project_id: datcom-website-dev - location: us-central1 - prediction_endpoint_id: "284894457873039360" + embeddingsSpec: + defaultIndex: "medium_ft" + enabledIndexes: [ + "base_uae_mem", + "bio_ft", + "medium_ft", + "medium_lance_ft", + "medium_vertex_ft", + "medium_vertex_mistral", + "sdg_ft", + "undata_ft", + ] + vertexAIModels: + dc-all-minilm-l6-v2-model: + project_id: datcom-website-dev + location: us-central1 + prediction_endpoint_id: "8518340991868993536" + uae-large-v1-model: + project_id: datcom-nl + location: us-central1 + prediction_endpoint_id: "8110162693219942400" + sfr-embedding-mistral-model: + project_id: datcom-website-dev + location: us-central1 + prediction_endpoint_id: "224012300019826688" + cross-encoder-ms-marco-miniilm-l6-v2: + project_id: datcom-website-dev + location: us-central1 + prediction_endpoint_id: "3977846152316846080" + cross-encoder-mxbai-rerank-base-v1: + project_id: datcom-website-dev + location: us-central1 + prediction_endpoint_id: "284894457873039360" + enableReranking: true nodejs: enabled: true diff --git a/deploy/helm_charts/envs/internal.yaml b/deploy/helm_charts/envs/internal.yaml index 6adc298a88..520f05374d 100644 --- a/deploy/helm_charts/envs/internal.yaml +++ b/deploy/helm_charts/envs/internal.yaml @@ -36,6 +36,14 @@ ingress: certName: website-ssl-certificate nl: enabled: true + embeddingsSpec: + 
defaultIndex: "medium_ft" + enabledIndexes: [ + "bio_ft", + "medium_ft", + "sdg_ft", + "undata_ft", + ] serviceGroups: recon: null diff --git a/deploy/helm_charts/envs/magic_eye.yaml b/deploy/helm_charts/envs/magic_eye.yaml index 06fee9be2f..7cb44d73c3 100644 --- a/deploy/helm_charts/envs/magic_eye.yaml +++ b/deploy/helm_charts/envs/magic_eye.yaml @@ -37,6 +37,14 @@ ingress: nl: enabled: true + embeddingsSpec: + defaultIndex: "medium_ft" + enabledIndexes: [ + "bio_ft", + "medium_ft", + "sdg_ft", + "undata_ft", + ] serviceGroups: recon: null diff --git a/deploy/helm_charts/envs/prod.yaml b/deploy/helm_charts/envs/prod.yaml index 4f6b3807d8..8f4c0f9374 100644 --- a/deploy/helm_charts/envs/prod.yaml +++ b/deploy/helm_charts/envs/prod.yaml @@ -41,6 +41,14 @@ serviceAccount: nl: enabled: true + embeddingsSpec: + defaultIndex: "medium_ft" + enabledIndexes: [ + "bio_ft", + "medium_ft", + "sdg_ft", + "undata_ft", + ] serviceGroups: recon: null diff --git a/deploy/helm_charts/envs/staging.yaml b/deploy/helm_charts/envs/staging.yaml index 183abbff24..81cf36eab8 100644 --- a/deploy/helm_charts/envs/staging.yaml +++ b/deploy/helm_charts/envs/staging.yaml @@ -32,6 +32,14 @@ serviceAccount: nl: enabled: true + embeddingsSpec: + defaultIndex: "medium_ft" + enabledIndexes: [ + "bio_ft", + "medium_ft", + "sdg_ft", + "undata_ft", + ] serviceGroups: recon: null diff --git a/deploy/helm_charts/envs/unsdg.yaml b/deploy/helm_charts/envs/unsdg.yaml index 2cc8306c66..04edf82e9d 100644 --- a/deploy/helm_charts/envs/unsdg.yaml +++ b/deploy/helm_charts/envs/unsdg.yaml @@ -40,6 +40,14 @@ ingress: nl: enabled: true + embeddingsSpec: + defaultIndex: "sdg_ft" + enabledIndexes: [ + "bio_ft", + "medium_ft", + "sdg_ft", + "undata_ft", + ] serviceGroups: recon: null diff --git a/deploy/helm_charts/envs/unsdg_staging.yaml b/deploy/helm_charts/envs/unsdg_staging.yaml index 1739517e7e..144839cbef 100644 --- a/deploy/helm_charts/envs/unsdg_staging.yaml +++ b/deploy/helm_charts/envs/unsdg_staging.yaml @@ 
-37,6 +37,14 @@ mixer: nl: enabled: true + embeddingsSpec: + defaultIndex: "sdg_ft" + enabledIndexes: [ + "bio_ft", + "medium_ft", + "sdg_ft", + "undata_ft", + ] serviceGroups: recon: null diff --git a/deploy/helm_charts/envs/worldbank.yaml b/deploy/helm_charts/envs/worldbank.yaml index 0a6736d31b..419ce03369 100644 --- a/deploy/helm_charts/envs/worldbank.yaml +++ b/deploy/helm_charts/envs/worldbank.yaml @@ -49,6 +49,14 @@ serviceGroups: memoryLimit: "8G" nl: enabled: true + embeddingsSpec: + defaultIndex: "medium_ft" + enabledIndexes: [ + "bio_ft", + "medium_ft", + "sdg_ft", + "undata_ft", + ] svg: blocklistFile: ["dc/g/Uncategorized", "oecd/g/OECD"] diff --git a/nl_server/config.py b/nl_server/config.py index 3e09114691..ca08aea855 100644 --- a/nl_server/config.py +++ b/nl_server/config.py @@ -15,28 +15,18 @@ from abc import ABC from dataclasses import dataclass from enum import Enum -import json -import logging -import os -from pathlib import Path from typing import Dict -import yaml - from shared.lib import constants -from shared.lib.custom_dc_util import is_custom_dc # Index constants. Passed in `url=` CUSTOM_DC_INDEX: str = 'custom_ft' -DEFAULT_INDEX_TYPE: str = 'medium_ft' # App Config constants. 
ATTRIBUTE_MODEL_KEY: str = 'ATTRIBUTE_MODEL' NL_EMBEDDINGS_KEY: str = 'NL_EMBEDDINGS' NL_EMBEDDINGS_VERSION_KEY: str = 'NL_EMBEDDINGS_VERSION_MAP' -VERTEX_AI_MODELS_KEY: str = 'VERTEX_AI_MODELS' - -_VERTEX_AI_MODEL_CONFIG_PATH: str = '/datacommons/nl/vertex_ai_models.json' +EMBEDDINGS_SPEC_KEY: str = 'EMBEDDINGS_SPEC' class StoreType(str, Enum): @@ -106,81 +96,42 @@ class EmbeddingsConfig: models: Dict[str, ModelConfig] -# -# Get Dict of vertex ai model to its info -# -def _get_vertex_ai_model_info() -> Dict[str, any]: - # Custom DC doesn't use vertex ai so just return an empty dict - # TODO: if we want to use vertex ai for custom dc, can add a file with the - # config to the custom dc docker image here: https://github.com/datacommonsorg/website/blob/master/build/web_compose/Dockerfile#L67 - if is_custom_dc(): - return {} - - # This is the path to model info when deployed in gke. - if os.path.exists(_VERTEX_AI_MODEL_CONFIG_PATH): - with open(_VERTEX_AI_MODEL_CONFIG_PATH) as f: - return json.load(f) or {} - # If that path doesn't exist, assume we are running locally and use the values - # from autopush. - else: - current_file_path = Path(__file__) - autopush_env_values = f'{current_file_path.parent.parent}/deploy/helm_charts/envs/autopush.yaml' - with open(autopush_env_values) as f: - autopush_env = yaml.full_load(f) - return autopush_env['nl']['vertex_ai_models'] +# Determines whether model is enabled +def _is_model_enabled(model_name: str, model_info: Dict[str, str], + used_models: set[str], reranking_enabled: bool): + if model_name in used_models: + return True + if model_info['usage'] == ModelUsage.RERANKING and reranking_enabled: + return True + return False # -# Parse the input `embeddings.yaml` dict representation into EmbeddingsInfo +# Parse the input `embeddings.yaml` dict representation into EmbeddingsConfig # object. 
# -def parse(embeddings_map: Dict[str, any]) -> EmbeddingsConfig: - get_vertex_ai_model_info = _get_vertex_ai_model_info() +def parse(embeddings_map: Dict[str, any], vertex_ai_model_info: Dict[str, any], + reranking_enabled: bool) -> EmbeddingsConfig: if embeddings_map['version'] == 1: - return parse_v1(embeddings_map, get_vertex_ai_model_info) + return parse_v1(embeddings_map, vertex_ai_model_info, reranking_enabled) else: raise AssertionError('Could not parse embeddings map: unsupported version.') # # Parses the v1 version of the `embeddings.yaml` dict representation into -# EmbeddingsInfo object. +# EmbeddingsConfig object. # -def parse_v1(embeddings_map: Dict[str, any], - vertex_ai_model_info: Dict[str, any]) -> EmbeddingsConfig: - # parse the models - models = {} - for model_name, model_info in embeddings_map.get('models', {}).items(): - model_type = model_info['type'] - score_threshold = model_info.get('score_threshold', - constants.SV_SCORE_DEFAULT_THRESHOLD) - if model_type == ModelType.LOCAL: - models[model_name] = LocalModelConfig(type=model_type, - score_threshold=score_threshold, - usage=model_info['usage'], - gcs_folder=model_info['gcs_folder']) - elif model_type == ModelType.VERTEXAI: - if model_name not in vertex_ai_model_info: - logging.error( - f'Could not find vertex ai model information for {model_name}') - continue - models[model_name] = VertexAIModelConfig( - type=model_type, - score_threshold=score_threshold, - usage=model_info['usage'], - project_id=vertex_ai_model_info[model_name]['project_id'], - prediction_endpoint_id=vertex_ai_model_info[model_name] - ['prediction_endpoint_id'], - location=vertex_ai_model_info[model_name]['location']) - else: - raise AssertionError( - 'Error parsing information for model {model_name}: unsupported type {model_type}' - ) +def parse_v1(embeddings_map: Dict[str, any], vertex_ai_model_info: Dict[str, + any], + reranking_enabled: bool) -> EmbeddingsConfig: + used_models = set() # parse the indexes indexes = {} 
for index_name, index_info in embeddings_map.get('indexes', {}).items(): store_type = index_info['store'] + used_models.add(index_info['model']) if store_type == StoreType.MEMORY: indexes[index_name] = MemoryIndexConfig( store_type=store_type, @@ -205,11 +156,32 @@ def parse_v1(embeddings_map: Dict[str, any], 'Error parsing information for index {index_name}: unsupported store type {store_type}' ) - return EmbeddingsConfig(indexes=indexes, models=models) - + # parse the models + models = {} + for model_name, model_info in embeddings_map.get('models', {}).items(): + if not _is_model_enabled(model_name, model_info, used_models, + reranking_enabled): + continue -# Returns true if VERTEXAI type models and VERTEXAI type stores are allowed -def allow_vertex_ai() -> bool: - return os.environ.get('FLASK_ENV') in [ - 'local', 'test', 'integration_test', 'autopush', 'dev' - ] + model_type = model_info['type'] + score_threshold = model_info.get('score_threshold', + constants.SV_SCORE_DEFAULT_THRESHOLD) + if model_type == ModelType.LOCAL: + models[model_name] = LocalModelConfig(type=model_type, + score_threshold=score_threshold, + usage=model_info['usage'], + gcs_folder=model_info['gcs_folder']) + elif model_type == ModelType.VERTEXAI: + models[model_name] = VertexAIModelConfig( + type=model_type, + score_threshold=score_threshold, + usage=model_info['usage'], + project_id=vertex_ai_model_info[model_name]['project_id'], + prediction_endpoint_id=vertex_ai_model_info[model_name] + ['prediction_endpoint_id'], + location=vertex_ai_model_info[model_name]['location']) + else: + raise AssertionError( + 'Error parsing information for model {model_name}: unsupported type {model_type}' + ) + return EmbeddingsConfig(indexes=indexes, models=models) diff --git a/nl_server/custom_dc_constants.py b/nl_server/custom_dc_constants.py new file mode 100644 index 0000000000..d40b326d95 --- /dev/null +++ b/nl_server/custom_dc_constants.py @@ -0,0 +1,22 @@ +# Copyright 2024 Google LLC +# +# Licensed 
under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Constants used for custom dc instances +TODO: should be moved to be part of deployment""" + +# TODO: move this to be part of custom dc deployment +CUSTOM_DC_EMBEDDINGS_SPEC = { + 'defaultIndex': 'medium_ft', + 'enabledIndexes': ['medium_ft'], + 'enableReranking': False +} diff --git a/nl_server/embeddings_map.py b/nl_server/embeddings_map.py index ae084085d6..5b01520951 100644 --- a/nl_server/embeddings_map.py +++ b/nl_server/embeddings_map.py @@ -15,14 +15,11 @@ import logging from typing import Dict -from nl_server.config import allow_vertex_ai -from nl_server.config import DEFAULT_INDEX_TYPE from nl_server.config import EmbeddingsConfig from nl_server.config import IndexConfig from nl_server.config import ModelConfig from nl_server.config import ModelType from nl_server.config import ModelUsage -from nl_server.config import parse from nl_server.config import StoreType from nl_server.embeddings import Embeddings from nl_server.embeddings import EmbeddingsModel @@ -40,17 +37,16 @@ # class EmbeddingsMap: - # Input is the in-memory representation of `embeddings.yaml` structure. - def __init__(self, embeddings_dict: dict[str, dict[str, str]]): + # Input is parsed embeddings config. 
+ def __init__(self, embeddings_config: EmbeddingsConfig): self.embeddings_map: dict[str, Embeddings] = {} self.name_to_emb_model: Dict[str, EmbeddingsModel] = {} self.name_to_rank_model: Dict[str, RerankingModel] = {} - embeddings_info = parse(embeddings_dict) - self.reset_index(embeddings_info) + self.reset_index(embeddings_config) # Note: The caller takes care of exceptions. - def get_index(self, index_type: str = DEFAULT_INDEX_TYPE) -> Embeddings: + def get_index(self, index_type: str) -> Embeddings: return self.embeddings_map.get(index_type) def get_reranking_model(self, model_name: str) -> RerankingModel: @@ -73,7 +69,7 @@ def _load_models(self, models: dict[str, ModelConfig]): # try creating a model object from the model info try: - if (allow_vertex_ai() and model_info.type == ModelType.VERTEXAI): + if model_info.type == ModelType.VERTEXAI: if model_info.usage == ModelUsage.EMBEDDINGS: model = VertexAIEmbeddingsModel(model_info) self.name_to_emb_model[model_name] = model @@ -105,7 +101,7 @@ def _set_embeddings(self, idx_name: str, idx_info: IndexConfig): else: logging.info('Not loading LanceDB in Custom DC environment!') return - elif idx_info.store_type == StoreType.VERTEXAI and allow_vertex_ai(): + elif idx_info.store_type == StoreType.VERTEXAI: store = VertexAIStore(idx_info) except Exception as e: logging.error(f'error loading index {idx_name}: {str(e)} ') diff --git a/nl_server/loader.py b/nl_server/loader.py index e019a5455c..8c9db212f9 100644 --- a/nl_server/loader.py +++ b/nl_server/loader.py @@ -12,23 +12,38 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from dataclasses import dataclass +import json import logging import os -from typing import Dict +from pathlib import Path +from typing import Dict, List from flask import Flask import yaml from nl_server import config +from nl_server.custom_dc_constants import CUSTOM_DC_EMBEDDINGS_SPEC import nl_server.embeddings_map as emb_map from nl_server.nl_attribute_model import NLAttributeModel from shared.lib.custom_dc_util import get_custom_dc_user_data_path +from shared.lib.custom_dc_util import is_custom_dc from shared.lib.gcs import download_gcs_file from shared.lib.gcs import is_gcs_path from shared.lib.gcs import join_gcs_path _EMBEDDINGS_YAML = 'embeddings.yaml' _CUSTOM_EMBEDDINGS_YAML_PATH = 'datacommons/nl/custom_embeddings.yaml' +_EMBEDDINGS_SPEC_PATH: str = '/datacommons/nl/embeddings_spec.json' +_LOCAL_ENV_VALUES_PATH: str = f'{Path(__file__).parent.parent}/deploy/helm_charts/envs/autopush.yaml' + + +@dataclass +class EmbeddingsSpec: + default_index: str + enabled_indexes: List[str] + vertex_ai_model_info: Dict[str, any] + enable_reranking: bool # @@ -37,10 +52,15 @@ def load_server_state(app: Flask): flask_env = os.environ.get('FLASK_ENV') - embeddings_dict = _load_yaml(flask_env) - nl_embeddings = emb_map.EmbeddingsMap(embeddings_dict) + embeddings_spec = _get_embeddings_spec() + embeddings_dict = _load_yaml(flask_env, embeddings_spec.enabled_indexes) + parsed_embeddings_dict = config.parse(embeddings_dict, + embeddings_spec.vertex_ai_model_info, + embeddings_spec.enable_reranking) + nl_embeddings = emb_map.EmbeddingsMap(parsed_embeddings_dict) attribute_model = NLAttributeModel() - _update_app_config(app, attribute_model, nl_embeddings, embeddings_dict) + _update_app_config(app, attribute_model, nl_embeddings, embeddings_dict, + embeddings_spec) def load_custom_embeddings(app: Flask): @@ -54,13 +74,16 @@ def load_custom_embeddings(app: Flask): on a local path. 
""" flask_env = os.environ.get('FLASK_ENV') - embeddings_map = _load_yaml(flask_env) + embeddings_spec = _get_embeddings_spec() + embeddings_map = _load_yaml(flask_env, embeddings_spec.enabled_indexes) # This lookup will raise an error if embeddings weren't already initialized previously. # This is intentional. nl_embeddings: emb_map.EmbeddingsMap = app.config[config.NL_EMBEDDINGS_KEY] try: - embeddings_info = config.parse(embeddings_map) + embeddings_info = config.parse(embeddings_map, + embeddings_spec.vertex_ai_model_info, + embeddings_spec.enable_reranking) # Reset the custom DC index. nl_embeddings.reset_index(embeddings_info) except Exception as e: @@ -68,33 +91,24 @@ def load_custom_embeddings(app: Flask): # Update app config. _update_app_config(app, app.config[config.ATTRIBUTE_MODEL_KEY], nl_embeddings, - embeddings_map) + embeddings_map, embeddings_spec) # Takes an embeddings map and returns a version that only has the default -# index and its model info -def _get_default_only_emb_map(embeddings_map: Dict[str, any]) -> Dict[str, any]: - default_model_name = embeddings_map['indexes'][ - config.DEFAULT_INDEX_TYPE]['model'] - return { - 'version': embeddings_map['version'], - 'indexes': { - config.DEFAULT_INDEX_TYPE: - embeddings_map['indexes'][config.DEFAULT_INDEX_TYPE] - }, - 'models': { - default_model_name: embeddings_map['models'][default_model_name] - } - } - - -def _load_yaml(flask_env: str) -> Dict[str, any]: +# enabled indexes and its model info +def _get_enabled_only_emb_map(embeddings_map: Dict[str, any], + enabled_indexes: List[str]) -> Dict[str, any]: + indexes = {} + for index_name in enabled_indexes: + indexes[index_name] = embeddings_map['indexes'][index_name] + embeddings_map['indexes'] = indexes + return embeddings_map + + +def _load_yaml(flask_env: str, enabled_indexes: List[str]) -> Dict[str, any]: with open(get_env_path(flask_env, _EMBEDDINGS_YAML)) as f: embeddings_map = yaml.full_load(f) - - # For custom DC dev env, only keep the 
default index. - if _is_custom_dc_dev(flask_env): - embeddings_map = _get_default_only_emb_map(embeddings_map) + embeddings_map = _get_enabled_only_emb_map(embeddings_map, enabled_indexes) assert embeddings_map, 'No embeddings.yaml found!' @@ -107,15 +121,14 @@ def _load_yaml(flask_env: str) -> Dict[str, any]: return embeddings_map -def _update_app_config(app: Flask, - attribute_model: NLAttributeModel, +def _update_app_config(app: Flask, attribute_model: NLAttributeModel, nl_embeddings: emb_map.EmbeddingsMap, embeddings_version_map: Dict[str, any], - vertex_ai_models: Dict[str, Dict] = None): + embeddings_spec: EmbeddingsSpec): app.config[config.ATTRIBUTE_MODEL_KEY] = attribute_model app.config[config.NL_EMBEDDINGS_KEY] = nl_embeddings app.config[config.NL_EMBEDDINGS_VERSION_KEY] = embeddings_version_map - app.config[config.VERTEX_AI_MODELS_KEY] = vertex_ai_models or {} + app.config[config.EMBEDDINGS_SPEC_KEY] = embeddings_spec def _maybe_load_custom_dc_yaml(): @@ -166,3 +179,31 @@ def get_env_path(flask_env: str, file_name: str) -> str: def _is_custom_dc_dev(flask_env: str) -> bool: return flask_env == 'custom_dev' + + +# Get embedding index information for an instance +# +def _get_embeddings_spec() -> EmbeddingsSpec: + embeddings_spec_dict = None + + # If custom dc, get from constant + if is_custom_dc(): + embeddings_spec_dict = CUSTOM_DC_EMBEDDINGS_SPEC + # otherwise try to get from gke. + elif os.path.exists(_EMBEDDINGS_SPEC_PATH): + with open(_EMBEDDINGS_SPEC_PATH) as f: + embeddings_spec_dict = json.load(f) or {} + # If that path doesn't exist, assume we are running locally and use the values + # from autopush. + else: + with open(_LOCAL_ENV_VALUES_PATH) as f: + env_values = yaml.full_load(f) + embeddings_spec_dict = env_values['nl']['embeddingsSpec'] + + return EmbeddingsSpec( + embeddings_spec_dict.get('defaultIndex', ''), + embeddings_spec_dict.get('enabledIndexes', []), + # When vertexAIModels the key exists, the value can be None. 
If value is + # None, we still want to use an empty object. + vertex_ai_model_info=embeddings_spec_dict.get('vertexAIModels') or {}, + enable_reranking=embeddings_spec_dict.get('enableReranking', False)) diff --git a/nl_server/routes.py b/nl_server/routes.py index 1b90a0f089..e8a7a07368 100644 --- a/nl_server/routes.py +++ b/nl_server/routes.py @@ -35,8 +35,12 @@ @bp.route('/healthz') def healthz(): + default_index_type = current_app.config[ + config.EMBEDDINGS_SPEC_KEY].default_index + if not default_index_type: + return 'Service Unavailable', 500 nl_embeddings = current_app.config[config.NL_EMBEDDINGS_KEY].get_index( - config.DEFAULT_INDEX_TYPE) + default_index_type) if nl_embeddings: result: VarCandidates = search.search_vars( [nl_embeddings], ['life expectancy'])['life expectancy'] @@ -58,9 +62,11 @@ def search_vars(): queries = request.json.get('queries', []) queries = [str(escape(q)) for q in queries] - idx = str(escape(request.args.get('idx', config.DEFAULT_INDEX_TYPE))) + default_index_type = current_app.config[ + config.EMBEDDINGS_SPEC_KEY].default_index + idx = str(escape(request.args.get('idx', default_index_type))) if not idx: - idx = config.DEFAULT_INDEX_TYPE + idx = default_index_type emb_map: EmbeddingsMap = current_app.config[config.NL_EMBEDDINGS_KEY] diff --git a/nl_server/tests/custom_embeddings_test.py b/nl_server/tests/custom_embeddings_test.py index 170292a8f6..6addc49caa 100644 --- a/nl_server/tests/custom_embeddings_test.py +++ b/nl_server/tests/custom_embeddings_test.py @@ -64,28 +64,30 @@ def setUpClass(cls) -> None: cls.default_file = _copy(_DEFAULT_FILE) cls.custom_file = _copy(_CUSTOM_FILE) - cls.custom = emb_map.EmbeddingsMap({ - 'version': 1, - 'indexes': { - 'medium_ft': { - 'embeddings': cls.default_file, - 'store': 'MEMORY', - 'model': _TUNED_MODEL_NAME - }, - 'custom_ft': { - 'embeddings': cls.custom_file, - 'store': 'MEMORY', - 'model': _TUNED_MODEL_NAME - } - }, - 'models': { - _TUNED_MODEL_NAME: { - 'type': 'LOCAL', - 'usage': 
'EMBEDDINGS', - 'gcs_folder': _TUNED_MODEL_GCS - } - } - }) + cls.custom = emb_map.EmbeddingsMap( + parse( + { + 'version': 1, + 'indexes': { + 'medium_ft': { + 'embeddings': cls.default_file, + 'store': 'MEMORY', + 'model': _TUNED_MODEL_NAME + }, + 'custom_ft': { + 'embeddings': cls.custom_file, + 'store': 'MEMORY', + 'model': _TUNED_MODEL_NAME + } + }, + 'models': { + _TUNED_MODEL_NAME: { + 'type': 'LOCAL', + 'usage': 'EMBEDDINGS', + 'gcs_folder': _TUNED_MODEL_GCS + } + } + }, {}, False)) def test_entries(self): self.assertEqual(1, len(self.custom.get_index('medium_ft').store.dcids)) @@ -113,46 +115,49 @@ def test_queries(self, query: str, index: str, expected: str): _test_query(self, indexes, query, expected) def test_merge_custom_embeddings(self): - embeddings = emb_map.EmbeddingsMap({ - 'version': 1, - 'indexes': { - 'medium_ft': { - 'embeddings': self.default_file, - 'store': 'MEMORY', - 'model': _TUNED_MODEL_NAME - }, - }, - 'models': { - _TUNED_MODEL_NAME: { - 'type': 'LOCAL', - 'usage': 'EMBEDDINGS', - 'gcs_folder': _TUNED_MODEL_GCS - } - } - }) + embeddings = emb_map.EmbeddingsMap( + parse( + { + 'version': 1, + 'indexes': { + 'medium_ft': { + 'embeddings': self.default_file, + 'store': 'MEMORY', + 'model': _TUNED_MODEL_NAME + }, + }, + 'models': { + _TUNED_MODEL_NAME: { + 'type': 'LOCAL', + 'usage': 'EMBEDDINGS', + 'gcs_folder': _TUNED_MODEL_GCS + } + } + }, {}, False)) _test_query(self, [embeddings.get_index("medium_ft")], "money", "dc/topic/sdg_1") _test_query(self, [embeddings.get_index("medium_ft")], "food", "") embeddings.reset_index( - parse({ - 'version': 1, - 'indexes': { - 'custom_ft': { - 'embeddings': self.custom_file, - 'store': 'MEMORY', - 'model': _TUNED_MODEL_NAME + parse( + { + 'version': 1, + 'indexes': { + 'custom_ft': { + 'embeddings': self.custom_file, + 'store': 'MEMORY', + 'model': _TUNED_MODEL_NAME + }, }, - }, - 'models': { - _TUNED_MODEL_NAME: { - 'type': 'LOCAL', - 'usage': 'EMBEDDINGS', - 'gcs_folder': _TUNED_MODEL_GCS + 
'models': { + _TUNED_MODEL_NAME: { + 'type': 'LOCAL', + 'usage': 'EMBEDDINGS', + 'gcs_folder': _TUNED_MODEL_GCS + } } - } - })) + }, {}, False)) emb_list = [ embeddings.get_index("custom_ft"), diff --git a/nl_server/tests/embeddings_test.py b/nl_server/tests/embeddings_test.py index a5feb4afe8..893e8e8440 100644 --- a/nl_server/tests/embeddings_test.py +++ b/nl_server/tests/embeddings_test.py @@ -20,7 +20,6 @@ from parameterized import parameterized import yaml -from nl_server import embeddings_map as emb_map from nl_server.config import parse from nl_server.embeddings import Embeddings from nl_server.model.sentence_transformer import LocalSentenceTransformerModel @@ -32,11 +31,20 @@ os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -def _get_embeddings_info(): +def _get_embeddings_spec(): + autopush_values_path = os.path.join(_root_dir, + 'deploy/helm_charts/envs/autopush.yaml') + with open(autopush_values_path) as f: + autopush_values = yaml.full_load(f) + return autopush_values['nl']['embeddingsSpec'] + + +def _get_embeddings_info(embeddings_spec): embeddings_config_path = os.path.join(_root_dir, 'deploy/nl/embeddings.yaml') with open(embeddings_config_path) as f: embeddings_map = yaml.full_load(f) - return parse(embeddings_map) + return parse(embeddings_map, embeddings_spec['vertexAIModels'], + embeddings_spec['enableReranking']) def _get_contents( @@ -48,9 +56,10 @@ class TestEmbeddings(unittest.TestCase): @classmethod def setUpClass(cls) -> None: - embeddings_info = _get_embeddings_info() + embeddings_spec = _get_embeddings_spec() + embeddings_info = _get_embeddings_info(embeddings_spec) # TODO(pradh): Expand tests to other index sizes. - idx_info = embeddings_info.indexes[emb_map.DEFAULT_INDEX_TYPE] + idx_info = embeddings_info.indexes[embeddings_spec['defaultIndex']] model_info = embeddings_info.models[idx_info.model] cls.nl_embeddings = Embeddings( model=LocalSentenceTransformerModel(model_info),