diff --git a/server/lib/nl/common/bad_words.py b/server/lib/nl/common/bad_words.py index 3d726396c4..df6d71e520 100644 --- a/server/lib/nl/common/bad_words.py +++ b/server/lib/nl/common/bad_words.py @@ -20,7 +20,7 @@ from server.lib.config import GLOBAL_CONFIG_BUCKET from shared.lib import gcs -BAD_WORDS_FILE = 'nl_bad_words.txt' +BAD_WORDS_PATH = gcs.make_path(GLOBAL_CONFIG_BUCKET, 'nl_bad_words.txt') _DELIM = ':' @@ -74,8 +74,7 @@ class BannedWords: # Loads a list of bad words from a text file. # def load_bad_words() -> BannedWords: - local_file = gcs.maybe_download( - f'gs://{GLOBAL_CONFIG_BUCKET}/{BAD_WORDS_FILE}') + local_file = gcs.maybe_download(BAD_WORDS_PATH) return load_bad_words_file(local_file) @@ -133,8 +132,7 @@ def load_bad_words_file(local_file: str, validate: bool = False) -> BannedWords: def validate_bad_words(): - local_file = gcs.maybe_download( - f'gs://{GLOBAL_CONFIG_BUCKET}/{BAD_WORDS_FILE}') + local_file = gcs.maybe_download(BAD_WORDS_PATH) load_bad_words_file(local_file, validate=True) diff --git a/shared/lib/gcs.py b/shared/lib/gcs.py index 8d7a5abae2..4c537de6bc 100644 --- a/shared/lib/gcs.py +++ b/shared/lib/gcs.py @@ -33,12 +33,16 @@ def get_path_parts(gcs_path: str) -> Tuple[str, str]: return tuple(gcs_path.removeprefix(GCS_PATH_PREFIX).split('/', 1)) +def make_path(bucket_name: str, blob_name: str) -> str: + return GCS_PATH_PREFIX + bucket_name + '/' + blob_name + + def download_blob(bucket_name: str, blob_name: str, local_path: str, use_anonymous_client: bool = False) -> bool: """ - Downloads the content of a GCS folder to a local folder. + Downloads the content of a GCS blob to a local path. Args: - bucket_name: The name of the GCS bucket. @@ -55,6 +59,7 @@ def download_blob(bucket_name: str, blobs = bucket.list_blobs(prefix=blob_name) count = 0 for blob in blobs: + # When a blob name ends with "/", the blob is a folder. No need to download. if blob.name.endswith("/"): continue # Get the relative path to the input blob. This is used to download folder. @@ -115,8 +120,11 @@ def maybe_download(gcs_path: str, raise ValueError(f"Invalid GCS path: {gcs_path}") bucket_name, blob_name = get_path_parts(gcs_path) local_path = os.path.join(local_path_root, bucket_name, blob_name) - if os.path.exists(local_path): + if os.path.exists(local_path) and len(os.listdir(local_path)) > 0: + # When running locally, we may already have downloaded the path. + # But sometimes after restart, the directories in `/tmp` become empty, + # so ensure that's not the case. return local_path return local_path - if download_blob_by_path(gcs_path, local_path, use_anonymous_client): + if download_blob(bucket_name, blob_name, local_path, use_anonymous_client): return local_path return None diff --git a/shared/lib/gcs_test.py b/shared/lib/gcs_test.py index 95ee8f7287..6e9c701ecd 100644 --- a/shared/lib/gcs_test.py +++ b/shared/lib/gcs_test.py @@ -24,6 +24,7 @@ class TestGCSFunctions(unittest.TestCase): + # All the test files are prepared in this bucket. bucket_name = 'datcom-ci-test' @pytest.fixture(autouse=True) @@ -55,11 +56,18 @@ def test_get_path_parts_invalid_path(self): ['a.txt', 'folder2/a.txt', 'a'], ]) def test_download_file(self, local_file_path, blob_name, content): + """ + Download a file from GCS to a local path. Here the blob_name is a file name. + """ f = self.tmp_path / local_file_path gcs.download_blob(self.bucket_name, blob_name, f) self.assertEqual(f.read_text(), content) def test_download_folder(self): + """ + Test downloading a folder from GCS to a local path. And check all the nested + files exist. + """ gcs_folder = 'folder1' gcs.download_blob(self.bucket_name, gcs_folder, self.tmp_path) p = Path(self.tmp_path) @@ -72,6 +80,8 @@ def test_download_folder(self): [f'gs://{bucket_name}/folder1/folder11/d.txt', 'tmp.txt', 'd'], ]) def test_download_blob_by_path(self, gcs_path, local_file_path, content): + """Download a file based on GCS path. + """ f = self.tmp_path / local_file_path gcs.download_blob_by_path(gcs_path, f) self.assertEqual(f.read_text(), content)