Better convert. #384

Merged · 1 commit · Nov 17, 2023
241 changes: 115 additions & 126 deletions bindings/python/convert.py
@@ -3,16 +3,14 @@
import os
import shutil
from collections import defaultdict
from inspect import signature
from tempfile import TemporaryDirectory
from typing import Dict, List, Optional, Set, Tuple

import torch

from huggingface_hub import CommitInfo, CommitOperationAdd, Discussion, HfApi, hf_hub_download
from huggingface_hub.file_download import repo_folder_name
from safetensors.torch import load_file, save_file
from transformers import AutoConfig
from safetensors.torch import save_file, load_file, _find_shared_tensors, _is_complete


COMMIT_DESCRIPTION = """
@@ -34,20 +32,78 @@

ConversionResult = Tuple[List["CommitOperationAdd"], List[Tuple[str, "Exception"]]]

def _remove_duplicate_names(
    state_dict: Dict[str, torch.Tensor],
    *,
    preferred_names: Optional[List[str]] = None,
    discard_names: Optional[List[str]] = None,
) -> Dict[str, List[str]]:
    if preferred_names is None:
        preferred_names = []
    preferred_names = set(preferred_names)
    if discard_names is None:
        discard_names = []
    discard_names = set(discard_names)

    shareds = _find_shared_tensors(state_dict)
    to_remove = defaultdict(list)
    for shared in shareds:
        complete_names = set(
            [name for name in shared if _is_complete(state_dict[name])]
        )
        if not complete_names:
            if len(shared) == 1:
                # Force contiguous
                name = list(shared)[0]
                state_dict[name] = state_dict[name].clone()
                complete_names = {name}
            else:
                raise RuntimeError(
                    f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None is covering the entire storage. Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue."
                )

        keep_name = sorted(list(complete_names))[0]

        # Mechanism to preferentially select keys to keep
        # coming from the on-disk file, to allow
        # loading models saved with a different choice
        # of keep_name
        preferred = complete_names.difference(discard_names)
        if preferred:
            keep_name = sorted(list(preferred))[0]

        if preferred_names:
            preferred = preferred_names.intersection(complete_names)
            if preferred:
                keep_name = sorted(list(preferred))[0]
        for name in sorted(shared):
            if name != keep_name:
                to_remove[keep_name].append(name)
    return to_remove
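For intuition, here is a minimal sketch (not part of the diff) of how this deduplication behaves on tied weights; the tensor names are invented for illustration, and `_find_shared_tensors`/`_is_complete` are the private safetensors helpers imported at the top of the file:

import torch

embed = torch.zeros((4, 8))
state_dict = {
    "embeddings.weight": embed,
    "lm_head.weight": embed,  # tied: both names point at the same storage
}
# Prefer dropping the alias that transformers would re-tie on load anyway.
print(_remove_duplicate_names(state_dict, discard_names=["lm_head.weight"]))
# {'embeddings.weight': ['lm_head.weight']}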

def get_discard_names(model_id: str, revision: Optional[str], folder: str, token: Optional[str]) -> List[str]:
    try:
        import transformers
        import json

        config_filename = hf_hub_download(
            model_id, revision=revision, filename="config.json", token=token, cache_dir=folder
        )
        with open(config_filename, "r") as f:
            config = json.load(f)
        architecture = config["architectures"][0]

        class_ = getattr(transformers, architecture)

def shared_pointers(tensors):
    ptrs = defaultdict(list)
    for k, v in tensors.items():
        ptrs[v.data_ptr()].append(k)
    failing = []
    for ptr, names in ptrs.items():
        if len(names) > 1:
            failing.append(names)
    return failing
        # The name of this attribute depends on the transformers version.
        discard_names = getattr(class_, "_tied_weights_keys", [])

    except Exception:
        discard_names = []
    return discard_names
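As a concrete illustration of what this returns (a hypothetical example, not from the diff): recent transformers versions expose the tied aliases as a class attribute, so for a GPT-2 style LM head model one would expect something like

import transformers

class_ = getattr(transformers, "GPT2LMHeadModel")
print(getattr(class_, "_tied_weights_keys", []))  # e.g. ['lm_head.weight']

The attribute name and contents vary across transformers versions and architectures, which is why the lookup is wrapped in a broad try/except that falls back to an empty list.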

class AlreadyExists(Exception):
    pass


def check_file_size(sf_filename: str, pt_filename: str):
@@ -70,8 +126,8 @@ def rename(pt_filename: str) -> str:
    return local
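The body of `rename` is collapsed in this view. A plausible implementation, consistent with its signature and the `return local` line above (an assumption, not taken from the diff), is:

def rename(pt_filename: str) -> str:
    filename, ext = os.path.splitext(pt_filename)
    local = f"{filename}.safetensors"
    local = local.replace("pytorch_model", "model")
    return local

# rename("pytorch_model-00001-of-00002.bin") == "model-00001-of-00002.safetensors"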


def convert_multi(model_id: str, folder: str, token: Optional[str]) -> ConversionResult:
    filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin.index.json", token=token, cache_dir=folder)
def convert_multi(model_id: str, *, revision: Optional[str], folder: str, token: Optional[str], discard_names: List[str]) -> ConversionResult:
    filename = hf_hub_download(repo_id=model_id, revision=revision, filename="pytorch_model.bin.index.json", token=token, cache_dir=folder)
    with open(filename, "r") as f:
        data = json.load(f)

@@ -82,7 +138,7 @@ def convert_multi(model_id: str, folder: str, token: Optional[str]) -> ConversionResult:

        sf_filename = rename(pt_filename)
        sf_filename = os.path.join(folder, sf_filename)
        convert_file(pt_filename, sf_filename)
        convert_file(pt_filename, sf_filename, discard_names=discard_names)
        local_filenames.append(sf_filename)

    index = os.path.join(folder, "model.safetensors.index.json")
@@ -101,12 +157,12 @@ def convert_multi(model_id: str, folder: str, token: Optional[str]) -> ConversionResult:
    return operations, errors
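The index rewrite itself is collapsed in this hunk. A hedged sketch of what convert_multi plausibly writes to `model.safetensors.index.json` (shard names in `weight_map` renamed the same way as the shard files; the exact keys are an assumption):

with open(index, "w") as f:
    newdata = {"metadata": data.get("metadata", {}), "weight_map": {}}
    for weight_name, shard in data["weight_map"].items():
        newdata["weight_map"][weight_name] = rename(shard)
    json.dump(newdata, f, indent=4)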


def convert_single(model_id: str, folder: str, token: Optional[str]) -> ConversionResult:
def convert_single(model_id: str, *, revision: Optional[str], folder: str, token: Optional[str], discard_names: List[str]) -> ConversionResult:
    pt_filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin", token=token, cache_dir=folder)
Comment (Contributor): Use the `revision` parameter to download `pt_filename`?

Reply (Collaborator, author): Forgot! Thanks.


    sf_name = "model.safetensors"
    sf_filename = os.path.join(folder, sf_name)
    convert_file(pt_filename, sf_filename)
    convert_file(pt_filename, sf_filename, discard_names)
    operations = [CommitOperationAdd(path_in_repo=sf_name, path_or_fileobj=sf_filename)]
    errors: List[Tuple[str, "Exception"]] = []
    return operations, errors
@@ -115,21 +171,25 @@ def convert_single(model_id: str, folder: str, token: Optional[str]) -> ConversionResult:
def convert_file(
    pt_filename: str,
    sf_filename: str,
    discard_names: List[str],
):
    loaded = torch.load(pt_filename, map_location="cpu")
    if "state_dict" in loaded:
        loaded = loaded["state_dict"]
    shared = shared_pointers(loaded)
    for shared_weights in shared:
        for name in shared_weights[1:]:
            loaded.pop(name)

    to_removes = _remove_duplicate_names(loaded, discard_names=discard_names)

    metadata = {"format": "pt"}
    for kept_name, to_remove_group in to_removes.items():
        for to_remove in to_remove_group:
            if to_remove not in metadata:
                metadata[to_remove] = kept_name
            del loaded[to_remove]
    # Force tensors to be contiguous
    loaded = {k: v.contiguous() for k, v in loaded.items()}

    dirname = os.path.dirname(sf_filename)
    os.makedirs(dirname, exist_ok=True)
    save_file(loaded, sf_filename, metadata={"format": "pt"})
    save_file(loaded, sf_filename, metadata=metadata)
    check_file_size(sf_filename, pt_filename)
    reloaded = load_file(sf_filename)
    for k in loaded:
@@ -155,79 +215,10 @@ def create_diff(pt_infos: Dict[str, List[str]], sf_infos: Dict[str, List[str]]) -> str:
    return "\n".join(errors)


def check_final_model(model_id: str, folder: str, token: Optional[str]):
    config = hf_hub_download(repo_id=model_id, filename="config.json", token=token, cache_dir=folder)
    shutil.copy(config, os.path.join(folder, "config.json"))
    config = AutoConfig.from_pretrained(folder)

    import transformers

    class_ = getattr(transformers, config.architectures[0])
    (pt_model, pt_infos) = class_.from_pretrained(folder, output_loading_info=True)
    (sf_model, sf_infos) = class_.from_pretrained(folder, output_loading_info=True)

    if pt_infos != sf_infos:
        error_string = create_diff(pt_infos, sf_infos)
        raise ValueError(f"Different infos when reloading the model: {error_string}")

    pt_params = pt_model.state_dict()
    sf_params = sf_model.state_dict()

    pt_shared = shared_pointers(pt_params)
    sf_shared = shared_pointers(sf_params)
    if pt_shared != sf_shared:
        raise RuntimeError(f"The reconstructed model is wrong, shared tensors are different {pt_shared} != {sf_shared}")

    sig = signature(pt_model.forward)
    input_ids = torch.arange(10).unsqueeze(0)
    pixel_values = torch.randn(1, 3, 224, 224)
    input_values = torch.arange(1000).float().unsqueeze(0)
    # Hardcoded for whisper basically
    input_features = torch.zeros((1, 80, 3000))
    kwargs = {}
    if "input_ids" in sig.parameters:
        kwargs["input_ids"] = input_ids
    if "input_features" in sig.parameters:
        kwargs["input_features"] = input_features
    if "decoder_input_ids" in sig.parameters:
        kwargs["decoder_input_ids"] = input_ids
    if "pixel_values" in sig.parameters:
        kwargs["pixel_values"] = pixel_values
    if "input_values" in sig.parameters:
        kwargs["input_values"] = input_values
    if "bbox" in sig.parameters:
        kwargs["bbox"] = torch.zeros((1, 10, 4)).long()
    if "image" in sig.parameters:
        kwargs["image"] = pixel_values

    if torch.cuda.is_available():
        pt_model = pt_model.cuda()
        sf_model = sf_model.cuda()
        kwargs = {k: v.cuda() for k, v in kwargs.items()}

    try:
        pt_logits = pt_model(**kwargs)[0]
    except Exception as e:
        try:
            # Musicgen special exception.
            decoder_input_ids = torch.ones((input_ids.shape[0] * pt_model.decoder.num_codebooks, 1), dtype=torch.long)
            if torch.cuda.is_available():
                decoder_input_ids = decoder_input_ids.cuda()

            kwargs["decoder_input_ids"] = decoder_input_ids
            pt_logits = pt_model(**kwargs)[0]
        except Exception:
            raise e
    sf_logits = sf_model(**kwargs)[0]

    torch.testing.assert_close(sf_logits, pt_logits)
    print(f"Model {model_id} is ok !")


def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
def previous_pr(api: "HfApi", model_id: str, pr_title: str, revision: Optional[str] = None) -> Optional["Discussion"]:
    try:
        main_commit = api.list_repo_commits(model_id)[0].commit_id
        discussions = api.get_repo_discussions(repo_id=model_id)
        main_commit = api.list_repo_commits(model_id, revision=revision)[0].commit_id
Comment (Contributor): No need to list the whole commit history just to get the latest commit; this information is available as `model_info(...).sha`.

(nit) I would also rename the variable to something like `revision_commit` instead of `main_commit`, since it is pulled from `revision`, not from `main`.

Suggested change:
-        main_commit = api.list_repo_commits(model_id, revision=revision)[0].commit_id
+        revision_commit = api.model_info(model_id, revision=revision).sha

Reply (Collaborator, author): Agreed! (This was written before, when I was only adding the revision.)

        discussions = api.get_repo_discussions(repo_id=model_id, revision=revision)
Comment (Contributor): `get_repo_discussions` lists all discussions/PRs in the Community Tab. It doesn't have a `revision` parameter, so `previous_pr` will always fail (return None).

Reply (Collaborator, author): Oops :(

    except Exception:
        return None
    for discussion in discussions:
@@ -239,15 +230,15 @@ def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
    return None


def convert_generic(model_id: str, folder: str, filenames: Set[str], token: Optional[str]) -> ConversionResult:
def convert_generic(model_id: str, *, revision: Optional[str], folder: str, filenames: Set[str], token: Optional[str]) -> ConversionResult:
    operations = []
    errors = []

    extensions = set([".bin", ".ckpt"])
    for filename in filenames:
        prefix, ext = os.path.splitext(filename)
        if ext in extensions:
            pt_filename = hf_hub_download(model_id, filename=filename, token=token, cache_dir=folder)
            pt_filename = hf_hub_download(model_id, revision=revision, filename=filename, token=token, cache_dir=folder)
            dirname, raw_filename = os.path.split(filename)
            if raw_filename == "pytorch_model.bin":
                # XXX: This is a special case to handle `transformers` and the
@@ -257,25 +248,25 @@ def convert_generic(model_id: str, folder: str, filenames: Set[str], token: Optional[str]) -> ConversionResult:
                sf_in_repo = f"{prefix}.safetensors"
            sf_filename = os.path.join(folder, sf_in_repo)
            try:
                convert_file(pt_filename, sf_filename)
                convert_file(pt_filename, sf_filename, discard_names=[])
                operations.append(CommitOperationAdd(path_in_repo=sf_in_repo, path_or_fileobj=sf_filename))
            except Exception as e:
                errors.append((pt_filename, e))
    return operations, errors


def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitInfo", List[Tuple[str, "Exception"]]]:
def convert(api: "HfApi", model_id: str, revision: Optional[str] = None, force: bool = False) -> Tuple["CommitInfo", List[Tuple[str, "Exception"]]]:
    pr_title = "Adding `safetensors` variant of this model"
    info = api.model_info(model_id)
    info = api.model_info(model_id, revision=revision)
    filenames = set(s.rfilename for s in info.siblings)

    with TemporaryDirectory() as d:
    with TemporaryDirectory(prefix=os.getenv("HF_HOME", "") + "/") as d:
Comment (Contributor): (nit) You can use `huggingface_hub.constants.HF_HOME` to retrieve the user's HF home (it checks the `HF_HOME` and `XDG_CACHE_HOME` environment variables and defaults to `~/.cache/huggingface`).

(nit 2) I would set a default directory of HF_HOME + "/safetensors_converter", just in case the converter crashes without cleaning the folder afterwards (at least all tmp directories would then be in the same place).

Reply (Collaborator, author): Indeed. I actually removed that part to keep everything in a real temporary folder (/tmp). It prevents cache reuse, but should let the OS clean up correctly.

I had set this up for testing to use a real SSD on the machine I was using (/tmp was mounted on a slower disk).
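A sketch of the reviewer's first suggestion (not what the diff does), rooting all converter temp directories under the user's HF home:

import os
from tempfile import TemporaryDirectory

from huggingface_hub.constants import HF_HOME

converter_root = os.path.join(HF_HOME, "safetensors_converter")
os.makedirs(converter_root, exist_ok=True)
with TemporaryDirectory(prefix=converter_root + "/") as d:
    print(d)  # e.g. ~/.cache/huggingface/safetensors_converter/tmpab12cd34

Note that the merged line, os.getenv("HF_HOME", "") + "/", would fall back to the filesystem root when HF_HOME is unset, which is part of why the author later dropped it in favor of a plain temporary folder.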

        folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
        os.makedirs(folder)
        new_pr = None
        try:
            operations = None
            pr = previous_pr(api, model_id, pr_title)
            pr = previous_pr(api, model_id, pr_title, revision=revision)

            library_name = getattr(info, "library_name", None)
            if any(filename.endswith(".safetensors") for filename in filenames) and not force:
@@ -285,19 +276,21 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitInfo", List[Tuple[str, "Exception"]]]:
                new_pr = pr
                raise AlreadyExists(f"Model {model_id} already has an open PR check out {url}")
            elif library_name == "transformers":

                discard_names = get_discard_names(model_id, revision=revision, folder=folder, token=api.token)
                if "pytorch_model.bin" in filenames:
                    operations, errors = convert_single(model_id, folder, token=api.token)
                    operations, errors = convert_single(model_id, revision=revision, folder=folder, token=api.token, discard_names=discard_names)
                elif "pytorch_model.bin.index.json" in filenames:
                    operations, errors = convert_multi(model_id, folder, token=api.token)
                    operations, errors = convert_multi(model_id, revision=revision, folder=folder, token=api.token, discard_names=discard_names)
                else:
                    raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
                check_final_model(model_id, folder, token=api.token)
            else:
                operations, errors = convert_generic(model_id, folder, filenames, token=api.token)
                operations, errors = convert_generic(model_id, revision=revision, folder=folder, filenames=filenames, token=api.token)

            if operations:
                new_pr = api.create_commit(
                    repo_id=model_id,
                    revision=revision,
                    operations=operations,
                    commit_message=pr_title,
                    commit_description=COMMIT_DESCRIPTION,
@@ -324,6 +317,11 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitInfo", List[Tuple[str, "Exception"]]]:
        type=str,
        help="The name of the model on the hub to convert. E.g. `gpt2` or `facebook/wav2vec2-base-960h`",
    )
    parser.add_argument(
        "--revision",
        type=str,
        help="The revision to convert",
    )
    parser.add_argument(
        "--force",
        action="store_true",
@@ -346,26 +344,17 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitInfo", List[Tuple[str, "Exception"]]]:
" Continue [Y/n] ?"
)
if txt.lower() in {"", "y"}:
try:
commit_info, errors = convert(api, model_id, force=args.force)
string = f"""
commit_info, errors = convert(api, model_id, revision=args.revision, force=args.force)
string = f"""
### Success 🔥
Yay! This model was successfully converted and a PR was open using your token, here:
[{commit_info.pr_url}]({commit_info.pr_url})
"""
if errors:
string += "\nErrors during conversion:\n"
string += "\n".join(
f"Error while converting {filename}: {e}, skipped conversion" for filename, e in errors
)
print(string)
except Exception as e:
print(
f"""
### Error 😢😢😢

{e}
"""
"""
if errors:
string += "\nErrors during conversion:\n"
string += "\n".join(
f"Error while converting {filename}: {e}, skipped conversion" for filename, e in errors
)
print(string)
else:
print(f"Answer was `{txt}` aborting.")