Skip to content

Commit

Permalink
resolve comments
Browse files Browse the repository at this point in the history
  • Loading branch information
shifucun committed May 20, 2024
1 parent c49eb62 commit 8bd8e9a
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 8 deletions.
8 changes: 3 additions & 5 deletions server/lib/nl/common/bad_words.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from server.lib.config import GLOBAL_CONFIG_BUCKET
from shared.lib import gcs

BAD_WORDS_FILE = 'nl_bad_words.txt'
BAD_WORDS_PATH = gcs.make_path(GLOBAL_CONFIG_BUCKET, 'nl_bad_words.txt')
_DELIM = ':'


Expand Down Expand Up @@ -74,8 +74,7 @@ class BannedWords:
# Loads a list of bad words from a text file.
#
def load_bad_words() -> BannedWords:
local_file = gcs.maybe_download(
f'gs://{GLOBAL_CONFIG_BUCKET}/{BAD_WORDS_FILE}')
local_file = gcs.maybe_download(BAD_WORDS_PATH)
return load_bad_words_file(local_file)


Expand Down Expand Up @@ -133,8 +132,7 @@ def load_bad_words_file(local_file: str, validate: bool = False) -> BannedWords:


def validate_bad_words():
local_file = gcs.maybe_download(
f'gs://{GLOBAL_CONFIG_BUCKET}/{BAD_WORDS_FILE}')
local_file = gcs.maybe_download(BAD_WORDS_PATH)
load_bad_words_file(local_file, validate=True)


Expand Down
14 changes: 11 additions & 3 deletions shared/lib/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,16 @@ def get_path_parts(gcs_path: str) -> Tuple[str, str]:
return tuple(gcs_path.removeprefix(GCS_PATH_PREFIX).split('/', 1))


def make_path(bucket_name: str, blob_name: str) -> str:
return GCS_PATH_PREFIX + bucket_name + '/' + blob_name


def download_blob(bucket_name: str,
blob_name: str,
local_path: str,
use_anonymous_client: bool = False) -> bool:
"""
Downloads the content of a GCS folder to a local folder.
Downloads the content of a GCS blob to a local path.
Args:
- bucket_name: The name of the GCS bucket.
Expand All @@ -55,6 +59,7 @@ def download_blob(bucket_name: str,
blobs = bucket.list_blobs(prefix=blob_name)
count = 0
for blob in blobs:
# When a blob name ends with "/", the blob is a folder. No need to download.
if blob.name.endswith("/"):
continue
# Get the relative path to the input blob. This is used to download folder.
Expand Down Expand Up @@ -115,8 +120,11 @@ def maybe_download(gcs_path: str,
raise ValueError(f"Invalid GCS path: {gcs_path}")
bucket_name, blob_name = get_path_parts(gcs_path)
local_path = os.path.join(local_path_root, bucket_name, blob_name)
if os.path.exists(local_path):
if os.path.exists(local_path) and len(os.listdir(local_path)) > 0:
# When running locally, we may already have downloaded the path.
# But sometimes after restart, the directories in `/tmp` become empty,
# so ensure that's not the case. return local_path
return local_path
if download_blob_by_path(gcs_path, local_path, use_anonymous_client):
if download_blob(bucket_name, blob_name, local_path, use_anonymous_client):
return local_path
return None
10 changes: 10 additions & 0 deletions shared/lib/gcs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@


class TestGCSFunctions(unittest.TestCase):
# All the test files are prepared in this bucket.
bucket_name = 'datcom-ci-test'

@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -55,11 +56,18 @@ def test_get_path_parts_invalid_path(self):
['a.txt', 'folder2/a.txt', 'a'],
])
def test_download_file(self, local_file_path, blob_name, content):
"""
Download a file from GCS to a local path. Here the blob_name is a file name.
"""
f = self.tmp_path / local_file_path
gcs.download_blob(self.bucket_name, blob_name, f)
self.assertEqual(f.read_text(), content)

def test_download_folder(self):
"""
Test downloading a folder from GCS to a local path. And check all the nested
files exist.
"""
gcs_folder = 'folder1'
gcs.download_blob(self.bucket_name, gcs_folder, self.tmp_path)
p = Path(self.tmp_path)
Expand All @@ -72,6 +80,8 @@ def test_download_folder(self):
[f'gs://{bucket_name}/folder1/folder11/d.txt', 'tmp.txt', 'd'],
])
def test_download_blob_by_path(self, gcs_path, local_file_path, content):
"""Download a file based on GCS path.
"""
f = self.tmp_path / local_file_path
gcs.download_blob_by_path(gcs_path, f)
self.assertEqual(f.read_text(), content)

0 comments on commit 8bd8e9a

Please sign in to comment.