diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 02b1545d0..9d9af8ac4 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -30,6 +30,7 @@ from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE +from torchchat.cli.download import get_model_dir from torchchat.model_config.model_config import resolve_model_config from torchchat.utils.build_utils import ( device_sync, @@ -73,7 +74,7 @@ def __post_init__(self): or (self.pte_path and Path(self.pte_path).is_file()) ): raise RuntimeError( - "need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path" + f"need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path {self.checkpoint_path}" ) if self.dso_path and self.pte_path: @@ -109,10 +110,10 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": model_config = resolve_model_config(args.model) checkpoint_path = ( - Path(args.model_directory) - / model_config.name + get_model_dir(model_config, args.model_directory) / model_config.checkpoint_file ) + print(f"Using checkpoint path: {checkpoint_path}") # The transformers config is keyed on the last section # of the name/path. params_table = ( @@ -264,8 +265,7 @@ def from_args(cls, args: argparse.Namespace) -> "TokenizerArgs": elif args.model: # Using a named, well-known model model_config = resolve_model_config(args.model) tokenizer_path = ( - Path(args.model_directory) - / model_config.name + get_model_dir(model_config, args.model_directory) / model_config.tokenizer_file ) elif args.checkpoint_path: diff --git a/torchchat/cli/cli.py b/torchchat/cli/cli.py index 1d624c6c4..92f8f9987 100644 --- a/torchchat/cli/cli.py +++ b/torchchat/cli/cli.py @@ -244,7 +244,7 @@ def _add_jit_downloading_args(parser) -> None: "--model-directory", type=Path, default=default_model_dir, - help=f"The directory to store downloaded model artifacts. Default: {default_model_dir}", + help=f"The directory to store downloaded model artifacts. Default: {default_model_dir}. This is overriden by the huggingface cache directory if the model is downloaded from HuggingFace.", ) diff --git a/torchchat/cli/convert_hf_checkpoint.py b/torchchat/cli/convert_hf_checkpoint.py index f95cbdaef..12bbae281 100644 --- a/torchchat/cli/convert_hf_checkpoint.py +++ b/torchchat/cli/convert_hf_checkpoint.py @@ -3,7 +3,6 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -import glob import json import os import re @@ -42,12 +41,7 @@ def convert_hf_checkpoint( print(f"Model config {config.__dict__}") # Load the json file containing weight mapping - model_map_json_matches = [Path(m) for m in glob.glob(str(model_dir / "*.index.json"))] - assert len(model_map_json_matches) <= 1, "Found multiple weight mapping files" - if len(model_map_json_matches): - model_map_json = model_map_json_matches[0] - else: - model_map_json = model_dir / "pytorch_model.bin.index.json" + model_map_json = model_dir / "pytorch_model.bin.index.json" # If there is no weight mapping, check for a consolidated model and # tokenizer we can move. Llama 2 and Mistral have weight mappings, while @@ -62,9 +56,10 @@ def convert_hf_checkpoint( str(consolidated_pth), map_location="cpu", mmap=True, weights_only=True ) del loaded_result # No longer needed - print(f"Moving checkpoint to {model_dir / 'model.pth'}.") - os.rename(consolidated_pth, model_dir / "model.pth") - os.rename(tokenizer_pth, model_dir / "tokenizer.model") + print(f"Symlinking checkpoint to {model_dir / 'model.pth'}.") + consolidated_pth = os.path.realpath(consolidated_pth) + os.symlink(consolidated_pth, model_dir / "model.pth") + os.symlink(tokenizer_pth, model_dir / "tokenizer.model") print("Done.") return else: @@ -81,17 +76,10 @@ def convert_hf_checkpoint( "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", - "model.layers.{}.self_attn.q_proj.bias": "layers.{}.attention.wq.bias", - "model.layers.{}.self_attn.k_proj.bias": "layers.{}.attention.wk.bias", - "model.layers.{}.self_attn.v_proj.bias": "layers.{}.attention.wv.bias", - "model.layers.{}.self_attn.o_proj.bias": "layers.{}.attention.wo.bias", "model.layers.{}.self_attn.rotary_emb.inv_freq": None, "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight", "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", - "model.layers.{}.mlp.gate_proj.bias": "layers.{}.feed_forward.w1.bias", - "model.layers.{}.mlp.up_proj.bias": "layers.{}.feed_forward.w3.bias", - "model.layers.{}.mlp.down_proj.bias": "layers.{}.feed_forward.w2.bias", "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", "model.norm.weight": "norm.weight", @@ -100,43 +88,19 @@ def convert_hf_checkpoint( bin_files = {model_dir / bin for bin in bin_index["weight_map"].values()} def permute(w, n_heads): + dim = config.dim return ( - w.view(n_heads, 2, config.head_dim // 2, *w.shape[1:]) + w.view(n_heads, 2, config.head_dim // 2, dim) .transpose(1, 2) - .reshape(w.shape) + .reshape(config.head_dim * n_heads, dim) ) merged_result = {} for file in sorted(bin_files): - - # The state_dict can be loaded from either a torch zip file or - # safetensors. We take our best guess from the name and try all - # possibilities - load_pt_mmap = lambda: torch.load( + state_dict = torch.load( str(file), map_location="cpu", mmap=True, weights_only=True ) - load_pt_no_mmap = lambda: torch.load( - str(file), map_location="cpu", mmap=False, weights_only=True - ) - def load_safetensors(): - import safetensors.torch - with open(file, "rb") as handle: - return safetensors.torch.load(handle.read()) - if "safetensors" in str(file): - loaders = [load_safetensors, load_pt_mmap, load_pt_no_mmap] - else: - loaders = [load_pt_mmap, load_pt_no_mmap, load_safetensors] - - state_dict = None - for loader in loaders: - try: - state_dict = loader() - break - except Exception: - continue - assert state_dict is not None, f"Unable to load tensors from {file}" merged_result.update(state_dict) - final_result = {} for key, value in merged_result.items(): if "layers" in key: @@ -152,18 +116,16 @@ def load_safetensors(): final_result[new_key] = value for key in tuple(final_result.keys()): - if "wq.weight" in key or "wq.bias" in key: - wk_key = key.replace("wq", "wk") - wv_key = key.replace("wq", "wv") + if "wq" in key: q = final_result[key] - k = final_result[wk_key] - v = final_result[wv_key] + k = final_result[key.replace("wq", "wk")] + v = final_result[key.replace("wq", "wv")] q = permute(q, config.n_heads) k = permute(k, config.n_local_heads) final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v]) del final_result[key] - del final_result[wk_key] - del final_result[wv_key] + del final_result[key.replace("wq", "wk")] + del final_result[key.replace("wq", "wv")] print(f"Saving checkpoint to {model_dir / 'model.pth'}. This may take a while.") torch.save(final_result, model_dir / "model.pth") print("Done.") @@ -184,10 +146,10 @@ def convert_hf_checkpoint_to_tune( consolidated_pth = model_dir / "original" / "consolidated.pth" tokenizer_pth = model_dir / "original" / "tokenizer.model" if consolidated_pth.is_file() and tokenizer_pth.is_file(): - print(f"Moving checkpoint to {model_dir / 'model.pth'}.") - os.rename(consolidated_pth, model_dir / "model.pth") - print(f"Moving tokenizer to {model_dir / 'tokenizer.model'}.") - os.rename(tokenizer_pth, model_dir / "tokenizer.model") + print(f"Creating symlink from {consolidated_pth} to {model_dir / 'model.pth'}.") + os.symlink(consolidated_pth, model_dir / "model.pth") + print(f"Creating symlink from {tokenizer_pth} to {model_dir / 'tokenizer.model'}.") + os.symlink(tokenizer_pth, model_dir / "tokenizer.model") print("Done.") else: raise RuntimeError(f"Could not find {consolidated_pth}") diff --git a/torchchat/cli/download.py b/torchchat/cli/download.py index 14dfeb062..3c6579d6f 100644 --- a/torchchat/cli/download.py +++ b/torchchat/cli/download.py @@ -18,15 +18,19 @@ resolve_model_config, ) +# By default, download models from HuggingFace to the Hugginface hub directory. +# Both $HF_HOME and $HUGGINGFACE_HUB_CACHE are valid environment variables for the same directory. +HUGGINGFACE_HOME_PATH = Path(os.environ.get("HF_HOME", os.environ.get("HUGGINGFACE_HUB_CACHE", os.path.expanduser("~/.cache/huggingface/hub")))) def _download_hf_snapshot( - model_config: ModelConfig, artifact_dir: Path, hf_token: Optional[str] + model_config: ModelConfig, hf_token: Optional[str] ): from huggingface_hub import model_info, snapshot_download from requests.exceptions import HTTPError # Download and store the HF model artifacts. - print(f"Downloading {model_config.name} from HuggingFace...", file=sys.stderr) + model_dir = get_model_dir(model_config, None) + print(f"Downloading {model_config.name} from Hugging Face to {model_dir}", file=sys.stderr, flush=True) try: # Fetch the info about the model's repo model_info = model_info(model_config.distribution_path, token=hf_token) @@ -56,8 +60,6 @@ def _download_hf_snapshot( snapshot_download( model_config.distribution_path, - local_dir=artifact_dir, - local_dir_use_symlinks=False, token=hf_token, ignore_patterns=ignore_patterns, ) @@ -76,16 +78,20 @@ def _download_hf_snapshot( else: raise e + # Update the model dir to include the snapshot we just downloaded. + model_dir = get_model_dir(model_config, None) + print("Model downloaded to", model_dir) + # Convert the Multimodal Llama model to the torchtune format. if model_config.name in {"meta-llama/Llama-3.2-11B-Vision-Instruct", "meta-llama/Llama-3.2-11B-Vision"}: print(f"Converting {model_config.name} to torchtune format...", file=sys.stderr) - convert_hf_checkpoint_to_tune( model_dir=artifact_dir, model_name=model_config.name) + convert_hf_checkpoint_to_tune( model_dir=model_dir, model_name=model_config.name) else: # Convert the model to the torchchat format. print(f"Converting {model_config.name} to torchchat format...", file=sys.stderr) convert_hf_checkpoint( - model_dir=artifact_dir, model_name=model_config.name, remove_bin_files=True + model_dir=model_dir, model_name=model_config.name, remove_bin_files=True ) @@ -99,12 +105,51 @@ def _download_direct( print(f"Downloading {url}...", file=sys.stderr) urllib.request.urlretrieve(url, str(local_path.absolute())) +def _get_hf_artifact_dir(model_config: ModelConfig) -> Path: + """ + Returns the directory where the model artifacts are stored. + + This is the root folder with blobs, refs and snapshots + """ + assert(model_config.distribution_channel == ModelDistributionChannel.HuggingFaceSnapshot) + return HUGGINGFACE_HOME_PATH / f"models--{model_config.distribution_path.replace('/', '--')}" + + +def get_model_dir(model_config: ModelConfig, models_dir: Optional[Path]) -> Path: + """ + Returns the directory where the model artifacts are stored. + For HuggingFace snapshots, this is the HuggingFace cache directory. + For all other distribution channels, we use the models_dir. + + For CLI usage, pass in args.model_directory. + """ + if model_config.distribution_channel == ModelDistributionChannel.HuggingFaceSnapshot: + artifact_dir = _get_hf_artifact_dir(model_config) + + # If these paths doesn't exist, it means the model hasn't been downloaded yet. + if not os.path.isdir(artifact_dir) and not os.path.isdir(artifact_dir / "snapshots"): + return artifact_dir + snapshot = open(artifact_dir / "refs" / "main", "r").read().strip() + return artifact_dir / "snapshots" / snapshot + else: + return models_dir / model_config.name + def download_and_convert( model: str, models_dir: Path, hf_token: Optional[str] = None ) -> None: model_config = resolve_model_config(model) - model_dir = models_dir / model_config.name + model_dir = get_model_dir(model_config, models_dir) + + # HuggingFace download + if ( + model_config.distribution_channel + == ModelDistributionChannel.HuggingFaceSnapshot + ): + _download_hf_snapshot(model_config, hf_token) + return + + # Direct download # Download into a temporary directory. We'll move to the final # location once the download and conversion is complete. This @@ -117,11 +162,6 @@ def download_and_convert( try: if ( - model_config.distribution_channel - == ModelDistributionChannel.HuggingFaceSnapshot - ): - _download_hf_snapshot(model_config, temp_dir, hf_token) - elif ( model_config.distribution_channel == ModelDistributionChannel.DirectDownload ): _download_direct(model_config, temp_dir) @@ -144,9 +184,9 @@ def download_and_convert( def is_model_downloaded(model: str, models_dir: Path) -> bool: model_config = resolve_model_config(model) - + # Check if the model directory exists and is not empty. - model_dir = models_dir / model_config.name + model_dir = get_model_dir(model_config, models_dir) return os.path.isdir(model_dir) and os.listdir(model_dir) @@ -194,13 +234,16 @@ def remove_main(args) -> None: return model_config = resolve_model_config(args.model) - model_dir = args.model_directory / model_config.name + model_dir = get_model_dir(model_config, args.model_directory) if not os.path.isdir(model_dir): - print(f"Model {args.model} has no downloaded artifacts.") + print(f"Model {args.model} has no downloaded artifacts in {model_dir}.") return + if model_config.distribution_channel == ModelDistributionChannel.HuggingFaceSnapshot: + # For HuggingFace models, we need to remove the entire root directory. + model_dir = _get_hf_artifact_dir(model_config) - print(f"Removing downloaded model artifacts for {args.model}...") + print(f"Removing downloaded model artifacts for {args.model} at {model_dir}...") shutil.rmtree(model_dir) print("Done.") @@ -216,10 +259,10 @@ def where_main(args) -> None: return model_config = resolve_model_config(args.model) - model_dir = args.model_directory / model_config.name + model_dir = get_model_dir(model_config, args.model_directory) if not os.path.isdir(model_dir): - raise RuntimeError(f"Model {args.model} has no downloaded artifacts.") + raise RuntimeError(f"Model {args.model} has no downloaded artifacts in {model_dir}.") print(str(os.path.abspath(model_dir))) exit(0) diff --git a/torchchat/usages/openai_api.py b/torchchat/usages/openai_api.py index 72a6dfc9b..2adf170bd 100644 --- a/torchchat/usages/openai_api.py +++ b/torchchat/usages/openai_api.py @@ -23,7 +23,7 @@ from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform -from torchchat.cli.download import is_model_downloaded, load_model_configs +from torchchat.cli.download import is_model_downloaded, load_model_configs, get_model_dir from torchchat.generate import Generator, GeneratorArgs from torchchat.model import FlamingoModel @@ -522,7 +522,7 @@ def retrieve_model_info(args, model_id: str) -> Union[ModelInfo, None]: """ if model_config := load_model_configs().get(model_id): if is_model_downloaded(model_id, args.model_directory): - path = args.model_directory / model_config.name + path = get_model_dir(model_config, args.model_directory) created = int(os.path.getctime(path)) owned_by = getpwuid(os.stat(path).st_uid).pw_name @@ -545,7 +545,7 @@ def get_model_info_list(args) -> ModelInfo: data = [] for model_id, model_config in load_model_configs().items(): if is_model_downloaded(model_id, args.model_directory): - path = args.model_directory / model_config.name + path = get_model_dir(model_config, args.model_directory) created = int(os.path.getctime(path)) owned_by = getpwuid(os.stat(path).st_uid).pw_name