diff --git a/nl_server/gcs.py b/nl_server/gcs.py deleted file mode 100644 index 55c093f7f5..0000000000 --- a/nl_server/gcs.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -BUCKET = 'datcom-nl-models' - -import os -from pathlib import Path -from typing import Any - -from google.cloud import storage - -from shared.lib import gcs as gcs_lib - - -def download_embeddings(embeddings_filename: str) -> str: - # Using use_anonymous_client = True because - # Custom DCs need to download models without DC creds. - return gcs_lib.get_or_download_file(bucket=BUCKET, - filename=embeddings_filename, - use_anonymous_client=True) - - -def local_path(embeddings_file: str) -> str: - return os.path.join(gcs_lib.TEMP_DIR, embeddings_file) - - -def download_folder_from_gcs(gcs_bucket: Any, local_dir: str, - folder_name: str) -> str: - """Downloads a folder from GCS. - - Args: - gcs_bucket: The GCS bucket. - local_dir: the local folder to download to. - folder_name: the GCS folder name for the model. - - Returns the path to the local directory where the folder was downloaded to. - """ - # Get list of files - blobs = gcs_bucket.list_blobs(prefix=folder_name) - for blob in blobs: - file_split = blob.name.split("/") - directory = local_dir - for p in file_split[0:-1]: - directory = os.path.join(directory, p) - Path(directory).mkdir(parents=True, exist_ok=True) - - if blob.name.endswith("/"): - continue - blob.download_to_filename(os.path.join(directory, file_split[-1])) - - return os.path.join(local_dir, folder_name) - - -# Downloads the `folder` or gs:// path from GCS to /tmp/ -# and return its path. -def download_folder(path: str) -> str: - # Using an anonymous client because - # Custom DCs need to download models without DC creds. - sc = storage.Client.create_anonymous_client() - - if gcs_lib.is_gcs_path(path): - bucket_name, base_name = gcs_lib.get_gcs_parts(path) - else: - bucket_name = BUCKET - base_name = path - local_dir_prefix = os.path.join(gcs_lib.TEMP_DIR, bucket_name) - - # Only download if needed. - local_path = os.path.join(local_dir_prefix, base_name) - if os.path.exists(local_path) and len(os.listdir(local_path)) > 0: - # When running locally, we may already have downloaded the path. - # But sometimes after restart, the directories in `/tmp` become - # empty, so ensure that's not the case. - return local_path - - print( - f"Directory ({base_name}) was either not previously downloaded or cannot successfully be loaded. Downloading to: {local_path}" - ) - bucket = sc.bucket(bucket_name=bucket_name) - return download_folder_from_gcs(bucket, local_dir_prefix, base_name) diff --git a/nl_server/loader.py b/nl_server/loader.py index ae37f48992..0614db0875 100644 --- a/nl_server/loader.py +++ b/nl_server/loader.py @@ -28,9 +28,8 @@ from nl_server.nl_attribute_model import NLAttributeModel from shared.lib.custom_dc_util import get_custom_dc_user_data_path from shared.lib.custom_dc_util import is_custom_dc -from shared.lib.gcs import download_gcs_file from shared.lib.gcs import is_gcs_path -from shared.lib.gcs import join_gcs_path +from shared.lib.gcs import maybe_download _EMBEDDINGS_YAML = 'embeddings.yaml' _CUSTOM_EMBEDDINGS_YAML_PATH = 'datacommons/nl/custom_embeddings.yaml' @@ -138,10 +137,10 @@ def _maybe_load_custom_dc_yaml(): # TODO: Consider reading the base path from a "version.txt" instead # of hardcoding `data` if is_gcs_path(base): - gcs_path = join_gcs_path(base, _CUSTOM_EMBEDDINGS_YAML_PATH) + gcs_path = os.path.join(base, _CUSTOM_EMBEDDINGS_YAML_PATH) logging.info('Downloading custom embeddings yaml from GCS path: %s', gcs_path) - file_path = download_gcs_file(gcs_path) + file_path = maybe_download(gcs_path) if not file_path: logging.info( "Custom embeddings yaml in GCS not found: %s. Custom embeddings will not be loaded.", diff --git a/nl_server/model/sentence_transformer.py b/nl_server/model/sentence_transformer.py index c585ac9c1d..a8f7f60131 100644 --- a/nl_server/model/sentence_transformer.py +++ b/nl_server/model/sentence_transformer.py @@ -20,8 +20,8 @@ import torch from nl_server import embeddings -from nl_server import gcs from nl_server.config import LocalModelConfig +from shared.lib import gcs class LocalSentenceTransformerModel(embeddings.EmbeddingsModel): @@ -31,7 +31,7 @@ def __init__(self, model_info: LocalModelConfig): # Download model from gcs if there is a gcs folder specified logging.info(f'Downloading tuned model from: {model_info.gcs_folder}') - model_path = gcs.download_folder(model_info.gcs_folder) + model_path = gcs.maybe_download(model_info.gcs_folder) logging.info(f'Loading tuned model from: {model_path}') self.model = SentenceTransformer(model_path) diff --git a/nl_server/store/lancedb.py b/nl_server/store/lancedb.py index 58eca01967..c96a40742f 100644 --- a/nl_server/store/lancedb.py +++ b/nl_server/store/lancedb.py @@ -18,12 +18,11 @@ import lancedb -from nl_server import gcs from nl_server.config import LanceDBIndexConfig from nl_server.embeddings import EmbeddingsMatch from nl_server.embeddings import EmbeddingsResult from nl_server.embeddings import EmbeddingsStore -from shared.lib.gcs import is_gcs_path +from shared.lib import gcs TABLE_NAME = 'datacommons' @@ -43,9 +42,9 @@ def __init__(self, idx_info: LanceDBIndexConfig) -> None: if idx_info.embeddings_path.startswith('/'): lance_db_dir = idx_info.embeddings_path - elif is_gcs_path(idx_info.embeddings_path): + elif gcs.is_gcs_path(idx_info.embeddings_path): logging.info('Downloading embeddings from GCS path: ') - lance_db_dir = gcs.download_folder(idx_info.embeddings_path) + lance_db_dir = gcs.maybe_download(idx_info.embeddings_path) if not lance_db_dir: raise AssertionError( f'Embeddings not downloaded from GCS. Please check the path: {idx_info.embeddings_path}' diff --git a/nl_server/store/memory.py b/nl_server/store/memory.py index dbfc646e68..26d6fd43b2 100644 --- a/nl_server/store/memory.py +++ b/nl_server/store/memory.py @@ -25,8 +25,8 @@ from nl_server.embeddings import EmbeddingsResult from nl_server.embeddings import EmbeddingsStore from shared.lib.custom_dc_util import use_anonymous_gcs_client -from shared.lib.gcs import download_gcs_file from shared.lib.gcs import is_gcs_path +from shared.lib.gcs import maybe_download class MemoryEmbeddingsStore(EmbeddingsStore): @@ -40,7 +40,7 @@ def __init__(self, idx_info: MemoryIndexConfig) -> None: embeddings_path = idx_info.embeddings_path elif is_gcs_path(idx_info.embeddings_path): logging.info('Downloading embeddings from GCS path: ') - embeddings_path = download_gcs_file( + embeddings_path = maybe_download( idx_info.embeddings_path, use_anonymous_client=use_anonymous_gcs_client()) if not embeddings_path: diff --git a/server/lib/nl/common/bad_words.py b/server/lib/nl/common/bad_words.py index 5f4f9f6796..df6d71e520 100644 --- a/server/lib/nl/common/bad_words.py +++ b/server/lib/nl/common/bad_words.py @@ -20,7 +20,7 @@ from server.lib.config import GLOBAL_CONFIG_BUCKET from shared.lib import gcs -BAD_WORDS_FILE = 'nl_bad_words.txt' +BAD_WORDS_PATH = gcs.make_path(GLOBAL_CONFIG_BUCKET, 'nl_bad_words.txt') _DELIM = ':' @@ -74,8 +74,7 @@ class BannedWords: # Loads a list of bad words from a text file. # def load_bad_words() -> BannedWords: - local_file = gcs.download_file(bucket=GLOBAL_CONFIG_BUCKET, - filename=BAD_WORDS_FILE) + local_file = gcs.maybe_download(BAD_WORDS_PATH) return load_bad_words_file(local_file) @@ -133,8 +132,7 @@ def load_bad_words_file(local_file: str, validate: bool = False) -> BannedWords: def validate_bad_words(): - local_file = gcs.download_file(bucket=GLOBAL_CONFIG_BUCKET, - filename=BAD_WORDS_FILE) + local_file = gcs.maybe_download(BAD_WORDS_PATH) load_bad_words_file(local_file, validate=True) diff --git a/server/lib/topic_cache.py b/server/lib/topic_cache.py index a25669013e..105c22baaa 100644 --- a/server/lib/topic_cache.py +++ b/server/lib/topic_cache.py @@ -25,8 +25,8 @@ from server.lib.nl.explore.params import DCNames from shared.lib.custom_dc_util import get_custom_dc_topic_cache_path from shared.lib.custom_dc_util import is_custom_dc -from shared.lib.gcs import download_gcs_file from shared.lib.gcs import is_gcs_path +from shared.lib.gcs import maybe_download # This might be a topic or svpg @@ -244,7 +244,7 @@ def _get_local_custom_dc_topic_cache_path() -> str: logging.info("Custom DC topic cache will be loaded from: %s", path) if is_gcs_path(path): - return download_gcs_file(path) + return maybe_download(path) if not os.path.exists(path): logging.warning( diff --git a/shared/lib/custom_dc_util.py b/shared/lib/custom_dc_util.py index ea854d14c5..c2c173d328 100644 --- a/shared/lib/custom_dc_util.py +++ b/shared/lib/custom_dc_util.py @@ -17,7 +17,6 @@ import os from shared.lib.gcs import is_gcs_path -from shared.lib.gcs import join_gcs_path _TOPIC_CACHE_PATH = "datacommons/nl/custom_dc_topic_cache.json" @@ -36,8 +35,6 @@ def get_custom_dc_topic_cache_path() -> str: base_path = get_custom_dc_user_data_path() if not base_path: return base_path - if is_gcs_path(base_path): - return join_gcs_path(base_path, _TOPIC_CACHE_PATH) return os.path.join(base_path, _TOPIC_CACHE_PATH) diff --git a/shared/lib/gcs.py b/shared/lib/gcs.py index 857f488c15..419eceda11 100644 --- a/shared/lib/gcs.py +++ b/shared/lib/gcs.py @@ -1,4 +1,4 @@ -# Copyright 2023 Google LLC +# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,81 +16,116 @@ import logging import os -from pathlib import Path from typing import Tuple from google.cloud import storage -_GCS_PATH_PREFIX = "gs://" +GCS_PATH_PREFIX = "gs://" def is_gcs_path(path: str) -> bool: - return path.strip().startswith(_GCS_PATH_PREFIX) + return path.startswith(GCS_PATH_PREFIX) and len(path) > len(GCS_PATH_PREFIX) -def join_gcs_path(base_path: str, sub_path: str) -> str: - if base_path.endswith('/'): - return f'{base_path}{sub_path}' - return f'{base_path}/{sub_path}' +def get_path_parts(gcs_path: str) -> Tuple[str, str]: + if not is_gcs_path(gcs_path): + raise ValueError(f"Invalid GCS path: {gcs_path}") + return tuple(gcs_path.removeprefix(GCS_PATH_PREFIX).split('/', 1)) -def get_gcs_parts(gcs_path: str) -> Tuple[str, str]: - return gcs_path[len(_GCS_PATH_PREFIX):].split('/', 1) +def make_path(bucket_name: str, blob_name: str) -> str: + return GCS_PATH_PREFIX + bucket_name + '/' + blob_name -def download_gcs_file(gcs_path: str, use_anonymous_client: bool = False) -> str: - """Downloads the file from the full GCS path (i.e. gs://bucket/path/to/file) - to a local path and returns the latter. +def download_blob(bucket_name: str, + blob_name: str, + local_path: str, + use_anonymous_client: bool = False) -> bool: """ - # If not a GCS path, return the path itself. - if not is_gcs_path(gcs_path): - return gcs_path - - bucket_name, blob_name = get_gcs_parts(gcs_path) - if not blob_name: - return '' - try: - return get_or_download_file(bucket_name, blob_name, use_anonymous_client) - except Exception as e: - logging.warning("Unable to download gcs file: %s (%s)", gcs_path, str(e)) - return '' - - -# -# Downloads the `filename` from GCS to TEMP_DIR -# and returns its path. -# -def download_file(bucket: str, - filename: str, - use_anonymous_client: bool = False) -> str: + Downloads the content of a GCS blob to a local path. + + Args: + - bucket_name: The name of the GCS bucket. + - blob_name: The GCS blob name, could be a folder or a file. + - local_path: The local path to download the blob to. + """ + logging.info("Download %s/%s to %s", bucket_name, blob_name, local_path) if use_anonymous_client: storage_client = storage.Client.create_anonymous_client() else: storage_client = storage.Client() - bucket_object = storage_client.bucket(bucket_name=bucket) - blob = bucket_object.get_blob(filename) - # Download - local_file_path = _get_local_path(bucket, filename) - # Create directory to file if it does not exist. - parent_dir = Path(local_file_path).parent - if not parent_dir.exists(): - parent_dir.mkdir(parents=True, exist_ok=True) - blob.download_to_filename(local_file_path) - return local_file_path - - -def get_or_download_file(bucket: str, - filename: str, - use_anonymous_client: bool = False) -> str: - """Returns the local file path if the file already exists. - Otherwise it downloads the file from GCS and returns the path it was downloaded to. - """ - local_file_path = _get_local_path(bucket, filename) - if os.path.exists(local_file_path): - logging.info("Using already downloaded GCS file: %s", local_file_path) - return local_file_path - return download_file(bucket, filename, use_anonymous_client) - -def _get_local_path(bucketname: str, filename: str) -> str: - return os.path.join(TEMP_DIR, bucketname, filename) + bucket = storage_client.bucket(bucket_name) + blobs = bucket.list_blobs(prefix=blob_name) + count = 0 + for blob in blobs: + # When a blob name ends with "/", the blob is a folder. No need to download. + if blob.name.endswith("/"): + continue + # Get the relative path to the input blob. This is used to download folder. + relative_path = os.path.relpath(blob.name, blob_name) + if relative_path == ".": + # In this case, the blob is a file. + local_file_path = local_path + else: + # In this case, the blob is a folder. + local_file_path = os.path.join(local_path, relative_path) + # Create the directory if it doesn't exist. + local_dir = os.path.dirname(local_file_path) + if not os.path.exists(local_dir): + os.makedirs(local_dir) + # Download the file. + blob.download_to_filename(local_file_path) + count += 1 + if count == 0: + logging.warning("No object found from %s/%s", bucket_name, blob_name) + return False + return True + + +def download_blob_by_path(gcs_path: str, + local_path: str, + use_anonymous_client: bool = False) -> bool: + """Downloads file/folder given full GCS path (i.e. gs://bucket/path/to/file) + to a local path. + + Args: + gcs_path: The full GCS path (i.e. gs://bucket/path/to/file/). + local_path: The local path to download the blob to. + """ + if not is_gcs_path(gcs_path): + raise ValueError(f"Invalid GCS path: {gcs_path}") + bucket_name, blob_name = get_path_parts(gcs_path) + return download_blob(bucket_name, blob_name, local_path, use_anonymous_client) + + +def maybe_download(gcs_path: str, + local_path_root: str = '/tmp', + use_anonymous_client: bool = False) -> str: + """Downloads file/folder from a GCS path (i.e. gs://bucket/path/to/file) + to a local path. If the local file/folder already exists, then do nothing. + + The local path expands the gcs_path pattern under local_path_root. + For example, if local_path_root is '/tmp', the local path will be + '/tmp/bucket/path/to/file'. + + Args: + gcs_path: The full GCS path (i.e. gs://bucket/path/to/file/). + local_path_root: The local root path to download the gcs resources to. + use_anonymous_client: Whether to use anonymous client to download the file. + Returns: + The local path of the downloaded file/folder. + """ + if not is_gcs_path(gcs_path): + raise ValueError(f"Invalid GCS path: {gcs_path}") + bucket_name, blob_name = get_path_parts(gcs_path) + local_path = os.path.join(local_path_root, bucket_name, blob_name) + if os.path.exists(local_path): + # When running locally, we may already have downloaded the path. + # But sometimes after restart, the directories in `/tmp` become empty, + # so ensure that's not the case. return local_path + if os.path.isfile(local_path) or len(os.listdir(local_path)) > 0: + return local_path + if download_blob(bucket_name, blob_name, local_path, use_anonymous_client): + return local_path + return None diff --git a/shared/lib/gcs_test.py b/shared/lib/gcs_test.py new file mode 100644 index 0000000000..6e9c701ecd --- /dev/null +++ b/shared/lib/gcs_test.py @@ -0,0 +1,87 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for gcs functions""" + +import os +from pathlib import Path +import unittest + +from parameterized import parameterized +import pytest + +import shared.lib.gcs as gcs + + +class TestGCSFunctions(unittest.TestCase): + # All the test files are prepared in this bucket. + bucket_name = 'datcom-ci-test' + + @pytest.fixture(autouse=True) + def _inject_fixtures(self, tmp_path): + self.tmp_path = tmp_path + + @parameterized.expand([ + ['abc', False], + ['gs://bucket/object', True], + ['gs://bucket', True], + ['gs://', False], + ]) + def test_is_gcs_path(self, path, result): + self.assertEqual(gcs.is_gcs_path(path), result) + + @parameterized.expand([ + ['gs://bucket/object', ('bucket', 'object')], + ['gs://bucket/folder/object', ('bucket', 'folder/object')], + ['gs://bucket', ('bucket',)], + ]) + def test_get_path_parts(self, path, result): + self.assertEqual(gcs.get_path_parts(path), result) + + def test_get_path_parts_invalid_path(self): + self.assertRaises(ValueError, gcs.get_path_parts, '@#@') + + @parameterized.expand([ + ['x/y.txt', 'folder1/folder11/d.txt', 'd'], + ['a.txt', 'folder2/a.txt', 'a'], + ]) + def test_download_file(self, local_file_path, blob_name, content): + """ + Download a file from GCS to a local path. Here the blob_name is a file name. + """ + f = self.tmp_path / local_file_path + gcs.download_blob(self.bucket_name, blob_name, f) + self.assertEqual(f.read_text(), content) + + def test_download_folder(self): + """ + Test downloading a folder from GCS to a local path. And check all the nested + files exist. + """ + gcs_folder = 'folder1' + gcs.download_blob(self.bucket_name, gcs_folder, self.tmp_path) + p = Path(self.tmp_path) + got_files = sorted([ + os.path.relpath(x, self.tmp_path) for x in p.rglob('*') if x.is_file() + ]) + self.assertEqual(got_files, ['b.txt', 'c.txt', 'folder11/d.txt']) + + @parameterized.expand([ + [f'gs://{bucket_name}/folder1/folder11/d.txt', 'tmp.txt', 'd'], + ]) + def test_download_blob_by_path(self, gcs_path, local_file_path, content): + """Download a file based on GCS path. + """ + f = self.tmp_path / local_file_path + gcs.download_blob_by_path(gcs_path, f) + self.assertEqual(f.read_text(), content) diff --git a/tools/nl/embeddings/build_custom_dc_embeddings.py b/tools/nl/embeddings/build_custom_dc_embeddings.py index 9a91802e9c..e438bbc8e0 100644 --- a/tools/nl/embeddings/build_custom_dc_embeddings.py +++ b/tools/nl/embeddings/build_custom_dc_embeddings.py @@ -83,8 +83,8 @@ def download(embeddings_yaml_path: str): # Download embeddings. embeddings_file_name = default_ft_embeddings_info.index_config['embeddings'] print(f"Downloading default embeddings: {embeddings_file_name}") - local_embeddings_path = gcs.download_gcs_file(embeddings_file_name, - use_anonymous_client=True) + local_embeddings_path = gcs.maybe_download(embeddings_file_name, + use_anonymous_client=True) if not local_embeddings_path: print(f"Unable to download default embeddings: {embeddings_file_name}") else: diff --git a/tools/nl/embeddings/file_util.py b/tools/nl/embeddings/file_util.py index 53d473c537..150ccb99f6 100644 --- a/tools/nl/embeddings/file_util.py +++ b/tools/nl/embeddings/file_util.py @@ -92,19 +92,13 @@ def write_string(self, content: str) -> None: self.blob.upload_from_string(content) def join(self, subpath: str) -> str: - return join_gcs_path(self.path, subpath) + return os.path.join(self.path, subpath) def is_gcs_path(path: str) -> bool: return path.startswith(_GCS_PATH_PREFIX) -def join_gcs_path(base_path: str, sub_path: str) -> str: - if base_path.endswith('/'): - return f'{base_path}{sub_path}' - return f'{base_path}/{sub_path}' - - def create_file_handler(path: str) -> FileHandler: if is_gcs_path(path): return GcsFileHandler(path)