Skip to content

Commit

Permalink
Add generic GCS download functions and use them in NL server/tools (#…
Browse files Browse the repository at this point in the history
…4252)

Added 3 generic GCS functions:

* download_blob(), which downloads a blob (file or folder) given a bucket,
blob name, and local path.
* download_blob_by_path(), which is a wrapper around download_blob() that
takes a GCS path.
* maybe_download(), which downloads the blob under 'local_path_prefix',
keeps the GCS path structure, and skips the download if it already
exists.

Use these functions throughout NL apps and remove unused functions.
  • Loading branch information
shifucun authored May 21, 2024
1 parent feec7ab commit 78122be
Show file tree
Hide file tree
Showing 12 changed files with 200 additions and 182 deletions.
91 changes: 0 additions & 91 deletions nl_server/gcs.py

This file was deleted.

7 changes: 3 additions & 4 deletions nl_server/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,8 @@
from nl_server.nl_attribute_model import NLAttributeModel
from shared.lib.custom_dc_util import get_custom_dc_user_data_path
from shared.lib.custom_dc_util import is_custom_dc
from shared.lib.gcs import download_gcs_file
from shared.lib.gcs import is_gcs_path
from shared.lib.gcs import join_gcs_path
from shared.lib.gcs import maybe_download

_EMBEDDINGS_YAML = 'embeddings.yaml'
_CUSTOM_EMBEDDINGS_YAML_PATH = 'datacommons/nl/custom_embeddings.yaml'
Expand Down Expand Up @@ -138,10 +137,10 @@ def _maybe_load_custom_dc_yaml():
# TODO: Consider reading the base path from a "version.txt" instead
# of hardcoding `data`
if is_gcs_path(base):
gcs_path = join_gcs_path(base, _CUSTOM_EMBEDDINGS_YAML_PATH)
gcs_path = os.path.join(base, _CUSTOM_EMBEDDINGS_YAML_PATH)
logging.info('Downloading custom embeddings yaml from GCS path: %s',
gcs_path)
file_path = download_gcs_file(gcs_path)
file_path = maybe_download(gcs_path)
if not file_path:
logging.info(
"Custom embeddings yaml in GCS not found: %s. Custom embeddings will not be loaded.",
Expand Down
4 changes: 2 additions & 2 deletions nl_server/model/sentence_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
import torch

from nl_server import embeddings
from nl_server import gcs
from nl_server.config import LocalModelConfig
from shared.lib import gcs


class LocalSentenceTransformerModel(embeddings.EmbeddingsModel):
Expand All @@ -31,7 +31,7 @@ def __init__(self, model_info: LocalModelConfig):

# Download model from gcs if there is a gcs folder specified
logging.info(f'Downloading tuned model from: {model_info.gcs_folder}')
model_path = gcs.download_folder(model_info.gcs_folder)
model_path = gcs.maybe_download(model_info.gcs_folder)
logging.info(f'Loading tuned model from: {model_path}')
self.model = SentenceTransformer(model_path)

Expand Down
7 changes: 3 additions & 4 deletions nl_server/store/lancedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@

import lancedb

from nl_server import gcs
from nl_server.config import LanceDBIndexConfig
from nl_server.embeddings import EmbeddingsMatch
from nl_server.embeddings import EmbeddingsResult
from nl_server.embeddings import EmbeddingsStore
from shared.lib.gcs import is_gcs_path
from shared.lib import gcs

TABLE_NAME = 'datacommons'

Expand All @@ -43,9 +42,9 @@ def __init__(self, idx_info: LanceDBIndexConfig) -> None:

if idx_info.embeddings_path.startswith('/'):
lance_db_dir = idx_info.embeddings_path
elif is_gcs_path(idx_info.embeddings_path):
elif gcs.is_gcs_path(idx_info.embeddings_path):
logging.info('Downloading embeddings from GCS path: ')
lance_db_dir = gcs.download_folder(idx_info.embeddings_path)
lance_db_dir = gcs.maybe_download(idx_info.embeddings_path)
if not lance_db_dir:
raise AssertionError(
f'Embeddings not downloaded from GCS. Please check the path: {idx_info.embeddings_path}'
Expand Down
4 changes: 2 additions & 2 deletions nl_server/store/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
from nl_server.embeddings import EmbeddingsResult
from nl_server.embeddings import EmbeddingsStore
from shared.lib.custom_dc_util import use_anonymous_gcs_client
from shared.lib.gcs import download_gcs_file
from shared.lib.gcs import is_gcs_path
from shared.lib.gcs import maybe_download


class MemoryEmbeddingsStore(EmbeddingsStore):
Expand All @@ -40,7 +40,7 @@ def __init__(self, idx_info: MemoryIndexConfig) -> None:
embeddings_path = idx_info.embeddings_path
elif is_gcs_path(idx_info.embeddings_path):
logging.info('Downloading embeddings from GCS path: ')
embeddings_path = download_gcs_file(
embeddings_path = maybe_download(
idx_info.embeddings_path,
use_anonymous_client=use_anonymous_gcs_client())
if not embeddings_path:
Expand Down
8 changes: 3 additions & 5 deletions server/lib/nl/common/bad_words.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from server.lib.config import GLOBAL_CONFIG_BUCKET
from shared.lib import gcs

BAD_WORDS_FILE = 'nl_bad_words.txt'
BAD_WORDS_PATH = gcs.make_path(GLOBAL_CONFIG_BUCKET, 'nl_bad_words.txt')
_DELIM = ':'


Expand Down Expand Up @@ -74,8 +74,7 @@ class BannedWords:
# Loads a list of bad words from a text file.
#
def load_bad_words() -> BannedWords:
local_file = gcs.download_file(bucket=GLOBAL_CONFIG_BUCKET,
filename=BAD_WORDS_FILE)
local_file = gcs.maybe_download(BAD_WORDS_PATH)
return load_bad_words_file(local_file)


Expand Down Expand Up @@ -133,8 +132,7 @@ def load_bad_words_file(local_file: str, validate: bool = False) -> BannedWords:


def validate_bad_words():
local_file = gcs.download_file(bucket=GLOBAL_CONFIG_BUCKET,
filename=BAD_WORDS_FILE)
local_file = gcs.maybe_download(BAD_WORDS_PATH)
load_bad_words_file(local_file, validate=True)


Expand Down
4 changes: 2 additions & 2 deletions server/lib/topic_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
from server.lib.nl.explore.params import DCNames
from shared.lib.custom_dc_util import get_custom_dc_topic_cache_path
from shared.lib.custom_dc_util import is_custom_dc
from shared.lib.gcs import download_gcs_file
from shared.lib.gcs import is_gcs_path
from shared.lib.gcs import maybe_download


# This might be a topic or svpg
Expand Down Expand Up @@ -244,7 +244,7 @@ def _get_local_custom_dc_topic_cache_path() -> str:
logging.info("Custom DC topic cache will be loaded from: %s", path)

if is_gcs_path(path):
return download_gcs_file(path)
return maybe_download(path)

if not os.path.exists(path):
logging.warning(
Expand Down
3 changes: 0 additions & 3 deletions shared/lib/custom_dc_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import os

from shared.lib.gcs import is_gcs_path
from shared.lib.gcs import join_gcs_path

_TOPIC_CACHE_PATH = "datacommons/nl/custom_dc_topic_cache.json"

Expand All @@ -36,8 +35,6 @@ def get_custom_dc_topic_cache_path() -> str:
base_path = get_custom_dc_user_data_path()
if not base_path:
return base_path
if is_gcs_path(base_path):
return join_gcs_path(base_path, _TOPIC_CACHE_PATH)
return os.path.join(base_path, _TOPIC_CACHE_PATH)


Expand Down
Loading

0 comments on commit 78122be

Please sign in to comment.