From 05f2c45500e54779e173c3f126ff05c1b9adf2d9 Mon Sep 17 00:00:00 2001 From: Liam Clark Date: Tue, 12 Mar 2024 12:16:48 +0200 Subject: [PATCH 1/2] Modifies json_utils.pull_neptune_data to now download files concurrently - Neptune logs are now downloaded concurrently yielding increased download speed. - We now only fetch the `sys/id` column when initialising the project instead of downloading the full table. - The downloaded JSON files are extracted directly into the `store_directory` creating a flat file structure. i.e. No more nested JSON files. - Downloaded files are also named according to their `run_id`. --- marl_eval/json_tools/json_utils.py | 126 +++++++++++++++++------------ 1 file changed, 76 insertions(+), 50 deletions(-) diff --git a/marl_eval/json_tools/json_utils.py b/marl_eval/json_tools/json_utils.py index 934bacef..58203450 100644 --- a/marl_eval/json_tools/json_utils.py +++ b/marl_eval/json_tools/json_utils.py @@ -18,6 +18,8 @@ import os import zipfile from collections import defaultdict +from concurrent.futures import as_completed, ThreadPoolExecutor +from pathlib import Path from typing import Dict, List, Tuple import neptune @@ -111,66 +113,90 @@ def concatenate_json_files( def pull_neptune_data( - project_name: str, - tag: List, - store_directory: str = "./downloaded_json_data", - neptune_data_key: str = "metrics", + project_name: str, + tags: List[str], + store_directory: str = "./downloaded_json_data", + neptune_data_key: str = "metrics", + disable_progress_bar: bool = False, ) -> None: - """Pulls experiment json data from Neptune to a local directory. + """Downloads logs from a Neptune project based on provided tags. Args: project_name (str): Name of the Neptune project. - tag (List): List of tags for the experiment(s) that contain the - desired JSON files. - store_directory (str, optional): Directory to store the data. - Default: ./downloaded_json_data. - neptune_data_key (str, optional): Key in the neptune run where the - json data is stored. Default: metrics. + tags (List[str]): List of tags associated with the desired experiments. + store_directory (str, optional): Directory to store the downloaded logs. + Default is "./downloaded_json_data". + neptune_data_key (str, optional): Key for the Neptune data to download. + Default is "metrics". + disable_progress_bar (bool, optional): Whether to hide a progress bar during download. + Default is False. + + Raises: + ValueError: If the provided project name or tags are invalid. """ - # Get the run ids - project = neptune.init_project(project=project_name) - runs_table_df = project.fetch_runs_table(state="inactive", tag=tag).to_pandas() - run_ids = runs_table_df["sys/id"].values.tolist() - - # Check if store_directory exists - if not os.path.exists(store_directory): - os.makedirs(store_directory) + # Create the log directory if it doesn't exist + os.makedirs(store_directory, exist_ok=True) - # Suppress neptune logger + # Disable Neptune logging neptune_logger = logging.getLogger("neptune") neptune_logger.setLevel(logging.ERROR) - # Download and unzip the data - for run_id in tqdm(run_ids, desc="Downloading Neptune Data"): - run = neptune.init_run(project=project_name, with_id=run_id, mode="read-only") - for j, data_key in enumerate( - run.get_structure()[neptune_data_key].keys(), start=1 - ): - # Create a unique filename - file_path = f"{store_directory}/{data_key}_{run_id}_{j}" - run[f"{neptune_data_key}/{data_key}"].download(destination=file_path) - # Try to unzip the file else continue to the next file - try: - with zipfile.ZipFile(file_path, "r") as zip_ref: - # Create a directory to store unzipped data - os.makedirs(f"{file_path}_unzip", exist_ok=True) - # Unzip the data - zip_ref.extractall(f"{file_path}_unzip") - # Remove the zip file - os.remove(file_path) - except zipfile.BadZipFile: - # If the file is not zipped continue to the next file - # as it is already downloaded and doesn't need to be - # unzipped. - continue - except Exception as e: - print( - f"The following error occurred while unzipping or storing JSON \ - data for run {run_id} at path {file_path}: {e}" - ) - run.stop() + # Initialize the Neptune project + try: + project = neptune.init_project(project=project_name) + except Exception as e: + raise ValueError(f"Invalid project name '{project_name}': {e}") + + # Fetch runs based on provided tags + try: + runs_table_df = project.fetch_runs_table( + state="inactive", + columns=['sys/id'], + tag=tags, + sort_by='sys/id' + ).to_pandas() + except Exception as e: + raise ValueError(f"Invalid tags {tags}: {e}") + + run_ids = runs_table_df["sys/id"].values.tolist() + + # Download logs concurrently + with ThreadPoolExecutor() as executor: + futures = [executor.submit(download_and_extract_data, project_name, run_id, store_directory, neptune_data_key) for run_id in run_ids] + for future in tqdm(as_completed(futures), total=len(futures), desc="Downloading JSON logs", disable=disable_progress_bar): + future.result() # Restore neptune logger level neptune_logger.setLevel(logging.INFO) - print(f"{Fore.CYAN}{Style.BRIGHT}Data downloaded successfully!{Style.RESET_ALL}") + + +def download_and_extract_data(project_name, run_id, store_directory, neptune_data_key): + try: + with neptune.init_run(project=project_name, with_id=run_id, mode="read-only") as run: + for j, data_key in enumerate(run.get_structure()[neptune_data_key].keys(), start=1): + file_path = f"{store_directory}/{run_id}" + if j > 1: + file_path += f"_{j}" + run[f"{neptune_data_key}/{data_key}"].download(destination=file_path) + extract_zip_file(file_path) + except Exception as e: + print(f"Error downloading data for run {run_id}: {e}") + + +def extract_zip_file(file_path): + try: + with zipfile.ZipFile(file_path, "r") as zip_ref: + for member in zip_ref.infolist(): + if not member.is_dir(): + target_path = Path(f"{file_path}{Path(member.filename).suffix}") + target_path.parent.mkdir(parents=True, exist_ok=True) + with zip_ref.open(member) as src, target_path.open("wb") as dest: + dest.write(src.read()) + # Remove the zip file + os.remove(file_path) + except zipfile.BadZipFile: + # If the file is not zipped, no action is required + pass + except Exception as e: + print(f"Error while unzipping or storing JSON data at path {file_path}: {e}") From 5869627d90ed74fc082dbb9486b7dcd2573d958c Mon Sep 17 00:00:00 2001 From: Liam Clark Date: Tue, 12 Mar 2024 12:31:46 +0200 Subject: [PATCH 2/2] Fix formatting --- marl_eval/json_tools/json_utils.py | 53 ++++++++++++++++++++---------- 1 file changed, 35 insertions(+), 18 deletions(-) diff --git a/marl_eval/json_tools/json_utils.py b/marl_eval/json_tools/json_utils.py index 58203450..93633868 100644 --- a/marl_eval/json_tools/json_utils.py +++ b/marl_eval/json_tools/json_utils.py @@ -18,7 +18,7 @@ import os import zipfile from collections import defaultdict -from concurrent.futures import as_completed, ThreadPoolExecutor +from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path from typing import Dict, List, Tuple @@ -113,11 +113,11 @@ def concatenate_json_files( def pull_neptune_data( - project_name: str, - tags: List[str], - store_directory: str = "./downloaded_json_data", - neptune_data_key: str = "metrics", - disable_progress_bar: bool = False, + project_name: str, + tags: List[str], + store_directory: str = "./downloaded_json_data", + neptune_data_key: str = "metrics", + disable_progress_bar: bool = False, ) -> None: """Downloads logs from a Neptune project based on provided tags. @@ -128,7 +128,7 @@ def pull_neptune_data( Default is "./downloaded_json_data". neptune_data_key (str, optional): Key for the Neptune data to download. Default is "metrics". - disable_progress_bar (bool, optional): Whether to hide a progress bar during download. + disable_progress_bar (bool, optional): Whether to hide a progress bar. Default is False. Raises: @@ -150,10 +150,7 @@ def pull_neptune_data( # Fetch runs based on provided tags try: runs_table_df = project.fetch_runs_table( - state="inactive", - columns=['sys/id'], - tag=tags, - sort_by='sys/id' + state="inactive", columns=["sys/id"], tag=tags, sort_by="sys/id" ).to_pandas() except Exception as e: raise ValueError(f"Invalid tags {tags}: {e}") @@ -162,8 +159,22 @@ def pull_neptune_data( # Download logs concurrently with ThreadPoolExecutor() as executor: - futures = [executor.submit(download_and_extract_data, project_name, run_id, store_directory, neptune_data_key) for run_id in run_ids] - for future in tqdm(as_completed(futures), total=len(futures), desc="Downloading JSON logs", disable=disable_progress_bar): + futures = [ + executor.submit( + _download_and_extract_data, + project_name, + run_id, + store_directory, + neptune_data_key, + ) + for run_id in run_ids + ] + for future in tqdm( + as_completed(futures), + total=len(futures), + desc="Downloading JSON logs", + disable=disable_progress_bar, + ): future.result() # Restore neptune logger level @@ -171,20 +182,26 @@ def pull_neptune_data( print(f"{Fore.CYAN}{Style.BRIGHT}Data downloaded successfully!{Style.RESET_ALL}") -def download_and_extract_data(project_name, run_id, store_directory, neptune_data_key): +def _download_and_extract_data( + project_name: str, run_id: str, store_directory: str, neptune_data_key: str +) -> None: try: - with neptune.init_run(project=project_name, with_id=run_id, mode="read-only") as run: - for j, data_key in enumerate(run.get_structure()[neptune_data_key].keys(), start=1): + with neptune.init_run( + project=project_name, with_id=run_id, mode="read-only" + ) as run: + for j, data_key in enumerate( + run.get_structure()[neptune_data_key].keys(), start=1 + ): file_path = f"{store_directory}/{run_id}" if j > 1: file_path += f"_{j}" run[f"{neptune_data_key}/{data_key}"].download(destination=file_path) - extract_zip_file(file_path) + _extract_zip_file(file_path) except Exception as e: print(f"Error downloading data for run {run_id}: {e}") -def extract_zip_file(file_path): +def _extract_zip_file(file_path: str) -> None: try: with zipfile.ZipFile(file_path, "r") as zip_ref: for member in zip_ref.infolist():