diff --git a/llm/generate.py b/llm/generate.py index 11fccd0..c8b3c46 100644 --- a/llm/generate.py +++ b/llm/generate.py @@ -7,8 +7,6 @@ import argparse import json import sys -import re -from collections import Counter from typing import List from huggingface_hub import snapshot_download import utils.marsgen as mg @@ -20,7 +18,6 @@ create_folder_if_not_exists, delete_directory, copy_file, - get_all_files_in_directory, check_if_folder_empty, ) @@ -29,63 +26,40 @@ MODEL_STORE_DIR = "model-store" HANDLER = "handler.py" MODEL_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "model_config.json") -FILE_EXTENSIONS_TO_IGNORE = [ - ".safetensors", - ".safetensors.index.json", - ".h5", - ".ot", - ".tflite", - ".msgpack", - ".onnx", +PREFERRED_MODEL_FORMATS = [".safetensors", ".bin"] # In order of Preference +OTHER_MODEL_FORMATS = [ + "*.pt", + "*.h5", + "*.gguf", + "*.msgpack", + "*.tflite", + "*.ot", + "*.onnx", ] -def get_ignore_pattern_list(extension_list: List[str]) -> List[str]: +def get_ignore_pattern_list(gen_model: GenerateDataModel) -> List[str]: """ - This function takes a list of file extensions and returns a list of patterns - that can be used to filter out files with these extensions. + This method creates a list of file extensions to ignore from a priority list based on files + present in the Hugging Face Repo. It filters out extensions not found in the repository and + returns them as ignore patterns prefixed with '*' which is expected by Hugging Face client. Args: - extension_list (list): A list of file extensions. + gen_model (GenerateDataModel): An instance of the GenerateDataModel class Returns: - list: A list of patterns with '*' prepended to each extension, suitable for filtering files. + list(str): A list of patterns with '*' prepended to each extension, + suitable for filtering files. """ - return ["*" + pattern for pattern in extension_list] - - -def compare_lists(list1: List[str], list2: List[str]) -> bool: - """ - This function checks if two lists are equal by - comparing their contents, regardless of the order. - Args: - list1 (list): The first list to compare. - list2 (list): The second list to compare. - - Returns: - bool: True if the lists have the same elements, False otherwise. - """ - return Counter(list1) == Counter(list2) - - -def filter_files_by_extension( - filenames: List[str], extensions_to_remove: List[str] -) -> List[str]: - """ - This function takes a list of filenames and a list - of extensions to remove. It returns a new list of filenames - after filtering out those with specified extensions. - Args: - filenames (list): A list of filenames to be filtered. - extensions_to_remove (list): A list of file extensions to remove. - Returns: - list: A list of filenames after filtering. - """ - pattern = "|".join([re.escape(suffix) + "$" for suffix in extensions_to_remove]) - # for the extensions in FILE_EXTENSIONS_TO_IGNORE - # pattern will be '\.safetensors$|\.safetensors\.index\.json$' - filtered_filenames = [ - filename for filename in filenames if not re.search(pattern, filename) - ] - return filtered_filenames + repo_file_extensions = hf.get_repo_file_extensions(gen_model) + for desired_extension in PREFERRED_MODEL_FORMATS: + if desired_extension in repo_file_extensions: + ignore_list = [ + "*" + ignore_extension + for ignore_extension in PREFERRED_MODEL_FORMATS + if ignore_extension != desired_extension + ] + ignore_list.extend(OTHER_MODEL_FORMATS) + return ignore_list + return [] def set_config(gen_model: GenerateDataModel) -> None: @@ -141,24 +115,6 @@ class with relevant information. config_file.writelines(config_info) -def check_if_model_files_exist(gen_model: GenerateDataModel) -> bool: - """ - This function compares the list of files in the downloaded model - directory with the list of files in the HuggingFace repository. - It takes into account any files to ignore based on predefined extensions. - Args: - gen_model (GenerateDataModel): An instance of the GenerateDataModel - class with relevant information. - Returns: - bool: True if the downloaded model files match the expected - repository files, False otherwise. - """ - extra_files_list = get_all_files_in_directory(gen_model.mar_utils.model_path) - repo_files = hf.get_repo_files_list(gen_model) - repo_files = filter_files_by_extension(repo_files, FILE_EXTENSIONS_TO_IGNORE) - return compare_lists(extra_files_list, repo_files) - - def check_if_mar_file_exist(gen_model: GenerateDataModel) -> bool: """ This function checks if the Model Archive (MAR) file for the @@ -266,27 +222,24 @@ class with relevant information. Returns: GenerateDataModel: An instance of the GenerateDataModel class. """ - if os.path.exists(gen_model.mar_utils.model_path) and check_if_model_files_exist( - gen_model - ): - print( - ( - "## Skipping downloading as model files of the needed" - " repo version are already present\n" - ) - ) - return gen_model print("## Starting model files download\n") delete_directory(gen_model.mar_utils.model_path) create_folder_if_not_exists(gen_model.mar_utils.model_path) + + tmp_hf_cache = os.path.join(gen_model.mar_utils.model_path, "tmp_hf_cache") + create_folder_if_not_exists(tmp_hf_cache) + snapshot_download( repo_id=gen_model.repo_info.repo_id, revision=gen_model.repo_info.repo_version, local_dir=gen_model.mar_utils.model_path, - local_dir_use_symlinks=False, token=gen_model.repo_info.hf_token, - ignore_patterns=get_ignore_pattern_list(FILE_EXTENSIONS_TO_IGNORE), + local_dir_use_symlinks=False, + cache_dir=tmp_hf_cache, + force_download=True, + ignore_patterns=get_ignore_pattern_list(gen_model), ) + delete_directory(tmp_hf_cache) print("## Successfully downloaded model_files\n") return gen_model @@ -305,23 +258,16 @@ class with relevant information. print("## Skipping generation of model archive file as it is present\n") else: check_if_path_exists(gen_model.mar_utils.model_path, "model_path", is_dir=True) - if not gen_model.is_custom: - if not check_if_model_files_exist(gen_model): - # checking if local model files are same the repository files - print("## Model files do not match HuggingFace repository Files") - sys.exit(1) - else: - if check_if_folder_empty(gen_model.mar_utils.model_path): - print( - f"\n##Error: {gen_model.model_name} model files for the custom" - f" model not found in the provided path: {gen_model.mar_utils.model_path}" - ) - sys.exit(1) - else: - print( - f"\n## Generating MAR file for custom model files: {gen_model.model_name} \n" - ) + if check_if_folder_empty(gen_model.mar_utils.model_path): + print( + f"\n##Error: {gen_model.model_name} model files for the custom" + f" model not found in the provided path: {gen_model.mar_utils.model_path}" + ) + sys.exit(1) + print( + f"\n## Generating MAR file for custom model files: {gen_model.model_name} \n" + ) create_folder_if_not_exists(gen_model.mar_utils.mar_output) mg.generate_mars( diff --git a/llm/requirements.txt b/llm/requirements.txt index d11051b..7c0d86d 100644 --- a/llm/requirements.txt +++ b/llm/requirements.txt @@ -1,4 +1,4 @@ torch-model-archiver==0.8.1 kubernetes==28.1.0 kserve==0.11.1 -huggingface-hub==0.20.1 \ No newline at end of file +huggingface-hub==0.22.2 \ No newline at end of file diff --git a/llm/utils/hf_utils.py b/llm/utils/hf_utils.py index 9def319..e60de05 100644 --- a/llm/utils/hf_utils.py +++ b/llm/utils/hf_utils.py @@ -3,9 +3,10 @@ """ import sys -from typing import List +import os from huggingface_hub import HfApi from huggingface_hub.utils import ( + GatedRepoError, RepositoryNotFoundError, RevisionNotFoundError, HfHubHTTPError, @@ -14,20 +15,19 @@ from utils.generate_data_model import GenerateDataModel -def get_repo_files_list(gen_model: GenerateDataModel) -> List[str]: +def get_repo_file_extensions(gen_model: GenerateDataModel) -> set: """ - This function returns a list of all files in the HuggingFace repo of + This function returns set of all file extensions in the Hugging Face repo of the model. Args: - gen_model (GenerateDataModel): An instance of the GenerateDataModel - class with relevant information. + gen_model (GenerateDataModel): An instance of the GenerateDataModel class Returns: - repo_files (list): all files in the HuggingFace repo of - the model + repo_file_extension (set): The set of all file extensions in the + Hugging Face repo of the model Raises: sys.exit(1): If repo_id, repo_version or huggingface token - is not valid, the function will terminate - the program with an exit code of 1. + is not valid, the function will terminate + the program with an exit code of 1. """ try: hf_api = HfApi() @@ -36,19 +36,19 @@ class with relevant information. revision=gen_model.repo_info.repo_version, token=gen_model.repo_info.hf_token, ) - return repo_files + return {os.path.splitext(file_name)[1] for file_name in repo_files} except ( - HfHubHTTPError, - HFValidationError, + GatedRepoError, RepositoryNotFoundError, RevisionNotFoundError, + HfHubHTTPError, + HFValidationError, + ValueError, KeyError, ): print( - ( - "\n## Error: Please check either repo_id, repo_version " - "or huggingface token is not correct\n" - ) + "## Error: Please check either repo_id, repo_version" + " or HuggingFace ID is not correct\n" ) sys.exit(1)