Skip to content

Commit

Permalink
more update
Browse files Browse the repository at this point in the history
  • Loading branch information
shifucun committed May 20, 2024
1 parent 080d2b0 commit cd0b6e5
Show file tree
Hide file tree
Showing 11 changed files with 66 additions and 60 deletions.
2 changes: 1 addition & 1 deletion nl_server/config_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def read_catalog_config() -> CatalogConfig:

if gcs.is_gcs_path(user_data_path):
full_gcs_path = os.path.join(user_data_path, _CATALOG_USER_PATH_SUFFIX)
gcs.download_file_by_path(full_gcs_path, _CATALOG_TMP_PATH)
gcs.download_blob_by_path(full_gcs_path, _CATALOG_TMP_PATH)
all_paths.append(_CATALOG_TMP_PATH)
else:
full_user_path = os.path.join(user_data_path, _CATALOG_USER_PATH_SUFFIX)
Expand Down
7 changes: 3 additions & 4 deletions nl_server/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,12 @@
#
class ResourceRegistry:

# Input is parsed runnable config.
# Input is server config object.
def __init__(self, server_config: ServerConfig):
self.name_to_emb: dict[str, Embeddings] = {}
self.name_to_emb_model: Dict[str, EmbeddingsModel] = {}
self.name_to_rank_model: Dict[str, RerankingModel] = {}
self._attribute_model = AttributeModel()
self.load(server_config)

# Note: The caller takes care of exceptions.
Expand All @@ -65,14 +66,12 @@ def attribute_model(self) -> AttributeModel:
def get_model(self, model_name: str) -> EmbeddingsModel:
return self.name_to_emb_model.get(model_name)

# Adds the new models and indexes in a RunnableConfig object to the
# embeddings
# Load the registry from the server config
def load(self, server_config: ServerConfig):
self._server_config = server_config
self._load_models(server_config.models)
for idx_name, idx_info in server_config.indexes.items():
self._set_embeddings(idx_name, idx_info)
self._attribute_model = AttributeModel()

# Loads a dict of model name -> model info
def _load_models(self, models: dict[str, ModelConfig]):
Expand Down
35 changes: 18 additions & 17 deletions nl_server/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from flask import request
from markupsafe import escape

from nl_server import registry
from nl_server import search
from nl_server.embeddings import Embeddings
from nl_server.registry import REGISTRY_KEY
Expand All @@ -35,12 +36,12 @@

@bp.route('/healthz')
def healthz():
registry: ResourceRegistry = current_app.config[REGISTRY_KEY]
default_indexes = registry.server_config().default_indexes
r: ResourceRegistry = current_app.config[REGISTRY_KEY]
default_indexes = r.server_config().default_indexes
if not default_indexes:
logging.warning('Health Check Failed: Default index name empty!')
return 'Service Unavailable', 500
embeddings: Embeddings = registry.get_index(default_indexes[0])
embeddings: Embeddings = r.get_index(default_indexes[0])
if embeddings:
query = embeddings.store.healthcheck_query
result: VarCandidates = search.search_vars([embeddings], [query]).get(query)
Expand Down Expand Up @@ -70,18 +71,19 @@ def search_vars():
if request.args.get('skip_topics'):
skip_topics = True

r: ResourceRegistry = current_app.config[REGISTRY_KEY]

