-
Notifications
You must be signed in to change notification settings - Fork 201
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Better convert. #384
Better convert. #384
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -3,16 +3,14 @@ | |||||
import os | ||||||
import shutil | ||||||
from collections import defaultdict | ||||||
from inspect import signature | ||||||
from tempfile import TemporaryDirectory | ||||||
from typing import Dict, List, Optional, Set, Tuple | ||||||
|
||||||
import torch | ||||||
|
||||||
from huggingface_hub import CommitInfo, CommitOperationAdd, Discussion, HfApi, hf_hub_download | ||||||
from huggingface_hub.file_download import repo_folder_name | ||||||
from safetensors.torch import load_file, save_file | ||||||
from transformers import AutoConfig | ||||||
from safetensors.torch import save_file, load_file, _find_shared_tensors, _is_complete | ||||||
|
||||||
|
||||||
COMMIT_DESCRIPTION = """ | ||||||
|
@@ -34,20 +32,78 @@ | |||||
|
||||||
ConversionResult = Tuple[List["CommitOperationAdd"], List[Tuple[str, "Exception"]]] | ||||||
|
||||||
def _remove_duplicate_names( | ||||||
state_dict: Dict[str, torch.Tensor], | ||||||
*, | ||||||
preferred_names: List[str] = None, | ||||||
discard_names: List[str] = None, | ||||||
) -> Dict[str, List[str]]: | ||||||
if preferred_names is None: | ||||||
preferred_names = [] | ||||||
preferred_names = set(preferred_names) | ||||||
if discard_names is None: | ||||||
discard_names = [] | ||||||
discard_names = set(discard_names) | ||||||
|
||||||
shareds = _find_shared_tensors(state_dict) | ||||||
to_remove = defaultdict(list) | ||||||
for shared in shareds: | ||||||
complete_names = set( | ||||||
[name for name in shared if _is_complete(state_dict[name])] | ||||||
) | ||||||
if not complete_names: | ||||||
if len(shared) == 1: | ||||||
# Force contiguous | ||||||
name = list(shared)[0] | ||||||
state_dict[name] = state_dict[name].clone() | ||||||
complete_names = {name} | ||||||
else: | ||||||
raise RuntimeError( | ||||||
f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None is covering the entire storage.Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue." | ||||||
) | ||||||
|
||||||
class AlreadyExists(Exception): | ||||||
pass | ||||||
keep_name = sorted(list(complete_names))[0] | ||||||
|
||||||
# Mecanism to preferentially select keys to keep | ||||||
# coming from the on-disk file to allow | ||||||
# loading models saved with a different choice | ||||||
# of keep_name | ||||||
preferred = complete_names.difference(discard_names) | ||||||
if preferred: | ||||||
keep_name = sorted(list(preferred))[0] | ||||||
|
||||||
if preferred_names: | ||||||
preferred = preferred_names.intersection(complete_names) | ||||||
if preferred: | ||||||
keep_name = sorted(list(preferred))[0] | ||||||
for name in sorted(shared): | ||||||
if name != keep_name: | ||||||
to_remove[keep_name].append(name) | ||||||
return to_remove | ||||||
|
||||||
def get_discard_names(model_id: str, revision: Optional[str], folder: str, token: Optional[str]) -> List[str]: | ||||||
try: | ||||||
import transformers | ||||||
import json | ||||||
|
||||||
config_filename = hf_hub_download( | ||||||
model_id, revision=revision, filename="config.json", token=token, cache_dir=folder | ||||||
) | ||||||
with open(config_filename, "r") as f: | ||||||
config = json.load(f) | ||||||
architecture = config["architectures"][0] | ||||||
|
||||||
class_ = getattr(transformers, architecture) | ||||||
|
||||||
def shared_pointers(tensors): | ||||||
ptrs = defaultdict(list) | ||||||
for k, v in tensors.items(): | ||||||
ptrs[v.data_ptr()].append(k) | ||||||
failing = [] | ||||||
for ptr, names in ptrs.items(): | ||||||
if len(names) > 1: | ||||||
failing.append(names) | ||||||
return failing | ||||||
# Name for this varible depends on transformers version. | ||||||
discard_names = getattr(class_, "_tied_weights_keys", []) | ||||||
|
||||||
except Exception as e: | ||||||
discard_names = [] | ||||||
return discard_names | ||||||
|
||||||
class AlreadyExists(Exception): | ||||||
pass | ||||||
|
||||||
|
||||||
def check_file_size(sf_filename: str, pt_filename: str): | ||||||
|
@@ -70,8 +126,8 @@ def rename(pt_filename: str) -> str: | |||||
return local | ||||||
|
||||||
|
||||||
def convert_multi(model_id: str, folder: str, token: Optional[str]) -> ConversionResult: | ||||||
filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin.index.json", token=token, cache_dir=folder) | ||||||
def convert_multi(model_id: str, *, revision=Optional[str], folder: str, token: Optional[str], discard_names: List[str]) -> ConversionResult: | ||||||
filename = hf_hub_download(repo_id=model_id, revision=revision, filename="pytorch_model.bin.index.json", token=token, cache_dir=folder) | ||||||
with open(filename, "r") as f: | ||||||
data = json.load(f) | ||||||
|
||||||
|
@@ -82,7 +138,7 @@ def convert_multi(model_id: str, folder: str, token: Optional[str]) -> Conversio | |||||
|
||||||
sf_filename = rename(pt_filename) | ||||||
sf_filename = os.path.join(folder, sf_filename) | ||||||
convert_file(pt_filename, sf_filename) | ||||||
convert_file(pt_filename, sf_filename, discard_names=discard_names) | ||||||
local_filenames.append(sf_filename) | ||||||
|
||||||
index = os.path.join(folder, "model.safetensors.index.json") | ||||||
|
@@ -101,12 +157,12 @@ def convert_multi(model_id: str, folder: str, token: Optional[str]) -> Conversio | |||||
return operations, errors | ||||||
|
||||||
|
||||||
def convert_single(model_id: str, folder: str, token: Optional[str]) -> ConversionResult: | ||||||
def convert_single(model_id: str, *, revision: Optional[str], folder: str, token: Optional[str], discard_names: List[str]) -> ConversionResult: | ||||||
pt_filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin", token=token, cache_dir=folder) | ||||||
|
||||||
sf_name = "model.safetensors" | ||||||
sf_filename = os.path.join(folder, sf_name) | ||||||
convert_file(pt_filename, sf_filename) | ||||||
convert_file(pt_filename, sf_filename, discard_names) | ||||||
operations = [CommitOperationAdd(path_in_repo=sf_name, path_or_fileobj=sf_filename)] | ||||||
errors: List[Tuple[str, "Exception"]] = [] | ||||||
return operations, errors | ||||||
|
@@ -115,21 +171,25 @@ def convert_single(model_id: str, folder: str, token: Optional[str]) -> Conversi | |||||
def convert_file( | ||||||
pt_filename: str, | ||||||
sf_filename: str, | ||||||
discard_names: List[str], | ||||||
): | ||||||
loaded = torch.load(pt_filename, map_location="cpu") | ||||||
if "state_dict" in loaded: | ||||||
loaded = loaded["state_dict"] | ||||||
shared = shared_pointers(loaded) | ||||||
for shared_weights in shared: | ||||||
for name in shared_weights[1:]: | ||||||
loaded.pop(name) | ||||||
|
||||||
# For tensors to be contiguous | ||||||
to_removes = _remove_duplicate_names(loaded, discard_names=discard_names) | ||||||
|
||||||
metadata = {"format": "pt"} | ||||||
for kept_name, to_remove_group in to_removes.items(): | ||||||
for to_remove in to_remove_group: | ||||||
if to_remove not in metadata: | ||||||
metadata[to_remove] = kept_name | ||||||
del loaded[to_remove] | ||||||
# Force tensors to be contiguous | ||||||
loaded = {k: v.contiguous() for k, v in loaded.items()} | ||||||
|
||||||
dirname = os.path.dirname(sf_filename) | ||||||
os.makedirs(dirname, exist_ok=True) | ||||||
save_file(loaded, sf_filename, metadata={"format": "pt"}) | ||||||
save_file(loaded, sf_filename, metadata=metadata) | ||||||
check_file_size(sf_filename, pt_filename) | ||||||
reloaded = load_file(sf_filename) | ||||||
for k in loaded: | ||||||
|
@@ -155,79 +215,10 @@ def create_diff(pt_infos: Dict[str, List[str]], sf_infos: Dict[str, List[str]]) | |||||
return "\n".join(errors) | ||||||
|
||||||
|
||||||
def check_final_model(model_id: str, folder: str, token: Optional[str]): | ||||||
config = hf_hub_download(repo_id=model_id, filename="config.json", token=token, cache_dir=folder) | ||||||
shutil.copy(config, os.path.join(folder, "config.json")) | ||||||
config = AutoConfig.from_pretrained(folder) | ||||||
|
||||||
import transformers | ||||||
|
||||||
class_ = getattr(transformers, config.architectures[0]) | ||||||
(pt_model, pt_infos) = class_.from_pretrained(folder, output_loading_info=True) | ||||||
(sf_model, sf_infos) = class_.from_pretrained(folder, output_loading_info=True) | ||||||
|
||||||
if pt_infos != sf_infos: | ||||||
error_string = create_diff(pt_infos, sf_infos) | ||||||
raise ValueError(f"Different infos when reloading the model: {error_string}") | ||||||
|
||||||
pt_params = pt_model.state_dict() | ||||||
sf_params = sf_model.state_dict() | ||||||
|
||||||
pt_shared = shared_pointers(pt_params) | ||||||
sf_shared = shared_pointers(sf_params) | ||||||
if pt_shared != sf_shared: | ||||||
raise RuntimeError("The reconstructed model is wrong, shared tensors are different {shared_pt} != {shared_tf}") | ||||||
|
||||||
sig = signature(pt_model.forward) | ||||||
input_ids = torch.arange(10).unsqueeze(0) | ||||||
pixel_values = torch.randn(1, 3, 224, 224) | ||||||
input_values = torch.arange(1000).float().unsqueeze(0) | ||||||
# Hardcoded for whisper basically | ||||||
input_features = torch.zeros((1, 80, 3000)) | ||||||
kwargs = {} | ||||||
if "input_ids" in sig.parameters: | ||||||
kwargs["input_ids"] = input_ids | ||||||
if "input_features" in sig.parameters: | ||||||
kwargs["input_features"] = input_features | ||||||
if "decoder_input_ids" in sig.parameters: | ||||||
kwargs["decoder_input_ids"] = input_ids | ||||||
if "pixel_values" in sig.parameters: | ||||||
kwargs["pixel_values"] = pixel_values | ||||||
if "input_values" in sig.parameters: | ||||||
kwargs["input_values"] = input_values | ||||||
if "bbox" in sig.parameters: | ||||||
kwargs["bbox"] = torch.zeros((1, 10, 4)).long() | ||||||
if "image" in sig.parameters: | ||||||
kwargs["image"] = pixel_values | ||||||
|
||||||
if torch.cuda.is_available(): | ||||||
pt_model = pt_model.cuda() | ||||||
sf_model = sf_model.cuda() | ||||||
kwargs = {k: v.cuda() for k, v in kwargs.items()} | ||||||
|
||||||
try: | ||||||
pt_logits = pt_model(**kwargs)[0] | ||||||
except Exception as e: | ||||||
try: | ||||||
# Musicgen special exception. | ||||||
decoder_input_ids = torch.ones((input_ids.shape[0] * pt_model.decoder.num_codebooks, 1), dtype=torch.long) | ||||||
if torch.cuda.is_available(): | ||||||
decoder_input_ids = decoder_input_ids.cuda() | ||||||
|
||||||
kwargs["decoder_input_ids"] = decoder_input_ids | ||||||
pt_logits = pt_model(**kwargs)[0] | ||||||
except Exception: | ||||||
raise e | ||||||
sf_logits = sf_model(**kwargs)[0] | ||||||
|
||||||
torch.testing.assert_close(sf_logits, pt_logits) | ||||||
print(f"Model {model_id} is ok !") | ||||||
|
||||||
|
||||||
def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]: | ||||||
def previous_pr(api: "HfApi", model_id: str, pr_title: str, revision=Optional[str]) -> Optional["Discussion"]: | ||||||
try: | ||||||
main_commit = api.list_repo_commits(model_id)[0].commit_id | ||||||
discussions = api.get_repo_discussions(repo_id=model_id) | ||||||
main_commit = api.list_repo_commits(model_id, revision=revision)[0].commit_id | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no need to list all commit history to get the last one. This information is available as (nit) I would also rename the variable to something like
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreed ! (Made send before only adding the revision) |
||||||
discussions = api.get_repo_discussions(repo_id=model_id, revision=revision) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oops :( |
||||||
except Exception: | ||||||
return None | ||||||
for discussion in discussions: | ||||||
|
@@ -239,15 +230,15 @@ def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discuss | |||||
return None | ||||||
|
||||||
|
||||||
def convert_generic(model_id: str, folder: str, filenames: Set[str], token: Optional[str]) -> ConversionResult: | ||||||
def convert_generic(model_id: str, *, revision=Optional[str], folder: str, filenames: Set[str], token: Optional[str]) -> ConversionResult: | ||||||
operations = [] | ||||||
errors = [] | ||||||
|
||||||
extensions = set([".bin", ".ckpt"]) | ||||||
for filename in filenames: | ||||||
prefix, ext = os.path.splitext(filename) | ||||||
if ext in extensions: | ||||||
pt_filename = hf_hub_download(model_id, filename=filename, token=token, cache_dir=folder) | ||||||
pt_filename = hf_hub_download(model_id, revision=revision, filename=filename, token=token, cache_dir=folder) | ||||||
dirname, raw_filename = os.path.split(filename) | ||||||
if raw_filename == "pytorch_model.bin": | ||||||
# XXX: This is a special case to handle `transformers` and the | ||||||
|
@@ -257,25 +248,25 @@ def convert_generic(model_id: str, folder: str, filenames: Set[str], token: Opti | |||||
sf_in_repo = f"{prefix}.safetensors" | ||||||
sf_filename = os.path.join(folder, sf_in_repo) | ||||||
try: | ||||||
convert_file(pt_filename, sf_filename) | ||||||
convert_file(pt_filename, sf_filename, discard_names=[]) | ||||||
operations.append(CommitOperationAdd(path_in_repo=sf_in_repo, path_or_fileobj=sf_filename)) | ||||||
except Exception as e: | ||||||
errors.append((pt_filename, e)) | ||||||
return operations, errors | ||||||
|
||||||
|
||||||
def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitInfo", List[Tuple[str, "Exception"]]]: | ||||||
def convert(api: "HfApi", model_id: str, revision: Optional[str] = None, force: bool = False) -> Tuple["CommitInfo", List[Tuple[str, "Exception"]]]: | ||||||
pr_title = "Adding `safetensors` variant of this model" | ||||||
info = api.model_info(model_id) | ||||||
info = api.model_info(model_id, revision=revision) | ||||||
filenames = set(s.rfilename for s in info.siblings) | ||||||
|
||||||
with TemporaryDirectory() as d: | ||||||
with TemporaryDirectory(prefix=os.getenv("HF_HOME", "") + "/") as d: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (nit) you can use (nit 2) I would set a default directory in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Indeed. I removed that part actually to keep everything in a real temporary folder ( I had set this up for testing to use a real SSD on the machine I was using (/tmp was mounted to a slower disk) |
||||||
folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models")) | ||||||
os.makedirs(folder) | ||||||
new_pr = None | ||||||
try: | ||||||
operations = None | ||||||
pr = previous_pr(api, model_id, pr_title) | ||||||
pr = previous_pr(api, model_id, pr_title, revision=revision) | ||||||
|
||||||
library_name = getattr(info, "library_name", None) | ||||||
if any(filename.endswith(".safetensors") for filename in filenames) and not force: | ||||||
|
@@ -285,19 +276,21 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitIn | |||||
new_pr = pr | ||||||
raise AlreadyExists(f"Model {model_id} already has an open PR check out {url}") | ||||||
elif library_name == "transformers": | ||||||
|
||||||
discard_names = get_discard_names(model_id, revision=revision, folder=folder, token=api.token) | ||||||
if "pytorch_model.bin" in filenames: | ||||||
operations, errors = convert_single(model_id, folder, token=api.token) | ||||||
operations, errors = convert_single(model_id, revision=revision, folder=folder, token=api.token, discard_names = discard_names) | ||||||
elif "pytorch_model.bin.index.json" in filenames: | ||||||
operations, errors = convert_multi(model_id, folder, token=api.token) | ||||||
operations, errors = convert_multi(model_id, revision=revision, folder=folder, token=api.token, discard_names = discard_names) | ||||||
else: | ||||||
raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert") | ||||||
check_final_model(model_id, folder, token=api.token) | ||||||
else: | ||||||
operations, errors = convert_generic(model_id, folder, filenames, token=api.token) | ||||||
operations, errors = convert_generic(model_id, revision=revision, folder=folder, filenames=filenames, token=api.token) | ||||||
|
||||||
if operations: | ||||||
new_pr = api.create_commit( | ||||||
repo_id=model_id, | ||||||
revision=revision, | ||||||
operations=operations, | ||||||
commit_message=pr_title, | ||||||
commit_description=COMMIT_DESCRIPTION, | ||||||
|
@@ -324,6 +317,11 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitIn | |||||
type=str, | ||||||
help="The name of the model on the hub to convert. E.g. `gpt2` or `facebook/wav2vec2-base-960h`", | ||||||
) | ||||||
parser.add_argument( | ||||||
"--revision", | ||||||
type=str, | ||||||
help="The revision to convert", | ||||||
) | ||||||
parser.add_argument( | ||||||
"--force", | ||||||
action="store_true", | ||||||
|
@@ -346,26 +344,17 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitIn | |||||
" Continue [Y/n] ?" | ||||||
) | ||||||
if txt.lower() in {"", "y"}: | ||||||
try: | ||||||
commit_info, errors = convert(api, model_id, force=args.force) | ||||||
string = f""" | ||||||
commit_info, errors = convert(api, model_id, revision=args.revision, force=args.force) | ||||||
string = f""" | ||||||
### Success 🔥 | ||||||
Yay! This model was successfully converted and a PR was open using your token, here: | ||||||
[{commit_info.pr_url}]({commit_info.pr_url}) | ||||||
""" | ||||||
if errors: | ||||||
string += "\nErrors during conversion:\n" | ||||||
string += "\n".join( | ||||||
f"Error while converting {filename}: {e}, skipped conversion" for filename, e in errors | ||||||
) | ||||||
print(string) | ||||||
except Exception as e: | ||||||
print( | ||||||
f""" | ||||||
### Error 😢😢😢 | ||||||
|
||||||
{e} | ||||||
""" | ||||||
""" | ||||||
if errors: | ||||||
string += "\nErrors during conversion:\n" | ||||||
string += "\n".join( | ||||||
f"Error while converting {filename}: {e}, skipped conversion" for filename, e in errors | ||||||
) | ||||||
print(string) | ||||||
else: | ||||||
print(f"Answer was `{txt}` aborting.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use
revision
parameter to downloadpt_filename
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Forgot ! Thanks