reranker_name = str(escape(request.args.get('reranker', '')))
reranker_model = registry.get_reranking_model(
reranker_model = r.get_reranking_model(
reranker_name) if reranker_name else None

registry = current_app.config[REGISTRY_KEY]
default_indexes = registry.server_config().default_indexes
default_indexes = r.server_config().default_indexes
idx_type_str = str(escape(request.args.get('idx', '')))
if not idx_type_str:
idx_types = default_indexes
else:
idx_types = idx_type_str.split(',')
embeddings = _get_indexes(registry, idx_types)
embeddings = _get_indexes(r, idx_types)

debug_logs = {'sv_detection_query_index_type': idx_types}
results = search.search_vars(embeddings, queries, skip_topics, reranker_model,
Expand All @@ -101,34 +103,33 @@ def detect_verbs():
List[str]
"""
query = str(escape(request.args.get('q')))
registry = current_app.config[REGISTRY_KEY]
return json.dumps(registry.attribute_model().detect_verbs(query.strip()))
r: ResourceRegistry = current_app.config[REGISTRY_KEY]
return json.dumps(r.attribute_model().detect_verbs(query.strip()))


@bp.route('/api/embeddings_version_map/', methods=['GET'])
def embeddings_version_map():
registry: ResourceRegistry = current_app.config[REGISTRY_KEY]
server_config = registry.server_config()
r: ResourceRegistry = current_app.config[REGISTRY_KEY]
server_config = r.server_config()
return json.dumps(asdict(server_config))


@bp.route('/api/load/', methods=['GET'])
def load():
try:
current_app.config[registry.REGISTRY_KEY] = registry.build()
current_app.config[REGISTRY_KEY] = registry.build()
except Exception as e:
logging.error(f'Custom embeddings not loaded due to error: {str(e)}')
registry = current_app.config[REGISTRY_KEY]
server_config = registry.server_config()
r: ResourceRegistry = current_app.config[REGISTRY_KEY]
server_config = r.server_config()
return json.dumps(asdict(server_config))


def _get_indexes(registry: ResourceRegistry,
idx_types: List[str]) -> List[Embeddings]:
def _get_indexes(r: ResourceRegistry, idx_types: List[str]) -> List[Embeddings]:
embeddings: List[Embeddings] = []
for idx in idx_types:
try:
emb = registry.get_index(idx)
emb = r.get_index(idx)
if emb:
embeddings.append(emb)
except Exception as e:
Expand Down
6 changes: 2 additions & 4 deletions run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,8 @@ function run_py_test {

# Tests within tools/nl/embeddings
echo "Running tests within tools/nl/embeddings:"
cd tools/nl/embeddings
pip3 install -r requirements.txt
python3 -m pytest ./ -s
cd ../../..
pip3 install -r tools/nl/embeddings/requirements.txt
python3 -m pytest tools/nl/embeddings/ -s

pip3 install yapf==0.40.2 -q
if ! command -v isort &> /dev/null
Expand Down
6 changes: 3 additions & 3 deletions tools/nl/embeddings/build_custom_dc_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,17 @@

from absl import app
from absl import flags
from file_util import create_file_handler
from file_util import FileHandler
import pandas as pd
import utils
import yaml

from nl_server import config
from nl_server import config_reader
from nl_server.config import CatalogConfig
from nl_server.config import MemoryIndexConfig
from nl_server.registry import ResourceRegistry
from tools.nl.embeddings import utils
from tools.nl.embeddings.file_util import create_file_handler
from tools.nl.embeddings.file_util import FileHandler


class Mode:
Expand Down
17 changes: 10 additions & 7 deletions tools/nl/embeddings/build_custom_dc_embeddings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,24 @@
# limitations under the License.

import os
from pathlib import Path
import tempfile
import unittest

from build_custom_dc_embeddings import EMBEDDINGS_CSV_FILENAME_PREFIX
from build_custom_dc_embeddings import EMBEDDINGS_YAML_FILE_NAME
import build_custom_dc_embeddings as builder
from file_util import create_file_handler
from sentence_transformers import SentenceTransformer
import utils

from nl_server.config import LocalModelConfig
from tools.nl.embeddings import utils
from tools.nl.embeddings.build_custom_dc_embeddings import \
EMBEDDINGS_CSV_FILENAME_PREFIX
from tools.nl.embeddings.build_custom_dc_embeddings import \
EMBEDDINGS_YAML_FILE_NAME
import tools.nl.embeddings.build_custom_dc_embeddings as builder
from tools.nl.embeddings.file_util import create_file_handler

MODEL_NAME = "all-MiniLM-L6-v2"
INPUT_DIR = "testdata/custom_dc/input"
EXPECTED_DIR = "testdata/custom_dc/expected"
INPUT_DIR = Path(__file__).parent / "testdata/custom_dc/input"
EXPECTED_DIR = Path(__file__).parent / "testdata/custom_dc/expected"


def _compare_files(test: unittest.TestCase, output_path, expected_path):
Expand Down
3 changes: 2 additions & 1 deletion tools/nl/embeddings/build_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
import lancedb
import pandas as pd
from sentence_transformers import SentenceTransformer
import utils

from tools.nl.embeddings import utils

VERTEX_AI_PROJECT = 'datcom-nl'
VERTEX_AI_PROJECT_LOCATION = 'us-central1'
Expand Down
25 changes: 16 additions & 9 deletions tools/nl/embeddings/build_embeddings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
import tempfile
import unittest
from unittest import mock
from pathlib import Path

import build_embeddings as be
import pandas as pd
from parameterized import parameterized
from sentence_transformers import SentenceTransformer
import utils

from tools.nl.embeddings import utils
import tools.nl.embeddings.build_embeddings as be


def get_test_sv_data():
Expand Down Expand Up @@ -96,7 +98,7 @@ def testFailure(self):
input_sheets_svs = []

# Filepaths all correspond to the testdata folder.
input_dir = "testdata/input"
input_dir = Path(__file__).parent / "testdata/input"
input_alternatives_filepattern = os.path.join(input_dir,
"*_alternatives.csv")
input_autogen_filepattern = os.path.join(input_dir, 'unknown_*.csv')
Expand All @@ -120,8 +122,8 @@ def testSuccess(self):
tmp="/tmp")

# Filepaths all correspond to the testdata folder.
input_dir = "testdata/input"
expected_dir = "testdata/expected"
input_dir = Path(__file__).parent / "testdata/input"
expected_dir = Path(__file__).parent / "testdata/expected"
input_alternatives_filepattern = os.path.join(input_dir,
"*_alternatives.csv")
input_autogen_filepattern = os.path.join(input_dir, "autogen_*.csv")
Expand Down Expand Up @@ -156,12 +158,16 @@ class TestEndToEndActualDataFiles(unittest.TestCase):
@parameterized.expand(["small", "medium"])
def testInputFilesValidations(self, sz):
# Verify that the required files exist.
sheets_filepath = "data/curated_input/main/sheets_svs.csv"
sheets_filepath = Path(
__file__).parent / "data/curated_input/main/sheets_svs.csv"
# TODO: Fix palm_batch13k_alternatives.csv to not have duplicate
# descriptions. Its technically okay since build_embeddings will take
# care of dups.
input_alternatives_filepattern = "data/alternatives/(palm|other)_alternaties.csv"
output_dcid_sentences_filepath = f'data/preindex/{sz}/sv_descriptions.csv'
parent_folder = str(Path(__file__).parent)
input_alternatives_filepattern = os.path.join(
parent_folder, "data/alternatives/(palm|other)_alternaties.csv")
output_dcid_sentences_filepath = os.path.join(
parent_folder, f'data/preindex/{sz}/sv_descriptions.csv')

# Check that all the files exist.
self.assertTrue(os.path.exists(sheets_filepath))
Expand Down Expand Up @@ -192,7 +198,8 @@ def testInputFilesValidations(self, sz):

@parameterized.expand(["small", "medium"])
def testOutputFileValidations(self, sz):
output_dcid_sentences_filepath = f'data/preindex/{sz}/sv_descriptions.csv'
output_dcid_sentences_filepath = Path(
__file__).parent / f'data/preindex/{sz}/sv_descriptions.csv'

dcid_sentence_df = pd.read_csv(output_dcid_sentences_filepath).fillna("")

Expand Down
14 changes: 6 additions & 8 deletions tools/nl/embeddings/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -97,24 +97,22 @@ done
cd ../../..
python3 -m venv .env
source .env/bin/activate
cd tools/nl/embeddings
python3 -m pip install --upgrade pip
pip3 install torch==2.2.2 --extra-index-url https://download.pytorch.org/whl/cpu
pip3 install -r requirements.txt
pip3 install -r tools/nl/embeddings/requirements.txt

if [[ "$MODEL_ENDPOINT_ID" != "" ]];then
python3 build_embeddings.py --embeddings_size=$2 \
python3 -m tools.nl.embeddings.build_embeddings --embeddings_size=$2 \
--vertex_ai_prediction_endpoint_id=$MODEL_ENDPOINT_ID \
--curated_input_dirs="data/curated_input/main" \
--autogen_input_basedir="" \
--alternatives_filepattern=""

elif [[ "$CURATED_INPUT_DIRS" != "" ]]; then
python3 build_embeddings.py --embeddings_size=$2 --finetuned_model_gcs=$FINETUNED_MODEL --curated_input_dirs=$CURATED_INPUT_DIRS --alternatives_filepattern=$ALTERNATIVES_FILE_PATTERN
python3 -m tools.nl.embeddings.build_embeddings --embeddings_size=$2 --finetuned_model_gcs=$FINETUNED_MODEL --curated_input_dirs=$CURATED_INPUT_DIRS --alternatives_filepattern=$ALTERNATIVES_FILE_PATTERN
elif [[ "$LANCEDB_OUTPUT_PATH" != "" ]]; then
python3 build_embeddings.py --embeddings_size=$2 --finetuned_model_gcs=$FINETUNED_MODEL --lancedb_output_path=$LANCEDB_OUTPUT_PATH --dry_run=True
python3 -m tools.nl.embeddings.build_embeddings --embeddings_size=$2 --finetuned_model_gcs=$FINETUNED_MODEL --lancedb_output_path=$LANCEDB_OUTPUT_PATH --dry_run=True
elif [[ "$FINETUNED_MODEL" != "" ]]; then
python3 build_embeddings.py --embeddings_size=$2 --finetuned_model_gcs=$FINETUNED_MODEL
python3 -m tools.nl.embeddings.build_embeddings --embeddings_size=$2 --finetuned_model_gcs=$FINETUNED_MODEL
else
python3 build_embeddings.py --embeddings_size=$2
python3 -m tools.nl.embeddings.build_embeddings --embeddings_size=$2
fi
6 changes: 3 additions & 3 deletions tools/nl/embeddings/run_custom.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@

set -e

cd ../../..
python3 -m venv .env
source .env/bin/activate
python3 -m pip install --upgrade pip
pip3 install torch==2.2.2 --extra-index-url https://download.pytorch.org/whl/cpu
pip3 install -r requirements.txt
pip3 install -r tools/nl/embeddings/requirements.txt

python3 build_custom_dc_embeddings.py "$@"
python3 -m tools.nl.embeddings.build_custom_dc_embeddings "$@"
5 changes: 2 additions & 3 deletions tools/nl/embeddings/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,13 @@
from dataclasses import dataclass
import itertools
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Tuple

from file_util import create_file_handler
from google.cloud import aiplatform
import pandas as pd

from tools.nl.embeddings.file_util import create_file_handler

# Col names in the input files/sheets.
DCID_COL = 'dcid'
NAME_COL = 'Name'
Expand Down

0 comments on commit cd0b6e5

Please sign in to comment.