diff --git a/.gitignore b/.gitignore
index 9ca20c67..46594cae 100644
--- a/.gitignore
+++ b/.gitignore
@@ -20,7 +20,7 @@ parts/
sdist/
var/
wheels/
-models/
+hordelib/models/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
diff --git a/hordelib/comfy_horde.py b/hordelib/comfy_horde.py
index 6dbc9f40..996132b8 100644
--- a/hordelib/comfy_horde.py
+++ b/hordelib/comfy_horde.py
@@ -410,6 +410,14 @@ def _set_comfyui_paths(self) -> None:
_comfy_supported_pt_extensions,
)
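+    # Register the gfpgan and codeformer model directories with ComfyUI so nodes
+    # that look up "facerestore_models" through folder_paths can find them.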
+ _comfy_folder_names_and_paths["facerestore_models"] = (
+ [
+ str(UserSettings.get_model_directory() / "gfpgan"),
+ str(UserSettings.get_model_directory() / "codeformer"),
+ ],
+ _comfy_supported_pt_extensions,
+ )
+
_comfy_folder_names_and_paths["controlnet"] = (
[
_comfy_folder_names_and_paths["controlnet"][0][0],
diff --git a/hordelib/horde.py b/hordelib/horde.py
index ddad5ae4..404dea7e 100644
--- a/hordelib/horde.py
+++ b/hordelib/horde.py
@@ -214,7 +214,7 @@ class HordeLib:
}
# pipeline parameter <- hordelib payload parameter mapping
- PAYLOAD_TO_PIPELINE_PARAMETER_MAPPING = { # FIXME
+ PAYLOAD_TO_PIPELINE_PARAMETER_MAPPING: dict[str, str | Callable] = { # FIXME
"sampler.sampler_name": "sampler_name",
"sampler.cfg": "cfg_scale",
"sampler.denoise": "denoising_strength",
@@ -820,6 +820,9 @@ def _final_pipeline_adjustments(self, payload, pipeline_data) -> tuple[dict, lis
# values for steps on things like stable cascade
if isinstance(key, FunctionType):
pipeline_params[newkey] = key(payload)
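+            # Each mapping entry must be a callable or a plain string; anything else is a
+            # bug in PAYLOAD_TO_PIPELINE_PARAMETER_MAPPING, so log it and raise immediately.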
+ elif not isinstance(key, str):
+ logger.error(f"Invalid key {key}")
+ raise RuntimeError(f"Invalid key {key}")
elif "*" in key:
key, multiplier = key.split("*", 1)
elif key in payload:
diff --git a/hordelib/nodes/facerestore/__init__.py b/hordelib/nodes/facerestore/__init__.py
deleted file mode 100644
index b2fdac14..00000000
--- a/hordelib/nodes/facerestore/__init__.py
+++ /dev/null
@@ -1,204 +0,0 @@
-import os
-import model_management
-import torch
-import comfy.utils
-import numpy as np
-import cv2
-import math
-from hordelib.nodes.facerestore.facelib.utils.face_restoration_helper import FaceRestoreHelper
-from hordelib.nodes.facerestore.facelib.detection.retinaface import retinaface
-from torchvision.transforms.functional import normalize
-import threading
-from loguru import logger
-
-
-def img2tensor(imgs, bgr2rgb=True, float32=True):
- """Numpy array to tensor.
-
- Args:
- imgs (list[ndarray] | ndarray): Input images.
- bgr2rgb (bool): Whether to change bgr to rgb.
- float32 (bool): Whether to change to float32.
-
- Returns:
- list[tensor] | tensor: Tensor images. If returned results only have
- one element, just return tensor.
- """
-
- def _totensor(img, bgr2rgb, float32):
- if img.shape[2] == 3 and bgr2rgb:
- if img.dtype == "float64":
- img = img.astype("float32")
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
- img = torch.from_numpy(img.transpose(2, 0, 1))
- if float32:
- img = img.float()
- return img
-
- if isinstance(imgs, list):
- return [_totensor(img, bgr2rgb, float32) for img in imgs]
- else:
- return _totensor(imgs, bgr2rgb, float32)
-
-
-def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
- """Convert torch Tensors into image numpy arrays.
-
- After clamping to [min, max], values will be normalized to [0, 1].
-
- Args:
- tensor (Tensor or list[Tensor]): Accept shapes:
- 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
- 2) 3D Tensor of shape (3/1 x H x W);
- 3) 2D Tensor of shape (H x W).
- Tensor channel should be in RGB order.
- rgb2bgr (bool): Whether to change rgb to bgr.
- out_type (numpy type): output types. If ``np.uint8``, transform outputs
- to uint8 type with range [0, 255]; otherwise, float type with
- range [0, 1]. Default: ``np.uint8``.
- min_max (tuple[int]): min and max values for clamp.
-
- Returns:
- (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
- shape (H x W). The channel order is BGR.
- """
- if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
- raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}")
-
- if torch.is_tensor(tensor):
- tensor = [tensor]
- result = []
- for _tensor in tensor:
- _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
- _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
-
- n_dim = _tensor.dim()
- if n_dim == 4:
- img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
- img_np = img_np.transpose(1, 2, 0)
- if rgb2bgr:
- img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
- elif n_dim == 3:
- img_np = _tensor.numpy()
- img_np = img_np.transpose(1, 2, 0)
- if img_np.shape[2] == 1: # gray image
- img_np = np.squeeze(img_np, axis=2)
- else:
- if rgb2bgr:
- img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
- elif n_dim == 2:
- img_np = _tensor.numpy()
- else:
- raise TypeError("Only support 4D, 3D or 2D tensor. " f"But received with dimension: {n_dim}")
- if out_type == np.uint8:
- # Unlike MATLAB, numpy.unit8() WILL NOT round by default.
- img_np = (img_np * 255.0).round()
- img_np = img_np.astype(out_type)
- result.append(img_np)
- if len(result) == 1:
- result = result[0]
- return result
-
-
-class FaceRestoreWithModel:
- @classmethod
- def INPUT_TYPES(s):
- return {
- "required": {
- "upscale_model": ("UPSCALE_MODEL",),
- "image": ("IMAGE",),
- "facedetection": (
- [
- "retinaface_resnet50",
- "retinaface_mobile0.25",
- "YOLOv5l",
- "YOLOv5n",
- ],
- ),
- }
- }
-
- RETURN_TYPES = ("IMAGE",)
-
- FUNCTION = "restore_face"
-
- CATEGORY = "facerestore"
-
- def restore_face(self, upscale_model, image, facedetection):
- # logger.warning(f"mutex:{id(FaceRestoreWithModel._mutex):x} Facerestore with upscale_model {id(upscale_model):x} and detection model {id(facedetection):x} and image {id(image):x}")
- # with FaceRestoreWithModel._mutex:
- # facedetection = copy.deepcopy(facedetection)
-
- device = model_management.get_torch_device()
- upscale_model.to(device)
- face_helper = FaceRestoreHelper(
- 1,
- face_size=512,
- crop_ratio=(1, 1),
- det_model=facedetection,
- save_ext="png",
- use_parse=True,
- device=device,
- )
-
- image_np = 255.0 * image.cpu().numpy().squeeze()
-
- image_np = image_np[:, :, ::-1]
-
- original_resolution = image_np.shape[0:2]
-
- if upscale_model is None or face_helper is None:
- return image
-
- face_helper.clean_all()
- face_helper.read_image(image_np)
- face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
- face_helper.align_warp_face()
- restored_face = None
-
- for idx, cropped_face in enumerate(face_helper.cropped_faces):
- cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True)
- normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
- cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
-
- try:
- with torch.no_grad():
- # output = upscale_model(cropped_face_t, w=strength, adain=True)[0]
- output = upscale_model(cropped_face_t)[0]
- restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
- del output
- # torch.cuda.empty_cache()
- except Exception as error:
- logger.error(f"Failed inference for CodeFormer: {error}")
- restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
-
- restored_face = restored_face.astype("uint8")
- face_helper.add_restored_face(restored_face)
-
- face_helper.get_inverse_affine(None)
-
- restored_img = face_helper.paste_faces_to_input_image()
- restored_img = restored_img[:, :, ::-1]
-
- if original_resolution != restored_img.shape[0:2]:
- restored_img = cv2.resize(
- restored_img,
- (0, 0),
- fx=original_resolution[1] / restored_img.shape[1],
- fy=original_resolution[0] / restored_img.shape[0],
- interpolation=cv2.INTER_LINEAR,
- )
-
- face_helper.clean_all()
-
- # restored_img = cv2.cvtColor(restored_face, cv2.COLOR_BGR2RGB)
-
- restored_img_np = np.array(restored_img).astype(np.float32) / 255.0
- restored_img_tensor = torch.from_numpy(restored_img_np).unsqueeze(0)
-
- return (restored_img_tensor,)
-
-
-NODE_CLASS_MAPPINGS = {
- "FaceRestoreWithModel": FaceRestoreWithModel,
-}
diff --git a/hordelib/nodes/facerestore/facelib/detection/yolov5face/utils/extract_ckpt.py b/hordelib/nodes/facerestore/facelib/detection/yolov5face/utils/extract_ckpt.py
deleted file mode 100644
index 07e780c7..00000000
--- a/hordelib/nodes/facerestore/facelib/detection/yolov5face/utils/extract_ckpt.py
+++ /dev/null
@@ -1,5 +0,0 @@
-import torch
-import sys
-sys.path.insert(0,'./facelib/detection/yolov5face')
-model = torch.load('facelib/detection/yolov5face/yolov5n-face.pt', map_location='cpu')['model']
-torch.save(model.state_dict(),'../../models/facedetection')
diff --git a/hordelib/nodes/facerestore/facelib/utils/__init__.py b/hordelib/nodes/facerestore/facelib/utils/__init__.py
deleted file mode 100644
index 23ef0352..00000000
--- a/hordelib/nodes/facerestore/facelib/utils/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from .face_utils import align_crop_face_landmarks, compute_increased_bbox, get_valid_bboxes, paste_face_back
-from .misc import img2tensor, load_file_from_url, download_pretrained_models, scandir
-
-__all__ = [
- 'align_crop_face_landmarks', 'compute_increased_bbox', 'get_valid_bboxes', 'load_file_from_url',
- 'download_pretrained_models', 'paste_face_back', 'img2tensor', 'scandir'
-]
diff --git a/hordelib/nodes/facerestore_cf/LICENSE b/hordelib/nodes/facerestore_cf/LICENSE
new file mode 100644
index 00000000..f288702d
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/LICENSE
@@ -0,0 +1,674 @@
+ GNU GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+ Preamble
+
+ The GNU General Public License is a free, copyleft license for
+software and other kinds of works.
+
+ The licenses for most software and other practical works are designed
+to take away your freedom to share and change the works. By contrast,
+the GNU General Public License is intended to guarantee your freedom to
+share and change all versions of a program--to make sure it remains free
+software for all its users. We, the Free Software Foundation, use the
+GNU General Public License for most of our software; it applies also to
+any other work released this way by its authors. You can apply it to
+your programs, too.
+
+ When we speak of free software, we are referring to freedom, not
+price. Our General Public Licenses are designed to make sure that you
+have the freedom to distribute copies of free software (and charge for
+them if you wish), that you receive source code or can get it if you
+want it, that you can change the software or use pieces of it in new
+free programs, and that you know you can do these things.
+
+ To protect your rights, we need to prevent others from denying you
+these rights or asking you to surrender the rights. Therefore, you have
+certain responsibilities if you distribute copies of the software, or if
+you modify it: responsibilities to respect the freedom of others.
+
+ For example, if you distribute copies of such a program, whether
+gratis or for a fee, you must pass on to the recipients the same
+freedoms that you received. You must make sure that they, too, receive
+or can get the source code. And you must show them these terms so they
+know their rights.
+
+ Developers that use the GNU GPL protect your rights with two steps:
+(1) assert copyright on the software, and (2) offer you this License
+giving you legal permission to copy, distribute and/or modify it.
+
+ For the developers' and authors' protection, the GPL clearly explains
+that there is no warranty for this free software. For both users' and
+authors' sake, the GPL requires that modified versions be marked as
+changed, so that their problems will not be attributed erroneously to
+authors of previous versions.
+
+ Some devices are designed to deny users access to install or run
+modified versions of the software inside them, although the manufacturer
+can do so. This is fundamentally incompatible with the aim of
+protecting users' freedom to change the software. The systematic
+pattern of such abuse occurs in the area of products for individuals to
+use, which is precisely where it is most unacceptable. Therefore, we
+have designed this version of the GPL to prohibit the practice for those
+products. If such problems arise substantially in other domains, we
+stand ready to extend this provision to those domains in future versions
+of the GPL, as needed to protect the freedom of users.
+
+ Finally, every program is threatened constantly by software patents.
+States should not allow patents to restrict development and use of
+software on general-purpose computers, but in those that do, we wish to
+avoid the special danger that patents applied to a free program could
+make it effectively proprietary. To prevent this, the GPL assures that
+patents cannot be used to render the program non-free.
+
+ The precise terms and conditions for copying, distribution and
+modification follow.
+
+ TERMS AND CONDITIONS
+
+ 0. Definitions.
+
+ "This License" refers to version 3 of the GNU General Public License.
+
+ "Copyright" also means copyright-like laws that apply to other kinds of
+works, such as semiconductor masks.
+
+ "The Program" refers to any copyrightable work licensed under this
+License. Each licensee is addressed as "you". "Licensees" and
+"recipients" may be individuals or organizations.
+
+ To "modify" a work means to copy from or adapt all or part of the work
+in a fashion requiring copyright permission, other than the making of an
+exact copy. The resulting work is called a "modified version" of the
+earlier work or a work "based on" the earlier work.
+
+ A "covered work" means either the unmodified Program or a work based
+on the Program.
+
+ To "propagate" a work means to do anything with it that, without
+permission, would make you directly or secondarily liable for
+infringement under applicable copyright law, except executing it on a
+computer or modifying a private copy. Propagation includes copying,
+distribution (with or without modification), making available to the
+public, and in some countries other activities as well.
+
+ To "convey" a work means any kind of propagation that enables other
+parties to make or receive copies. Mere interaction with a user through
+a computer network, with no transfer of a copy, is not conveying.
+
+ An interactive user interface displays "Appropriate Legal Notices"
+to the extent that it includes a convenient and prominently visible
+feature that (1) displays an appropriate copyright notice, and (2)
+tells the user that there is no warranty for the work (except to the
+extent that warranties are provided), that licensees may convey the
+work under this License, and how to view a copy of this License. If
+the interface presents a list of user commands or options, such as a
+menu, a prominent item in the list meets this criterion.
+
+ 1. Source Code.
+
+ The "source code" for a work means the preferred form of the work
+for making modifications to it. "Object code" means any non-source
+form of a work.
+
+ A "Standard Interface" means an interface that either is an official
+standard defined by a recognized standards body, or, in the case of
+interfaces specified for a particular programming language, one that
+is widely used among developers working in that language.
+
+ The "System Libraries" of an executable work include anything, other
+than the work as a whole, that (a) is included in the normal form of
+packaging a Major Component, but which is not part of that Major
+Component, and (b) serves only to enable use of the work with that
+Major Component, or to implement a Standard Interface for which an
+implementation is available to the public in source code form. A
+"Major Component", in this context, means a major essential component
+(kernel, window system, and so on) of the specific operating system
+(if any) on which the executable work runs, or a compiler used to
+produce the work, or an object code interpreter used to run it.
+
+ The "Corresponding Source" for a work in object code form means all
+the source code needed to generate, install, and (for an executable
+work) run the object code and to modify the work, including scripts to
+control those activities. However, it does not include the work's
+System Libraries, or general-purpose tools or generally available free
+programs which are used unmodified in performing those activities but
+which are not part of the work. For example, Corresponding Source
+includes interface definition files associated with source files for
+the work, and the source code for shared libraries and dynamically
+linked subprograms that the work is specifically designed to require,
+such as by intimate data communication or control flow between those
+subprograms and other parts of the work.
+
+ The Corresponding Source need not include anything that users
+can regenerate automatically from other parts of the Corresponding
+Source.
+
+ The Corresponding Source for a work in source code form is that
+same work.
+
+ 2. Basic Permissions.
+
+ All rights granted under this License are granted for the term of
+copyright on the Program, and are irrevocable provided the stated
+conditions are met. This License explicitly affirms your unlimited
+permission to run the unmodified Program. The output from running a
+covered work is covered by this License only if the output, given its
+content, constitutes a covered work. This License acknowledges your
+rights of fair use or other equivalent, as provided by copyright law.
+
+ You may make, run and propagate covered works that you do not
+convey, without conditions so long as your license otherwise remains
+in force. You may convey covered works to others for the sole purpose
+of having them make modifications exclusively for you, or provide you
+with facilities for running those works, provided that you comply with
+the terms of this License in conveying all material for which you do
+not control copyright. Those thus making or running the covered works
+for you must do so exclusively on your behalf, under your direction
+and control, on terms that prohibit them from making any copies of
+your copyrighted material outside their relationship with you.
+
+ Conveying under any other circumstances is permitted solely under
+the conditions stated below. Sublicensing is not allowed; section 10
+makes it unnecessary.
+
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+ No covered work shall be deemed part of an effective technological
+measure under any applicable law fulfilling obligations under article
+11 of the WIPO copyright treaty adopted on 20 December 1996, or
+similar laws prohibiting or restricting circumvention of such
+measures.
+
+ When you convey a covered work, you waive any legal power to forbid
+circumvention of technological measures to the extent such circumvention
+is effected by exercising rights under this License with respect to
+the covered work, and you disclaim any intention to limit operation or
+modification of the work as a means of enforcing, against the work's
+users, your or third parties' legal rights to forbid circumvention of
+technological measures.
+
+ 4. Conveying Verbatim Copies.
+
+ You may convey verbatim copies of the Program's source code as you
+receive it, in any medium, provided that you conspicuously and
+appropriately publish on each copy an appropriate copyright notice;
+keep intact all notices stating that this License and any
+non-permissive terms added in accord with section 7 apply to the code;
+keep intact all notices of the absence of any warranty; and give all
+recipients a copy of this License along with the Program.
+
+ You may charge any price or no price for each copy that you convey,
+and you may offer support or warranty protection for a fee.
+
+ 5. Conveying Modified Source Versions.
+
+ You may convey a work based on the Program, or the modifications to
+produce it from the Program, in the form of source code under the
+terms of section 4, provided that you also meet all of these conditions:
+
+ a) The work must carry prominent notices stating that you modified
+ it, and giving a relevant date.
+
+ b) The work must carry prominent notices stating that it is
+ released under this License and any conditions added under section
+ 7. This requirement modifies the requirement in section 4 to
+ "keep intact all notices".
+
+ c) You must license the entire work, as a whole, under this
+ License to anyone who comes into possession of a copy. This
+ License will therefore apply, along with any applicable section 7
+ additional terms, to the whole of the work, and all its parts,
+ regardless of how they are packaged. This License gives no
+ permission to license the work in any other way, but it does not
+ invalidate such permission if you have separately received it.
+
+ d) If the work has interactive user interfaces, each must display
+ Appropriate Legal Notices; however, if the Program has interactive
+ interfaces that do not display Appropriate Legal Notices, your
+ work need not make them do so.
+
+ A compilation of a covered work with other separate and independent
+works, which are not by their nature extensions of the covered work,
+and which are not combined with it such as to form a larger program,
+in or on a volume of a storage or distribution medium, is called an
+"aggregate" if the compilation and its resulting copyright are not
+used to limit the access or legal rights of the compilation's users
+beyond what the individual works permit. Inclusion of a covered work
+in an aggregate does not cause this License to apply to the other
+parts of the aggregate.
+
+ 6. Conveying Non-Source Forms.
+
+ You may convey a covered work in object code form under the terms
+of sections 4 and 5, provided that you also convey the
+machine-readable Corresponding Source under the terms of this License,
+in one of these ways:
+
+ a) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by the
+ Corresponding Source fixed on a durable physical medium
+ customarily used for software interchange.
+
+ b) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by a
+ written offer, valid for at least three years and valid for as
+ long as you offer spare parts or customer support for that product
+ model, to give anyone who possesses the object code either (1) a
+ copy of the Corresponding Source for all the software in the
+ product that is covered by this License, on a durable physical
+ medium customarily used for software interchange, for a price no
+ more than your reasonable cost of physically performing this
+ conveying of source, or (2) access to copy the
+ Corresponding Source from a network server at no charge.
+
+ c) Convey individual copies of the object code with a copy of the
+ written offer to provide the Corresponding Source. This
+ alternative is allowed only occasionally and noncommercially, and
+ only if you received the object code with such an offer, in accord
+ with subsection 6b.
+
+ d) Convey the object code by offering access from a designated
+ place (gratis or for a charge), and offer equivalent access to the
+ Corresponding Source in the same way through the same place at no
+ further charge. You need not require recipients to copy the
+ Corresponding Source along with the object code. If the place to
+ copy the object code is a network server, the Corresponding Source
+ may be on a different server (operated by you or a third party)
+ that supports equivalent copying facilities, provided you maintain
+ clear directions next to the object code saying where to find the
+ Corresponding Source. Regardless of what server hosts the
+ Corresponding Source, you remain obligated to ensure that it is
+ available for as long as needed to satisfy these requirements.
+
+ e) Convey the object code using peer-to-peer transmission, provided
+ you inform other peers where the object code and Corresponding
+ Source of the work are being offered to the general public at no
+ charge under subsection 6d.
+
+ A separable portion of the object code, whose source code is excluded
+from the Corresponding Source as a System Library, need not be
+included in conveying the object code work.
+
+ A "User Product" is either (1) a "consumer product", which means any
+tangible personal property which is normally used for personal, family,
+or household purposes, or (2) anything designed or sold for incorporation
+into a dwelling. In determining whether a product is a consumer product,
+doubtful cases shall be resolved in favor of coverage. For a particular
+product received by a particular user, "normally used" refers to a
+typical or common use of that class of product, regardless of the status
+of the particular user or of the way in which the particular user
+actually uses, or expects or is expected to use, the product. A product
+is a consumer product regardless of whether the product has substantial
+commercial, industrial or non-consumer uses, unless such uses represent
+the only significant mode of use of the product.
+
+ "Installation Information" for a User Product means any methods,
+procedures, authorization keys, or other information required to install
+and execute modified versions of a covered work in that User Product from
+a modified version of its Corresponding Source. The information must
+suffice to ensure that the continued functioning of the modified object
+code is in no case prevented or interfered with solely because
+modification has been made.
+
+ If you convey an object code work under this section in, or with, or
+specifically for use in, a User Product, and the conveying occurs as
+part of a transaction in which the right of possession and use of the
+User Product is transferred to the recipient in perpetuity or for a
+fixed term (regardless of how the transaction is characterized), the
+Corresponding Source conveyed under this section must be accompanied
+by the Installation Information. But this requirement does not apply
+if neither you nor any third party retains the ability to install
+modified object code on the User Product (for example, the work has
+been installed in ROM).
+
+ The requirement to provide Installation Information does not include a
+requirement to continue to provide support service, warranty, or updates
+for a work that has been modified or installed by the recipient, or for
+the User Product in which it has been modified or installed. Access to a
+network may be denied when the modification itself materially and
+adversely affects the operation of the network or violates the rules and
+protocols for communication across the network.
+
+ Corresponding Source conveyed, and Installation Information provided,
+in accord with this section must be in a format that is publicly
+documented (and with an implementation available to the public in
+source code form), and must require no special password or key for
+unpacking, reading or copying.
+
+ 7. Additional Terms.
+
+ "Additional permissions" are terms that supplement the terms of this
+License by making exceptions from one or more of its conditions.
+Additional permissions that are applicable to the entire Program shall
+be treated as though they were included in this License, to the extent
+that they are valid under applicable law. If additional permissions
+apply only to part of the Program, that part may be used separately
+under those permissions, but the entire Program remains governed by
+this License without regard to the additional permissions.
+
+ When you convey a copy of a covered work, you may at your option
+remove any additional permissions from that copy, or from any part of
+it. (Additional permissions may be written to require their own
+removal in certain cases when you modify the work.) You may place
+additional permissions on material, added by you to a covered work,
+for which you have or can give appropriate copyright permission.
+
+ Notwithstanding any other provision of this License, for material you
+add to a covered work, you may (if authorized by the copyright holders of
+that material) supplement the terms of this License with terms:
+
+ a) Disclaiming warranty or limiting liability differently from the
+ terms of sections 15 and 16 of this License; or
+
+ b) Requiring preservation of specified reasonable legal notices or
+ author attributions in that material or in the Appropriate Legal
+ Notices displayed by works containing it; or
+
+ c) Prohibiting misrepresentation of the origin of that material, or
+ requiring that modified versions of such material be marked in
+ reasonable ways as different from the original version; or
+
+ d) Limiting the use for publicity purposes of names of licensors or
+ authors of the material; or
+
+ e) Declining to grant rights under trademark law for use of some
+ trade names, trademarks, or service marks; or
+
+ f) Requiring indemnification of licensors and authors of that
+ material by anyone who conveys the material (or modified versions of
+ it) with contractual assumptions of liability to the recipient, for
+ any liability that these contractual assumptions directly impose on
+ those licensors and authors.
+
+ All other non-permissive additional terms are considered "further
+restrictions" within the meaning of section 10. If the Program as you
+received it, or any part of it, contains a notice stating that it is
+governed by this License along with a term that is a further
+restriction, you may remove that term. If a license document contains
+a further restriction but permits relicensing or conveying under this
+License, you may add to a covered work material governed by the terms
+of that license document, provided that the further restriction does
+not survive such relicensing or conveying.
+
+ If you add terms to a covered work in accord with this section, you
+must place, in the relevant source files, a statement of the
+additional terms that apply to those files, or a notice indicating
+where to find the applicable terms.
+
+ Additional terms, permissive or non-permissive, may be stated in the
+form of a separately written license, or stated as exceptions;
+the above requirements apply either way.
+
+ 8. Termination.
+
+ You may not propagate or modify a covered work except as expressly
+provided under this License. Any attempt otherwise to propagate or
+modify it is void, and will automatically terminate your rights under
+this License (including any patent licenses granted under the third
+paragraph of section 11).
+
+ However, if you cease all violation of this License, then your
+license from a particular copyright holder is reinstated (a)
+provisionally, unless and until the copyright holder explicitly and
+finally terminates your license, and (b) permanently, if the copyright
+holder fails to notify you of the violation by some reasonable means
+prior to 60 days after the cessation.
+
+ Moreover, your license from a particular copyright holder is
+reinstated permanently if the copyright holder notifies you of the
+violation by some reasonable means, this is the first time you have
+received notice of violation of this License (for any work) from that
+copyright holder, and you cure the violation prior to 30 days after
+your receipt of the notice.
+
+ Termination of your rights under this section does not terminate the
+licenses of parties who have received copies or rights from you under
+this License. If your rights have been terminated and not permanently
+reinstated, you do not qualify to receive new licenses for the same
+material under section 10.
+
+ 9. Acceptance Not Required for Having Copies.
+
+ You are not required to accept this License in order to receive or
+run a copy of the Program. Ancillary propagation of a covered work
+occurring solely as a consequence of using peer-to-peer transmission
+to receive a copy likewise does not require acceptance. However,
+nothing other than this License grants you permission to propagate or
+modify any covered work. These actions infringe copyright if you do
+not accept this License. Therefore, by modifying or propagating a
+covered work, you indicate your acceptance of this License to do so.
+
+ 10. Automatic Licensing of Downstream Recipients.
+
+ Each time you convey a covered work, the recipient automatically
+receives a license from the original licensors, to run, modify and
+propagate that work, subject to this License. You are not responsible
+for enforcing compliance by third parties with this License.
+
+ An "entity transaction" is a transaction transferring control of an
+organization, or substantially all assets of one, or subdividing an
+organization, or merging organizations. If propagation of a covered
+work results from an entity transaction, each party to that
+transaction who receives a copy of the work also receives whatever
+licenses to the work the party's predecessor in interest had or could
+give under the previous paragraph, plus a right to possession of the
+Corresponding Source of the work from the predecessor in interest, if
+the predecessor has it or can get it with reasonable efforts.
+
+ You may not impose any further restrictions on the exercise of the
+rights granted or affirmed under this License. For example, you may
+not impose a license fee, royalty, or other charge for exercise of
+rights granted under this License, and you may not initiate litigation
+(including a cross-claim or counterclaim in a lawsuit) alleging that
+any patent claim is infringed by making, using, selling, offering for
+sale, or importing the Program or any portion of it.
+
+ 11. Patents.
+
+ A "contributor" is a copyright holder who authorizes use under this
+License of the Program or a work on which the Program is based. The
+work thus licensed is called the contributor's "contributor version".
+
+ A contributor's "essential patent claims" are all patent claims
+owned or controlled by the contributor, whether already acquired or
+hereafter acquired, that would be infringed by some manner, permitted
+by this License, of making, using, or selling its contributor version,
+but do not include claims that would be infringed only as a
+consequence of further modification of the contributor version. For
+purposes of this definition, "control" includes the right to grant
+patent sublicenses in a manner consistent with the requirements of
+this License.
+
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
+patent license under the contributor's essential patent claims, to
+make, use, sell, offer for sale, import and otherwise run, modify and
+propagate the contents of its contributor version.
+
+ In the following three paragraphs, a "patent license" is any express
+agreement or commitment, however denominated, not to enforce a patent
+(such as an express permission to practice a patent or covenant not to
+sue for patent infringement). To "grant" such a patent license to a
+party means to make such an agreement or commitment not to enforce a
+patent against the party.
+
+ If you convey a covered work, knowingly relying on a patent license,
+and the Corresponding Source of the work is not available for anyone
+to copy, free of charge and under the terms of this License, through a
+publicly available network server or other readily accessible means,
+then you must either (1) cause the Corresponding Source to be so
+available, or (2) arrange to deprive yourself of the benefit of the
+patent license for this particular work, or (3) arrange, in a manner
+consistent with the requirements of this License, to extend the patent
+license to downstream recipients. "Knowingly relying" means you have
+actual knowledge that, but for the patent license, your conveying the
+covered work in a country, or your recipient's use of the covered work
+in a country, would infringe one or more identifiable patents in that
+country that you have reason to believe are valid.
+
+ If, pursuant to or in connection with a single transaction or
+arrangement, you convey, or propagate by procuring conveyance of, a
+covered work, and grant a patent license to some of the parties
+receiving the covered work authorizing them to use, propagate, modify
+or convey a specific copy of the covered work, then the patent license
+you grant is automatically extended to all recipients of the covered
+work and works based on it.
+
+ A patent license is "discriminatory" if it does not include within
+the scope of its coverage, prohibits the exercise of, or is
+conditioned on the non-exercise of one or more of the rights that are
+specifically granted under this License. You may not convey a covered
+work if you are a party to an arrangement with a third party that is
+in the business of distributing software, under which you make payment
+to the third party based on the extent of your activity of conveying
+the work, and under which the third party grants, to any of the
+parties who would receive the covered work from you, a discriminatory
+patent license (a) in connection with copies of the covered work
+conveyed by you (or copies made from those copies), or (b) primarily
+for and in connection with specific products or compilations that
+contain the covered work, unless you entered into that arrangement,
+or that patent license was granted, prior to 28 March 2007.
+
+ Nothing in this License shall be construed as excluding or limiting
+any implied license or other defenses to infringement that may
+otherwise be available to you under applicable patent law.
+
+ 12. No Surrender of Others' Freedom.
+
+ If conditions are imposed on you (whether by court order, agreement or
+otherwise) that contradict the conditions of this License, they do not
+excuse you from the conditions of this License. If you cannot convey a
+covered work so as to satisfy simultaneously your obligations under this
+License and any other pertinent obligations, then as a consequence you may
+not convey it at all. For example, if you agree to terms that obligate you
+to collect a royalty for further conveying from those to whom you convey
+the Program, the only way you could satisfy both those terms and this
+License would be to refrain entirely from conveying the Program.
+
+ 13. Use with the GNU Affero General Public License.
+
+ Notwithstanding any other provision of this License, you have
+permission to link or combine any covered work with a work licensed
+under version 3 of the GNU Affero General Public License into a single
+combined work, and to convey the resulting work. The terms of this
+License will continue to apply to the part which is the covered work,
+but the special requirements of the GNU Affero General Public License,
+section 13, concerning interaction through a network will apply to the
+combination as such.
+
+ 14. Revised Versions of this License.
+
+ The Free Software Foundation may publish revised and/or new versions of
+the GNU General Public License from time to time. Such new versions will
+be similar in spirit to the present version, but may differ in detail to
+address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Program specifies that a certain numbered version of the GNU General
+Public License "or any later version" applies to it, you have the
+option of following the terms and conditions either of that numbered
+version or of any later version published by the Free Software
+Foundation. If the Program does not specify a version number of the
+GNU General Public License, you may choose any version ever published
+by the Free Software Foundation.
+
+ If the Program specifies that a proxy can decide which future
+versions of the GNU General Public License can be used, that proxy's
+public statement of acceptance of a version permanently authorizes you
+to choose that version for the Program.
+
+ Later license versions may give you additional or different
+permissions. However, no additional obligations are imposed on any
+author or copyright holder as a result of your choosing to follow a
+later version.
+
+ 15. Disclaimer of Warranty.
+
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
+APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
+HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
+OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
+THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
+IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
+ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
+
+ 16. Limitation of Liability.
+
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
+THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
+GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
+USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
+DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
+PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
+EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
+SUCH DAMAGES.
+
+ 17. Interpretation of Sections 15 and 16.
+
+ If the disclaimer of warranty and limitation of liability provided
+above cannot be given local legal effect according to their terms,
+reviewing courts shall apply local law that most closely approximates
+an absolute waiver of all civil liability in connection with the
+Program, unless a warranty or assumption of liability accompanies a
+copy of the Program in return for a fee.
+
+ END OF TERMS AND CONDITIONS
+
+ How to Apply These Terms to Your New Programs
+
+ If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these terms.
+
+ To do so, attach the following notices to the program. It is safest
+to attach them to the start of each source file to most effectively
+state the exclusion of warranty; and each file should have at least
+the "copyright" line and a pointer to where the full notice is found.
+
+    <one line to give the program's name and a brief idea of what it does.>
+    Copyright (C) <year>  <name of author>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+    along with this program.  If not, see <https://www.gnu.org/licenses/>.
+
+Also add information on how to contact you by electronic and paper mail.
+
+ If the program does terminal interaction, make it output a short
+notice like this when it starts in an interactive mode:
+
+    <program>  Copyright (C) <year>  <name of author>
+ This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
+ This is free software, and you are welcome to redistribute it
+ under certain conditions; type `show c' for details.
+
+The hypothetical commands `show w' and `show c' should show the appropriate
+parts of the General Public License. Of course, your program's commands
+might be different; for a GUI interface, you would use an "about box".
+
+ You should also get your employer (if you work as a programmer) or school,
+if any, to sign a "copyright disclaimer" for the program, if necessary.
+For more information on this, and how to apply and follow the GNU GPL, see
+<https://www.gnu.org/licenses/>.
+
+ The GNU General Public License does not permit incorporating your program
+into proprietary programs. If your program is a subroutine library, you
+may consider it more useful to permit linking proprietary applications with
+the library. If this is what you want to do, use the GNU Lesser General
+Public License instead of this License. But first, please read
+<https://www.gnu.org/licenses/why-not-lgpl.html>.
diff --git a/hordelib/nodes/facerestore_cf/README.md b/hordelib/nodes/facerestore_cf/README.md
new file mode 100644
index 00000000..abe522cb
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/README.md
@@ -0,0 +1 @@
+Packaged code in this directory (unless stated otherwise) is licensed under the GPL and sourced from https://github.com/mav-rik/facerestore_cf. See LICENSE for more information.
diff --git a/hordelib/nodes/facerestore_cf/__init__.py b/hordelib/nodes/facerestore_cf/__init__.py
new file mode 100644
index 00000000..81c6c1c6
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/__init__.py
@@ -0,0 +1,333 @@
+import math
+import os
+import sys
+
+import comfy.utils
+import cv2
+import folder_paths
+import model_management
+import numpy as np
+import torch
+# from comfy_extras.chainner_models import model_loading
+from hordelib.nodes.facerestore_cf.r_chainner import model_loading
+from torchvision.transforms.functional import normalize
+from torchvision.utils import make_grid  # used by tensor2img() for 4D batched tensors
+
+from hordelib.nodes.facerestore_cf.basicsr.utils.registry import ARCH_REGISTRY
+from hordelib.nodes.facerestore_cf.facelib.detection.retinaface import retinaface
+from hordelib.nodes.facerestore_cf.facelib.utils.face_restoration_helper import FaceRestoreHelper
+
+# import codeformer_arch
+
+
+def img2tensor(imgs, bgr2rgb=True, float32=True):
+ """Numpy array to tensor.
+
+ Args:
+ imgs (list[ndarray] | ndarray): Input images.
+ bgr2rgb (bool): Whether to change bgr to rgb.
+ float32 (bool): Whether to change to float32.
+
+ Returns:
+ list[tensor] | tensor: Tensor images. If returned results only have
+ one element, just return tensor.
+ """
+
+ def _totensor(img, bgr2rgb, float32):
+ if img.shape[2] == 3 and bgr2rgb:
+ if img.dtype == "float64":
+ img = img.astype("float32")
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img = torch.from_numpy(img.transpose(2, 0, 1))
+ if float32:
+ img = img.float()
+ return img
+
+ if isinstance(imgs, list):
+ return [_totensor(img, bgr2rgb, float32) for img in imgs]
+ else:
+ return _totensor(imgs, bgr2rgb, float32)
+
+
+def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
+ """Convert torch Tensors into image numpy arrays.
+
+ After clamping to [min, max], values will be normalized to [0, 1].
+
+ Args:
+ tensor (Tensor or list[Tensor]): Accept shapes:
+ 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
+ 2) 3D Tensor of shape (3/1 x H x W);
+ 3) 2D Tensor of shape (H x W).
+ Tensor channel should be in RGB order.
+ rgb2bgr (bool): Whether to change rgb to bgr.
+ out_type (numpy type): output types. If ``np.uint8``, transform outputs
+ to uint8 type with range [0, 255]; otherwise, float type with
+ range [0, 1]. Default: ``np.uint8``.
+ min_max (tuple[int]): min and max values for clamp.
+
+ Returns:
+ (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
+ shape (H x W). The channel order is BGR.
+ """
+ if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
+ raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}")
+
+ if torch.is_tensor(tensor):
+ tensor = [tensor]
+ result = []
+ for _tensor in tensor:
+ _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
+ _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
+
+ n_dim = _tensor.dim()
+ if n_dim == 4:
+ img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
+ img_np = img_np.transpose(1, 2, 0)
+ if rgb2bgr:
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+ elif n_dim == 3:
+ img_np = _tensor.numpy()
+ img_np = img_np.transpose(1, 2, 0)
+ if img_np.shape[2] == 1: # gray image
+ img_np = np.squeeze(img_np, axis=2)
+ else:
+ if rgb2bgr:
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+ elif n_dim == 2:
+ img_np = _tensor.numpy()
+ else:
+ raise TypeError("Only support 4D, 3D or 2D tensor. " f"But received with dimension: {n_dim}")
+ if out_type == np.uint8:
+ # Unlike MATLAB, numpy.unit8() WILL NOT round by default.
+ img_np = (img_np * 255.0).round()
+ img_np = img_np.astype(out_type)
+ result.append(img_np)
+ if len(result) == 1:
+ result = result[0]
+ return result
+
+
+class FaceRestoreCFWithModel:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "facerestore_model": ("FACERESTORE_MODEL",),
+ "image": ("IMAGE",),
+ "facedetection": (["retinaface_resnet50", "retinaface_mobile0.25", "YOLOv5l", "YOLOv5n"],),
+ "codeformer_fidelity": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1, "step": 0.05}),
+ },
+ }
+
+ RETURN_TYPES = ("IMAGE",)
+
+ FUNCTION = "restore_face"
+
+ CATEGORY = "facerestore_cf"
+
+ def __init__(self):
+ self.face_helper = None
+
+ def restore_face(self, facerestore_model, image, facedetection, codeformer_fidelity):
+ print(f"\tStarting restore_face with codeformer_fidelity: {codeformer_fidelity}")
+ device = model_management.get_torch_device()
+ facerestore_model.to(device)
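+        # Build the FaceRestoreHelper lazily on first use; it bundles face detection,
+        # alignment and paste-back, and is reused across calls.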
+ if self.face_helper is None:
+ self.face_helper = FaceRestoreHelper(
+ 1,
+ face_size=512,
+ crop_ratio=(1, 1),
+ det_model=facedetection,
+ save_ext="png",
+ use_parse=True,
+ device=device,
+ )
+
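+        # ComfyUI images are (B, H, W, C) float tensors in [0, 1]; scale to 0-255 here
+        # and flip RGB->BGR per image below, since facelib expects OpenCV-style arrays.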
+ image_np = 255.0 * image.cpu().numpy()
+
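+        # Process the batch one image at a time, collecting results into out_images.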
+ total_images = image_np.shape[0]
+ out_images = np.ndarray(shape=image_np.shape)
+
+ for i in range(total_images):
+ cur_image_np = image_np[i, :, :, ::-1]
+
+ original_resolution = cur_image_np.shape[0:2]
+
+ if facerestore_model is None or self.face_helper is None:
+ return image
+
+ self.face_helper.clean_all()
+ self.face_helper.read_image(cur_image_np)
+ self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
+ self.face_helper.align_warp_face()
+
+ restored_face = None
+ for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
+ cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True)
+ normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+ cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
+
+ try:
+ with torch.no_grad():
+ # output = facerestore_model(cropped_face_t, w=strength, adain=True)[0]
+ # output = facerestore_model(cropped_face_t)[0]
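+                        # codeformer_fidelity is CodeFormer's fidelity weight "w": lower values
+                        # favour restoration quality, higher values favour fidelity to the input face.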
+ output = facerestore_model(cropped_face_t, w=codeformer_fidelity)[0]
+ restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
+ del output
+ torch.cuda.empty_cache()
+ except Exception as error:
+ print(f"\tFailed inference for CodeFormer: {error}", file=sys.stderr)
+ restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
+
+ restored_face = restored_face.astype("uint8")
+ self.face_helper.add_restored_face(restored_face)
+
+ self.face_helper.get_inverse_affine(None)
+
+ restored_img = self.face_helper.paste_faces_to_input_image()
+ restored_img = restored_img[:, :, ::-1]
+
+ if original_resolution != restored_img.shape[0:2]:
+ restored_img = cv2.resize(
+ restored_img,
+ (0, 0),
+ fx=original_resolution[1] / restored_img.shape[1],
+ fy=original_resolution[0] / restored_img.shape[0],
+ interpolation=cv2.INTER_LINEAR,
+ )
+
+ self.face_helper.clean_all()
+
+ # restored_img = cv2.cvtColor(restored_face, cv2.COLOR_BGR2RGB)
+
+ out_images[i] = restored_img
+
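+        # Back to a (B, H, W, C) float tensor in [0, 1], the layout ComfyUI expects.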
+ restored_img_np = np.array(out_images).astype(np.float32) / 255.0
+ restored_img_tensor = torch.from_numpy(restored_img_np)
+ return (restored_img_tensor,)
+
+
+class CropFace:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "image": ("IMAGE",),
+ "facedetection": (["retinaface_resnet50", "retinaface_mobile0.25", "YOLOv5l", "YOLOv5n"],),
+ },
+ }
+
+ RETURN_TYPES = ("IMAGE",)
+
+ FUNCTION = "crop_face"
+
+ CATEGORY = "facerestore_cf"
+
+ def __init__(self):
+ self.face_helper = None
+
+ def crop_face(self, image, facedetection):
+ device = model_management.get_torch_device()
+ if self.face_helper is None:
+ self.face_helper = FaceRestoreHelper(
+ 1,
+ face_size=512,
+ crop_ratio=(1, 1),
+ det_model=facedetection,
+ save_ext="png",
+ use_parse=True,
+ device=device,
+ )
+
+ image_np = 255.0 * image.cpu().numpy()
+
+ total_images = image_np.shape[0]
+ out_images = np.ndarray(shape=(total_images, 512, 512, 3))
+ next_idx = 0
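+        # next_idx tracks the next free slot in out_images; the buffer grows below if
+        # an image yields more than one face.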
+
+ for i in range(total_images):
+
+ cur_image_np = image_np[i, :, :, ::-1]
+
+ original_resolution = cur_image_np.shape[0:2]
+
+ if self.face_helper is None:
+ return image
+
+ self.face_helper.clean_all()
+ self.face_helper.read_image(cur_image_np)
+ self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
+ self.face_helper.align_warp_face()
+
+ faces_found = len(self.face_helper.cropped_faces)
+ if faces_found == 0:
+ next_idx += 1 # output black image for no face
+ if out_images.shape[0] < next_idx + faces_found:
+                # Grow the output buffer when an image yields more faces than allocated.
+                out_images = np.resize(out_images, (next_idx + faces_found, 512, 512, 3))
+ for j in range(faces_found):
+ cropped_face_1 = self.face_helper.cropped_faces[j]
+ cropped_face_2 = img2tensor(cropped_face_1 / 255.0, bgr2rgb=True, float32=True)
+ normalize(cropped_face_2, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+ cropped_face_3 = cropped_face_2.unsqueeze(0).to(device)
+ cropped_face_4 = tensor2img(cropped_face_3, rgb2bgr=True, min_max=(-1, 1)).astype("uint8")
+ cropped_face_5 = cv2.cvtColor(cropped_face_4, cv2.COLOR_BGR2RGB)
+ out_images[next_idx] = cropped_face_5
+ next_idx += 1
+
+ cropped_face_6 = np.array(out_images).astype(np.float32) / 255.0
+ cropped_face_7 = torch.from_numpy(cropped_face_6)
+ return (cropped_face_7,)
+
+
+class FaceRestoreModelLoader:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "model_name": (folder_paths.get_filename_list("facerestore_models"),),
+ },
+ }
+
+ RETURN_TYPES = ("FACERESTORE_MODEL",)
+ FUNCTION = "load_model"
+
+ CATEGORY = "facerestore_cf"
+
+ # def load_model(self, model_name):
+ # model_path = folder_paths.get_full_path("facerestore_models", model_name)
+ # sd = comfy.utils.load_torch_file(model_path, safe_load=True)
+ # out = model_loading.load_state_dict(sd).eval()
+ # return (out, )
+
+ def load_model(self, model_name):
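+        # CodeFormer checkpoints need the CodeFormer architecture from the bundled basicsr
+        # ARCH_REGISTRY; other face restoration models go through the generic chainner
+        # state-dict loader.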
+ if "codeformer" in model_name.lower():
+ print(f"\tLoading CodeFormer: {model_name}")
+ model_path = folder_paths.get_full_path("facerestore_models", model_name)
+ device = model_management.get_torch_device()
+ codeformer_net = ARCH_REGISTRY.get("CodeFormer")(
+ dim_embd=512,
+ codebook_size=1024,
+ n_head=8,
+ n_layers=9,
+ connect_list=["32", "64", "128", "256"],
+ ).to(device)
+ checkpoint = torch.load(model_path)["params_ema"]
+ codeformer_net.load_state_dict(checkpoint)
+ out = codeformer_net.eval()
+ return (out,)
+ else:
+ model_path = folder_paths.get_full_path("facerestore_models", model_name)
+ sd = comfy.utils.load_torch_file(model_path, safe_load=True)
+ out = model_loading.load_state_dict(sd).eval()
+ return (out,)
+
+
+NODE_CLASS_MAPPINGS = {
+ "FaceRestoreCFWithModel": FaceRestoreCFWithModel,
+ "CropFace": CropFace,
+ "FaceRestoreModelLoader": FaceRestoreModelLoader,
+}
diff --git a/hordelib/nodes/facerestore_cf/basicsr/VERSION b/hordelib/nodes/facerestore_cf/basicsr/VERSION
new file mode 100644
index 00000000..b85bccc7
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/VERSION
@@ -0,0 +1 @@
+1.3.2
diff --git a/hordelib/nodes/facerestore_cf/basicsr/__init__.py b/hordelib/nodes/facerestore_cf/basicsr/__init__.py
new file mode 100644
index 00000000..2a06af02
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/__init__.py
@@ -0,0 +1,11 @@
+# https://github.com/xinntao/BasicSR
+# flake8: noqa
+from .archs import *
+from .data import *
+from .losses import *
+from .metrics import *
+from .models import *
+from .ops import *
+from .train import *
+from .utils import *
+from .version import __gitsha__, __version__
diff --git a/hordelib/nodes/facerestore_cf/basicsr/archs/__init__.py b/hordelib/nodes/facerestore_cf/basicsr/archs/__init__.py
new file mode 100644
index 00000000..41b0cbc3
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/archs/__init__.py
@@ -0,0 +1,27 @@
+import importlib
+from copy import deepcopy
+from os import path as osp
+
+from hordelib.nodes.facerestore_cf.basicsr.utils import get_root_logger, scandir
+from hordelib.nodes.facerestore_cf.basicsr.utils.registry import ARCH_REGISTRY
+
+__all__ = ["build_network"]
+
+# automatically scan and import arch modules for registry
+# scan all the files under the 'archs' folder and collect files ending with
+# '_arch.py'
+arch_folder = osp.dirname(osp.abspath(__file__))
+arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith("_arch.py")]
+# import all the arch modules
+_arch_modules = [
+ importlib.import_module(f"hordelib.nodes.facerestore_cf.basicsr.archs.{file_name}") for file_name in arch_filenames
+]
+
+
+def build_network(opt):
+ opt = deepcopy(opt)
+ network_type = opt.pop("type")
+ net = ARCH_REGISTRY.get(network_type)(**opt)
+ logger = get_root_logger()
+ logger.info(f"Network [{net.__class__.__name__}] is created.")
+ return net
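+
+
+# Hedged usage sketch (comments only): build_network() pops "type" and forwards the
+# remaining keys as keyword arguments to the registered architecture class. The option
+# values below are illustrative, not a recommended configuration.
+#
+#   net = build_network({"type": "CodeFormer", "dim_embd": 512, "n_head": 8, "n_layers": 9})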
diff --git a/hordelib/nodes/facerestore_cf/basicsr/archs/arcface_arch.py b/hordelib/nodes/facerestore_cf/basicsr/archs/arcface_arch.py
new file mode 100644
index 00000000..b2627251
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/archs/arcface_arch.py
@@ -0,0 +1,252 @@
+import torch.nn as nn
+
+from hordelib.nodes.facerestore_cf.basicsr.utils.registry import ARCH_REGISTRY
+
+
+def conv3x3(inplanes, outplanes, stride=1):
+ """A simple wrapper for 3x3 convolution with padding.
+
+ Args:
+ inplanes (int): Channel number of inputs.
+ outplanes (int): Channel number of outputs.
+ stride (int): Stride in convolution. Default: 1.
+ """
+ return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False)
+
+
+class BasicBlock(nn.Module):
+ """Basic residual block used in the ResNetArcFace architecture.
+
+ Args:
+ inplanes (int): Channel number of inputs.
+ planes (int): Channel number of outputs.
+ stride (int): Stride in convolution. Default: 1.
+ downsample (nn.Module): The downsample module. Default: None.
+ """
+
+ expansion = 1 # output channel expansion ratio
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(BasicBlock, self).__init__()
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class IRBlock(nn.Module):
+ """Improved residual block (IR Block) used in the ResNetArcFace architecture.
+
+ Args:
+ inplanes (int): Channel number of inputs.
+ planes (int): Channel number of outputs.
+ stride (int): Stride in convolution. Default: 1.
+ downsample (nn.Module): The downsample module. Default: None.
+ use_se (bool): Whether to use the SEBlock (squeeze-and-excitation block). Default: True.
+ """
+
+ expansion = 1 # output channel expansion ratio
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
+ super(IRBlock, self).__init__()
+ self.bn0 = nn.BatchNorm2d(inplanes)
+ self.conv1 = conv3x3(inplanes, inplanes)
+ self.bn1 = nn.BatchNorm2d(inplanes)
+ self.prelu = nn.PReLU()
+ self.conv2 = conv3x3(inplanes, planes, stride)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.downsample = downsample
+ self.stride = stride
+ self.use_se = use_se
+ if self.use_se:
+ self.se = SEBlock(planes)
+
+ def forward(self, x):
+ residual = x
+ out = self.bn0(x)
+ out = self.conv1(out)
+ out = self.bn1(out)
+ out = self.prelu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ if self.use_se:
+ out = self.se(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.prelu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ """Bottleneck block used in the ResNetArcFace architecture.
+
+ Args:
+ inplanes (int): Channel number of inputs.
+ planes (int): Channel number of outputs.
+ stride (int): Stride in convolution. Default: 1.
+ downsample (nn.Module): The downsample module. Default: None.
+ """
+
+ expansion = 4 # output channel expansion ratio
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(Bottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class SEBlock(nn.Module):
+ """The squeeze-and-excitation block (SEBlock) used in the IRBlock.
+
+ Args:
+ channel (int): Channel number of inputs.
+ reduction (int): Channel reduction ratio. Default: 16.
+ """
+
+ def __init__(self, channel, reduction=16):
+ super(SEBlock, self).__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2d(1) # pool to 1x1 without spatial information
+ self.fc = nn.Sequential(
+ nn.Linear(channel, channel // reduction),
+ nn.PReLU(),
+ nn.Linear(channel // reduction, channel),
+ nn.Sigmoid(),
+ )
+
+ def forward(self, x):
+ b, c, _, _ = x.size()
+ y = self.avg_pool(x).view(b, c)
+ y = self.fc(y).view(b, c, 1, 1)
+ return x * y
+
+
+@ARCH_REGISTRY.register()
+class ResNetArcFace(nn.Module):
+ """ArcFace with ResNet architectures.
+
+ Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.
+
+ Args:
+ block (str): Block used in the ArcFace architecture.
+ layers (tuple(int)): Block numbers in each layer.
+ use_se (bool): Whether to use the SEBlock (squeeze-and-excitation block). Default: True.
+ """
+
+ def __init__(self, block, layers, use_se=True):
+ if block == "IRBlock":
+ block = IRBlock
+ self.inplanes = 64
+ self.use_se = use_se
+ super(ResNetArcFace, self).__init__()
+
+ self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(64)
+ self.prelu = nn.PReLU()
+ self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+ self.bn4 = nn.BatchNorm2d(512)
+ self.dropout = nn.Dropout()
+ self.fc5 = nn.Linear(512 * 8 * 8, 512)
+ self.bn5 = nn.BatchNorm1d(512)
+
+ # initialization
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.xavier_normal_(m.weight)
+ elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ nn.init.xavier_normal_(m.weight)
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, block, planes, num_blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(planes * block.expansion),
+ )
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
+ self.inplanes = planes
+ for _ in range(1, num_blocks):
+ layers.append(block(self.inplanes, planes, use_se=self.use_se))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.prelu(x)
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.bn4(x)
+ x = self.dropout(x)
+ x = x.view(x.size(0), -1)
+ x = self.fc5(x)
+ x = self.bn5(x)
+
+ return x
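+
+
+# Hedged shape sketch (comments only): with the "IRBlock" block type and a hypothetical
+# ResNet-18-style layout of (2, 2, 2, 2), the fc5 layer (512 * 8 * 8 inputs) implies a
+# single-channel 128x128 input: 128 -> 64 (maxpool) -> 64 -> 32 -> 16 -> 8 across the
+# four stages.
+#
+#   net = ResNetArcFace(block="IRBlock", layers=(2, 2, 2, 2), use_se=True)
+#   emb = net(torch.randn(4, 1, 128, 128))   # -> shape (4, 512)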
diff --git a/hordelib/nodes/facerestore_cf/basicsr/archs/arch_util.py b/hordelib/nodes/facerestore_cf/basicsr/archs/arch_util.py
new file mode 100644
index 00000000..a40bf5f2
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/archs/arch_util.py
@@ -0,0 +1,342 @@
+import collections.abc
+import math
+import warnings
+from distutils.version import LooseVersion
+from itertools import repeat
+
+import torch
+import torchvision
+from torch import nn as nn
+from torch.nn import functional as F
+from torch.nn import init as init
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from hordelib.nodes.facerestore_cf.basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
+from hordelib.nodes.facerestore_cf.basicsr.utils import get_root_logger
+
+
+@torch.no_grad()
+def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
+ """Initialize network weights.
+
+ Args:
+ module_list (list[nn.Module] | nn.Module): Modules to be initialized.
+ scale (float): Scale initialized weights, especially for residual
+ blocks. Default: 1.
+ bias_fill (float): The value to fill bias. Default: 0
+ kwargs (dict): Other arguments for initialization function.
+ """
+ if not isinstance(module_list, list):
+ module_list = [module_list]
+ for module in module_list:
+ for m in module.modules():
+ if isinstance(m, nn.Conv2d):
+ init.kaiming_normal_(m.weight, **kwargs)
+ m.weight.data *= scale
+ if m.bias is not None:
+ m.bias.data.fill_(bias_fill)
+ elif isinstance(m, nn.Linear):
+ init.kaiming_normal_(m.weight, **kwargs)
+ m.weight.data *= scale
+ if m.bias is not None:
+ m.bias.data.fill_(bias_fill)
+ elif isinstance(m, _BatchNorm):
+ init.constant_(m.weight, 1)
+ if m.bias is not None:
+ m.bias.data.fill_(bias_fill)
+
+
+def make_layer(basic_block, num_basic_block, **kwarg):
+ """Make layers by stacking the same blocks.
+
+ Args:
+ basic_block (nn.module): nn.module class for basic block.
+ num_basic_block (int): number of blocks.
+
+ Returns:
+ nn.Sequential: Stacked blocks in nn.Sequential.
+ """
+ layers = []
+ for _ in range(num_basic_block):
+ layers.append(basic_block(**kwarg))
+ return nn.Sequential(*layers)
+
+
+class ResidualBlockNoBN(nn.Module):
+ """Residual block without BN.
+
+ It has a style of:
+ ---Conv-ReLU-Conv-+-
+ |________________|
+
+ Args:
+ num_feat (int): Channel number of intermediate features.
+ Default: 64.
+ res_scale (float): Residual scale. Default: 1.
+ pytorch_init (bool): If set to True, use pytorch default init,
+ otherwise, use default_init_weights. Default: False.
+ """
+
+ def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
+ super(ResidualBlockNoBN, self).__init__()
+ self.res_scale = res_scale
+ self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+ self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+ self.relu = nn.ReLU(inplace=True)
+
+ if not pytorch_init:
+ default_init_weights([self.conv1, self.conv2], 0.1)
+
+ def forward(self, x):
+ identity = x
+ out = self.conv2(self.relu(self.conv1(x)))
+ return identity + out * self.res_scale
+
+
+class Upsample(nn.Sequential):
+ """Upsample module.
+
+ Args:
+ scale (int): Scale factor. Supported scales: 2^n and 3.
+ num_feat (int): Channel number of intermediate features.
+ """
+
+ def __init__(self, scale, num_feat):
+ m = []
+ if (scale & (scale - 1)) == 0: # scale = 2^n
+ for _ in range(int(math.log(scale, 2))):
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(2))
+ elif scale == 3:
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(3))
+ else:
+ raise ValueError(f"scale {scale} is not supported. Supported scales: 2^n and 3.")
+ super(Upsample, self).__init__(*m)
+
+
+def flow_warp(x, flow, interp_mode="bilinear", padding_mode="zeros", align_corners=True):
+ """Warp an image or feature map with optical flow.
+
+ Args:
+ x (Tensor): Tensor with size (n, c, h, w).
+ flow (Tensor): Tensor with size (n, h, w, 2), normal value.
+ interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
+ padding_mode (str): 'zeros' or 'border' or 'reflection'.
+ Default: 'zeros'.
+ align_corners (bool): Before pytorch 1.3, the default value is
+ align_corners=True. After pytorch 1.3, the default value is
+ align_corners=False. Here, we use True as the default.
+
+ Returns:
+ Tensor: Warped image or feature map.
+ """
+ assert x.size()[-2:] == flow.size()[1:3]
+ _, _, h, w = x.size()
+ # create mesh grid
+ grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
+ grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
+ grid.requires_grad = False
+
+ vgrid = grid + flow
+ # scale grid to [-1,1]
+ vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
+ vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
+ vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
+ output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
+
+ # TODO, what if align_corners=False
+ return output
+
+
+def resize_flow(flow, size_type, sizes, interp_mode="bilinear", align_corners=False):
+ """Resize a flow according to ratio or shape.
+
+ Args:
+ flow (Tensor): Precomputed flow. shape [N, 2, H, W].
+ size_type (str): 'ratio' or 'shape'.
+ sizes (list[int | float]): the ratio for resizing or the final output
+ shape.
+ 1) The order of ratio should be [ratio_h, ratio_w]. For
+ downsampling, the ratio should be smaller than 1.0 (i.e., ratio
+ < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
+ ratio > 1.0).
+ 2) The order of output_size should be [out_h, out_w].
+ interp_mode (str): The mode of interpolation for resizing.
+ Default: 'bilinear'.
+ align_corners (bool): Whether align corners. Default: False.
+
+ Returns:
+ Tensor: Resized flow.
+ """
+ _, _, flow_h, flow_w = flow.size()
+ if size_type == "ratio":
+ output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
+ elif size_type == "shape":
+ output_h, output_w = sizes[0], sizes[1]
+ else:
+ raise ValueError(f"Size type should be ratio or shape, but got type {size_type}.")
+
+ input_flow = flow.clone()
+ ratio_h = output_h / flow_h
+ ratio_w = output_w / flow_w
+ input_flow[:, 0, :, :] *= ratio_w
+ input_flow[:, 1, :, :] *= ratio_h
+ resized_flow = F.interpolate(
+ input=input_flow,
+ size=(output_h, output_w),
+ mode=interp_mode,
+ align_corners=align_corners,
+ )
+ return resized_flow
+
+
+# TODO: may write a cpp file
+def pixel_unshuffle(x, scale):
+ """Pixel unshuffle.
+
+ Args:
+ x (Tensor): Input feature with shape (b, c, hh, hw).
+ scale (int): Downsample ratio.
+
+ Returns:
+ Tensor: the pixel unshuffled feature.
+ """
+ b, c, hh, hw = x.size()
+ out_channel = c * (scale**2)
+ assert hh % scale == 0 and hw % scale == 0
+ h = hh // scale
+ w = hw // scale
+ x_view = x.view(b, c, h, scale, w, scale)
+ return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
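+
+# Hedged shape sketch (comments only): pixel_unshuffle folds each scale x scale spatial
+# block into channels, e.g. a (1, 3, 8, 8) input with scale=2 becomes (1, 12, 4, 4).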
+
+
+class DCNv2Pack(ModulatedDeformConvPack):
+ """Modulated deformable conv for deformable alignment.
+
+ Different from the official DCNv2Pack, which generates offsets and masks
+ from the preceding features, this DCNv2Pack takes separate features
+ to generate offsets and masks.
+
+ Ref:
+ Delving Deep into Deformable Alignment in Video Super-Resolution.
+ """
+
+ def forward(self, x, feat):
+ out = self.conv_offset(feat)
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
+ offset = torch.cat((o1, o2), dim=1)
+ mask = torch.sigmoid(mask)
+
+ offset_absmean = torch.mean(torch.abs(offset))
+ if offset_absmean > 50:
+ logger = get_root_logger()
+ logger.warning(f"Offset abs mean is {offset_absmean}, larger than 50.")
+
+ if LooseVersion(torchvision.__version__) >= LooseVersion("0.9.0"):
+ return torchvision.ops.deform_conv2d(
+ x,
+ offset,
+ self.weight,
+ self.bias,
+ self.stride,
+ self.padding,
+ self.dilation,
+ mask,
+ )
+ else:
+ return modulated_deform_conv(
+ x,
+ offset,
+ mask,
+ self.weight,
+ self.bias,
+ self.stride,
+ self.padding,
+ self.dilation,
+ self.groups,
+ self.deformable_groups,
+ )
+
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+ # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn(
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+ "The distribution of values may be incorrect.",
+ stacklevel=2,
+ )
+
+ with torch.no_grad():
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ low = norm_cdf((a - mean) / std)
+ up = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [low, up], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * low - 1, 2 * up - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.0))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+
+def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
+ r"""Fills the input Tensor with values drawn from a truncated
+ normal distribution.
+
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
+
+ The values are effectively drawn from the
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+ with values outside :math:`[a, b]` redrawn until they are within
+ the bounds. The method used for generating the random values works
+ best when :math:`a \leq \text{mean} \leq b`.
+
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+ mean: the mean of the normal distribution
+ std: the standard deviation of the normal distribution
+ a: the minimum cutoff value
+ b: the maximum cutoff value
+
+ Examples:
+ >>> w = torch.empty(3, 5)
+ >>> nn.init.trunc_normal_(w)
+ """
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+# From PyTorch
+def _ntuple(n):
+
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable):
+ return x
+ return tuple(repeat(x, n))
+
+ return parse
+
+
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+to_3tuple = _ntuple(3)
+to_4tuple = _ntuple(4)
+to_ntuple = _ntuple
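+
+
+# Hedged usage sketch (comments only): the _ntuple helpers expand a scalar to a tuple
+# while passing iterables through unchanged, e.g. to_2tuple(3) == (3, 3) and
+# to_2tuple((3, 5)) == (3, 5).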
diff --git a/hordelib/nodes/facerestore_cf/basicsr/archs/codeformer_arch.py b/hordelib/nodes/facerestore_cf/basicsr/archs/codeformer_arch.py
new file mode 100644
index 00000000..e70ae4c0
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/archs/codeformer_arch.py
@@ -0,0 +1,291 @@
+import math
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+from hordelib.nodes.facerestore_cf.basicsr.archs.vqgan_arch import *
+from hordelib.nodes.facerestore_cf.basicsr.utils.registry import ARCH_REGISTRY
+
+
+def calc_mean_std(feat, eps=1e-5):
+ """Calculate mean and std for adaptive_instance_normalization.
+
+ Args:
+ feat (Tensor): 4D tensor.
+ eps (float): A small value added to the variance to avoid
+ divide-by-zero. Default: 1e-5.
+ """
+ size = feat.size()
+ assert len(size) == 4, "The input feature should be 4D tensor."
+ b, c = size[:2]
+ feat_var = feat.view(b, c, -1).var(dim=2) + eps
+ feat_std = feat_var.sqrt().view(b, c, 1, 1)
+ feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
+ return feat_mean, feat_std
+
+
+def adaptive_instance_normalization(content_feat, style_feat):
+ """Adaptive instance normalization.
+
+ Adjust the reference features to have similar color and illumination
+ to those in the degraded features.
+
+ Args:
+ content_feat (Tensor): The reference feature.
+ style_feat (Tensor): The degraded features.
+ """
+ size = content_feat.size()
+ style_mean, style_std = calc_mean_std(style_feat)
+ content_mean, content_std = calc_mean_std(content_feat)
+ normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
+ return normalized_feat * style_std.expand(size) + style_mean.expand(size)
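+
+# Hedged worked example (comments only): per channel this computes
+#   out = (content - mean_c) / std_c * std_s + mean_s,
+# i.e. the content feature is whitened with its own statistics and re-colored with the
+# style (degraded) feature's statistics; shapes are preserved, e.g. (1, 256, 16, 16) in
+# and out.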
+
+
+class PositionEmbeddingSine(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one
+ used by the Attention is all you need paper, generalized to work on images.
+ """
+
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ def forward(self, x, mask=None):
+ if mask is None:
+ mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
+ not_mask = ~mask
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack(
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
+ dim=4,
+ ).flatten(3)
+ pos_y = torch.stack(
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
+ dim=4,
+ ).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
+
+
+def _get_activation_fn(activation):
+ """Return an activation function given a string"""
+ if activation == "relu":
+ return F.relu
+ if activation == "gelu":
+ return F.gelu
+ if activation == "glu":
+ return F.glu
+ raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
+
+
+class TransformerSALayer(nn.Module):
+ def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
+ # Implementation of Feedforward model - MLP
+ self.linear1 = nn.Linear(embed_dim, dim_mlp)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_mlp, embed_dim)
+
+ self.norm1 = nn.LayerNorm(embed_dim)
+ self.norm2 = nn.LayerNorm(embed_dim)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+
+ def with_pos_embed(self, tensor, pos: Tensor | None):
+ return tensor if pos is None else tensor + pos
+
+ def forward(
+ self,
+ tgt,
+ tgt_mask: Tensor | None = None,
+ tgt_key_padding_mask: Tensor | None = None,
+ query_pos: Tensor | None = None,
+ ):
+
+ # self attention
+ tgt2 = self.norm1(tgt)
+ q = k = self.with_pos_embed(tgt2, query_pos)
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
+ tgt = tgt + self.dropout1(tgt2)
+
+ # ffn
+ tgt2 = self.norm2(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+ tgt = tgt + self.dropout2(tgt2)
+ return tgt
+
+
+class Fuse_sft_block(nn.Module):
+ def __init__(self, in_ch, out_ch):
+ super().__init__()
+ self.encode_enc = ResBlock(2 * in_ch, out_ch)
+
+ self.scale = nn.Sequential(
+ nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
+ nn.LeakyReLU(0.2, True),
+ nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
+ )
+
+ self.shift = nn.Sequential(
+ nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
+ nn.LeakyReLU(0.2, True),
+ nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
+ )
+
+ def forward(self, enc_feat, dec_feat, w=1):
+ enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
+ scale = self.scale(enc_feat)
+ shift = self.shift(enc_feat)
+ residual = w * (dec_feat * scale + shift)
+ out = dec_feat + residual
+ return out
+
+
+@ARCH_REGISTRY.register()
+class CodeFormer(VQAutoEncoder):
+ def __init__(
+ self,
+ dim_embd=512,
+ n_head=8,
+ n_layers=9,
+ codebook_size=1024,
+ latent_size=256,
+ connect_list=["32", "64", "128", "256"],
+ fix_modules=["quantize", "generator"],
+ ):
+ super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], "nearest", 2, [16], codebook_size)
+
+ if fix_modules is not None:
+ for module in fix_modules:
+ for param in getattr(self, module).parameters():
+ param.requires_grad = False
+
+ self.connect_list = connect_list
+ self.n_layers = n_layers
+ self.dim_embd = dim_embd
+ self.dim_mlp = dim_embd * 2
+
+ self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
+ self.feat_emb = nn.Linear(256, self.dim_embd)
+
+ # transformer
+ self.ft_layers = nn.Sequential(
+ *[
+ TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
+ for _ in range(self.n_layers)
+ ],
+ )
+
+ # logits_predict head
+ self.idx_pred_layer = nn.Sequential(nn.LayerNorm(dim_embd), nn.Linear(dim_embd, codebook_size, bias=False))
+
+ self.channels = {
+ "16": 512,
+ "32": 256,
+ "64": 256,
+ "128": 128,
+ "256": 128,
+ "512": 64,
+ }
+
+ # after second residual block for > 16, before attn layer for ==16
+ self.fuse_encoder_block = {"512": 2, "256": 5, "128": 8, "64": 11, "32": 14, "16": 18}
+ # after first residual block for > 16, before attn layer for ==16
+ self.fuse_generator_block = {"16": 6, "32": 9, "64": 12, "128": 15, "256": 18, "512": 21}
+
+ # fuse_convs_dict
+ self.fuse_convs_dict = nn.ModuleDict()
+ for f_size in self.connect_list:
+ in_ch = self.channels[f_size]
+ self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
+
+ def _init_weights(self, module):
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
+ # ################### Encoder #####################
+ enc_feat_dict = {}
+ out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
+ for i, block in enumerate(self.encoder.blocks):
+ x = block(x)
+ if i in out_list:
+ enc_feat_dict[str(x.shape[-1])] = x.clone()
+
+ lq_feat = x
+ # ################# Transformer ###################
+ # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
+ pos_emb = self.position_emb.unsqueeze(1).repeat(1, x.shape[0], 1)
+ # BCHW -> BC(HW) -> (HW)BC
+ feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2, 0, 1))
+ query_emb = feat_emb
+ # Transformer encoder
+ for layer in self.ft_layers:
+ query_emb = layer(query_emb, query_pos=pos_emb)
+
+ # output logits
+ logits = self.idx_pred_layer(query_emb) # (hw)bn
+ logits = logits.permute(1, 0, 2) # (hw)bn -> b(hw)n
+
+ if code_only: # for training stage II
+ # logits doesn't need softmax before cross_entropy loss
+ return logits, lq_feat
+
+ # ################# Quantization ###################
+ # if self.training:
+ # quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
+ # # b(hw)c -> bc(hw) -> bchw
+ # quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
+ # ------------
+ soft_one_hot = F.softmax(logits, dim=2)
+ _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
+ quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0], 16, 16, 256])
+ # preserve gradients
+ # quant_feat = lq_feat + (quant_feat - lq_feat).detach()
+
+ if detach_16:
+ quant_feat = quant_feat.detach() # for training stage III
+ if adain:
+ quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
+
+ # ################## Generator ####################
+ x = quant_feat
+ fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
+
+ for i, block in enumerate(self.generator.blocks):
+ x = block(x)
+ if i in fuse_list: # fuse after i-th block
+ f_size = str(x.shape[-1])
+ if w > 0:
+ x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
+ out = x
+ # logits doesn't need softmax before cross_entropy loss
+ return out, logits, lq_feat
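+
+
+# Hedged usage sketch (comments only): with the default 512x512 VQGAN backbone the
+# forward pass maps an image-sized tensor back to the same resolution and also returns
+# the codebook logits and the low-quality latent. The values below are illustrative.
+#
+#   model = CodeFormer(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9)
+#   out, logits, lq_feat = model(torch.randn(1, 3, 512, 512), w=0.5)
+#   # out: (1, 3, 512, 512), logits: (1, 256, 1024), lq_feat: (1, 256, 16, 16)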
diff --git a/hordelib/nodes/facerestore_cf/basicsr/archs/rrdbnet_arch.py b/hordelib/nodes/facerestore_cf/basicsr/archs/rrdbnet_arch.py
new file mode 100644
index 00000000..bc0d00ab
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/archs/rrdbnet_arch.py
@@ -0,0 +1,120 @@
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from hordelib.nodes.facerestore_cf.basicsr.utils.registry import ARCH_REGISTRY
+
+from .arch_util import default_init_weights, make_layer, pixel_unshuffle
+
+
+class ResidualDenseBlock(nn.Module):
+ """Residual Dense Block.
+
+ Used in RRDB block in ESRGAN.
+
+ Args:
+ num_feat (int): Channel number of intermediate features.
+ num_grow_ch (int): Channels for each growth.
+ """
+
+ def __init__(self, num_feat=64, num_grow_ch=32):
+ super(ResidualDenseBlock, self).__init__()
+ self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
+ self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
+ self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
+ self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
+ self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
+
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+ # initialization
+ default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
+
+ def forward(self, x):
+ x1 = self.lrelu(self.conv1(x))
+ x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
+ x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
+ x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+ # Empirically, we use 0.2 to scale the residual for better performance
+ return x5 * 0.2 + x
+
+
+class RRDB(nn.Module):
+ """Residual in Residual Dense Block.
+
+ Used in RRDB-Net in ESRGAN.
+
+ Args:
+ num_feat (int): Channel number of intermediate features.
+ num_grow_ch (int): Channels for each growth.
+ """
+
+ def __init__(self, num_feat, num_grow_ch=32):
+ super(RRDB, self).__init__()
+ self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
+ self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
+ self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
+
+ def forward(self, x):
+ out = self.rdb1(x)
+ out = self.rdb2(out)
+ out = self.rdb3(out)
+ # Empirically, we use 0.2 to scale the residual for better performance
+ return out * 0.2 + x
+
+
+@ARCH_REGISTRY.register()
+class RRDBNet(nn.Module):
+ """Networks consisting of Residual in Residual Dense Block, which is used
+ in ESRGAN.
+
+ ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
+
+ We extend ESRGAN for scale x2 and scale x1.
+ Note: This is one option for scale 1, scale 2 in RRDBNet.
+ We first employ pixel-unshuffle (an inverse operation of pixel shuffle) to reduce the spatial size
+ and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
+
+ Args:
+ num_in_ch (int): Channel number of inputs.
+ num_out_ch (int): Channel number of outputs.
+ num_feat (int): Channel number of intermediate features.
+ Default: 64
+ num_block (int): Number of blocks in the trunk network. Default: 23.
+ num_grow_ch (int): Channels for each growth. Default: 32.
+ """
+
+ def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
+ super(RRDBNet, self).__init__()
+ self.scale = scale
+ if scale == 2:
+ num_in_ch = num_in_ch * 4
+ elif scale == 1:
+ num_in_ch = num_in_ch * 16
+ self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+ self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
+ self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ # upsample
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+ def forward(self, x):
+ if self.scale == 2:
+ feat = pixel_unshuffle(x, scale=2)
+ elif self.scale == 1:
+ feat = pixel_unshuffle(x, scale=4)
+ else:
+ feat = x
+ feat = self.conv_first(feat)
+ body_feat = self.conv_body(self.body(feat))
+ feat = feat + body_feat
+ # upsample
+ feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode="nearest")))
+ feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode="nearest")))
+ out = self.conv_last(self.lrelu(self.conv_hr(feat)))
+ return out
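+
+
+# Hedged shape sketch (comments only): the two nearest-neighbour upsample stages give a
+# fixed 4x output relative to the trunk input, so scale=4 maps (1, 3, 64, 64) to
+# (1, 3, 256, 256), while scale=2 and scale=1 first pixel-unshuffle the input by 2x or
+# 4x so the overall magnification matches the requested scale.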
diff --git a/hordelib/nodes/facerestore_cf/basicsr/archs/vgg_arch.py b/hordelib/nodes/facerestore_cf/basicsr/archs/vgg_arch.py
new file mode 100644
index 00000000..fc4e6e7f
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/archs/vgg_arch.py
@@ -0,0 +1,264 @@
+import os
+from collections import OrderedDict
+
+import torch
+from torch import nn as nn
+from torchvision.models import vgg as vgg
+
+from hordelib.nodes.facerestore_cf.basicsr.utils.registry import ARCH_REGISTRY
+
+VGG_PRETRAIN_PATH = "experiments/pretrained_models/vgg19-dcbb9e9d.pth"
+NAMES = {
+ "vgg11": [
+ "conv1_1",
+ "relu1_1",
+ "pool1",
+ "conv2_1",
+ "relu2_1",
+ "pool2",
+ "conv3_1",
+ "relu3_1",
+ "conv3_2",
+ "relu3_2",
+ "pool3",
+ "conv4_1",
+ "relu4_1",
+ "conv4_2",
+ "relu4_2",
+ "pool4",
+ "conv5_1",
+ "relu5_1",
+ "conv5_2",
+ "relu5_2",
+ "pool5",
+ ],
+ "vgg13": [
+ "conv1_1",
+ "relu1_1",
+ "conv1_2",
+ "relu1_2",
+ "pool1",
+ "conv2_1",
+ "relu2_1",
+ "conv2_2",
+ "relu2_2",
+ "pool2",
+ "conv3_1",
+ "relu3_1",
+ "conv3_2",
+ "relu3_2",
+ "pool3",
+ "conv4_1",
+ "relu4_1",
+ "conv4_2",
+ "relu4_2",
+ "pool4",
+ "conv5_1",
+ "relu5_1",
+ "conv5_2",
+ "relu5_2",
+ "pool5",
+ ],
+ "vgg16": [
+ "conv1_1",
+ "relu1_1",
+ "conv1_2",
+ "relu1_2",
+ "pool1",
+ "conv2_1",
+ "relu2_1",
+ "conv2_2",
+ "relu2_2",
+ "pool2",
+ "conv3_1",
+ "relu3_1",
+ "conv3_2",
+ "relu3_2",
+ "conv3_3",
+ "relu3_3",
+ "pool3",
+ "conv4_1",
+ "relu4_1",
+ "conv4_2",
+ "relu4_2",
+ "conv4_3",
+ "relu4_3",
+ "pool4",
+ "conv5_1",
+ "relu5_1",
+ "conv5_2",
+ "relu5_2",
+ "conv5_3",
+ "relu5_3",
+ "pool5",
+ ],
+ "vgg19": [
+ "conv1_1",
+ "relu1_1",
+ "conv1_2",
+ "relu1_2",
+ "pool1",
+ "conv2_1",
+ "relu2_1",
+ "conv2_2",
+ "relu2_2",
+ "pool2",
+ "conv3_1",
+ "relu3_1",
+ "conv3_2",
+ "relu3_2",
+ "conv3_3",
+ "relu3_3",
+ "conv3_4",
+ "relu3_4",
+ "pool3",
+ "conv4_1",
+ "relu4_1",
+ "conv4_2",
+ "relu4_2",
+ "conv4_3",
+ "relu4_3",
+ "conv4_4",
+ "relu4_4",
+ "pool4",
+ "conv5_1",
+ "relu5_1",
+ "conv5_2",
+ "relu5_2",
+ "conv5_3",
+ "relu5_3",
+ "conv5_4",
+ "relu5_4",
+ "pool5",
+ ],
+}
+
+
+def insert_bn(names):
+ """Insert bn layer after each conv.
+
+ Args:
+ names (list): The list of layer names.
+
+ Returns:
+ list: The list of layer names with bn layers.
+ """
+ names_bn = []
+ for name in names:
+ names_bn.append(name)
+ if "conv" in name:
+ position = name.replace("conv", "")
+ names_bn.append("bn" + position)
+ return names_bn
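+
+# Hedged doctest-style example (comments only):
+#   insert_bn(["conv1_1", "relu1_1", "pool1"]) == ["conv1_1", "bn1_1", "relu1_1", "pool1"]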
+
+
+@ARCH_REGISTRY.register()
+class VGGFeatureExtractor(nn.Module):
+ """VGG network for feature extraction.
+
+ In this implementation, we allow users to choose whether to use normalization
+ on the input feature and which type of VGG network to use. Note that the
+ pretrained path must match the VGG type.
+
+ Args:
+ layer_name_list (list[str]): Forward function returns the corresponding
+ features according to the layer_name_list.
+ Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
+ vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
+ use_input_norm (bool): If True, normalize the input image. Importantly,
+ the input feature must be in the range [0, 1]. Default: True.
+ range_norm (bool): If True, normalize images from the range [-1, 1] to [0, 1].
+ Default: False.
+ requires_grad (bool): If true, the parameters of VGG network will be
+ optimized. Default: False.
+ remove_pooling (bool): If true, the max pooling operations in VGG net
+ will be removed. Default: False.
+ pooling_stride (int): The stride of max pooling operation. Default: 2.
+ """
+
+ def __init__(
+ self,
+ layer_name_list,
+ vgg_type="vgg19",
+ use_input_norm=True,
+ range_norm=False,
+ requires_grad=False,
+ remove_pooling=False,
+ pooling_stride=2,
+ ):
+ super(VGGFeatureExtractor, self).__init__()
+
+ self.layer_name_list = layer_name_list
+ self.use_input_norm = use_input_norm
+ self.range_norm = range_norm
+
+ self.names = NAMES[vgg_type.replace("_bn", "")]
+ if "bn" in vgg_type:
+ self.names = insert_bn(self.names)
+
+ # only borrow layers that will be used to avoid unused params
+ max_idx = 0
+ for v in layer_name_list:
+ idx = self.names.index(v)
+ if idx > max_idx:
+ max_idx = idx
+
+ if os.path.exists(VGG_PRETRAIN_PATH):
+ vgg_net = getattr(vgg, vgg_type)(pretrained=False)
+ state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage)
+ vgg_net.load_state_dict(state_dict)
+ else:
+ vgg_net = getattr(vgg, vgg_type)(pretrained=True)
+
+ features = vgg_net.features[: max_idx + 1]
+
+ modified_net = OrderedDict()
+ for k, v in zip(self.names, features, strict=False):
+ if "pool" in k:
+ # if remove_pooling is true, pooling operation will be removed
+ if remove_pooling:
+ continue
+ else:
+ # in some cases, we may want to change the default stride
+ modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
+ else:
+ modified_net[k] = v
+
+ self.vgg_net = nn.Sequential(modified_net)
+
+ if not requires_grad:
+ self.vgg_net.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+ else:
+ self.vgg_net.train()
+ for param in self.parameters():
+ param.requires_grad = True
+
+ if self.use_input_norm:
+ # the mean is for image with range [0, 1]
+ self.register_buffer("mean", torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+ # the std is for image with range [0, 1]
+ self.register_buffer("std", torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+ def forward(self, x):
+ """Forward function.
+
+ Args:
+ x (Tensor): Input tensor with shape (n, c, h, w).
+
+ Returns:
+ Tensor: Forward results.
+ """
+ if self.range_norm:
+ x = (x + 1) / 2
+ if self.use_input_norm:
+ x = (x - self.mean) / self.std
+ output = {}
+
+ for key, layer in self.vgg_net._modules.items():
+ x = layer(x)
+ if key in self.layer_name_list:
+ output[key] = x.clone()
+
+ return output
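+
+
+# Hedged usage sketch (comments only): requesting "relu3_1" from a vgg19 backbone on a
+# 224x224 input in [0, 1] yields a (1, 256, 56, 56) feature map. Instantiation may
+# download torchvision's pretrained weights if no local copy exists.
+#
+#   extractor = VGGFeatureExtractor(["relu3_1"], vgg_type="vgg19")
+#   feats = extractor(torch.rand(1, 3, 224, 224))
+#   feats["relu3_1"].shape   # -> torch.Size([1, 256, 56, 56])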
diff --git a/hordelib/nodes/facerestore_cf/basicsr/archs/vqgan_arch.py b/hordelib/nodes/facerestore_cf/basicsr/archs/vqgan_arch.py
new file mode 100644
index 00000000..58ca0d69
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/archs/vqgan_arch.py
@@ -0,0 +1,456 @@
+"""
+VQGAN code, adapted from the original created by the Unleashing Transformers authors:
+https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
+
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from hordelib.nodes.facerestore_cf.basicsr.utils import get_root_logger
+from hordelib.nodes.facerestore_cf.basicsr.utils.registry import ARCH_REGISTRY
+
+
+def normalize(in_channels):
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+@torch.jit.script
+def swish(x):
+ return x * torch.sigmoid(x)
+
+
+# Define VQVAE classes
+class VectorQuantizer(nn.Module):
+ def __init__(self, codebook_size, emb_dim, beta):
+ super(VectorQuantizer, self).__init__()
+ self.codebook_size = codebook_size # number of embeddings
+ self.emb_dim = emb_dim # dimension of embedding
+ self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
+ self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
+
+ def forward(self, z):
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = z.permute(0, 2, 3, 1).contiguous()
+ z_flattened = z.view(-1, self.emb_dim)
+
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+ d = (
+ (z_flattened**2).sum(dim=1, keepdim=True)
+ + (self.embedding.weight**2).sum(1)
+ - 2 * torch.matmul(z_flattened, self.embedding.weight.t())
+ )
+
+ mean_distance = torch.mean(d)
+ # find closest encodings
+ # min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
+ min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
+ # [0-1], higher score, higher confidence
+ min_encoding_scores = torch.exp(-min_encoding_scores / 10)
+
+ min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
+ min_encodings.scatter_(1, min_encoding_indices, 1)
+
+ # get quantized latent vectors
+ z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
+ # compute loss for embedding
+ loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # perplexity
+ e_mean = torch.mean(min_encodings, dim=0)
+ perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return (
+ z_q,
+ loss,
+ {
+ "perplexity": perplexity,
+ "min_encodings": min_encodings,
+ "min_encoding_indices": min_encoding_indices,
+ "min_encoding_scores": min_encoding_scores,
+ "mean_distance": mean_distance,
+ },
+ )
+
+ def get_codebook_feat(self, indices, shape):
+ # input indices: batch*token_num -> (batch*token_num)*1
+ # shape: batch, height, width, channel
+ indices = indices.view(-1, 1)
+ min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
+ min_encodings.scatter_(1, indices, 1)
+ # get quantized latent vectors
+ z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
+
+ if shape is not None: # reshape back to match original input shape
+ z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
+
+ return z_q
+
+
+class GumbelQuantizer(nn.Module):
+ def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
+ super().__init__()
+ self.codebook_size = codebook_size # number of embeddings
+ self.emb_dim = emb_dim # dimension of embedding
+ self.straight_through = straight_through
+ self.temperature = temp_init
+ self.kl_weight = kl_weight
+ self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits
+ self.embed = nn.Embedding(codebook_size, emb_dim)
+
+ def forward(self, z):
+ hard = self.straight_through if self.training else True
+
+ logits = self.proj(z)
+
+ soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
+
+ z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
+
+ # + kl divergence to the prior loss
+ qy = F.softmax(logits, dim=1)
+ diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
+ min_encoding_indices = soft_one_hot.argmax(dim=1)
+
+ return (
+ z_q,
+ diff,
+ {
+ "min_encoding_indices": min_encoding_indices,
+ },
+ )
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+
+ def forward(self, x):
+ pad = (0, 1, 0, 1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ return x
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, x):
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
+ x = self.conv(x)
+
+ return x
+
+
+class ResBlock(nn.Module):
+ def __init__(self, in_channels, out_channels=None):
+ super(ResBlock, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = in_channels if out_channels is None else out_channels
+ self.norm1 = normalize(in_channels)
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ self.norm2 = normalize(out_channels)
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ if self.in_channels != self.out_channels:
+ self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x_in):
+ x = x_in
+ x = self.norm1(x)
+ x = swish(x)
+ x = self.conv1(x)
+ x = self.norm2(x)
+ x = swish(x)
+ x = self.conv2(x)
+ if self.in_channels != self.out_channels:
+ x_in = self.conv_out(x_in)
+
+ return x + x_in
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = normalize(in_channels)
+ self.q = torch.nn.Conv2d(
+ in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+ self.k = torch.nn.Conv2d(
+ in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+ self.v = torch.nn.Conv2d(
+ in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+ self.proj_out = torch.nn.Conv2d(
+ in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = q.reshape(b, c, h * w)
+ q = q.permute(0, 2, 1)
+ k = k.reshape(b, c, h * w)
+ w_ = torch.bmm(q, k)
+ w_ = w_ * (int(c) ** (-0.5))
+ w_ = F.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b, c, h * w)
+ w_ = w_.permute(0, 2, 1)
+ h_ = torch.bmm(v, w_)
+ h_ = h_.reshape(b, c, h, w)
+
+ h_ = self.proj_out(h_)
+
+ return x + h_
+
+
+class Encoder(nn.Module):
+ def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
+ super().__init__()
+ self.nf = nf
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.attn_resolutions = attn_resolutions
+
+ curr_res = self.resolution
+ in_ch_mult = (1,) + tuple(ch_mult)
+
+ blocks = []
+ # initial convolution
+ blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
+
+ # residual and downsampling blocks, with attention on smaller res (16x16)
+ for i in range(self.num_resolutions):
+ block_in_ch = nf * in_ch_mult[i]
+ block_out_ch = nf * ch_mult[i]
+ for _ in range(self.num_res_blocks):
+ blocks.append(ResBlock(block_in_ch, block_out_ch))
+ block_in_ch = block_out_ch
+ if curr_res in attn_resolutions:
+ blocks.append(AttnBlock(block_in_ch))
+
+ if i != self.num_resolutions - 1:
+ blocks.append(Downsample(block_in_ch))
+ curr_res = curr_res // 2
+
+ # non-local attention block
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
+ blocks.append(AttnBlock(block_in_ch))
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
+
+ # normalise and convert to latent size
+ blocks.append(normalize(block_in_ch))
+ blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
+ self.blocks = nn.ModuleList(blocks)
+
+ def forward(self, x):
+ for block in self.blocks:
+ x = block(x)
+
+ return x
+
+
+class Generator(nn.Module):
+ def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
+ super().__init__()
+ self.nf = nf
+ self.ch_mult = ch_mult
+ self.num_resolutions = len(self.ch_mult)
+ self.num_res_blocks = res_blocks
+ self.resolution = img_size
+ self.attn_resolutions = attn_resolutions
+ self.in_channels = emb_dim
+ self.out_channels = 3
+ block_in_ch = self.nf * self.ch_mult[-1]
+ curr_res = self.resolution // 2 ** (self.num_resolutions - 1)
+
+ blocks = []
+ # initial conv
+ blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
+
+ # non-local attention block
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
+ blocks.append(AttnBlock(block_in_ch))
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
+
+ for i in reversed(range(self.num_resolutions)):
+ block_out_ch = self.nf * self.ch_mult[i]
+
+ for _ in range(self.num_res_blocks):
+ blocks.append(ResBlock(block_in_ch, block_out_ch))
+ block_in_ch = block_out_ch
+
+ if curr_res in self.attn_resolutions:
+ blocks.append(AttnBlock(block_in_ch))
+
+ if i != 0:
+ blocks.append(Upsample(block_in_ch))
+ curr_res = curr_res * 2
+
+ blocks.append(normalize(block_in_ch))
+ blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
+
+ self.blocks = nn.ModuleList(blocks)
+
+ def forward(self, x):
+ for block in self.blocks:
+ x = block(x)
+
+ return x
+
+
+@ARCH_REGISTRY.register()
+class VQAutoEncoder(nn.Module):
+ def __init__(
+ self,
+ img_size,
+ nf,
+ ch_mult,
+ quantizer="nearest",
+ res_blocks=2,
+ attn_resolutions=[16],
+ codebook_size=1024,
+ emb_dim=256,
+ beta=0.25,
+ gumbel_straight_through=False,
+ gumbel_kl_weight=1e-8,
+ model_path=None,
+ ):
+ super().__init__()
+ logger = get_root_logger()
+ self.in_channels = 3
+ self.nf = nf
+ self.n_blocks = res_blocks
+ self.codebook_size = codebook_size
+ self.embed_dim = emb_dim
+ self.ch_mult = ch_mult
+ self.resolution = img_size
+ self.attn_resolutions = attn_resolutions
+ self.quantizer_type = quantizer
+ self.encoder = Encoder(
+ self.in_channels,
+ self.nf,
+ self.embed_dim,
+ self.ch_mult,
+ self.n_blocks,
+ self.resolution,
+ self.attn_resolutions,
+ )
+ if self.quantizer_type == "nearest":
+ self.beta = beta # 0.25
+ self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
+ elif self.quantizer_type == "gumbel":
+ self.gumbel_num_hiddens = emb_dim
+ self.straight_through = gumbel_straight_through
+ self.kl_weight = gumbel_kl_weight
+ self.quantize = GumbelQuantizer(
+ self.codebook_size,
+ self.embed_dim,
+ self.gumbel_num_hiddens,
+ self.straight_through,
+ self.kl_weight,
+ )
+ self.generator = Generator(
+ self.nf,
+ self.embed_dim,
+ self.ch_mult,
+ self.n_blocks,
+ self.resolution,
+ self.attn_resolutions,
+ )
+
+ if model_path is not None:
+ chkpt = torch.load(model_path, map_location="cpu")
+ if "params_ema" in chkpt:
+ self.load_state_dict(torch.load(model_path, map_location="cpu")["params_ema"])
+ logger.info(f"vqgan is loaded from: {model_path} [params_ema]")
+ elif "params" in chkpt:
+ self.load_state_dict(torch.load(model_path, map_location="cpu")["params"])
+ logger.info(f"vqgan is loaded from: {model_path} [params]")
+ else:
+ raise ValueError("Wrong params!")
+
+ def forward(self, x):
+ x = self.encoder(x)
+ quant, codebook_loss, quant_stats = self.quantize(x)
+ x = self.generator(quant)
+ return x, codebook_loss, quant_stats
+
+
+# patch based discriminator
+@ARCH_REGISTRY.register()
+class VQGANDiscriminator(nn.Module):
+ def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
+ super().__init__()
+
+ layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
+ ndf_mult = 1
+ ndf_mult_prev = 1
+ for n in range(1, n_layers): # gradually increase the number of filters
+ ndf_mult_prev = ndf_mult
+ ndf_mult = min(2**n, 8)
+ layers += [
+ nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
+ nn.BatchNorm2d(ndf * ndf_mult),
+ nn.LeakyReLU(0.2, True),
+ ]
+
+ ndf_mult_prev = ndf_mult
+ ndf_mult = min(2**n_layers, 8)
+
+ layers += [
+ nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
+ nn.BatchNorm2d(ndf * ndf_mult),
+ nn.LeakyReLU(0.2, True),
+ ]
+
+ layers += [nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map
+ self.main = nn.Sequential(*layers)
+
+ if model_path is not None:
+ chkpt = torch.load(model_path, map_location="cpu")
+ if "params_d" in chkpt:
+ self.load_state_dict(torch.load(model_path, map_location="cpu")["params_d"])
+ elif "params" in chkpt:
+ self.load_state_dict(torch.load(model_path, map_location="cpu")["params"])
+ else:
+ raise ValueError("Wrong params!")
+
+ def forward(self, x):
+ return self.main(x)
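+
+
+# Hedged shape sketch (comments only): the discriminator is fully convolutional and
+# returns a 1-channel patch prediction map, e.g. (1, 1, 14, 14) for a (1, 3, 256, 256)
+# input with the default nc=3, ndf=64, n_layers=4.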
diff --git a/hordelib/nodes/facerestore_cf/basicsr/data/__init__.py b/hordelib/nodes/facerestore_cf/basicsr/data/__init__.py
new file mode 100644
index 00000000..2dc9effe
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/data/__init__.py
@@ -0,0 +1,103 @@
+import importlib
+import random
+from copy import deepcopy
+from functools import partial
+from os import path as osp
+
+import numpy as np
+import torch
+import torch.utils.data
+
+from hordelib.nodes.facerestore_cf.basicsr.data.prefetch_dataloader import PrefetchDataLoader
+from hordelib.nodes.facerestore_cf.basicsr.utils import get_root_logger, scandir
+from hordelib.nodes.facerestore_cf.basicsr.utils.dist_util import get_dist_info
+from hordelib.nodes.facerestore_cf.basicsr.utils.registry import DATASET_REGISTRY
+
+__all__ = ["build_dataset", "build_dataloader"]
+
+# automatically scan and import dataset modules for registry
+# scan all the files under the data folder with '_dataset' in file names
+data_folder = osp.dirname(osp.abspath(__file__))
+dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith("_dataset.py")]
+# import all the dataset modules
+_dataset_modules = [
+ importlib.import_module(f"hordelib.nodes.facerestore_cf.basicsr.data.{file_name}")
+ for file_name in dataset_filenames
+]
+
+
+def build_dataset(dataset_opt):
+ """Build dataset from options.
+
+ Args:
+ dataset_opt (dict): Configuration for dataset. It must contain:
+ name (str): Dataset name.
+ type (str): Dataset type.
+ """
+ dataset_opt = deepcopy(dataset_opt)
+ dataset = DATASET_REGISTRY.get(dataset_opt["type"])(dataset_opt)
+ logger = get_root_logger()
+ logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} ' "is built.")
+ return dataset
+
+
+def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
+ """Build dataloader.
+
+ Args:
+ dataset (torch.utils.data.Dataset): Dataset.
+ dataset_opt (dict): Dataset options. It contains the following keys:
+ phase (str): 'train' or 'val'.
+ num_worker_per_gpu (int): Number of workers for each GPU.
+ batch_size_per_gpu (int): Training batch size for each GPU.
+ num_gpu (int): Number of GPUs. Used only in the train phase.
+ Default: 1.
+ dist (bool): Whether in distributed training. Used only in the train
+ phase. Default: False.
+ sampler (torch.utils.data.sampler): Data sampler. Default: None.
+ seed (int | None): Seed. Default: None
+ """
+ phase = dataset_opt["phase"]
+ rank, _ = get_dist_info()
+ if phase == "train":
+ if dist: # distributed training
+ batch_size = dataset_opt["batch_size_per_gpu"]
+ num_workers = dataset_opt["num_worker_per_gpu"]
+ else: # non-distributed training
+ multiplier = 1 if num_gpu == 0 else num_gpu
+ batch_size = dataset_opt["batch_size_per_gpu"] * multiplier
+ num_workers = dataset_opt["num_worker_per_gpu"] * multiplier
+ dataloader_args = dict(
+ dataset=dataset,
+ batch_size=batch_size,
+ shuffle=False,
+ num_workers=num_workers,
+ sampler=sampler,
+ drop_last=True,
+ )
+ if sampler is None:
+ dataloader_args["shuffle"] = True
+ dataloader_args["worker_init_fn"] = (
+ partial(worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
+ )
+ elif phase in ["val", "test"]: # validation
+ dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
+ else:
+ raise ValueError(f"Wrong dataset phase: {phase}. " "Supported ones are 'train', 'val' and 'test'.")
+
+ dataloader_args["pin_memory"] = dataset_opt.get("pin_memory", False)
+
+ prefetch_mode = dataset_opt.get("prefetch_mode")
+ if prefetch_mode == "cpu": # CPUPrefetcher
+ num_prefetch_queue = dataset_opt.get("num_prefetch_queue", 1)
+ logger = get_root_logger()
+ logger.info(f"Use {prefetch_mode} prefetch dataloader: " f"num_prefetch_queue = {num_prefetch_queue}")
+ return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
+ else:
+ # prefetch_mode=None: Normal dataloader
+ # prefetch_mode='cuda': dataloader for CUDAPrefetcher
+ return torch.utils.data.DataLoader(**dataloader_args)
+
+
+def worker_init_fn(worker_id, num_workers, rank, seed):
+ # Set the worker seed to num_workers * rank + worker_id + seed
+ worker_seed = num_workers * rank + worker_id + seed
+ np.random.seed(worker_seed)
+ random.seed(worker_seed)
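+
+
+# Hedged usage sketch (comments only): `train_set` stands for any dataset built via
+# build_dataset(); option keys follow the docstring above and the values are illustrative.
+#
+#   loader = build_dataloader(
+#       train_set,
+#       {"phase": "train", "batch_size_per_gpu": 4, "num_worker_per_gpu": 2},
+#       num_gpu=1,
+#       dist=False,
+#       seed=0,
+#   )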
diff --git a/hordelib/nodes/facerestore_cf/basicsr/data/data_sampler.py b/hordelib/nodes/facerestore_cf/basicsr/data/data_sampler.py
new file mode 100644
index 00000000..e4f2ccf8
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/data/data_sampler.py
@@ -0,0 +1,49 @@
+import math
+
+import torch
+from torch.utils.data.sampler import Sampler
+
+
+class EnlargedSampler(Sampler):
+ """Sampler that restricts data loading to a subset of the dataset.
+
+    Modified from torch.utils.data.distributed.DistributedSampler.
+    Supports enlarging the dataset for iteration-based training, saving
+    time when restarting the dataloader after each epoch.
+
+ Args:
+ dataset (torch.utils.data.Dataset): Dataset used for sampling.
+ num_replicas (int | None): Number of processes participating in
+ the training. It is usually the world_size.
+ rank (int | None): Rank of the current process within num_replicas.
+ ratio (int): Enlarging ratio. Default: 1.
+ """
+
+ def __init__(self, dataset, num_replicas, rank, ratio=1):
+ self.dataset = dataset
+ self.num_replicas = num_replicas
+ self.rank = rank
+ self.epoch = 0
+ self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
+ self.total_size = self.num_samples * self.num_replicas
+
+ def __iter__(self):
+ # deterministically shuffle based on epoch
+ g = torch.Generator()
+ g.manual_seed(self.epoch)
+ indices = torch.randperm(self.total_size, generator=g).tolist()
+
+ dataset_size = len(self.dataset)
+ indices = [v % dataset_size for v in indices]
+
+ # subsample
+ indices = indices[self.rank : self.total_size : self.num_replicas]
+ assert len(indices) == self.num_samples
+
+ return iter(indices)
+
+ def __len__(self):
+ return self.num_samples
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
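+
+
+# Usage sketch (illustrative only, not part of upstream basicsr): an enlarged
+# sampler for iteration-based distributed training. `dataset`, `world_size`
+# and `rank` are assumed; the latter two would normally come from get_dist_info().
+#
+#   >>> sampler = EnlargedSampler(dataset, num_replicas=world_size, rank=rank, ratio=100)
+#   >>> for epoch in range(start_epoch, total_epochs):
+#   ...     sampler.set_epoch(epoch)  # re-seed the deterministic shuffle
+#   ...     for batch in dataloader:  # a dataloader built with sampler=sampler
+#   ...         pass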
diff --git a/hordelib/nodes/facerestore_cf/basicsr/data/data_util.py b/hordelib/nodes/facerestore_cf/basicsr/data/data_util.py
new file mode 100644
index 00000000..b2481c00
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/data/data_util.py
@@ -0,0 +1,313 @@
+from os import path as osp
+
+import cv2
+import numpy as np
+import torch
+from torch.nn import functional as F
+
+from hordelib.nodes.facerestore_cf.basicsr.data.transforms import mod_crop
+from hordelib.nodes.facerestore_cf.basicsr.utils import img2tensor, scandir
+
+
+def read_img_seq(path, require_mod_crop=False, scale=1):
+ """Read a sequence of images from a given folder path.
+
+ Args:
+ path (list[str] | str): List of image paths or image folder path.
+ require_mod_crop (bool): Require mod crop for each image.
+ Default: False.
+ scale (int): Scale factor for mod_crop. Default: 1.
+
+ Returns:
+ Tensor: size (t, c, h, w), RGB, [0, 1].
+ """
+ if isinstance(path, list):
+ img_paths = path
+ else:
+ img_paths = sorted(list(scandir(path, full_path=True)))
+ imgs = [cv2.imread(v).astype(np.float32) / 255.0 for v in img_paths]
+ if require_mod_crop:
+ imgs = [mod_crop(img, scale) for img in imgs]
+ imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
+ imgs = torch.stack(imgs, dim=0)
+ return imgs
+
+
+def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding="reflection"):
+ """Generate an index list for reading `num_frames` frames from a sequence
+ of images.
+
+ Args:
+ crt_idx (int): Current center index.
+        max_frame_num (int): Maximum frame number of the sequence (counting from 1).
+        num_frames (int): Number of frames to read.
+ padding (str): Padding mode, one of
+ 'replicate' | 'reflection' | 'reflection_circle' | 'circle'
+ Examples: current_idx = 0, num_frames = 5
+ The generated frame indices under different padding mode:
+ replicate: [0, 0, 0, 1, 2]
+ reflection: [2, 1, 0, 1, 2]
+ reflection_circle: [4, 3, 0, 1, 2]
+ circle: [3, 4, 0, 1, 2]
+
+ Returns:
+ list[int]: A list of indices.
+ """
+ assert num_frames % 2 == 1, "num_frames should be an odd number."
+ assert padding in ("replicate", "reflection", "reflection_circle", "circle"), f"Wrong padding mode: {padding}."
+
+ max_frame_num = max_frame_num - 1 # start from 0
+ num_pad = num_frames // 2
+
+ indices = []
+ for i in range(crt_idx - num_pad, crt_idx + num_pad + 1):
+ if i < 0:
+ if padding == "replicate":
+ pad_idx = 0
+ elif padding == "reflection":
+ pad_idx = -i
+ elif padding == "reflection_circle":
+ pad_idx = crt_idx + num_pad - i
+ else:
+ pad_idx = num_frames + i
+ elif i > max_frame_num:
+ if padding == "replicate":
+ pad_idx = max_frame_num
+ elif padding == "reflection":
+ pad_idx = max_frame_num * 2 - i
+ elif padding == "reflection_circle":
+ pad_idx = (crt_idx - num_pad) - (i - max_frame_num)
+ else:
+ pad_idx = i - num_frames
+ else:
+ pad_idx = i
+ indices.append(pad_idx)
+ return indices
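+
+
+# Illustrative example (not part of upstream basicsr): padding behaviour of
+# generate_frame_indices at the two ends of a 100-frame sequence.
+#
+#   >>> generate_frame_indices(0, 100, 5, padding="replicate")
+#   [0, 0, 0, 1, 2]
+#   >>> generate_frame_indices(0, 100, 5, padding="reflection")
+#   [2, 1, 0, 1, 2]
+#   >>> generate_frame_indices(99, 100, 5, padding="reflection")
+#   [97, 98, 99, 98, 97]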
+
+
+def paired_paths_from_lmdb(folders, keys):
+ """Generate paired paths from lmdb files.
+
+ Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:
+
+ lq.lmdb
+ ├── data.mdb
+ ├── lock.mdb
+ ├── meta_info.txt
+
+ The data.mdb and lock.mdb are standard lmdb files and you can refer to
+ https://lmdb.readthedocs.io/en/release/ for more details.
+
+    The meta_info.txt is a txt file that records the meta information of the
+    dataset. It is created automatically when the dataset is prepared with the
+    provided dataset tools.
+    Each line in the txt file records, separated by white space:
+    1) image name (with extension),
+    2) image shape,
+    3) compression level.
+    Example: `baboon.png (120,125,3) 1`
+
+ We use the image name without extension as the lmdb key.
+ Note that we use the same key for the corresponding lq and gt images.
+
+ Args:
+        folders (list[str]): A list of folder paths. The order of the list
+            should be [input_folder, gt_folder].
+        keys (list[str]): A list of keys identifying folders. The order should
+            be consistent with folders, e.g., ['lq', 'gt'].
+            Note that these keys are different from lmdb keys.
+
+    Returns:
+        list[dict]: Returned path list, where each entry is a dict with
+            '{input_key}_path' and '{gt_key}_path' keys.
+ """
+ assert len(folders) == 2, (
+ "The len of folders should be 2 with [input_folder, gt_folder]. " f"But got {len(folders)}"
+ )
+ assert len(keys) == 2, "The len of keys should be 2 with [input_key, gt_key]. " f"But got {len(keys)}"
+ input_folder, gt_folder = folders
+ input_key, gt_key = keys
+
+ if not (input_folder.endswith(".lmdb") and gt_folder.endswith(".lmdb")):
+ raise ValueError(
+ f"{input_key} folder and {gt_key} folder should both in lmdb "
+ f"formats. But received {input_key}: {input_folder}; "
+ f"{gt_key}: {gt_folder}",
+ )
+ # ensure that the two meta_info files are the same
+ with open(osp.join(input_folder, "meta_info.txt")) as fin:
+ input_lmdb_keys = [line.split(".")[0] for line in fin]
+ with open(osp.join(gt_folder, "meta_info.txt")) as fin:
+ gt_lmdb_keys = [line.split(".")[0] for line in fin]
+ if set(input_lmdb_keys) != set(gt_lmdb_keys):
+ raise ValueError(f"Keys in {input_key}_folder and {gt_key}_folder are different.")
+ else:
+ paths = []
+ for lmdb_key in sorted(input_lmdb_keys):
+ paths.append(dict([(f"{input_key}_path", lmdb_key), (f"{gt_key}_path", lmdb_key)]))
+ return paths
+
+
+def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
+ """Generate paired paths from an meta information file.
+
+ Each line in the meta information file contains the image names and
+ image shape (usually for gt), separated by a white space.
+
+    Example of a meta information file:
+ ```
+ 0001_s001.png (480,480,3)
+ 0001_s002.png (480,480,3)
+ ```
+
+ Args:
+        folders (list[str]): A list of folder paths. The order of the list
+            should be [input_folder, gt_folder].
+        keys (list[str]): A list of keys identifying folders. The order should
+            be consistent with folders, e.g., ['lq', 'gt'].
+ meta_info_file (str): Path to the meta information file.
+ filename_tmpl (str): Template for each filename. Note that the
+ template excludes the file extension. Usually the filename_tmpl is
+ for files in the input folder.
+
+    Returns:
+        list[dict]: Returned path list, where each entry is a dict with
+            '{input_key}_path' and '{gt_key}_path' keys.
+ """
+ assert len(folders) == 2, (
+ "The len of folders should be 2 with [input_folder, gt_folder]. " f"But got {len(folders)}"
+ )
+ assert len(keys) == 2, "The len of keys should be 2 with [input_key, gt_key]. " f"But got {len(keys)}"
+ input_folder, gt_folder = folders
+ input_key, gt_key = keys
+
+ with open(meta_info_file) as fin:
+ gt_names = [line.split(" ")[0] for line in fin]
+
+ paths = []
+ for gt_name in gt_names:
+ basename, ext = osp.splitext(osp.basename(gt_name))
+ input_name = f"{filename_tmpl.format(basename)}{ext}"
+ input_path = osp.join(input_folder, input_name)
+ gt_path = osp.join(gt_folder, gt_name)
+ paths.append(dict([(f"{input_key}_path", input_path), (f"{gt_key}_path", gt_path)]))
+ return paths
+
+
+def paired_paths_from_folder(folders, keys, filename_tmpl):
+ """Generate paired paths from folders.
+
+ Args:
+        folders (list[str]): A list of folder paths. The order of the list
+            should be [input_folder, gt_folder].
+        keys (list[str]): A list of keys identifying folders. The order should
+            be consistent with folders, e.g., ['lq', 'gt'].
+ filename_tmpl (str): Template for each filename. Note that the
+ template excludes the file extension. Usually the filename_tmpl is
+ for files in the input folder.
+
+    Returns:
+        list[dict]: Returned path list, where each entry is a dict with
+            '{input_key}_path' and '{gt_key}_path' keys.
+ """
+ assert len(folders) == 2, (
+ "The len of folders should be 2 with [input_folder, gt_folder]. " f"But got {len(folders)}"
+ )
+ assert len(keys) == 2, "The len of keys should be 2 with [input_key, gt_key]. " f"But got {len(keys)}"
+ input_folder, gt_folder = folders
+ input_key, gt_key = keys
+
+ input_paths = list(scandir(input_folder))
+ gt_paths = list(scandir(gt_folder))
+ assert len(input_paths) == len(gt_paths), (
+ f"{input_key} and {gt_key} datasets have different number of images: " f"{len(input_paths)}, {len(gt_paths)}."
+ )
+ paths = []
+ for gt_path in gt_paths:
+ basename, ext = osp.splitext(osp.basename(gt_path))
+ input_name = f"{filename_tmpl.format(basename)}{ext}"
+ input_path = osp.join(input_folder, input_name)
+ assert input_name in input_paths, f"{input_name} is not in " f"{input_key}_paths."
+ gt_path = osp.join(gt_folder, gt_path)
+ paths.append(dict([(f"{input_key}_path", input_path), (f"{gt_key}_path", gt_path)]))
+ return paths
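+
+
+# Illustrative example (not part of upstream basicsr): pairing LQ/GT folders
+# that share basenames. The folder paths, file name and the identity template
+# are assumptions for this example.
+#
+#   >>> paths = paired_paths_from_folder(["datasets/lq", "datasets/gt"], ["lq", "gt"], "{}")
+#   >>> paths[0]
+#   {'lq_path': 'datasets/lq/0001.png', 'gt_path': 'datasets/gt/0001.png'}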
+
+
+def paths_from_folder(folder):
+ """Generate paths from folder.
+
+ Args:
+ folder (str): Folder path.
+
+ Returns:
+ list[str]: Returned path list.
+ """
+
+ paths = list(scandir(folder))
+ paths = [osp.join(folder, path) for path in paths]
+ return paths
+
+
+def paths_from_lmdb(folder):
+ """Generate paths from lmdb.
+
+ Args:
+ folder (str): Folder path.
+
+ Returns:
+ list[str]: Returned path list.
+ """
+ if not folder.endswith(".lmdb"):
+ raise ValueError(f"Folder {folder}folder should in lmdb format.")
+ with open(osp.join(folder, "meta_info.txt")) as fin:
+ paths = [line.split(".")[0] for line in fin]
+ return paths
+
+
+def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
+ """Generate Gaussian kernel used in `duf_downsample`.
+
+ Args:
+ kernel_size (int): Kernel size. Default: 13.
+ sigma (float): Sigma of the Gaussian kernel. Default: 1.6.
+
+ Returns:
+ np.array: The Gaussian kernel.
+ """
+    # scipy.ndimage.filters is deprecated; import gaussian_filter directly.
+    from scipy.ndimage import gaussian_filter
+
+    kernel = np.zeros((kernel_size, kernel_size))
+    # set element at the middle to one, a dirac delta
+    kernel[kernel_size // 2, kernel_size // 2] = 1
+    # gaussian-smooth the dirac, resulting in a gaussian filter
+    return gaussian_filter(kernel, sigma)
+
+
+def duf_downsample(x, kernel_size=13, scale=4):
+ """Downsamping with Gaussian kernel used in the DUF official code.
+
+ Args:
+ x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
+ kernel_size (int): Kernel size. Default: 13.
+ scale (int): Downsampling factor. Supported scale: (2, 3, 4).
+ Default: 4.
+
+ Returns:
+ Tensor: DUF downsampled frames.
+ """
+ assert scale in (2, 3, 4), f"Only support scale (2, 3, 4), but got {scale}."
+
+ squeeze_flag = False
+ if x.ndim == 4:
+ squeeze_flag = True
+ x = x.unsqueeze(0)
+ b, t, c, h, w = x.size()
+ x = x.view(-1, 1, h, w)
+ pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
+ x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), "reflect")
+
+ gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
+ gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
+ x = F.conv2d(x, gaussian_filter, stride=scale)
+ x = x[:, :, 2:-2, 2:-2]
+ x = x.view(b, t, c, x.size(2), x.size(3))
+ if squeeze_flag:
+ x = x.squeeze(0)
+ return x
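+
+
+# Usage sketch (illustrative only, not part of upstream basicsr): DUF-style 4x
+# downsampling of a batch of frames; the shape below is just an example.
+#
+#   >>> frames = torch.rand(1, 7, 3, 256, 256)  # (b, t, c, h, w)
+#   >>> lq = duf_downsample(frames, kernel_size=13, scale=4)
+#   >>> lq.shape
+#   torch.Size([1, 7, 3, 64, 64])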
diff --git a/hordelib/nodes/facerestore_cf/basicsr/data/prefetch_dataloader.py b/hordelib/nodes/facerestore_cf/basicsr/data/prefetch_dataloader.py
new file mode 100644
index 00000000..dd84bb15
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/data/prefetch_dataloader.py
@@ -0,0 +1,126 @@
+import queue as Queue
+import threading
+
+import torch
+from torch.utils.data import DataLoader
+
+
+class PrefetchGenerator(threading.Thread):
+ """A general prefetch generator.
+
+ Ref:
+ https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
+
+ Args:
+ generator: Python generator.
+ num_prefetch_queue (int): Number of prefetch queue.
+ """
+
+ def __init__(self, generator, num_prefetch_queue):
+ threading.Thread.__init__(self)
+ self.queue = Queue.Queue(num_prefetch_queue)
+ self.generator = generator
+ self.daemon = True
+ self.start()
+
+ def run(self):
+ for item in self.generator:
+ self.queue.put(item)
+ self.queue.put(None)
+
+ def __next__(self):
+ next_item = self.queue.get()
+ if next_item is None:
+ raise StopIteration
+ return next_item
+
+ def __iter__(self):
+ return self
+
+
+class PrefetchDataLoader(DataLoader):
+ """Prefetch version of dataloader.
+
+ Ref:
+ https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
+
+ TODO:
+ Need to test on single gpu and ddp (multi-gpu). There is a known issue in
+ ddp.
+
+ Args:
+ num_prefetch_queue (int): Number of prefetch queue.
+ kwargs (dict): Other arguments for dataloader.
+ """
+
+ def __init__(self, num_prefetch_queue, **kwargs):
+ self.num_prefetch_queue = num_prefetch_queue
+ super(PrefetchDataLoader, self).__init__(**kwargs)
+
+ def __iter__(self):
+ return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
+
+
+class CPUPrefetcher:
+ """CPU prefetcher.
+
+ Args:
+ loader: Dataloader.
+ """
+
+ def __init__(self, loader):
+ self.ori_loader = loader
+ self.loader = iter(loader)
+
+ def next(self):
+ try:
+ return next(self.loader)
+ except StopIteration:
+ return None
+
+ def reset(self):
+ self.loader = iter(self.ori_loader)
+
+
+class CUDAPrefetcher:
+ """CUDA prefetcher.
+
+ Ref:
+ https://github.com/NVIDIA/apex/issues/304#
+
+    It may consume more GPU memory.
+
+ Args:
+ loader: Dataloader.
+ opt (dict): Options.
+ """
+
+ def __init__(self, loader, opt):
+ self.ori_loader = loader
+ self.loader = iter(loader)
+ self.opt = opt
+ self.stream = torch.cuda.Stream()
+ self.device = torch.device("cuda" if opt["num_gpu"] != 0 else "cpu")
+ self.preload()
+
+ def preload(self):
+ try:
+ self.batch = next(self.loader) # self.batch is a dict
+ except StopIteration:
+ self.batch = None
+ return
+ # put tensors to gpu
+ with torch.cuda.stream(self.stream):
+ for k, v in self.batch.items():
+ if torch.is_tensor(v):
+ self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
+
+ def next(self):
+ torch.cuda.current_stream().wait_stream(self.stream)
+ batch = self.batch
+ self.preload()
+ return batch
+
+ def reset(self):
+ self.loader = iter(self.ori_loader)
+ self.preload()
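+
+
+# Usage sketch (illustrative only, not part of upstream basicsr): iterating a
+# training dataloader through the CUDA prefetcher. `train_loader` and `opt`
+# (which must contain 'num_gpu') are assumed to exist.
+#
+#   >>> prefetcher = CUDAPrefetcher(train_loader, opt)
+#   >>> batch = prefetcher.next()
+#   >>> while batch is not None:
+#   ...     ...  # feed `batch` (tensors already on the GPU) to the model
+#   ...     batch = prefetcher.next()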
diff --git a/hordelib/nodes/facerestore_cf/basicsr/data/transforms.py b/hordelib/nodes/facerestore_cf/basicsr/data/transforms.py
new file mode 100644
index 00000000..ac562bff
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/data/transforms.py
@@ -0,0 +1,170 @@
+import random
+
+import cv2
+
+
+def mod_crop(img, scale):
+ """Mod crop images, used during testing.
+
+ Args:
+ img (ndarray): Input image.
+ scale (int): Scale factor.
+
+ Returns:
+ ndarray: Result image.
+ """
+ img = img.copy()
+ if img.ndim in (2, 3):
+ h, w = img.shape[0], img.shape[1]
+ h_remainder, w_remainder = h % scale, w % scale
+ img = img[: h - h_remainder, : w - w_remainder, ...]
+ else:
+ raise ValueError(f"Wrong img ndim: {img.ndim}.")
+ return img
+
+
+def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path):
+ """Paired random crop.
+
+ It crops lists of lq and gt images with corresponding locations.
+
+ Args:
+ img_gts (list[ndarray] | ndarray): GT images. Note that all images
+ should have the same shape. If the input is an ndarray, it will
+ be transformed to a list containing itself.
+ img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
+ should have the same shape. If the input is an ndarray, it will
+ be transformed to a list containing itself.
+ gt_patch_size (int): GT patch size.
+ scale (int): Scale factor.
+ gt_path (str): Path to ground-truth.
+
+ Returns:
+ list[ndarray] | ndarray: GT images and LQ images. If returned results
+ only have one element, just return ndarray.
+ """
+
+ if not isinstance(img_gts, list):
+ img_gts = [img_gts]
+ if not isinstance(img_lqs, list):
+ img_lqs = [img_lqs]
+
+ h_lq, w_lq, _ = img_lqs[0].shape
+ h_gt, w_gt, _ = img_gts[0].shape
+ lq_patch_size = gt_patch_size // scale
+
+ if h_gt != h_lq * scale or w_gt != w_lq * scale:
+ raise ValueError(
+ f"Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ",
+ f"multiplication of LQ ({h_lq}, {w_lq}).",
+ )
+ if h_lq < lq_patch_size or w_lq < lq_patch_size:
+ raise ValueError(
+ f"LQ ({h_lq}, {w_lq}) is smaller than patch size "
+ f"({lq_patch_size}, {lq_patch_size}). "
+ f"Please remove {gt_path}.",
+ )
+
+ # randomly choose top and left coordinates for lq patch
+ top = random.randint(0, h_lq - lq_patch_size)
+ left = random.randint(0, w_lq - lq_patch_size)
+
+ # crop lq patch
+ img_lqs = [v[top : top + lq_patch_size, left : left + lq_patch_size, ...] for v in img_lqs]
+
+ # crop corresponding gt patch
+ top_gt, left_gt = int(top * scale), int(left * scale)
+ img_gts = [v[top_gt : top_gt + gt_patch_size, left_gt : left_gt + gt_patch_size, ...] for v in img_gts]
+ if len(img_gts) == 1:
+ img_gts = img_gts[0]
+ if len(img_lqs) == 1:
+ img_lqs = img_lqs[0]
+ return img_gts, img_lqs
+
+
+def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
+ """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
+
+ We use vertical flip and transpose for rotation implementation.
+ All the images in the list use the same augmentation.
+
+ Args:
+ imgs (list[ndarray] | ndarray): Images to be augmented. If the input
+ is an ndarray, it will be transformed to a list.
+ hflip (bool): Horizontal flip. Default: True.
+        rotation (bool): Rotation. Default: True.
+        flows (list[ndarray]): Flows to be augmented. If the input is an
+ ndarray, it will be transformed to a list.
+ Dimension is (h, w, 2). Default: None.
+ return_status (bool): Return the status of flip and rotation.
+ Default: False.
+
+ Returns:
+ list[ndarray] | ndarray: Augmented images and flows. If returned
+ results only have one element, just return ndarray.
+
+ """
+ hflip = hflip and random.random() < 0.5
+ vflip = rotation and random.random() < 0.5
+ rot90 = rotation and random.random() < 0.5
+
+ def _augment(img):
+ if hflip: # horizontal
+ cv2.flip(img, 1, img)
+ if vflip: # vertical
+ cv2.flip(img, 0, img)
+ if rot90:
+ img = img.transpose(1, 0, 2)
+ return img
+
+ def _augment_flow(flow):
+ if hflip: # horizontal
+ cv2.flip(flow, 1, flow)
+ flow[:, :, 0] *= -1
+ if vflip: # vertical
+ cv2.flip(flow, 0, flow)
+ flow[:, :, 1] *= -1
+ if rot90:
+ flow = flow.transpose(1, 0, 2)
+ flow = flow[:, :, [1, 0]]
+ return flow
+
+ if not isinstance(imgs, list):
+ imgs = [imgs]
+ imgs = [_augment(img) for img in imgs]
+ if len(imgs) == 1:
+ imgs = imgs[0]
+
+ if flows is not None:
+ if not isinstance(flows, list):
+ flows = [flows]
+ flows = [_augment_flow(flow) for flow in flows]
+ if len(flows) == 1:
+ flows = flows[0]
+ return imgs, flows
+ else:
+ if return_status:
+ return imgs, (hflip, vflip, rot90)
+ else:
+ return imgs
+
+
+def img_rotate(img, angle, center=None, scale=1.0):
+ """Rotate image.
+
+ Args:
+ img (ndarray): Image to be rotated.
+ angle (float): Rotation angle in degrees. Positive values mean
+ counter-clockwise rotation.
+ center (tuple[int]): Rotation center. If the center is None,
+ initialize it as the center of the image. Default: None.
+ scale (float): Isotropic scale factor. Default: 1.0.
+ """
+ (h, w) = img.shape[:2]
+
+ if center is None:
+ center = (w // 2, h // 2)
+
+ matrix = cv2.getRotationMatrix2D(center, angle, scale)
+ rotated_img = cv2.warpAffine(img, matrix, (w, h))
+ return rotated_img
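+
+
+# Usage sketch (illustrative only, not part of upstream basicsr): a typical
+# training-time pipeline of paired cropping followed by flip/rotation
+# augmentation. `img_gt`, `img_lq` and `gt_path` are assumed inputs.
+#
+#   >>> img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_patch_size=512, scale=4, gt_path=gt_path)
+#   >>> img_gt, img_lq = augment([img_gt, img_lq], hflip=True, rotation=True)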
diff --git a/hordelib/nodes/facerestore_cf/basicsr/losses/__init__.py b/hordelib/nodes/facerestore_cf/basicsr/losses/__init__.py
new file mode 100644
index 00000000..fb45a76e
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/losses/__init__.py
@@ -0,0 +1,43 @@
+from copy import deepcopy
+
+from hordelib.nodes.facerestore_cf.basicsr.utils import get_root_logger
+from hordelib.nodes.facerestore_cf.basicsr.utils.registry import LOSS_REGISTRY
+
+from .losses import (
+ CharbonnierLoss,
+ GANLoss,
+ L1Loss,
+ MSELoss,
+ PerceptualLoss,
+ WeightedTVLoss,
+ g_path_regularize,
+ gradient_penalty_loss,
+ r1_penalty,
+)
+
+__all__ = [
+ "L1Loss",
+ "MSELoss",
+ "CharbonnierLoss",
+ "WeightedTVLoss",
+ "PerceptualLoss",
+ "GANLoss",
+ "gradient_penalty_loss",
+ "r1_penalty",
+ "g_path_regularize",
+]
+
+
+def build_loss(opt):
+ """Build loss from options.
+
+ Args:
+        opt (dict): Configuration. It must contain:
+ type (str): Model type.
+ """
+ opt = deepcopy(opt)
+ loss_type = opt.pop("type")
+ loss = LOSS_REGISTRY.get(loss_type)(**opt)
+ logger = get_root_logger()
+ logger.info(f"Loss [{loss.__class__.__name__}] is created.")
+ return loss
diff --git a/hordelib/nodes/facerestore_cf/basicsr/losses/loss_util.py b/hordelib/nodes/facerestore_cf/basicsr/losses/loss_util.py
new file mode 100644
index 00000000..b08bcaee
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/losses/loss_util.py
@@ -0,0 +1,96 @@
+import functools
+
+from torch.nn import functional as F
+
+
+def reduce_loss(loss, reduction):
+ """Reduce loss as specified.
+
+ Args:
+ loss (Tensor): Elementwise loss tensor.
+ reduction (str): Options are 'none', 'mean' and 'sum'.
+
+ Returns:
+ Tensor: Reduced loss tensor.
+ """
+ reduction_enum = F._Reduction.get_enum(reduction)
+ # none: 0, elementwise_mean:1, sum: 2
+ if reduction_enum == 0:
+ return loss
+ elif reduction_enum == 1:
+ return loss.mean()
+ else:
+ return loss.sum()
+
+
+def weight_reduce_loss(loss, weight=None, reduction="mean"):
+ """Apply element-wise weight and reduce loss.
+
+ Args:
+ loss (Tensor): Element-wise loss.
+ weight (Tensor): Element-wise weights. Default: None.
+ reduction (str): Same as built-in losses of PyTorch. Options are
+ 'none', 'mean' and 'sum'. Default: 'mean'.
+
+ Returns:
+ Tensor: Loss values.
+ """
+ # if weight is specified, apply element-wise weight
+ if weight is not None:
+ assert weight.dim() == loss.dim()
+ assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
+ loss = loss * weight
+
+ # if weight is not specified or reduction is sum, just reduce the loss
+ if weight is None or reduction == "sum":
+ loss = reduce_loss(loss, reduction)
+ # if reduction is mean, then compute mean over weight region
+ elif reduction == "mean":
+ if weight.size(1) > 1:
+ weight = weight.sum()
+ else:
+ weight = weight.sum() * loss.size(1)
+ loss = loss.sum() / weight
+
+ return loss
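+
+
+# Illustrative example (not part of upstream basicsr, assumes `import torch`):
+# with reduction='mean', the loss is averaged over the weighted region only,
+# not over all elements.
+#
+#   >>> loss = torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]])
+#   >>> weight = torch.tensor([[[[1.0, 1.0, 0.0, 0.0]]]])
+#   >>> weight_reduce_loss(loss, weight, reduction="mean")
+#   tensor(1.5000)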
+
+
+def weighted_loss(loss_func):
+ """Create a weighted version of a given loss function.
+
+ To use this decorator, the loss function must have the signature like
+ `loss_func(pred, target, **kwargs)`. The function only needs to compute
+ element-wise loss without any reduction. This decorator will add weight
+ and reduction arguments to the function. The decorated function will have
+ the signature like `loss_func(pred, target, weight=None, reduction='mean',
+ **kwargs)`.
+
+ :Example:
+
+ >>> import torch
+ >>> @weighted_loss
+ >>> def l1_loss(pred, target):
+ >>> return (pred - target).abs()
+
+ >>> pred = torch.Tensor([0, 2, 3])
+ >>> target = torch.Tensor([1, 1, 1])
+ >>> weight = torch.Tensor([1, 0, 1])
+
+ >>> l1_loss(pred, target)
+ tensor(1.3333)
+ >>> l1_loss(pred, target, weight)
+ tensor(1.5000)
+ >>> l1_loss(pred, target, reduction='none')
+ tensor([1., 1., 2.])
+ >>> l1_loss(pred, target, weight, reduction='sum')
+ tensor(3.)
+ """
+
+ @functools.wraps(loss_func)
+ def wrapper(pred, target, weight=None, reduction="mean", **kwargs):
+ # get element-wise loss
+ loss = loss_func(pred, target, **kwargs)
+ loss = weight_reduce_loss(loss, weight, reduction)
+ return loss
+
+ return wrapper
diff --git a/hordelib/nodes/facerestore_cf/basicsr/losses/losses.py b/hordelib/nodes/facerestore_cf/basicsr/losses/losses.py
new file mode 100644
index 00000000..a6c9cd8d
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/losses/losses.py
@@ -0,0 +1,470 @@
+import math
+
+import lpips
+import torch
+from torch import autograd as autograd
+from torch import nn as nn
+from torch.nn import functional as F
+
+from hordelib.nodes.facerestore_cf.basicsr.archs.vgg_arch import VGGFeatureExtractor
+from hordelib.nodes.facerestore_cf.basicsr.utils.registry import LOSS_REGISTRY
+
+from .loss_util import weighted_loss
+
+_reduction_modes = ["none", "mean", "sum"]
+
+
+@weighted_loss
+def l1_loss(pred, target):
+ return F.l1_loss(pred, target, reduction="none")
+
+
+@weighted_loss
+def mse_loss(pred, target):
+ return F.mse_loss(pred, target, reduction="none")
+
+
+@weighted_loss
+def charbonnier_loss(pred, target, eps=1e-12):
+ return torch.sqrt((pred - target) ** 2 + eps)
+
+
+@LOSS_REGISTRY.register()
+class L1Loss(nn.Module):
+ """L1 (mean absolute error, MAE) loss.
+
+ Args:
+ loss_weight (float): Loss weight for L1 loss. Default: 1.0.
+ reduction (str): Specifies the reduction to apply to the output.
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
+ """
+
+ def __init__(self, loss_weight=1.0, reduction="mean"):
+ super(L1Loss, self).__init__()
+ if reduction not in ["none", "mean", "sum"]:
+ raise ValueError(f"Unsupported reduction mode: {reduction}. " f"Supported ones are: {_reduction_modes}")
+
+ self.loss_weight = loss_weight
+ self.reduction = reduction
+
+ def forward(self, pred, target, weight=None, **kwargs):
+ """
+ Args:
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
+ target (Tensor): of shape (N, C, H, W). Ground truth tensor.
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise
+ weights. Default: None.
+ """
+ return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction)
+
+
+@LOSS_REGISTRY.register()
+class MSELoss(nn.Module):
+ """MSE (L2) loss.
+
+ Args:
+ loss_weight (float): Loss weight for MSE loss. Default: 1.0.
+ reduction (str): Specifies the reduction to apply to the output.
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
+ """
+
+ def __init__(self, loss_weight=1.0, reduction="mean"):
+ super(MSELoss, self).__init__()
+ if reduction not in ["none", "mean", "sum"]:
+ raise ValueError(f"Unsupported reduction mode: {reduction}. " f"Supported ones are: {_reduction_modes}")
+
+ self.loss_weight = loss_weight
+ self.reduction = reduction
+
+ def forward(self, pred, target, weight=None, **kwargs):
+ """
+ Args:
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
+ target (Tensor): of shape (N, C, H, W). Ground truth tensor.
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise
+ weights. Default: None.
+ """
+ return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction)
+
+
+@LOSS_REGISTRY.register()
+class CharbonnierLoss(nn.Module):
+ """Charbonnier loss (one variant of Robust L1Loss, a differentiable
+ variant of L1Loss).
+
+ Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
+ Super-Resolution".
+
+ Args:
+ loss_weight (float): Loss weight for L1 loss. Default: 1.0.
+ reduction (str): Specifies the reduction to apply to the output.
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
+ eps (float): A value used to control the curvature near zero.
+ Default: 1e-12.
+ """
+
+ def __init__(self, loss_weight=1.0, reduction="mean", eps=1e-12):
+ super(CharbonnierLoss, self).__init__()
+ if reduction not in ["none", "mean", "sum"]:
+ raise ValueError(f"Unsupported reduction mode: {reduction}. " f"Supported ones are: {_reduction_modes}")
+
+ self.loss_weight = loss_weight
+ self.reduction = reduction
+ self.eps = eps
+
+ def forward(self, pred, target, weight=None, **kwargs):
+ """
+ Args:
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
+ target (Tensor): of shape (N, C, H, W). Ground truth tensor.
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise
+ weights. Default: None.
+ """
+ return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction)
+
+
+@LOSS_REGISTRY.register()
+class WeightedTVLoss(L1Loss):
+ """Weighted TV loss.
+
+ Args:
+ loss_weight (float): Loss weight. Default: 1.0.
+ """
+
+ def __init__(self, loss_weight=1.0):
+ super(WeightedTVLoss, self).__init__(loss_weight=loss_weight)
+
+    def forward(self, pred, weight=None):
+        # Guard against the default weight=None, which would otherwise fail on slicing.
+        y_weight = None if weight is None else weight[:, :, :-1, :]
+        x_weight = None if weight is None else weight[:, :, :, :-1]
+        y_diff = super(WeightedTVLoss, self).forward(
+            pred[:, :, :-1, :],
+            pred[:, :, 1:, :],
+            weight=y_weight,
+        )
+        x_diff = super(WeightedTVLoss, self).forward(
+            pred[:, :, :, :-1],
+            pred[:, :, :, 1:],
+            weight=x_weight,
+        )
+
+ loss = x_diff + y_diff
+
+ return loss
+
+
+@LOSS_REGISTRY.register()
+class PerceptualLoss(nn.Module):
+ """Perceptual loss with commonly used style loss.
+
+ Args:
+ layer_weights (dict): The weight for each layer of vgg feature.
+ Here is an example: {'conv5_4': 1.}, which means the conv5_4
+ feature layer (before relu5_4) will be extracted with weight
+            1.0 in calculating losses.
+ vgg_type (str): The type of vgg network used as feature extractor.
+ Default: 'vgg19'.
+ use_input_norm (bool): If True, normalize the input image in vgg.
+ Default: True.
+ range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
+ Default: False.
+        perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
+            loss will be calculated and multiplied by this weight.
+            Default: 1.0.
+        style_weight (float): If `style_weight > 0`, the style loss will be
+            calculated and multiplied by this weight.
+            Default: 0.
+ criterion (str): Criterion used for perceptual loss. Default: 'l1'.
+ """
+
+ def __init__(
+ self,
+ layer_weights,
+ vgg_type="vgg19",
+ use_input_norm=True,
+ range_norm=False,
+ perceptual_weight=1.0,
+ style_weight=0.0,
+ criterion="l1",
+ ):
+ super(PerceptualLoss, self).__init__()
+ self.perceptual_weight = perceptual_weight
+ self.style_weight = style_weight
+ self.layer_weights = layer_weights
+ self.vgg = VGGFeatureExtractor(
+ layer_name_list=list(layer_weights.keys()),
+ vgg_type=vgg_type,
+ use_input_norm=use_input_norm,
+ range_norm=range_norm,
+ )
+
+ self.criterion_type = criterion
+ if self.criterion_type == "l1":
+ self.criterion = torch.nn.L1Loss()
+ elif self.criterion_type == "l2":
+ self.criterion = torch.nn.L2loss()
+ elif self.criterion_type == "mse":
+ self.criterion = torch.nn.MSELoss(reduction="mean")
+ elif self.criterion_type == "fro":
+ self.criterion = None
+ else:
+ raise NotImplementedError(f"{criterion} criterion has not been supported.")
+
+ def forward(self, x, gt):
+ """Forward function.
+
+ Args:
+ x (Tensor): Input tensor with shape (n, c, h, w).
+ gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
+
+ Returns:
+ Tensor: Forward results.
+ """
+ # extract vgg features
+ x_features = self.vgg(x)
+ gt_features = self.vgg(gt.detach())
+
+ # calculate perceptual loss
+ if self.perceptual_weight > 0:
+ percep_loss = 0
+ for k in x_features.keys():
+ if self.criterion_type == "fro":
+ percep_loss += torch.norm(x_features[k] - gt_features[k], p="fro") * self.layer_weights[k]
+ else:
+ percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
+ percep_loss *= self.perceptual_weight
+ else:
+ percep_loss = None
+
+ # calculate style loss
+ if self.style_weight > 0:
+ style_loss = 0
+ for k in x_features.keys():
+ if self.criterion_type == "fro":
+ style_loss += (
+ torch.norm(self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p="fro")
+ * self.layer_weights[k]
+ )
+ else:
+ style_loss += (
+ self.criterion(self._gram_mat(x_features[k]), self._gram_mat(gt_features[k]))
+ * self.layer_weights[k]
+ )
+ style_loss *= self.style_weight
+ else:
+ style_loss = None
+
+ return percep_loss, style_loss
+
+ def _gram_mat(self, x):
+ """Calculate Gram matrix.
+
+ Args:
+ x (torch.Tensor): Tensor with shape of (n, c, h, w).
+
+ Returns:
+ torch.Tensor: Gram matrix.
+ """
+ n, c, h, w = x.size()
+ features = x.view(n, c, w * h)
+ features_t = features.transpose(1, 2)
+ gram = features.bmm(features_t) / (c * h * w)
+ return gram
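+
+
+# Usage sketch (illustrative only, not part of upstream basicsr): VGG19
+# perceptual loss on conv5_4 features. `fake_img` and `gt_img` are assumed
+# (n, c, h, w) tensors; the layer weights are an example configuration.
+#
+#   >>> percep_criterion = PerceptualLoss(
+#   ...     layer_weights={"conv5_4": 1.0}, vgg_type="vgg19",
+#   ...     perceptual_weight=1.0, style_weight=0.0,
+#   ... )
+#   >>> percep_loss, style_loss = percep_criterion(fake_img, gt_img)
+#   >>> # style_loss is None here because style_weight == 0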
+
+
+@LOSS_REGISTRY.register()
+class LPIPSLoss(nn.Module):
+ def __init__(self, loss_weight=1.0, use_input_norm=True, range_norm=False):
+ super(LPIPSLoss, self).__init__()
+ self.perceptual = lpips.LPIPS(net="vgg", spatial=False).eval()
+ self.loss_weight = loss_weight
+ self.use_input_norm = use_input_norm
+ self.range_norm = range_norm
+
+ if self.use_input_norm:
+ # the mean is for image with range [0, 1]
+ self.register_buffer("mean", torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+ # the std is for image with range [0, 1]
+ self.register_buffer("std", torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+ def forward(self, pred, target):
+ if self.range_norm:
+ pred = (pred + 1) / 2
+ target = (target + 1) / 2
+ if self.use_input_norm:
+ pred = (pred - self.mean) / self.std
+ target = (target - self.mean) / self.std
+ lpips_loss = self.perceptual(target.contiguous(), pred.contiguous())
+ return self.loss_weight * lpips_loss.mean()
+
+
+@LOSS_REGISTRY.register()
+class GANLoss(nn.Module):
+ """Define GAN loss.
+
+ Args:
+ gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
+ real_label_val (float): The value for real label. Default: 1.0.
+ fake_label_val (float): The value for fake label. Default: 0.0.
+ loss_weight (float): Loss weight. Default: 1.0.
+            Note that loss_weight applies only to generators; it is always 1.0
+            for discriminators.
+ """
+
+ def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
+ super(GANLoss, self).__init__()
+ self.gan_type = gan_type
+ self.loss_weight = loss_weight
+ self.real_label_val = real_label_val
+ self.fake_label_val = fake_label_val
+
+ if self.gan_type == "vanilla":
+ self.loss = nn.BCEWithLogitsLoss()
+ elif self.gan_type == "lsgan":
+ self.loss = nn.MSELoss()
+ elif self.gan_type == "wgan":
+ self.loss = self._wgan_loss
+ elif self.gan_type == "wgan_softplus":
+ self.loss = self._wgan_softplus_loss
+ elif self.gan_type == "hinge":
+ self.loss = nn.ReLU()
+ else:
+ raise NotImplementedError(f"GAN type {self.gan_type} is not implemented.")
+
+ def _wgan_loss(self, input, target):
+ """wgan loss.
+
+ Args:
+ input (Tensor): Input tensor.
+ target (bool): Target label.
+
+ Returns:
+ Tensor: wgan loss.
+ """
+ return -input.mean() if target else input.mean()
+
+ def _wgan_softplus_loss(self, input, target):
+ """wgan loss with soft plus. softplus is a smooth approximation to the
+ ReLU function.
+
+ In StyleGAN2, it is called:
+ Logistic loss for discriminator;
+ Non-saturating loss for generator.
+
+ Args:
+ input (Tensor): Input tensor.
+ target (bool): Target label.
+
+ Returns:
+ Tensor: wgan loss.
+ """
+ return F.softplus(-input).mean() if target else F.softplus(input).mean()
+
+ def get_target_label(self, input, target_is_real):
+ """Get target label.
+
+ Args:
+ input (Tensor): Input tensor.
+ target_is_real (bool): Whether the target is real or fake.
+
+ Returns:
+ (bool | Tensor): Target tensor. Return bool for wgan, otherwise,
+ return Tensor.
+ """
+
+ if self.gan_type in ["wgan", "wgan_softplus"]:
+ return target_is_real
+ target_val = self.real_label_val if target_is_real else self.fake_label_val
+ return input.new_ones(input.size()) * target_val
+
+ def forward(self, input, target_is_real, is_disc=False):
+ """
+ Args:
+ input (Tensor): The input for the loss module, i.e., the network
+ prediction.
+            target_is_real (bool): Whether the target is real or fake.
+ is_disc (bool): Whether the loss for discriminators or not.
+ Default: False.
+
+ Returns:
+ Tensor: GAN loss value.
+ """
+ if self.gan_type == "hinge":
+ if is_disc: # for discriminators in hinge-gan
+ input = -input if target_is_real else input
+ loss = self.loss(1 + input).mean()
+ else: # for generators in hinge-gan
+ loss = -input.mean()
+ else: # other gan types
+ target_label = self.get_target_label(input, target_is_real)
+ loss = self.loss(input, target_label)
+
+ # loss_weight is always 1.0 for discriminators
+ return loss if is_disc else loss * self.loss_weight
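+
+
+# Usage sketch (illustrative only, not part of upstream basicsr): typical
+# generator/discriminator updates with the hinge GAN loss. `disc`, `fake_img`
+# and `real_img` are assumed to exist.
+#
+#   >>> gan_criterion = GANLoss("hinge", loss_weight=0.1)
+#   >>> l_g_gan = gan_criterion(disc(fake_img), target_is_real=True, is_disc=False)
+#   >>> l_d_real = gan_criterion(disc(real_img), target_is_real=True, is_disc=True)
+#   >>> l_d_fake = gan_criterion(disc(fake_img.detach()), target_is_real=False, is_disc=True)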
+
+
+def r1_penalty(real_pred, real_img):
+ """R1 regularization for discriminator. The core idea is to
+ penalize the gradient on real data alone: when the
+ generator distribution produces the true data distribution
+ and the discriminator is equal to 0 on the data manifold, the
+ gradient penalty ensures that the discriminator cannot create
+ a non-zero gradient orthogonal to the data manifold without
+ suffering a loss in the GAN game.
+
+ Ref:
+ Eq. 9 in Which training methods for GANs do actually converge.
+ """
+ grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
+ grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
+ return grad_penalty
+
+
+def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
+ noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
+ grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0]
+ path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
+
+ path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
+
+ path_penalty = (path_lengths - path_mean).pow(2).mean()
+
+ return path_penalty, path_lengths.detach().mean(), path_mean.detach()
+
+
+def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
+ """Calculate gradient penalty for wgan-gp.
+
+ Args:
+ discriminator (nn.Module): Network for the discriminator.
+ real_data (Tensor): Real input data.
+ fake_data (Tensor): Fake input data.
+ weight (Tensor): Weight tensor. Default: None.
+
+ Returns:
+ Tensor: A tensor for gradient penalty.
+ """
+
+ batch_size = real_data.size(0)
+ alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1))
+
+ # interpolate between real_data and fake_data
+ interpolates = alpha * real_data + (1.0 - alpha) * fake_data
+ interpolates = autograd.Variable(interpolates, requires_grad=True)
+
+ disc_interpolates = discriminator(interpolates)
+ gradients = autograd.grad(
+ outputs=disc_interpolates,
+ inputs=interpolates,
+ grad_outputs=torch.ones_like(disc_interpolates),
+ create_graph=True,
+ retain_graph=True,
+ only_inputs=True,
+ )[0]
+
+ if weight is not None:
+ gradients = gradients * weight
+
+ gradients_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
+ if weight is not None:
+ gradients_penalty /= torch.mean(weight)
+
+ return gradients_penalty
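+
+
+# Usage sketch (illustrative only, not part of upstream basicsr): the two
+# regularizers as they would appear in a discriminator update. `disc`,
+# `real_img` and `fake_img` are assumed; r1_penalty needs gradients w.r.t.
+# the real images.
+#
+#   >>> real_img.requires_grad_(True)
+#   >>> l_d_r1 = r1_penalty(disc(real_img), real_img)
+#   >>> l_d_gp = gradient_penalty_loss(disc, real_img, fake_img.detach())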
diff --git a/hordelib/nodes/facerestore_cf/basicsr/metrics/__init__.py b/hordelib/nodes/facerestore_cf/basicsr/metrics/__init__.py
new file mode 100644
index 00000000..5e9ae7c6
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/metrics/__init__.py
@@ -0,0 +1,20 @@
+from copy import deepcopy
+
+from hordelib.nodes.facerestore_cf.basicsr.utils.registry import METRIC_REGISTRY
+
+from .psnr_ssim import calculate_psnr, calculate_ssim
+
+__all__ = ["calculate_psnr", "calculate_ssim"]
+
+
+def calculate_metric(data, opt):
+ """Calculate metric from data and options.
+
+ Args:
+        opt (dict): Configuration. It must contain:
+ type (str): Model type.
+ """
+ opt = deepcopy(opt)
+ metric_type = opt.pop("type")
+ metric = METRIC_REGISTRY.get(metric_type)(**data, **opt)
+ return metric
diff --git a/hordelib/nodes/facerestore_cf/basicsr/metrics/metric_util.py b/hordelib/nodes/facerestore_cf/basicsr/metrics/metric_util.py
new file mode 100644
index 00000000..c77e591e
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/metrics/metric_util.py
@@ -0,0 +1,45 @@
+import numpy as np
+
+from hordelib.nodes.facerestore_cf.basicsr.utils.matlab_functions import bgr2ycbcr
+
+
+def reorder_image(img, input_order="HWC"):
+ """Reorder images to 'HWC' order.
+
+    If the input image shape is (h, w), return (h, w, 1);
+    if the input is (c, h, w), return (h, w, c);
+    if the input is already (h, w, c), return it as is.
+
+ Args:
+ img (ndarray): Input image.
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
+ If the input image shape is (h, w), input_order will not have
+ effects. Default: 'HWC'.
+
+ Returns:
+ ndarray: reordered image.
+ """
+
+ if input_order not in ["HWC", "CHW"]:
+ raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are " "'HWC' and 'CHW'")
+ if len(img.shape) == 2:
+ img = img[..., None]
+ if input_order == "CHW":
+ img = img.transpose(1, 2, 0)
+ return img
+
+
+def to_y_channel(img):
+ """Change to Y channel of YCbCr.
+
+ Args:
+ img (ndarray): Images with range [0, 255].
+
+ Returns:
+        (ndarray): Images with range [0, 255] (float type), without rounding.
+ """
+ img = img.astype(np.float32) / 255.0
+ if img.ndim == 3 and img.shape[2] == 3:
+ img = bgr2ycbcr(img, y_only=True)
+ img = img[..., None]
+ return img * 255.0
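+
+
+# Illustrative example (not part of upstream basicsr): reorder_image always
+# returns an HWC array, adding a channel axis for 2-D inputs.
+#
+#   >>> reorder_image(np.zeros((3, 64, 64)), input_order="CHW").shape
+#   (64, 64, 3)
+#   >>> reorder_image(np.zeros((64, 64)), input_order="HWC").shape
+#   (64, 64, 1)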
diff --git a/hordelib/nodes/facerestore_cf/basicsr/metrics/psnr_ssim.py b/hordelib/nodes/facerestore_cf/basicsr/metrics/psnr_ssim.py
new file mode 100644
index 00000000..e7735956
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/metrics/psnr_ssim.py
@@ -0,0 +1,128 @@
+import cv2
+import numpy as np
+
+from hordelib.nodes.facerestore_cf.basicsr.metrics.metric_util import reorder_image, to_y_channel
+from hordelib.nodes.facerestore_cf.basicsr.utils.registry import METRIC_REGISTRY
+
+
+@METRIC_REGISTRY.register()
+def calculate_psnr(img1, img2, crop_border, input_order="HWC", test_y_channel=False):
+ """Calculate PSNR (Peak Signal-to-Noise Ratio).
+
+ Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
+
+ Args:
+ img1 (ndarray): Images with range [0, 255].
+ img2 (ndarray): Images with range [0, 255].
+ crop_border (int): Cropped pixels in each edge of an image. These
+ pixels are not involved in the PSNR calculation.
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
+ Default: 'HWC'.
+ test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
+
+ Returns:
+ float: psnr result.
+ """
+
+    assert img1.shape == img2.shape, f"Image shapes are different: {img1.shape}, {img2.shape}."
+ if input_order not in ["HWC", "CHW"]:
+ raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are " '"HWC" and "CHW"')
+ img1 = reorder_image(img1, input_order=input_order)
+ img2 = reorder_image(img2, input_order=input_order)
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+
+ if crop_border != 0:
+ img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
+ img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
+
+ if test_y_channel:
+ img1 = to_y_channel(img1)
+ img2 = to_y_channel(img2)
+
+ mse = np.mean((img1 - img2) ** 2)
+ if mse == 0:
+ return float("inf")
+ return 20.0 * np.log10(255.0 / np.sqrt(mse))
+
+
+def _ssim(img1, img2):
+ """Calculate SSIM (structural similarity) for one channel images.
+
+ It is called by func:`calculate_ssim`.
+
+ Args:
+ img1 (ndarray): Images with range [0, 255] with order 'HWC'.
+ img2 (ndarray): Images with range [0, 255] with order 'HWC'.
+
+ Returns:
+ float: ssim result.
+ """
+
+ C1 = (0.01 * 255) ** 2
+ C2 = (0.03 * 255) ** 2
+
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+ kernel = cv2.getGaussianKernel(11, 1.5)
+ window = np.outer(kernel, kernel.transpose())
+
+ mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
+ mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
+ mu1_sq = mu1**2
+ mu2_sq = mu2**2
+ mu1_mu2 = mu1 * mu2
+ sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
+ sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
+ sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
+
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
+ return ssim_map.mean()
+
+
+@METRIC_REGISTRY.register()
+def calculate_ssim(img1, img2, crop_border, input_order="HWC", test_y_channel=False):
+ """Calculate SSIM (structural similarity).
+
+ Ref:
+ Image quality assessment: From error visibility to structural similarity
+
+ The results are the same as that of the official released MATLAB code in
+ https://ece.uwaterloo.ca/~z70wang/research/ssim/.
+
+ For three-channel images, SSIM is calculated for each channel and then
+ averaged.
+
+ Args:
+ img1 (ndarray): Images with range [0, 255].
+ img2 (ndarray): Images with range [0, 255].
+ crop_border (int): Cropped pixels in each edge of an image. These
+ pixels are not involved in the SSIM calculation.
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
+ Default: 'HWC'.
+ test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
+
+ Returns:
+ float: ssim result.
+ """
+
+    assert img1.shape == img2.shape, f"Image shapes are different: {img1.shape}, {img2.shape}."
+ if input_order not in ["HWC", "CHW"]:
+ raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are " '"HWC" and "CHW"')
+ img1 = reorder_image(img1, input_order=input_order)
+ img2 = reorder_image(img2, input_order=input_order)
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+
+ if crop_border != 0:
+ img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
+ img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
+
+ if test_y_channel:
+ img1 = to_y_channel(img1)
+ img2 = to_y_channel(img2)
+
+ ssims = []
+ for i in range(img1.shape[2]):
+ ssims.append(_ssim(img1[..., i], img2[..., i]))
+ return np.array(ssims).mean()
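+
+
+# Usage sketch (illustrative only, not part of upstream basicsr): comparing a
+# restored image against ground truth on the Y channel, a common SR evaluation
+# protocol. `restored` and `gt` are assumed HWC images in [0, 255].
+#
+#   >>> psnr = calculate_psnr(restored, gt, crop_border=4, test_y_channel=True)
+#   >>> ssim = calculate_ssim(restored, gt, crop_border=4, test_y_channel=True)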
diff --git a/hordelib/nodes/facerestore_cf/basicsr/models/__init__.py b/hordelib/nodes/facerestore_cf/basicsr/models/__init__.py
new file mode 100644
index 00000000..73c723c6
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/models/__init__.py
@@ -0,0 +1,30 @@
+import importlib
+from copy import deepcopy
+from os import path as osp
+
+from hordelib.nodes.facerestore_cf.basicsr.utils import get_root_logger, scandir
+from hordelib.nodes.facerestore_cf.basicsr.utils.registry import MODEL_REGISTRY
+
+__all__ = ["build_model"]
+
+# automatically scan and import model modules for registry
+# scan all the files under the 'models' folder and collect files ending with
+# '_model.py'
+model_folder = osp.dirname(osp.abspath(__file__))
+model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith("_model.py")]
+# import all the model modules
+_model_modules = [
+    importlib.import_module(f"hordelib.nodes.facerestore_cf.basicsr.models.{file_name}")
+    for file_name in model_filenames
+]
+
+
+def build_model(opt):
+ """Build model from options.
+
+ Args:
+        opt (dict): Configuration. It must contain:
+ model_type (str): Model type.
+ """
+ opt = deepcopy(opt)
+ model = MODEL_REGISTRY.get(opt["model_type"])(opt)
+ logger = get_root_logger()
+ logger.info(f"Model [{model.__class__.__name__}] is created.")
+ return model
diff --git a/hordelib/nodes/facerestore/facelib/__init__.py b/hordelib/nodes/facerestore_cf/basicsr/ops/__init__.py
similarity index 100%
rename from hordelib/nodes/facerestore/facelib/__init__.py
rename to hordelib/nodes/facerestore_cf/basicsr/ops/__init__.py
diff --git a/hordelib/nodes/facerestore_cf/basicsr/ops/dcn/__init__.py b/hordelib/nodes/facerestore_cf/basicsr/ops/dcn/__init__.py
new file mode 100644
index 00000000..55b8b8e9
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/ops/dcn/__init__.py
@@ -0,0 +1,17 @@
+from .deform_conv import (
+ DeformConv,
+ DeformConvPack,
+ ModulatedDeformConv,
+ ModulatedDeformConvPack,
+ deform_conv,
+ modulated_deform_conv,
+)
+
+__all__ = [
+ "DeformConv",
+ "DeformConvPack",
+ "ModulatedDeformConv",
+ "ModulatedDeformConvPack",
+ "deform_conv",
+ "modulated_deform_conv",
+]
diff --git a/hordelib/nodes/facerestore_cf/basicsr/ops/dcn/deform_conv.py b/hordelib/nodes/facerestore_cf/basicsr/ops/dcn/deform_conv.py
new file mode 100644
index 00000000..020c86d3
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/ops/dcn/deform_conv.py
@@ -0,0 +1,503 @@
+import math
+
+import torch
+from torch import nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn import functional as F
+from torch.nn.modules.utils import _pair, _single
+
+try:
+ from . import deform_conv_ext
+except ImportError:
+ import os
+
+ BASICSR_JIT = os.getenv("BASICSR_JIT")
+ if BASICSR_JIT == "True":
+ from torch.utils.cpp_extension import load
+
+ module_path = os.path.dirname(__file__)
+ deform_conv_ext = load(
+ "deform_conv",
+ sources=[
+ os.path.join(module_path, "src", "deform_conv_ext.cpp"),
+ os.path.join(module_path, "src", "deform_conv_cuda.cpp"),
+ os.path.join(module_path, "src", "deform_conv_cuda_kernel.cu"),
+ ],
+ )
+
+
+class DeformConvFunction(Function):
+
+ @staticmethod
+ def forward(
+ ctx,
+ input,
+ offset,
+ weight,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deformable_groups=1,
+ im2col_step=64,
+ ):
+ if input is not None and input.dim() != 4:
+ raise ValueError(f"Expected 4D tensor as input, got {input.dim()}" "D tensor instead.")
+ ctx.stride = _pair(stride)
+ ctx.padding = _pair(padding)
+ ctx.dilation = _pair(dilation)
+ ctx.groups = groups
+ ctx.deformable_groups = deformable_groups
+ ctx.im2col_step = im2col_step
+
+ ctx.save_for_backward(input, offset, weight)
+
+ output = input.new_empty(DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride))
+
+ ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
+
+ if not input.is_cuda:
+ raise NotImplementedError
+ else:
+ cur_im2col_step = min(ctx.im2col_step, input.shape[0])
+ assert (input.shape[0] % cur_im2col_step) == 0, "im2col step must divide batchsize"
+ deform_conv_ext.deform_conv_forward(
+ input,
+ weight,
+ offset,
+ output,
+ ctx.bufs_[0],
+ ctx.bufs_[1],
+ weight.size(3),
+ weight.size(2),
+ ctx.stride[1],
+ ctx.stride[0],
+ ctx.padding[1],
+ ctx.padding[0],
+ ctx.dilation[1],
+ ctx.dilation[0],
+ ctx.groups,
+ ctx.deformable_groups,
+ cur_im2col_step,
+ )
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ input, offset, weight = ctx.saved_tensors
+
+ grad_input = grad_offset = grad_weight = None
+
+ if not grad_output.is_cuda:
+ raise NotImplementedError
+ else:
+ cur_im2col_step = min(ctx.im2col_step, input.shape[0])
+ assert (input.shape[0] % cur_im2col_step) == 0, "im2col step must divide batchsize"
+
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
+ grad_input = torch.zeros_like(input)
+ grad_offset = torch.zeros_like(offset)
+ deform_conv_ext.deform_conv_backward_input(
+ input,
+ offset,
+ grad_output,
+ grad_input,
+ grad_offset,
+ weight,
+ ctx.bufs_[0],
+ weight.size(3),
+ weight.size(2),
+ ctx.stride[1],
+ ctx.stride[0],
+ ctx.padding[1],
+ ctx.padding[0],
+ ctx.dilation[1],
+ ctx.dilation[0],
+ ctx.groups,
+ ctx.deformable_groups,
+ cur_im2col_step,
+ )
+
+ if ctx.needs_input_grad[2]:
+ grad_weight = torch.zeros_like(weight)
+ deform_conv_ext.deform_conv_backward_parameters(
+ input,
+ offset,
+ grad_output,
+ grad_weight,
+ ctx.bufs_[0],
+ ctx.bufs_[1],
+ weight.size(3),
+ weight.size(2),
+ ctx.stride[1],
+ ctx.stride[0],
+ ctx.padding[1],
+ ctx.padding[0],
+ ctx.dilation[1],
+ ctx.dilation[0],
+ ctx.groups,
+ ctx.deformable_groups,
+ 1,
+ cur_im2col_step,
+ )
+
+ return (grad_input, grad_offset, grad_weight, None, None, None, None, None)
+
+ @staticmethod
+ def _output_size(input, weight, padding, dilation, stride):
+ channels = weight.size(0)
+ output_size = (input.size(0), channels)
+ for d in range(input.dim() - 2):
+ in_size = input.size(d + 2)
+ pad = padding[d]
+ kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
+ stride_ = stride[d]
+ output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1,)
+ if not all(map(lambda s: s > 0, output_size)):
+ raise ValueError("convolution input is too small (output would be " f'{"x".join(map(str, output_size))})')
+ return output_size
+
+
+class ModulatedDeformConvFunction(Function):
+
+ @staticmethod
+ def forward(
+ ctx,
+ input,
+ offset,
+ mask,
+ weight,
+ bias=None,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deformable_groups=1,
+ ):
+ ctx.stride = stride
+ ctx.padding = padding
+ ctx.dilation = dilation
+ ctx.groups = groups
+ ctx.deformable_groups = deformable_groups
+ ctx.with_bias = bias is not None
+ if not ctx.with_bias:
+ bias = input.new_empty(1) # fake tensor
+ if not input.is_cuda:
+ raise NotImplementedError
+ if weight.requires_grad or mask.requires_grad or offset.requires_grad or input.requires_grad:
+ ctx.save_for_backward(input, offset, mask, weight, bias)
+ output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
+ ctx._bufs = [input.new_empty(0), input.new_empty(0)]
+ deform_conv_ext.modulated_deform_conv_forward(
+ input,
+ weight,
+ bias,
+ ctx._bufs[0],
+ offset,
+ mask,
+ output,
+ ctx._bufs[1],
+ weight.shape[2],
+ weight.shape[3],
+ ctx.stride,
+ ctx.stride,
+ ctx.padding,
+ ctx.padding,
+ ctx.dilation,
+ ctx.dilation,
+ ctx.groups,
+ ctx.deformable_groups,
+ ctx.with_bias,
+ )
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ if not grad_output.is_cuda:
+ raise NotImplementedError
+ input, offset, mask, weight, bias = ctx.saved_tensors
+ grad_input = torch.zeros_like(input)
+ grad_offset = torch.zeros_like(offset)
+ grad_mask = torch.zeros_like(mask)
+ grad_weight = torch.zeros_like(weight)
+ grad_bias = torch.zeros_like(bias)
+ deform_conv_ext.modulated_deform_conv_backward(
+ input,
+ weight,
+ bias,
+ ctx._bufs[0],
+ offset,
+ mask,
+ ctx._bufs[1],
+ grad_input,
+ grad_weight,
+ grad_bias,
+ grad_offset,
+ grad_mask,
+ grad_output,
+ weight.shape[2],
+ weight.shape[3],
+ ctx.stride,
+ ctx.stride,
+ ctx.padding,
+ ctx.padding,
+ ctx.dilation,
+ ctx.dilation,
+ ctx.groups,
+ ctx.deformable_groups,
+ ctx.with_bias,
+ )
+ if not ctx.with_bias:
+ grad_bias = None
+
+ return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None, None)
+
+ @staticmethod
+ def _infer_shape(ctx, input, weight):
+ n = input.size(0)
+ channels_out = weight.size(0)
+ height, width = input.shape[2:4]
+ kernel_h, kernel_w = weight.shape[2:4]
+ height_out = (height + 2 * ctx.padding - (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
+ width_out = (width + 2 * ctx.padding - (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
+ return n, channels_out, height_out, width_out
+
+
+deform_conv = DeformConvFunction.apply
+modulated_deform_conv = ModulatedDeformConvFunction.apply
+
+
+class DeformConv(nn.Module):
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deformable_groups=1,
+ bias=False,
+ ):
+ super(DeformConv, self).__init__()
+
+ assert not bias
+ assert in_channels % groups == 0, f"in_channels {in_channels} is not divisible by groups {groups}"
+ assert out_channels % groups == 0, f"out_channels {out_channels} is not divisible " f"by groups {groups}"
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = _pair(kernel_size)
+ self.stride = _pair(stride)
+ self.padding = _pair(padding)
+ self.dilation = _pair(dilation)
+ self.groups = groups
+ self.deformable_groups = deformable_groups
+ # enable compatibility with nn.Conv2d
+ self.transposed = False
+ self.output_padding = _single(0)
+
+ self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size))
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ n = self.in_channels
+ for k in self.kernel_size:
+ n *= k
+ stdv = 1.0 / math.sqrt(n)
+ self.weight.data.uniform_(-stdv, stdv)
+
+ def forward(self, x, offset):
+        # Pad the input when the image is smaller than the kernel, to avoid
+        # an assert error in deform_conv_cuda.cpp:128.
+ input_pad = x.size(2) < self.kernel_size[0] or x.size(3) < self.kernel_size[1]
+ if input_pad:
+ pad_h = max(self.kernel_size[0] - x.size(2), 0)
+ pad_w = max(self.kernel_size[1] - x.size(3), 0)
+ x = F.pad(x, (0, pad_w, 0, pad_h), "constant", 0).contiguous()
+ offset = F.pad(offset, (0, pad_w, 0, pad_h), "constant", 0).contiguous()
+ out = deform_conv(
+ x,
+ offset,
+ self.weight,
+ self.stride,
+ self.padding,
+ self.dilation,
+ self.groups,
+ self.deformable_groups,
+ )
+ if input_pad:
+ out = out[:, :, : out.size(2) - pad_h, : out.size(3) - pad_w].contiguous()
+ return out
+
+
+class DeformConvPack(DeformConv):
+ """A Deformable Conv Encapsulation that acts as normal Conv layers.
+
+ Args:
+ in_channels (int): Same as nn.Conv2d.
+ out_channels (int): Same as nn.Conv2d.
+ kernel_size (int or tuple[int]): Same as nn.Conv2d.
+ stride (int or tuple[int]): Same as nn.Conv2d.
+ padding (int or tuple[int]): Same as nn.Conv2d.
+ dilation (int or tuple[int]): Same as nn.Conv2d.
+ groups (int): Same as nn.Conv2d.
+ bias (bool or str): If specified as `auto`, it will be decided by the
+ norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
+ False.
+ """
+
+ _version = 2
+
+ def __init__(self, *args, **kwargs):
+ super(DeformConvPack, self).__init__(*args, **kwargs)
+
+ self.conv_offset = nn.Conv2d(
+ self.in_channels,
+ self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
+ kernel_size=self.kernel_size,
+ stride=_pair(self.stride),
+ padding=_pair(self.padding),
+ dilation=_pair(self.dilation),
+ bias=True,
+ )
+ self.init_offset()
+
+ def init_offset(self):
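+        # Zero-init the offset branch so the layer initially behaves like a plain convolution.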
+ self.conv_offset.weight.data.zero_()
+ self.conv_offset.bias.data.zero_()
+
+ def forward(self, x):
+ offset = self.conv_offset(x)
+ return deform_conv(
+ x,
+ offset,
+ self.weight,
+ self.stride,
+ self.padding,
+ self.dilation,
+ self.groups,
+ self.deformable_groups,
+ )
+
+
+class ModulatedDeformConv(nn.Module):
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deformable_groups=1,
+ bias=True,
+ ):
+ super(ModulatedDeformConv, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = _pair(kernel_size)
+ self.stride = stride
+ self.padding = padding
+ self.dilation = dilation
+ self.groups = groups
+ self.deformable_groups = deformable_groups
+ self.with_bias = bias
+ # enable compatibility with nn.Conv2d
+ self.transposed = False
+ self.output_padding = _single(0)
+
+ self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
+ if bias:
+ self.bias = nn.Parameter(torch.Tensor(out_channels))
+ else:
+ self.register_parameter("bias", None)
+ self.init_weights()
+
+ def init_weights(self):
+ n = self.in_channels
+ for k in self.kernel_size:
+ n *= k
+ stdv = 1.0 / math.sqrt(n)
+ self.weight.data.uniform_(-stdv, stdv)
+ if self.bias is not None:
+ self.bias.data.zero_()
+
+ def forward(self, x, offset, mask):
+ return modulated_deform_conv(
+ x,
+ offset,
+ mask,
+ self.weight,
+ self.bias,
+ self.stride,
+ self.padding,
+ self.dilation,
+ self.groups,
+ self.deformable_groups,
+ )
+
+
+class ModulatedDeformConvPack(ModulatedDeformConv):
+ """A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers.
+
+ Args:
+ in_channels (int): Same as nn.Conv2d.
+ out_channels (int): Same as nn.Conv2d.
+ kernel_size (int or tuple[int]): Same as nn.Conv2d.
+ stride (int or tuple[int]): Same as nn.Conv2d.
+ padding (int or tuple[int]): Same as nn.Conv2d.
+ dilation (int or tuple[int]): Same as nn.Conv2d.
+ groups (int): Same as nn.Conv2d.
+ bias (bool or str): If specified as `auto`, it will be decided by the
+ norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
+ False.
+ """
+
+ _version = 2
+
+ def __init__(self, *args, **kwargs):
+ super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
+
+ self.conv_offset = nn.Conv2d(
+ self.in_channels,
+ self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
+ kernel_size=self.kernel_size,
+ stride=_pair(self.stride),
+ padding=_pair(self.padding),
+ dilation=_pair(self.dilation),
+ bias=True,
+ )
+ self.init_weights()
+
+ def init_weights(self):
+ super(ModulatedDeformConvPack, self).init_weights()
+ if hasattr(self, "conv_offset"):
+ self.conv_offset.weight.data.zero_()
+ self.conv_offset.bias.data.zero_()
+
+ def forward(self, x):
+ out = self.conv_offset(x)
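+        # conv_offset predicts 3 * K*K * deformable_groups channels per location:
+        # two K*K blocks of offsets and one K*K block of mask logits.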
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
+ offset = torch.cat((o1, o2), dim=1)
+ mask = torch.sigmoid(mask)
+ return modulated_deform_conv(
+ x,
+ offset,
+ mask,
+ self.weight,
+ self.bias,
+ self.stride,
+ self.padding,
+ self.dilation,
+ self.groups,
+ self.deformable_groups,
+ )
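+
+
+# Minimal usage sketch (assumes the compiled `deform_conv_ext` extension and a CUDA
+# device are available):
+#   dcn = ModulatedDeformConvPack(64, 64, kernel_size=3, padding=1).cuda()
+#   y = dcn(torch.randn(1, 64, 32, 32, device="cuda"))  # offsets and mask predicted internally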
diff --git a/hordelib/nodes/facerestore_cf/basicsr/ops/dcn/src/deform_conv_cuda.cpp b/hordelib/nodes/facerestore_cf/basicsr/ops/dcn/src/deform_conv_cuda.cpp
new file mode 100644
index 00000000..6fbef833
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/ops/dcn/src/deform_conv_cuda.cpp
@@ -0,0 +1,685 @@
+// modify from
+// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
+
+#include <torch/extension.h>
+#include <ATen/DeviceGuard.h>
+
+#include <cmath>
+#include <vector>
+
+void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset,
+ const int channels, const int height, const int width,
+ const int ksize_h, const int ksize_w, const int pad_h,
+ const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int parallel_imgs, const int deformable_group,
+ at::Tensor data_col);
+
+void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset,
+ const int channels, const int height, const int width,
+ const int ksize_h, const int ksize_w, const int pad_h,
+ const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int parallel_imgs, const int deformable_group,
+ at::Tensor grad_im);
+
+void deformable_col2im_coord(
+ const at::Tensor data_col, const at::Tensor data_im,
+ const at::Tensor data_offset, const int channels, const int height,
+ const int width, const int ksize_h, const int ksize_w, const int pad_h,
+ const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w, const int parallel_imgs,
+ const int deformable_group, at::Tensor grad_offset);
+
+void modulated_deformable_im2col_cuda(
+ const at::Tensor data_im, const at::Tensor data_offset,
+ const at::Tensor data_mask, const int batch_size, const int channels,
+ const int height_im, const int width_im, const int height_col,
+    const int width_col, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w, const int deformable_group,
+ at::Tensor data_col);
+
+void modulated_deformable_col2im_cuda(
+ const at::Tensor data_col, const at::Tensor data_offset,
+ const at::Tensor data_mask, const int batch_size, const int channels,
+ const int height_im, const int width_im, const int height_col,
+    const int width_col, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w, const int deformable_group,
+ at::Tensor grad_im);
+
+void modulated_deformable_col2im_coord_cuda(
+ const at::Tensor data_col, const at::Tensor data_im,
+ const at::Tensor data_offset, const at::Tensor data_mask,
+ const int batch_size, const int channels, const int height_im,
+ const int width_im, const int height_col, const int width_col,
+    const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w, const int dilation_h,
+ const int dilation_w, const int deformable_group, at::Tensor grad_offset,
+ at::Tensor grad_mask);
+
+void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
+ at::Tensor weight, int kH, int kW, int dH, int dW, int padH,
+ int padW, int dilationH, int dilationW, int group,
+ int deformable_group) {
+ TORCH_CHECK(weight.ndimension() == 4,
+ "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
+ "but got: %s",
+ weight.ndimension());
+
+ TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+
+ TORCH_CHECK(kW > 0 && kH > 0,
+ "kernel size should be greater than zero, but got kH: %d kW: %d", kH,
+ kW);
+
+ TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW),
+ "kernel size should be consistent with weight, ",
+ "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH,
+ kW, weight.size(2), weight.size(3));
+
+ TORCH_CHECK(dW > 0 && dH > 0,
+ "stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
+
+ TORCH_CHECK(
+ dilationW > 0 && dilationH > 0,
+ "dilation should be greater than 0, but got dilationH: %d dilationW: %d",
+ dilationH, dilationW);
+
+ int ndim = input.ndimension();
+ int dimf = 0;
+ int dimh = 1;
+ int dimw = 2;
+
+ if (ndim == 4) {
+ dimf++;
+ dimh++;
+ dimw++;
+ }
+
+ TORCH_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s",
+ ndim);
+
+ long nInputPlane = weight.size(1) * group;
+ long inputHeight = input.size(dimh);
+ long inputWidth = input.size(dimw);
+ long nOutputPlane = weight.size(0);
+ long outputHeight =
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+ long outputWidth =
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+
+ TORCH_CHECK(nInputPlane % deformable_group == 0,
+ "input channels must divide deformable group size");
+
+ if (outputWidth < 1 || outputHeight < 1)
+ AT_ERROR(
+ "Given input size: (%ld x %ld x %ld). "
+ "Calculated output size: (%ld x %ld x %ld). Output size is too small",
+ nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
+ outputWidth);
+
+ TORCH_CHECK(input.size(1) == nInputPlane,
+ "invalid number of input planes, expected: %d, but got: %d",
+ nInputPlane, input.size(1));
+
+ TORCH_CHECK((inputHeight >= kH && inputWidth >= kW),
+ "input image is smaller than kernel");
+
+ TORCH_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
+ "invalid spatial size of offset, expected height: %d width: %d, but "
+ "got height: %d width: %d",
+ outputHeight, outputWidth, offset.size(2), offset.size(3));
+
+ TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
+ "invalid number of channels of offset");
+
+ if (gradOutput != NULL) {
+ TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane,
+ "invalid number of gradOutput planes, expected: %d, but got: %d",
+ nOutputPlane, gradOutput->size(dimf));
+
+ TORCH_CHECK((gradOutput->size(dimh) == outputHeight &&
+ gradOutput->size(dimw) == outputWidth),
+ "invalid size of gradOutput, expected height: %d width: %d , but "
+ "got height: %d width: %d",
+ outputHeight, outputWidth, gradOutput->size(dimh),
+ gradOutput->size(dimw));
+ }
+}
+
+int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
+ at::Tensor offset, at::Tensor output,
+ at::Tensor columns, at::Tensor ones, int kW,
+ int kH, int dW, int dH, int padW, int padH,
+ int dilationW, int dilationH, int group,
+ int deformable_group, int im2col_step) {
+ // todo: resize columns to include im2col: done
+ // todo: add im2col_step as input
+ // todo: add new output buffer and transpose it to output (or directly
+ // transpose output) todo: possibly change data indexing because of
+ // parallel_imgs
+
+ shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
+ dilationH, dilationW, group, deformable_group);
+ at::DeviceGuard guard(input.device());
+
+ input = input.contiguous();
+ offset = offset.contiguous();
+ weight = weight.contiguous();
+
+ int batch = 1;
+ if (input.ndimension() == 3) {
+ // Force batch
+ batch = 0;
+ input.unsqueeze_(0);
+ offset.unsqueeze_(0);
+ }
+
+ // todo: assert batchsize dividable by im2col_step
+
+ long batchSize = input.size(0);
+ long nInputPlane = input.size(1);
+ long inputHeight = input.size(2);
+ long inputWidth = input.size(3);
+
+ long nOutputPlane = weight.size(0);
+
+ long outputWidth =
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+ long outputHeight =
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+ TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+
+ output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
+ outputHeight, outputWidth});
+ columns = at::zeros(
+ {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+ input.options());
+
+ if (ones.ndimension() != 2 ||
+ ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
+ ones = at::ones({outputHeight, outputWidth}, input.options());
+ }
+
+ input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+ inputHeight, inputWidth});
+ offset =
+ offset.view({batchSize / im2col_step, im2col_step,
+ deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ at::Tensor output_buffer =
+ at::zeros({batchSize / im2col_step, nOutputPlane,
+ im2col_step * outputHeight, outputWidth},
+ output.options());
+
+ output_buffer = output_buffer.view(
+ {output_buffer.size(0), group, output_buffer.size(1) / group,
+ output_buffer.size(2), output_buffer.size(3)});
+
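+  // For each im2col "super-batch" of size im2col_step: expand the input into
+  // column form, then run a grouped GEMM (weight x columns) into the output buffer.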
+ for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+ deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
+ inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+ dilationW, im2col_step, deformable_group, columns);
+
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
+ weight.size(2), weight.size(3)});
+
+ for (int g = 0; g < group; g++) {
+ output_buffer[elt][g] = output_buffer[elt][g]
+ .flatten(1)
+ .addmm_(weight[g].flatten(1), columns[g])
+ .view_as(output_buffer[elt][g]);
+ }
+ }
+
+ output_buffer = output_buffer.view(
+ {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
+ output_buffer.size(3), output_buffer.size(4)});
+
+ output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
+ im2col_step, outputHeight, outputWidth});
+ output_buffer.transpose_(1, 2);
+ output.copy_(output_buffer);
+ output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+ input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ offset = offset.view(
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ if (batch == 0) {
+ output = output.view({nOutputPlane, outputHeight, outputWidth});
+ input = input.view({nInputPlane, inputHeight, inputWidth});
+ offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
+ }
+
+ return 1;
+}
+
+int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
+ at::Tensor gradOutput, at::Tensor gradInput,
+ at::Tensor gradOffset, at::Tensor weight,
+ at::Tensor columns, int kW, int kH, int dW,
+ int dH, int padW, int padH, int dilationW,
+ int dilationH, int group,
+ int deformable_group, int im2col_step) {
+ shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW,
+ dilationH, dilationW, group, deformable_group);
+ at::DeviceGuard guard(input.device());
+
+ input = input.contiguous();
+ offset = offset.contiguous();
+ gradOutput = gradOutput.contiguous();
+ weight = weight.contiguous();
+
+ int batch = 1;
+
+ if (input.ndimension() == 3) {
+ // Force batch
+ batch = 0;
+ input = input.view({1, input.size(0), input.size(1), input.size(2)});
+ offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
+ gradOutput = gradOutput.view(
+ {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
+ }
+
+ long batchSize = input.size(0);
+ long nInputPlane = input.size(1);
+ long inputHeight = input.size(2);
+ long inputWidth = input.size(3);
+
+ long nOutputPlane = weight.size(0);
+
+ long outputWidth =
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+ long outputHeight =
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+  TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+ gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ columns = at::zeros(
+ {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+ input.options());
+
+ // change order of grad output
+ gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
+ nOutputPlane, outputHeight, outputWidth});
+ gradOutput.transpose_(1, 2);
+
+ gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
+ inputHeight, inputWidth});
+ input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+ inputHeight, inputWidth});
+ gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
+ deformable_group * 2 * kH * kW, outputHeight,
+ outputWidth});
+ offset =
+ offset.view({batchSize / im2col_step, im2col_step,
+ deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+ // divide into groups
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
+ weight.size(2), weight.size(3)});
+ gradOutput = gradOutput.view(
+ {gradOutput.size(0), group, gradOutput.size(1) / group,
+ gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});
+
+ for (int g = 0; g < group; g++) {
+ columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
+ gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
+ }
+
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ gradOutput = gradOutput.view(
+ {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
+ gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});
+
+ deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
+ inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
+ dilationH, dilationW, im2col_step, deformable_group,
+ gradOffset[elt]);
+
+ deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
+ inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+ dilationW, im2col_step, deformable_group, gradInput[elt]);
+ }
+
+ gradOutput.transpose_(1, 2);
+ gradOutput =
+ gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+ gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ gradOffset = gradOffset.view(
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+ offset = offset.view(
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ if (batch == 0) {
+ gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
+ input = input.view({nInputPlane, inputHeight, inputWidth});
+ gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
+ offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
+ gradOffset =
+ gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
+ }
+
+ return 1;
+}
+
+int deform_conv_backward_parameters_cuda(
+ at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
+ at::Tensor gradWeight, // at::Tensor gradBias,
+ at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
+ int padW, int padH, int dilationW, int dilationH, int group,
+ int deformable_group, float scale, int im2col_step) {
+ // todo: transpose and reshape outGrad
+ // todo: reshape columns
+ // todo: add im2col_step as input
+
+ shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH,
+ padW, dilationH, dilationW, group, deformable_group);
+ at::DeviceGuard guard(input.device());
+
+ input = input.contiguous();
+ offset = offset.contiguous();
+ gradOutput = gradOutput.contiguous();
+
+ int batch = 1;
+
+ if (input.ndimension() == 3) {
+ // Force batch
+ batch = 0;
+ input = input.view(
+ at::IntList({1, input.size(0), input.size(1), input.size(2)}));
+ gradOutput = gradOutput.view(
+ {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
+ }
+
+ long batchSize = input.size(0);
+ long nInputPlane = input.size(1);
+ long inputHeight = input.size(2);
+ long inputWidth = input.size(3);
+
+ long nOutputPlane = gradWeight.size(0);
+
+ long outputWidth =
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+ long outputHeight =
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+ TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+
+ columns = at::zeros(
+ {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+ input.options());
+
+ gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
+ nOutputPlane, outputHeight, outputWidth});
+ gradOutput.transpose_(1, 2);
+
+ at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
+ gradOutputBuffer =
+ gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step,
+ outputHeight, outputWidth});
+ gradOutputBuffer.copy_(gradOutput);
+ gradOutputBuffer =
+ gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
+ im2col_step * outputHeight, outputWidth});
+
+ gradOutput.transpose_(1, 2);
+ gradOutput =
+ gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+ input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+ inputHeight, inputWidth});
+ offset =
+ offset.view({batchSize / im2col_step, im2col_step,
+ deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+ deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
+ inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+ dilationW, im2col_step, deformable_group, columns);
+
+ // divide into group
+ gradOutputBuffer = gradOutputBuffer.view(
+ {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
+ gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ gradWeight =
+ gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
+ gradWeight.size(2), gradWeight.size(3)});
+
+ for (int g = 0; g < group; g++) {
+ gradWeight[g] = gradWeight[g]
+ .flatten(1)
+ .addmm_(gradOutputBuffer[elt][g].flatten(1),
+ columns[g].transpose(1, 0), 1.0, scale)
+ .view_as(gradWeight[g]);
+ }
+ gradOutputBuffer = gradOutputBuffer.view(
+ {gradOutputBuffer.size(0),
+ gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
+ gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
+ gradWeight.size(2), gradWeight.size(3),
+ gradWeight.size(4)});
+ }
+
+ input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ offset = offset.view(
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ if (batch == 0) {
+ gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
+ input = input.view({nInputPlane, inputHeight, inputWidth});
+ }
+
+ return 1;
+}
+
+void modulated_deform_conv_cuda_forward(
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+ at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
+ int kernel_h, int kernel_w, const int stride_h, const int stride_w,
+ const int pad_h, const int pad_w, const int dilation_h,
+ const int dilation_w, const int group, const int deformable_group,
+ const bool with_bias) {
+ TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+ TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+ at::DeviceGuard guard(input.device());
+
+ const int batch = input.size(0);
+ const int channels = input.size(1);
+ const int height = input.size(2);
+ const int width = input.size(3);
+
+ const int channels_out = weight.size(0);
+ const int channels_kernel = weight.size(1);
+ const int kernel_h_ = weight.size(2);
+ const int kernel_w_ = weight.size(3);
+
+ if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
+ AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
+ kernel_h_, kernel_w, kernel_h_, kernel_w_);
+ if (channels != channels_kernel * group)
+ AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
+ channels, channels_kernel * group);
+
+ const int height_out =
+ (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+ const int width_out =
+ (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+ if (ones.ndimension() != 2 ||
+ ones.size(0) * ones.size(1) < height_out * width_out) {
+ // Resize plane and fill with ones...
+ ones = at::ones({height_out, width_out}, input.options());
+ }
+
+ // resize output
+ output = output.view({batch, channels_out, height_out, width_out}).zero_();
+ // resize temporary columns
+ columns =
+ at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
+ input.options());
+
+ output = output.view({output.size(0), group, output.size(1) / group,
+ output.size(2), output.size(3)});
+
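+  // Per-sample loop: modulated im2col builds the (mask-weighted) column matrix,
+  // then a grouped GEMM accumulates it into the output.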
+ for (int b = 0; b < batch; b++) {
+ modulated_deformable_im2col_cuda(
+ input[b], offset[b], mask[b], 1, channels, height, width, height_out,
+ width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, deformable_group, columns);
+
+ // divide into group
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
+ weight.size(2), weight.size(3)});
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+
+ for (int g = 0; g < group; g++) {
+ output[b][g] = output[b][g]
+ .flatten(1)
+ .addmm_(weight[g].flatten(1), columns[g])
+ .view_as(output[b][g]);
+ }
+
+ weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
+ weight.size(3), weight.size(4)});
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ }
+
+ output = output.view({output.size(0), output.size(1) * output.size(2),
+ output.size(3), output.size(4)});
+
+ if (with_bias) {
+ output += bias.view({1, bias.size(0), 1, 1});
+ }
+}
+
+void modulated_deform_conv_cuda_backward(
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+ at::Tensor offset, at::Tensor mask, at::Tensor columns,
+ at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
+ at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
+ int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
+ int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
+ const bool with_bias) {
+ TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+ TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+ at::DeviceGuard guard(input.device());
+
+ const int batch = input.size(0);
+ const int channels = input.size(1);
+ const int height = input.size(2);
+ const int width = input.size(3);
+
+ const int channels_kernel = weight.size(1);
+ const int kernel_h_ = weight.size(2);
+ const int kernel_w_ = weight.size(3);
+ if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
+ AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
+ kernel_h_, kernel_w, kernel_h_, kernel_w_);
+ if (channels != channels_kernel * group)
+ AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
+ channels, channels_kernel * group);
+
+ const int height_out =
+ (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+ const int width_out =
+ (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+ if (ones.ndimension() != 2 ||
+ ones.size(0) * ones.size(1) < height_out * width_out) {
+ // Resize plane and fill with ones...
+ ones = at::ones({height_out, width_out}, input.options());
+ }
+
+ grad_input = grad_input.view({batch, channels, height, width});
+ columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
+ input.options());
+
+ grad_output =
+ grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
+ grad_output.size(2), grad_output.size(3)});
+
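+  // Backward per sample: propagate grad_output into columns (weight^T x grad),
+  // scatter to grad_offset/grad_mask/grad_input via col2im, then rebuild the
+  // columns to accumulate grad_weight (and grad_bias).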
+ for (int b = 0; b < batch; b++) {
+    // divide into groups
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
+ weight.size(2), weight.size(3)});
+
+ for (int g = 0; g < group; g++) {
+ columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
+ grad_output[b][g].flatten(1), 0.0f, 1.0f);
+ }
+
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
+ weight.size(3), weight.size(4)});
+
+ // gradient w.r.t. input coordinate data
+ modulated_deformable_col2im_coord_cuda(
+ columns, input[b], offset[b], mask[b], 1, channels, height, width,
+ height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
+ stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
+ grad_mask[b]);
+ // gradient w.r.t. input data
+ modulated_deformable_col2im_cuda(
+ columns, offset[b], mask[b], 1, channels, height, width, height_out,
+ width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, deformable_group, grad_input[b]);
+
+ // gradient w.r.t. weight, dWeight should accumulate across the batch and
+ // group
+ modulated_deformable_im2col_cuda(
+ input[b], offset[b], mask[b], 1, channels, height, width, height_out,
+ width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, deformable_group, columns);
+
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
+ grad_weight.size(1), grad_weight.size(2),
+ grad_weight.size(3)});
+ if (with_bias)
+ grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
+
+ for (int g = 0; g < group; g++) {
+ grad_weight[g] =
+ grad_weight[g]
+ .flatten(1)
+ .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
+ .view_as(grad_weight[g]);
+ if (with_bias) {
+ grad_bias[g] =
+ grad_bias[g]
+ .view({-1, 1})
+ .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
+ .view(-1);
+ }
+ }
+
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
+ grad_weight.size(2), grad_weight.size(3),
+ grad_weight.size(4)});
+ if (with_bias)
+ grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
+ }
+ grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
+ grad_output.size(2), grad_output.size(3),
+ grad_output.size(4)});
+}
diff --git a/hordelib/nodes/facerestore_cf/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu b/hordelib/nodes/facerestore_cf/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu
new file mode 100644
index 00000000..9fe9ba3a
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu
@@ -0,0 +1,867 @@
+/*!
+ ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
+ *
+ * COPYRIGHT
+ *
+ * All contributions by the University of California:
+ * Copyright (c) 2014-2017 The Regents of the University of California (Regents)
+ * All rights reserved.
+ *
+ * All other contributions:
+ * Copyright (c) 2014-2017, the respective contributors
+ * All rights reserved.
+ *
+ * Caffe uses a shared copyright model: each contributor holds copyright over
+ * their contributions to Caffe. The project versioning records all such
+ * contribution and copyright details. If a contributor wants to further mark
+ * their specific copyright on a particular contribution, they should indicate
+ * their copyright solely in the commit message of the change when it is
+ * committed.
+ *
+ * LICENSE
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+ * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ * CONTRIBUTION AGREEMENT
+ *
+ * By contributing to the BVLC/caffe repository through pull-request, comment,
+ * or otherwise, the contributor releases their content to the
+ * license and copyright terms herein.
+ *
+ ***************** END Caffe Copyright Notice and Disclaimer ********************
+ *
+ * Copyright (c) 2018 Microsoft
+ * Licensed under The MIT License [see LICENSE for details]
+ * \file modulated_deformable_im2col.cuh
+ * \brief Function definitions of converting an image to
+ * column matrix based on kernel, padding, dilation, and offset.
+ * These functions are mainly used in deformable convolution operators.
+ * \ref: https://arxiv.org/abs/1703.06211
+ * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
+ */
+
+// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <THC/THCAtomics.cuh>
+#include <stdio.h>
+#include <math.h>
+#include <float.h>
+
+using namespace at;
+
+#define CUDA_KERNEL_LOOP(i, n) \
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+ i += blockDim.x * gridDim.x)
+
+const int CUDA_NUM_THREADS = 1024;
+const int kMaxGridNum = 65535;
+
+inline int GET_BLOCKS(const int N)
+{
+ return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
+}
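+
+// CUDA_KERNEL_LOOP is a grid-stride loop and GET_BLOCKS caps the grid at
+// kMaxGridNum, so kernels stay correct when N exceeds the launched thread count.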
+
+template <typename scalar_t>
+__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
+ const int height, const int width, scalar_t h, scalar_t w)
+{
+
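+  // Bilinear sample at the fractional position (h, w); neighbours that fall
+  // outside the image contribute zero.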
+ int h_low = floor(h);
+ int w_low = floor(w);
+ int h_high = h_low + 1;
+ int w_high = w_low + 1;
+
+ scalar_t lh = h - h_low;
+ scalar_t lw = w - w_low;
+ scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ v1 = bottom_data[h_low * data_width + w_low];
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ v2 = bottom_data[h_low * data_width + w_high];
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ v3 = bottom_data[h_high * data_width + w_low];
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ v4 = bottom_data[h_high * data_width + w_high];
+
+ scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+ scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ return val;
+}
+
+template <typename scalar_t>
+__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
+ const int h, const int w, const int height, const int width)
+{
+
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+ {
+ //empty
+ return 0;
+ }
+
+ int argmax_h_low = floor(argmax_h);
+ int argmax_w_low = floor(argmax_w);
+ int argmax_h_high = argmax_h_low + 1;
+ int argmax_w_high = argmax_w_low + 1;
+
+ scalar_t weight = 0;
+ if (h == argmax_h_low && w == argmax_w_low)
+ weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
+ if (h == argmax_h_low && w == argmax_w_high)
+ weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
+ if (h == argmax_h_high && w == argmax_w_low)
+ weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
+ if (h == argmax_h_high && w == argmax_w_high)
+ weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
+ return weight;
+}
+
+template <typename scalar_t>
+__device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
+ const int height, const int width, const scalar_t *im_data,
+ const int data_width, const int bp_dir)
+{
+
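+  // Derivative of the bilinear sample w.r.t. the sampling coordinate:
+  // bp_dir == 0 differentiates along h, bp_dir == 1 along w.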
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+ {
+ //empty
+ return 0;
+ }
+
+ int argmax_h_low = floor(argmax_h);
+ int argmax_w_low = floor(argmax_w);
+ int argmax_h_high = argmax_h_low + 1;
+ int argmax_w_high = argmax_w_low + 1;
+
+ scalar_t weight = 0;
+
+ if (bp_dir == 0)
+ {
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
+ weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+ weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+ weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+ weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+ }
+ else if (bp_dir == 1)
+ {
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
+ weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+ weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+ weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+ weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+ }
+
+ return weight;
+}
+
+template <typename scalar_t>
+__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset,
+ const int height, const int width, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w, const int channel_per_deformable_group,
+ const int batch_size, const int num_channels, const int deformable_group,
+ const int height_col, const int width_col,
+ scalar_t *data_col)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ // index index of output matrix
+ const int w_col = index % width_col;
+ const int h_col = (index / width_col) % height_col;
+ const int b_col = (index / width_col / height_col) % batch_size;
+ const int c_im = (index / width_col / height_col) / batch_size;
+ const int c_col = c_im * kernel_h * kernel_w;
+
+ // compute deformable group index
+ const int deformable_group_index = c_im / channel_per_deformable_group;
+
+ const int h_in = h_col * stride_h - pad_h;
+ const int w_in = w_col * stride_w - pad_w;
+ scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
+ //const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
+ const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
+ const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+
+ for (int i = 0; i < kernel_h; ++i)
+ {
+ for (int j = 0; j < kernel_w; ++j)
+ {
+ const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
+ const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+        scalar_t val = static_cast<scalar_t>(0);
+ const scalar_t h_im = h_in + i * dilation_h + offset_h;
+ const scalar_t w_im = w_in + j * dilation_w + offset_w;
+ if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
+ {
+ //const scalar_t map_h = i * dilation_h + offset_h;
+ //const scalar_t map_w = j * dilation_w + offset_w;
+ //const int cur_height = height - h_in;
+ //const int cur_width = width - w_in;
+ //val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
+ val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
+ }
+ *data_col_ptr = val;
+ data_col_ptr += batch_size * height_col * width_col;
+ }
+ }
+ }
+}
+
+void deformable_im2col(
+ const at::Tensor data_im, const at::Tensor data_offset, const int channels,
+ const int height, const int width, const int ksize_h, const int ksize_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w, const int parallel_imgs,
+ const int deformable_group, at::Tensor data_col)
+{
+ // num_axes should be smaller than block size
+ // todo: check parallel_imgs is correctly passed in
+ int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+ int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+ int num_kernels = channels * height_col * width_col * parallel_imgs;
+ int channel_per_deformable_group = channels / deformable_group;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data_im.scalar_type(), "deformable_im2col_gpu", ([&] {
+        const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
+        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+        scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+
+        deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+ num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w,
+ pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ channel_per_deformable_group, parallel_imgs, channels, deformable_group,
+ height_col, width_col, data_col_);
+ }));
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in deformable_im2col: %s\n", cudaGetErrorString(err));
+ }
+}
+
+template <typename scalar_t>
+__global__ void deformable_col2im_gpu_kernel(
+ const int n, const scalar_t *data_col, const scalar_t *data_offset,
+ const int channels, const int height, const int width,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group,
+ const int batch_size, const int deformable_group,
+ const int height_col, const int width_col,
+ scalar_t *grad_im)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ const int j = (index / width_col / height_col / batch_size) % kernel_w;
+ const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
+ const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
+ // compute the start and end of the output
+
+ const int deformable_group_index = c / channel_per_deformable_group;
+
+ int w_out = index % width_col;
+ int h_out = (index / width_col) % height_col;
+ int b = (index / width_col / height_col) % batch_size;
+ int w_in = w_out * stride_w - pad_w;
+ int h_in = h_out * stride_h - pad_h;
+
+ const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) *
+ 2 * kernel_h * kernel_w * height_col * width_col;
+ const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
+ const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+ const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
+ const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
+
+ const scalar_t cur_top_grad = data_col[index];
+ const int cur_h = (int)cur_inv_h_data;
+ const int cur_w = (int)cur_inv_w_data;
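+    // Scatter-add the gradient to the integer pixels around the sampled location;
+    // atomicAdd is required because several columns can hit the same pixel.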
+ for (int dy = -2; dy <= 2; dy++)
+ {
+ for (int dx = -2; dx <= 2; dx++)
+ {
+ if (cur_h + dy >= 0 && cur_h + dy < height &&
+ cur_w + dx >= 0 && cur_w + dx < width &&
+ abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
+ abs(cur_inv_w_data - (cur_w + dx)) < 1)
+ {
+ int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
+ scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
+ atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
+ }
+ }
+ }
+ }
+}
+
+void deformable_col2im(
+ const at::Tensor data_col, const at::Tensor data_offset, const int channels,
+ const int height, const int width, const int ksize_h,
+ const int ksize_w, const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int parallel_imgs, const int deformable_group,
+ at::Tensor grad_im)
+{
+
+ // todo: make sure parallel_imgs is passed in correctly
+ int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+ int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+ int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs;
+ int channel_per_deformable_group = channels / deformable_group;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data_col.scalar_type(), "deformable_col2im_gpu", ([&] {
+        const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+        scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();
+
+        deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+ num_kernels, data_col_, data_offset_, channels, height, width, ksize_h,
+ ksize_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, channel_per_deformable_group,
+ parallel_imgs, deformable_group, height_col, width_col, grad_im_);
+ }));
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in deformable_col2im: %s\n", cudaGetErrorString(err));
+ }
+}
+
+template <typename scalar_t>
+__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col,
+ const scalar_t *data_im, const scalar_t *data_offset,
+ const int channels, const int height, const int width,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group,
+ const int batch_size, const int offset_channels, const int deformable_group,
+ const int height_col, const int width_col, scalar_t *grad_offset)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ scalar_t val = 0;
+ int w = index % width_col;
+ int h = (index / width_col) % height_col;
+ int c = (index / width_col / height_col) % offset_channels;
+ int b = (index / width_col / height_col) / offset_channels;
+ // compute the start and end of the output
+
+ const int deformable_group_index = c / (2 * kernel_h * kernel_w);
+ const int col_step = kernel_h * kernel_w;
+ int cnt = 0;
+ const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group *
+ batch_size * width_col * height_col;
+ const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) *
+ channel_per_deformable_group / kernel_h / kernel_w * height * width;
+ const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 *
+ kernel_h * kernel_w * height_col * width_col;
+
+ const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
+
+ for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
+ {
+ const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
+ const int bp_dir = offset_c % 2;
+
+ int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
+ int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
+ int w_out = col_pos % width_col;
+ int h_out = (col_pos / width_col) % height_col;
+ int w_in = w_out * stride_w - pad_w;
+ int h_in = h_out * stride_h - pad_h;
+ const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
+ const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+ scalar_t inv_h = h_in + i * dilation_h + offset_h;
+ scalar_t inv_w = w_in + j * dilation_w + offset_w;
+ if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
+ {
+ inv_h = inv_w = -2;
+ }
+ const scalar_t weight = get_coordinate_weight(
+ inv_h, inv_w,
+ height, width, data_im_ptr + cnt * height * width, width, bp_dir);
+ val += weight * data_col_ptr[col_pos];
+ cnt += 1;
+ }
+
+ grad_offset[index] = val;
+ }
+}
+
+void deformable_col2im_coord(
+ const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset,
+ const int channels, const int height, const int width, const int ksize_h,
+ const int ksize_w, const int pad_h, const int pad_w, const int stride_h,
+ const int stride_w, const int dilation_h, const int dilation_w,
+ const int parallel_imgs, const int deformable_group, at::Tensor grad_offset)
+{
+
+ int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+ int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+ int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs;
+ int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] {
+        const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+        const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
+        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+        scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();
+
+        deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+ num_kernels, data_col_, data_im_, data_offset_, channels, height, width,
+ ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, channel_per_deformable_group,
+ parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group,
+ height_col, width_col, grad_offset_);
+ }));
+}
+
+template <typename scalar_t>
+__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
+ const int height, const int width, scalar_t h, scalar_t w)
+{
+ int h_low = floor(h);
+ int w_low = floor(w);
+ int h_high = h_low + 1;
+ int w_high = w_low + 1;
+
+ scalar_t lh = h - h_low;
+ scalar_t lw = w - w_low;
+ scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ v1 = bottom_data[h_low * data_width + w_low];
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ v2 = bottom_data[h_low * data_width + w_high];
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ v3 = bottom_data[h_high * data_width + w_low];
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ v4 = bottom_data[h_high * data_width + w_high];
+
+ scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+ scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ return val;
+}
+
+template <typename scalar_t>
+__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
+ const int h, const int w, const int height, const int width)
+{
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+ {
+ //empty
+ return 0;
+ }
+
+ int argmax_h_low = floor(argmax_h);
+ int argmax_w_low = floor(argmax_w);
+ int argmax_h_high = argmax_h_low + 1;
+ int argmax_w_high = argmax_w_low + 1;
+
+ scalar_t weight = 0;
+ if (h == argmax_h_low && w == argmax_w_low)
+ weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
+ if (h == argmax_h_low && w == argmax_w_high)
+ weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
+ if (h == argmax_h_high && w == argmax_w_low)
+ weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
+ if (h == argmax_h_high && w == argmax_w_high)
+ weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
+ return weight;
+}
+
+template <typename scalar_t>
+__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
+ const int height, const int width, const scalar_t *im_data,
+ const int data_width, const int bp_dir)
+{
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+ {
+ //empty
+ return 0;
+ }
+
+ int argmax_h_low = floor(argmax_h);
+ int argmax_w_low = floor(argmax_w);
+ int argmax_h_high = argmax_h_low + 1;
+ int argmax_w_high = argmax_w_low + 1;
+
+ scalar_t weight = 0;
+
+ if (bp_dir == 0)
+ {
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
+ weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+ weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+ weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+ weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+ }
+ else if (bp_dir == 1)
+ {
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
+ weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+ weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+ weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+ weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+ }
+
+ return weight;
+}
+
+template <typename scalar_t>
+__global__ void modulated_deformable_im2col_gpu_kernel(const int n,
+ const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask,
+ const int height, const int width, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group,
+ const int batch_size, const int num_channels, const int deformable_group,
+ const int height_col, const int width_col,
+ scalar_t *data_col)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ // index index of output matrix
+ const int w_col = index % width_col;
+ const int h_col = (index / width_col) % height_col;
+ const int b_col = (index / width_col / height_col) % batch_size;
+ const int c_im = (index / width_col / height_col) / batch_size;
+ const int c_col = c_im * kernel_h * kernel_w;
+
+ // compute deformable group index
+ const int deformable_group_index = c_im / channel_per_deformable_group;
+
+ const int h_in = h_col * stride_h - pad_h;
+ const int w_in = w_col * stride_w - pad_w;
+
+ scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
+ //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
+ const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
+ const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+
+ const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
+
+ for (int i = 0; i < kernel_h; ++i)
+ {
+ for (int j = 0; j < kernel_w; ++j)
+ {
+ const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
+ const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
+ const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+ const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
+        scalar_t val = static_cast<scalar_t>(0);
+ const scalar_t h_im = h_in + i * dilation_h + offset_h;
+ const scalar_t w_im = w_in + j * dilation_w + offset_w;
+ //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
+ if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
+ {
+ //const float map_h = i * dilation_h + offset_h;
+ //const float map_w = j * dilation_w + offset_w;
+ //const int cur_height = height - h_in;
+ //const int cur_width = width - w_in;
+ //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
+ val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
+ }
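+        // DCNv2 modulation: scale the sampled value by its learned mask weight.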
+ *data_col_ptr = val * mask;
+ data_col_ptr += batch_size * height_col * width_col;
+ //data_col_ptr += height_col * width_col;
+ }
+ }
+ }
+}
+
+template <typename scalar_t>
+__global__ void modulated_deformable_col2im_gpu_kernel(const int n,
+ const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask,
+ const int channels, const int height, const int width,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group,
+ const int batch_size, const int deformable_group,
+ const int height_col, const int width_col,
+ scalar_t *grad_im)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ const int j = (index / width_col / height_col / batch_size) % kernel_w;
+ const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
+ const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
+ // compute the start and end of the output
+
+ const int deformable_group_index = c / channel_per_deformable_group;
+
+ int w_out = index % width_col;
+ int h_out = (index / width_col) % height_col;
+ int b = (index / width_col / height_col) % batch_size;
+ int w_in = w_out * stride_w - pad_w;
+ int h_in = h_out * stride_h - pad_h;
+
+ const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+ const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
+ const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
+ const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
+ const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+ const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
+ const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
+ const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
+
+ const scalar_t cur_top_grad = data_col[index] * mask;
+ const int cur_h = (int)cur_inv_h_data;
+ const int cur_w = (int)cur_inv_w_data;
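+ // Scatter the gradient back to the input image: the 5x5 sweep below is conservative;
+ // the guard inside keeps only the (at most four) integer pixels within distance < 1 of the
+ // fractional sampling point, which are exactly the ones with a non-zero bilinear weight
+ // from dmcn_get_gradient_weight(). Every other visited position contributes nothing.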
+ for (int dy = -2; dy <= 2; dy++)
+ {
+ for (int dx = -2; dx <= 2; dx++)
+ {
+ if (cur_h + dy >= 0 && cur_h + dy < height &&
+ cur_w + dx >= 0 && cur_w + dx < width &&
+ abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
+ abs(cur_inv_w_data - (cur_w + dx)) < 1)
+ {
+ int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
+ scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
+ atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
+ }
+ }
+ }
+ }
+}
+
+template <typename scalar_t>
+__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,
+ const scalar_t *data_col, const scalar_t *data_im,
+ const scalar_t *data_offset, const scalar_t *data_mask,
+ const int channels, const int height, const int width,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group,
+ const int batch_size, const int offset_channels, const int deformable_group,
+ const int height_col, const int width_col,
+ scalar_t *grad_offset, scalar_t *grad_mask)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ scalar_t val = 0, mval = 0;
+ int w = index % width_col;
+ int h = (index / width_col) % height_col;
+ int c = (index / width_col / height_col) % offset_channels;
+ int b = (index / width_col / height_col) / offset_channels;
+ // compute the start and end of the output
+
+ const int deformable_group_index = c / (2 * kernel_h * kernel_w);
+ const int col_step = kernel_h * kernel_w;
+ int cnt = 0;
+ const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
+ const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;
+ const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+ const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
+
+ const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
+
+ for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
+ {
+ const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
+ const int bp_dir = offset_c % 2;
+
+ int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
+ int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
+ int w_out = col_pos % width_col;
+ int h_out = (col_pos / width_col) % height_col;
+ int w_in = w_out * stride_w - pad_w;
+ int h_in = h_out * stride_h - pad_h;
+ const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
+ const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
+ const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+ const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
+ scalar_t inv_h = h_in + i * dilation_h + offset_h;
+ scalar_t inv_w = w_in + j * dilation_w + offset_w;
+ if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
+ {
+ inv_h = inv_w = -2;
+ }
+ else
+ {
+ mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);
+ }
+ const scalar_t weight = dmcn_get_coordinate_weight(
+ inv_h, inv_w,
+ height, width, data_im_ptr + cnt * height * width, width, bp_dir);
+ val += weight * data_col_ptr[col_pos] * mask;
+ cnt += 1;
+ }
+ // KERNEL_ASSIGN(grad_offset[index], offset_req, val);
+ grad_offset[index] = val;
+ if (offset_c % 2 == 0)
+ // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);
+ grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval;
+ }
+}
+
+void modulated_deformable_im2col_cuda(
+ const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
+ const int batch_size, const int channels, const int height_im, const int width_im,
+ const int height_col, const int width_col, const int kernel_h, const int kenerl_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int deformable_group, at::Tensor data_col)
+{
+ // num_axes should be smaller than block size
+ const int channel_per_deformable_group = channels / deformable_group;
+ const int num_kernels = channels * batch_size * height_col * width_col;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] {
+ const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
+ const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+ const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
+ scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+
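+ // NOTE: the launch configuration below assumes the GET_BLOCKS() helper and the
+ // CUDA_NUM_THREADS constant declared earlier in this file (one thread per element
+ // of the im2col output matrix).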
+ modulated_deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+ num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kenerl_w,
+ pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
+ batch_size, channels, deformable_group, height_col, width_col, data_col_);
+ }));
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
+ }
+}
+
+void modulated_deformable_col2im_cuda(
+ const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask,
+ const int batch_size, const int channels, const int height_im, const int width_im,
+ const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int deformable_group, at::Tensor grad_im)
+{
+
+ const int channel_per_deformable_group = channels / deformable_group;
+ const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] {
+ const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+ const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+ const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
+ scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();
+
+ modulated_deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+ num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im,
+ kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, channel_per_deformable_group,
+ batch_size, deformable_group, height_col, width_col, grad_im_);
+ }));
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
+ }
+}
+
+void modulated_deformable_col2im_coord_cuda(
+ const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
+ const int batch_size, const int channels, const int height_im, const int width_im,
+ const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int deformable_group,
+ at::Tensor grad_offset, at::Tensor grad_mask)
+{
+ const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;
+ const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] {
+ const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+ const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
+ const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+ const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
+ scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();
+ scalar_t *grad_mask_ = grad_mask.data_ptr<scalar_t>();
+
+ modulated_deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+ num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im,
+ kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, channel_per_deformable_group,
+ batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col,
+ grad_offset_, grad_mask_);
+ }));
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err));
+ }
+}
diff --git a/hordelib/nodes/facerestore_cf/basicsr/ops/dcn/src/deform_conv_ext.cpp b/hordelib/nodes/facerestore_cf/basicsr/ops/dcn/src/deform_conv_ext.cpp
new file mode 100644
index 00000000..5c21d02c
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/ops/dcn/src/deform_conv_ext.cpp
@@ -0,0 +1,164 @@
+// modify from
+// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
+
+#include <torch/extension.h>
+#include <ATen/DivRtn.h>
+
+#include <cmath>
+#include <vector>
+
+#define WITH_CUDA // always use cuda
+#ifdef WITH_CUDA
+int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
+ at::Tensor offset, at::Tensor output,
+ at::Tensor columns, at::Tensor ones, int kW,
+ int kH, int dW, int dH, int padW, int padH,
+ int dilationW, int dilationH, int group,
+ int deformable_group, int im2col_step);
+
+int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
+ at::Tensor gradOutput, at::Tensor gradInput,
+ at::Tensor gradOffset, at::Tensor weight,
+ at::Tensor columns, int kW, int kH, int dW,
+ int dH, int padW, int padH, int dilationW,
+ int dilationH, int group,
+ int deformable_group, int im2col_step);
+
+int deform_conv_backward_parameters_cuda(
+ at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
+ at::Tensor gradWeight, // at::Tensor gradBias,
+ at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
+ int padW, int padH, int dilationW, int dilationH, int group,
+ int deformable_group, float scale, int im2col_step);
+
+void modulated_deform_conv_cuda_forward(
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+ at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
+ int kernel_h, int kernel_w, const int stride_h, const int stride_w,
+ const int pad_h, const int pad_w, const int dilation_h,
+ const int dilation_w, const int group, const int deformable_group,
+ const bool with_bias);
+
+void modulated_deform_conv_cuda_backward(
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+ at::Tensor offset, at::Tensor mask, at::Tensor columns,
+ at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
+ at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
+ int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
+ int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
+ const bool with_bias);
+#endif
+
+int deform_conv_forward(at::Tensor input, at::Tensor weight,
+ at::Tensor offset, at::Tensor output,
+ at::Tensor columns, at::Tensor ones, int kW,
+ int kH, int dW, int dH, int padW, int padH,
+ int dilationW, int dilationH, int group,
+ int deformable_group, int im2col_step) {
+ if (input.device().is_cuda()) {
+#ifdef WITH_CUDA
+ return deform_conv_forward_cuda(input, weight, offset, output, columns,
+ ones, kW, kH, dW, dH, padW, padH, dilationW, dilationH, group,
+ deformable_group, im2col_step);
+#else
+ AT_ERROR("deform conv is not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("deform conv is not implemented on CPU");
+}
+
+int deform_conv_backward_input(at::Tensor input, at::Tensor offset,
+ at::Tensor gradOutput, at::Tensor gradInput,
+ at::Tensor gradOffset, at::Tensor weight,
+ at::Tensor columns, int kW, int kH, int dW,
+ int dH, int padW, int padH, int dilationW,
+ int dilationH, int group,
+ int deformable_group, int im2col_step) {
+ if (input.device().is_cuda()) {
+#ifdef WITH_CUDA
+ return deform_conv_backward_input_cuda(input, offset, gradOutput,
+ gradInput, gradOffset, weight, columns, kW, kH, dW, dH, padW, padH,
+ dilationW, dilationH, group, deformable_group, im2col_step);
+#else
+ AT_ERROR("deform conv is not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("deform conv is not implemented on CPU");
+}
+
+int deform_conv_backward_parameters(
+ at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
+ at::Tensor gradWeight, // at::Tensor gradBias,
+ at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
+ int padW, int padH, int dilationW, int dilationH, int group,
+ int deformable_group, float scale, int im2col_step) {
+ if (input.device().is_cuda()) {
+#ifdef WITH_CUDA
+ return deform_conv_backward_parameters_cuda(input, offset, gradOutput,
+ gradWeight, columns, ones, kW, kH, dW, dH, padW, padH, dilationW,
+ dilationH, group, deformable_group, scale, im2col_step);
+#else
+ AT_ERROR("deform conv is not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("deform conv is not implemented on CPU");
+}
+
+void modulated_deform_conv_forward(
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+ at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
+ int kernel_h, int kernel_w, const int stride_h, const int stride_w,
+ const int pad_h, const int pad_w, const int dilation_h,
+ const int dilation_w, const int group, const int deformable_group,
+ const bool with_bias) {
+ if (input.device().is_cuda()) {
+#ifdef WITH_CUDA
+ return modulated_deform_conv_cuda_forward(input, weight, bias, ones,
+ offset, mask, output, columns, kernel_h, kernel_w, stride_h,
+ stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
+ deformable_group, with_bias);
+#else
+ AT_ERROR("modulated deform conv is not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("modulated deform conv is not implemented on CPU");
+}
+
+void modulated_deform_conv_backward(
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+ at::Tensor offset, at::Tensor mask, at::Tensor columns,
+ at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
+ at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
+ int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
+ int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
+ const bool with_bias) {
+ if (input.device().is_cuda()) {
+#ifdef WITH_CUDA
+ return modulated_deform_conv_cuda_backward(input, weight, bias, ones,
+ offset, mask, columns, grad_input, grad_weight, grad_bias, grad_offset,
+ grad_mask, grad_output, kernel_h, kernel_w, stride_h, stride_w,
+ pad_h, pad_w, dilation_h, dilation_w, group, deformable_group,
+ with_bias);
+#else
+ AT_ERROR("modulated deform conv is not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("modulated deform conv is not implemented on CPU");
+}
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("deform_conv_forward", &deform_conv_forward,
+ "deform forward");
+ m.def("deform_conv_backward_input", &deform_conv_backward_input,
+ "deform_conv_backward_input");
+ m.def("deform_conv_backward_parameters",
+ &deform_conv_backward_parameters,
+ "deform_conv_backward_parameters");
+ m.def("modulated_deform_conv_forward",
+ &modulated_deform_conv_forward,
+ "modulated deform conv forward");
+ m.def("modulated_deform_conv_backward",
+ &modulated_deform_conv_backward,
+ "modulated deform conv backward");
+}
diff --git a/hordelib/nodes/facerestore_cf/basicsr/ops/fused_act/__init__.py b/hordelib/nodes/facerestore_cf/basicsr/ops/fused_act/__init__.py
new file mode 100644
index 00000000..81f5b1e8
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/ops/fused_act/__init__.py
@@ -0,0 +1,3 @@
+from .fused_act import FusedLeakyReLU, fused_leaky_relu
+
+__all__ = ["FusedLeakyReLU", "fused_leaky_relu"]
diff --git a/hordelib/nodes/facerestore_cf/basicsr/ops/fused_act/fused_act.py b/hordelib/nodes/facerestore_cf/basicsr/ops/fused_act/fused_act.py
new file mode 100644
index 00000000..bb8be954
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/ops/fused_act/fused_act.py
@@ -0,0 +1,98 @@
+# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py
+
+import torch
+from torch import nn
+from torch.autograd import Function
+
+try:
+ from . import fused_act_ext
+except ImportError:
+ import os
+
+ BASICSR_JIT = os.getenv("BASICSR_JIT")
+ if BASICSR_JIT == "True":
+ from torch.utils.cpp_extension import load
+
+ module_path = os.path.dirname(__file__)
+ fused_act_ext = load(
+ "fused",
+ sources=[
+ os.path.join(module_path, "src", "fused_bias_act.cpp"),
+ os.path.join(module_path, "src", "fused_bias_act_kernel.cu"),
+ ],
+ )
+
+
+class FusedLeakyReLUFunctionBackward(Function):
+
+ @staticmethod
+ def forward(ctx, grad_output, out, negative_slope, scale):
+ ctx.save_for_backward(out)
+ ctx.negative_slope = negative_slope
+ ctx.scale = scale
+
+ empty = grad_output.new_empty(0)
+
+ grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale)
+
+ dim = [0]
+
+ if grad_input.ndim > 2:
+ dim += list(range(2, grad_input.ndim))
+
+ grad_bias = grad_input.sum(dim).detach()
+
+ return grad_input, grad_bias
+
+ @staticmethod
+ def backward(ctx, gradgrad_input, gradgrad_bias):
+ (out,) = ctx.saved_tensors
+ gradgrad_out = fused_act_ext.fused_bias_act(
+ gradgrad_input,
+ gradgrad_bias,
+ out,
+ 3,
+ 1,
+ ctx.negative_slope,
+ ctx.scale,
+ )
+
+ return gradgrad_out, None, None, None
+
+
+class FusedLeakyReLUFunction(Function):
+
+ @staticmethod
+ def forward(ctx, input, bias, negative_slope, scale):
+ empty = input.new_empty(0)
+ out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
+ ctx.save_for_backward(out)
+ ctx.negative_slope = negative_slope
+ ctx.scale = scale
+
+ return out
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ (out,) = ctx.saved_tensors
+
+ grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale)
+
+ return grad_input, grad_bias, None, None
+
+
+class FusedLeakyReLU(nn.Module):
+
+ def __init__(self, channel, negative_slope=0.2, scale=2**0.5):
+ super().__init__()
+
+ self.bias = nn.Parameter(torch.zeros(channel))
+ self.negative_slope = negative_slope
+ self.scale = scale
+
+ def forward(self, input):
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
+
+
+def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5):
+ return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
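+
+# Minimal usage sketch (illustrative only; the tensor shapes and device below are assumptions,
+# not part of this module):
+#
+#   act = FusedLeakyReLU(channel=512).cuda()
+#   y = act(torch.randn(4, 512, 32, 32, device="cuda"))
+#
+# or, functionally, with an explicit bias parameter:
+#
+#   y = fused_leaky_relu(x, bias, negative_slope=0.2, scale=2**0.5)
+#
+# Both paths need the compiled fused_act_ext extension (or BASICSR_JIT=True so it is JIT-built above).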
diff --git a/hordelib/nodes/facerestore_cf/basicsr/ops/fused_act/src/fused_bias_act.cpp b/hordelib/nodes/facerestore_cf/basicsr/ops/fused_act/src/fused_bias_act.cpp
new file mode 100644
index 00000000..c6225bbc
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/ops/fused_act/src/fused_bias_act.cpp
@@ -0,0 +1,26 @@
+// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp
+#include <torch/extension.h>
+
+
+torch::Tensor fused_bias_act_op(const torch::Tensor& input,
+ const torch::Tensor& bias,
+ const torch::Tensor& refer,
+ int act, int grad, float alpha, float scale);
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+torch::Tensor fused_bias_act(const torch::Tensor& input,
+ const torch::Tensor& bias,
+ const torch::Tensor& refer,
+ int act, int grad, float alpha, float scale) {
+ CHECK_CUDA(input);
+ CHECK_CUDA(bias);
+
+ return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
+}
diff --git a/hordelib/nodes/facerestore_cf/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu b/hordelib/nodes/facerestore_cf/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu
new file mode 100644
index 00000000..31a536f9
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu
@@ -0,0 +1,100 @@
+// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu
+// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, visit
+// https://nvlabs.github.io/stylegan2/license.html
+
+#include <torch/types.h>
+
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+
+template <typename scalar_t>
+static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
+ int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
+ int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
+
+ scalar_t zero = 0.0;
+
+ for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
+ scalar_t x = p_x[xi];
+
+ if (use_bias) {
+ x += p_b[(xi / step_b) % size_b];
+ }
+
+ scalar_t ref = use_ref ? p_ref[xi] : zero;
+
+ scalar_t y;
+
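+ // act * 10 + grad selects the operation: act 1 = identity, act 3 = leaky ReLU;
+ // grad 0 = forward, grad 1 = first backward (slope chosen from the sign of the reference
+ // tensor), grad 2 = second backward (zero for these piecewise-linear activations).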
+ switch (act * 10 + grad) {
+ default:
+ case 10: y = x; break;
+ case 11: y = x; break;
+ case 12: y = 0.0; break;
+
+ case 30: y = (x > 0.0) ? x : x * alpha; break;
+ case 31: y = (ref > 0.0) ? x : x * alpha; break;
+ case 32: y = 0.0; break;
+ }
+
+ out[xi] = y * scale;
+ }
+}
+
+
+torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+ int act, int grad, float alpha, float scale) {
+ int curDevice = -1;
+ cudaGetDevice(&curDevice);
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+
+ auto x = input.contiguous();
+ auto b = bias.contiguous();
+ auto ref = refer.contiguous();
+
+ int use_bias = b.numel() ? 1 : 0;
+ int use_ref = ref.numel() ? 1 : 0;
+
+ int size_x = x.numel();
+ int size_b = b.numel();
+ int step_b = 1;
+
+ for (int i = 1 + 1; i < x.dim(); i++) {
+ step_b *= x.size(i);
+ }
+
+ int loop_x = 4;
+ int block_size = 4 * 32;
+ int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
+
+ auto y = torch::empty_like(x);
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
+ fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
+ y.data_ptr<scalar_t>(),
+ x.data_ptr<scalar_t>(),
+ b.data_ptr<scalar_t>(),
+ ref.data_ptr<scalar_t>(),
+ act,
+ grad,
+ alpha,
+ scale,
+ loop_x,
+ size_x,
+ step_b,
+ size_b,
+ use_bias,
+ use_ref
+ );
+ });
+
+ return y;
+}
diff --git a/hordelib/nodes/facerestore_cf/basicsr/ops/upfirdn2d/__init__.py b/hordelib/nodes/facerestore_cf/basicsr/ops/upfirdn2d/__init__.py
new file mode 100644
index 00000000..c6fd35e4
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/ops/upfirdn2d/__init__.py
@@ -0,0 +1,3 @@
+from .upfirdn2d import upfirdn2d
+
+__all__ = ["upfirdn2d"]
diff --git a/hordelib/nodes/facerestore_cf/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp b/hordelib/nodes/facerestore_cf/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp
new file mode 100644
index 00000000..12b56617
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp
@@ -0,0 +1,24 @@
+// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp
+#include <torch/extension.h>
+
+
+torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
+ int up_x, int up_y, int down_x, int down_y,
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1);
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
+ int up_x, int up_y, int down_x, int down_y,
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
+ CHECK_CUDA(input);
+ CHECK_CUDA(kernel);
+
+ return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
+}
diff --git a/hordelib/nodes/facerestore_cf/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu b/hordelib/nodes/facerestore_cf/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu
new file mode 100644
index 00000000..e82913f5
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu
@@ -0,0 +1,370 @@
+// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d_kernel.cu
+// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, visit
+// https://nvlabs.github.io/stylegan2/license.html
+
+#include <torch/types.h>
+
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
+ int c = a / b;
+
+ if (c * b > a) {
+ c--;
+ }
+
+ return c;
+}
+
+struct UpFirDn2DKernelParams {
+ int up_x;
+ int up_y;
+ int down_x;
+ int down_y;
+ int pad_x0;
+ int pad_x1;
+ int pad_y0;
+ int pad_y1;
+
+ int major_dim;
+ int in_h;
+ int in_w;
+ int minor_dim;
+ int kernel_h;
+ int kernel_w;
+ int out_h;
+ int out_w;
+ int loop_major;
+ int loop_x;
+};
+
+template <typename scalar_t>
+__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
+ const scalar_t *kernel,
+ const UpFirDn2DKernelParams p) {
+ int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
+ int out_y = minor_idx / p.minor_dim;
+ minor_idx -= out_y * p.minor_dim;
+ int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
+ int major_idx_base = blockIdx.z * p.loop_major;
+
+ if (out_x_base >= p.out_w || out_y >= p.out_h ||
+ major_idx_base >= p.major_dim) {
+ return;
+ }
+
+ int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
+ int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
+ int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
+ int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
+
+ for (int loop_major = 0, major_idx = major_idx_base;
+ loop_major < p.loop_major && major_idx < p.major_dim;
+ loop_major++, major_idx++) {
+ for (int loop_x = 0, out_x = out_x_base;
+ loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
+ int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
+ int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
+ int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
+ int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
+
+ const scalar_t *x_p =
+ &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
+ minor_idx];
+ const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
+ int x_px = p.minor_dim;
+ int k_px = -p.up_x;
+ int x_py = p.in_w * p.minor_dim;
+ int k_py = -p.up_y * p.kernel_w;
+
+ scalar_t v = 0.0f;
+
+ for (int y = 0; y < h; y++) {
+ for (int x = 0; x < w; x++) {
+ v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
+ x_p += x_px;
+ k_p += k_px;
+ }
+
+ x_p += x_py - w * x_px;
+ k_p += k_py - w * k_px;
+ }
+
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
+ minor_idx] = v;
+ }
+ }
+}
+
+template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
+__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
+ const scalar_t *kernel,
+ const UpFirDn2DKernelParams p) {
+ const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
+ const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
+
+ __shared__ volatile float sk[kernel_h][kernel_w];
+ __shared__ volatile float sx[tile_in_h][tile_in_w];
+
+ int minor_idx = blockIdx.x;
+ int tile_out_y = minor_idx / p.minor_dim;
+ minor_idx -= tile_out_y * p.minor_dim;
+ tile_out_y *= tile_out_h;
+ int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
+ int major_idx_base = blockIdx.z * p.loop_major;
+
+ if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
+ major_idx_base >= p.major_dim) {
+ return;
+ }
+
+ for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
+ tap_idx += blockDim.x) {
+ int ky = tap_idx / kernel_w;
+ int kx = tap_idx - ky * kernel_w;
+ scalar_t v = 0.0;
+
+ if (kx < p.kernel_w & ky < p.kernel_h) {
+ v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
+ }
+
+ sk[ky][kx] = v;
+ }
+
+ for (int loop_major = 0, major_idx = major_idx_base;
+ loop_major < p.loop_major & major_idx < p.major_dim;
+ loop_major++, major_idx++) {
+ for (int loop_x = 0, tile_out_x = tile_out_x_base;
+ loop_x < p.loop_x & tile_out_x < p.out_w;
+ loop_x++, tile_out_x += tile_out_w) {
+ int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
+ int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
+ int tile_in_x = floor_div(tile_mid_x, up_x);
+ int tile_in_y = floor_div(tile_mid_y, up_y);
+
+ __syncthreads();
+
+ for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
+ in_idx += blockDim.x) {
+ int rel_in_y = in_idx / tile_in_w;
+ int rel_in_x = in_idx - rel_in_y * tile_in_w;
+ int in_x = rel_in_x + tile_in_x;
+ int in_y = rel_in_y + tile_in_y;
+
+ scalar_t v = 0.0;
+
+ if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
+ v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
+ p.minor_dim +
+ minor_idx];
+ }
+
+ sx[rel_in_y][rel_in_x] = v;
+ }
+
+ __syncthreads();
+ for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
+ out_idx += blockDim.x) {
+ int rel_out_y = out_idx / tile_out_w;
+ int rel_out_x = out_idx - rel_out_y * tile_out_w;
+ int out_x = rel_out_x + tile_out_x;
+ int out_y = rel_out_y + tile_out_y;
+
+ int mid_x = tile_mid_x + rel_out_x * down_x;
+ int mid_y = tile_mid_y + rel_out_y * down_y;
+ int in_x = floor_div(mid_x, up_x);
+ int in_y = floor_div(mid_y, up_y);
+ int rel_in_x = in_x - tile_in_x;
+ int rel_in_y = in_y - tile_in_y;
+ int kernel_x = (in_x + 1) * up_x - mid_x - 1;
+ int kernel_y = (in_y + 1) * up_y - mid_y - 1;
+
+ scalar_t v = 0.0;
+
+#pragma unroll
+ for (int y = 0; y < kernel_h / up_y; y++)
+#pragma unroll
+ for (int x = 0; x < kernel_w / up_x; x++)
+ v += sx[rel_in_y + y][rel_in_x + x] *
+ sk[kernel_y + y * up_y][kernel_x + x * up_x];
+
+ if (out_x < p.out_w & out_y < p.out_h) {
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
+ minor_idx] = v;
+ }
+ }
+ }
+ }
+}
+
+torch::Tensor upfirdn2d_op(const torch::Tensor &input,
+ const torch::Tensor &kernel, int up_x, int up_y,
+ int down_x, int down_y, int pad_x0, int pad_x1,
+ int pad_y0, int pad_y1) {
+ int curDevice = -1;
+ cudaGetDevice(&curDevice);
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+
+ UpFirDn2DKernelParams p;
+
+ auto x = input.contiguous();
+ auto k = kernel.contiguous();
+
+ p.major_dim = x.size(0);
+ p.in_h = x.size(1);
+ p.in_w = x.size(2);
+ p.minor_dim = x.size(3);
+ p.kernel_h = k.size(0);
+ p.kernel_w = k.size(1);
+ p.up_x = up_x;
+ p.up_y = up_y;
+ p.down_x = down_x;
+ p.down_y = down_y;
+ p.pad_x0 = pad_x0;
+ p.pad_x1 = pad_x1;
+ p.pad_y0 = pad_y0;
+ p.pad_y1 = pad_y1;
+
+ p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
+ p.down_y;
+ p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
+ p.down_x;
+
+ auto out =
+ at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
+
+ int mode = -1;
+
+ int tile_out_h = -1;
+ int tile_out_w = -1;
+
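+ // Pick a specialised, shared-memory-tiled kernel instantiation for the common
+ // up/down/kernel-size combinations below; anything else falls through with mode == -1
+ // to the generic (slower) upfirdn2d_kernel_large path.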
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
+ mode = 1;
+ tile_out_h = 16;
+ tile_out_w = 64;
+ }
+
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
+ p.kernel_h <= 3 && p.kernel_w <= 3) {
+ mode = 2;
+ tile_out_h = 16;
+ tile_out_w = 64;
+ }
+
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
+ mode = 3;
+ tile_out_h = 16;
+ tile_out_w = 64;
+ }
+
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
+ p.kernel_h <= 2 && p.kernel_w <= 2) {
+ mode = 4;
+ tile_out_h = 16;
+ tile_out_w = 64;
+ }
+
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
+ mode = 5;
+ tile_out_h = 8;
+ tile_out_w = 32;
+ }
+
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
+ p.kernel_h <= 2 && p.kernel_w <= 2) {
+ mode = 6;
+ tile_out_h = 8;
+ tile_out_w = 32;
+ }
+
+ dim3 block_size;
+ dim3 grid_size;
+
+ if (tile_out_h > 0 && tile_out_w > 0) {
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
+ p.loop_x = 1;
+ block_size = dim3(32 * 8, 1, 1);
+ grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
+ (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
+ (p.major_dim - 1) / p.loop_major + 1);
+ } else {
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
+ p.loop_x = 4;
+ block_size = dim3(4, 32, 1);
+ grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
+ (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
+ (p.major_dim - 1) / p.loop_major + 1);
+ }
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
+ switch (mode) {
+ case 1:
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+ x.data_ptr<scalar_t>(),
+ k.data_ptr<scalar_t>(), p);
+
+ break;
+
+ case 2:
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+ x.data_ptr<scalar_t>(),
+ k.data_ptr<scalar_t>(), p);
+
+ break;
+
+ case 3:
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+ x.data_ptr<scalar_t>(),
+ k.data_ptr<scalar_t>(), p);
+
+ break;
+
+ case 4:
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+ x.data_ptr<scalar_t>(),
+ k.data_ptr<scalar_t>(), p);
+
+ break;
+
+ case 5:
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+ x.data_ptr<scalar_t>(),
+ k.data_ptr<scalar_t>(), p);
+
+ break;
+
+ case 6:
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+ x.data_ptr<scalar_t>(),
+ k.data_ptr<scalar_t>(), p);
+
+ break;
+
+ default:
+ upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
+ k.data_ptr<scalar_t>(), p);
+ }
+ });
+
+ return out;
+}
diff --git a/hordelib/nodes/facerestore_cf/basicsr/ops/upfirdn2d/upfirdn2d.py b/hordelib/nodes/facerestore_cf/basicsr/ops/upfirdn2d/upfirdn2d.py
new file mode 100644
index 00000000..59a9411e
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/ops/upfirdn2d/upfirdn2d.py
@@ -0,0 +1,188 @@
+# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py
+
+import torch
+from torch.autograd import Function
+from torch.nn import functional as F
+
+try:
+ from . import upfirdn2d_ext
+except ImportError:
+ import os
+
+ BASICSR_JIT = os.getenv("BASICSR_JIT")
+ if BASICSR_JIT == "True":
+ from torch.utils.cpp_extension import load
+
+ module_path = os.path.dirname(__file__)
+ upfirdn2d_ext = load(
+ "upfirdn2d",
+ sources=[
+ os.path.join(module_path, "src", "upfirdn2d.cpp"),
+ os.path.join(module_path, "src", "upfirdn2d_kernel.cu"),
+ ],
+ )
+
+
+class UpFirDn2dBackward(Function):
+
+ @staticmethod
+ def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size):
+
+ up_x, up_y = up
+ down_x, down_y = down
+ g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
+
+ grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
+
+ grad_input = upfirdn2d_ext.upfirdn2d(
+ grad_output,
+ grad_kernel,
+ down_x,
+ down_y,
+ up_x,
+ up_y,
+ g_pad_x0,
+ g_pad_x1,
+ g_pad_y0,
+ g_pad_y1,
+ )
+ grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
+
+ ctx.save_for_backward(kernel)
+
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+ ctx.up_x = up_x
+ ctx.up_y = up_y
+ ctx.down_x = down_x
+ ctx.down_y = down_y
+ ctx.pad_x0 = pad_x0
+ ctx.pad_x1 = pad_x1
+ ctx.pad_y0 = pad_y0
+ ctx.pad_y1 = pad_y1
+ ctx.in_size = in_size
+ ctx.out_size = out_size
+
+ return grad_input
+
+ @staticmethod
+ def backward(ctx, gradgrad_input):
+ (kernel,) = ctx.saved_tensors
+
+ gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
+
+ gradgrad_out = upfirdn2d_ext.upfirdn2d(
+ gradgrad_input,
+ kernel,
+ ctx.up_x,
+ ctx.up_y,
+ ctx.down_x,
+ ctx.down_y,
+ ctx.pad_x0,
+ ctx.pad_x1,
+ ctx.pad_y0,
+ ctx.pad_y1,
+ )
+ # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0],
+ # ctx.out_size[1], ctx.in_size[3])
+ gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1])
+
+ return gradgrad_out, None, None, None, None, None, None, None, None
+
+
+class UpFirDn2d(Function):
+
+ @staticmethod
+ def forward(ctx, input, kernel, up, down, pad):
+ up_x, up_y = up
+ down_x, down_y = down
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+ kernel_h, kernel_w = kernel.shape
+ batch, channel, in_h, in_w = input.shape
+ ctx.in_size = input.shape
+
+ input = input.reshape(-1, in_h, in_w, 1)
+
+ ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
+
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+ ctx.out_size = (out_h, out_w)
+
+ ctx.up = (up_x, up_y)
+ ctx.down = (down_x, down_y)
+ ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
+
+ g_pad_x0 = kernel_w - pad_x0 - 1
+ g_pad_y0 = kernel_h - pad_y0 - 1
+ g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
+ g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
+
+ ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
+
+ out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1)
+ # out = out.view(major, out_h, out_w, minor)
+ out = out.view(-1, channel, out_h, out_w)
+
+ return out
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ kernel, grad_kernel = ctx.saved_tensors
+
+ grad_input = UpFirDn2dBackward.apply(
+ grad_output,
+ kernel,
+ grad_kernel,
+ ctx.up,
+ ctx.down,
+ ctx.pad,
+ ctx.g_pad,
+ ctx.in_size,
+ ctx.out_size,
+ )
+
+ return grad_input, None, None, None, None
+
+
+def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
+ if input.device.type == "cpu":
+ out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
+ else:
+ out = UpFirDn2d.apply(input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]))
+
+ return out
+
+
+def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
+ _, channel, in_h, in_w = input.shape
+ input = input.reshape(-1, in_h, in_w, 1)
+
+ _, in_h, in_w, minor = input.shape
+ kernel_h, kernel_w = kernel.shape
+
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
+
+ out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
+ out = out[:, max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), :]
+
+ out = out.permute(0, 3, 1, 2)
+ out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
+ out = F.conv2d(out, w)
+ out = out.reshape(
+ -1,
+ minor,
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
+ )
+ out = out.permute(0, 2, 3, 1)
+ out = out[:, ::down_y, ::down_x, :]
+
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+
+ return out.view(-1, channel, out_h, out_w)
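+
+# Minimal usage sketch (illustrative values only; the 4-tap blur kernel and the NCHW float
+# CUDA tensor `x` below are assumptions, not something this module defines):
+#
+#   k1d = torch.tensor([1.0, 3.0, 3.0, 1.0])
+#   k = (k1d[None, :] * k1d[:, None]) / k1d.sum() ** 2     # normalised separable 4x4 blur
+#   up = upfirdn2d(x, k * 4, up=2, down=1, pad=(2, 1))     # 2x upsample (roughly StyleGAN2's Upsample)
+#   down = upfirdn2d(x, k, up=1, down=2, pad=(1, 1))       # 2x downsample
+#
+# CPU tensors take the pure-PyTorch upfirdn2d_native() path above; CUDA tensors require the
+# compiled upfirdn2d_ext extension (or BASICSR_JIT=True so it is JIT-built at import time).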
diff --git a/hordelib/nodes/facerestore_cf/basicsr/setup.py b/hordelib/nodes/facerestore_cf/basicsr/setup.py
new file mode 100644
index 00000000..b8755f97
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/setup.py
@@ -0,0 +1,171 @@
+#!/usr/bin/env python
+
+import os
+import subprocess
+import sys
+import time
+
+import torch
+from setuptools import find_packages, setup
+from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
+
+version_file = "./basicsr/version.py"
+
+
+def readme():
+ with open("README.md", encoding="utf-8") as f:
+ content = f.read()
+ return content
+
+
+def get_git_hash():
+
+ def _minimal_ext_cmd(cmd):
+ # construct minimal environment
+ env = {}
+ for k in ["SYSTEMROOT", "PATH", "HOME"]:
+ v = os.environ.get(k)
+ if v is not None:
+ env[k] = v
+ # LANGUAGE is used on win32
+ env["LANGUAGE"] = "C"
+ env["LANG"] = "C"
+ env["LC_ALL"] = "C"
+ out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
+ return out
+
+ try:
+ out = _minimal_ext_cmd(["git", "rev-parse", "HEAD"])
+ sha = out.strip().decode("ascii")
+ except OSError:
+ sha = "unknown"
+
+ return sha
+
+
+def get_hash():
+ if os.path.exists(".git"):
+ sha = get_git_hash()[:7]
+ elif os.path.exists(version_file):
+ try:
+ from version import __version__
+
+ sha = __version__.split("+")[-1]
+ except ImportError:
+ raise ImportError("Unable to get git version")
+ else:
+ sha = "unknown"
+
+ return sha
+
+
+def write_version_py():
+ content = """# GENERATED VERSION FILE
+# TIME: {}
+__version__ = '{}'
+__gitsha__ = '{}'
+version_info = ({})
+"""
+ sha = get_hash()
+ with open("./basicsr/VERSION") as f:
+ SHORT_VERSION = f.read().strip()
+ VERSION_INFO = ", ".join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split(".")])
+
+ version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO)
+ with open(version_file, "w") as f:
+ f.write(version_file_str)
+
+
+def get_version():
+ with open(version_file) as f:
+ exec(compile(f.read(), version_file, "exec"))
+ return locals()["__version__"]
+
+
+def make_cuda_ext(name, module, sources, sources_cuda=None):
+ if sources_cuda is None:
+ sources_cuda = []
+ define_macros = []
+ extra_compile_args = {"cxx": []}
+
+ if torch.cuda.is_available() or os.getenv("FORCE_CUDA", "0") == "1":
+ define_macros += [("WITH_CUDA", None)]
+ extension = CUDAExtension
+ extra_compile_args["nvcc"] = [
+ "-D__CUDA_NO_HALF_OPERATORS__",
+ "-D__CUDA_NO_HALF_CONVERSIONS__",
+ "-D__CUDA_NO_HALF2_OPERATORS__",
+ ]
+ sources += sources_cuda
+ else:
+ print(f"Compiling {name} without CUDA")
+ extension = CppExtension
+
+ return extension(
+ name=f"{module}.{name}",
+ sources=[os.path.join(*module.split("."), p) for p in sources],
+ define_macros=define_macros,
+ extra_compile_args=extra_compile_args,
+ )
+
+
+def get_requirements(filename="requirements.txt"):
+ with open(os.path.join(".", filename)) as f:
+ requires = [line.replace("\n", "") for line in f.readlines()]
+ return requires
+
+
+if __name__ == "__main__":
+ if "--cuda_ext" in sys.argv:
+ ext_modules = [
+ make_cuda_ext(
+ name="deform_conv_ext",
+ module="ops.dcn",
+ sources=["src/deform_conv_ext.cpp"],
+ sources_cuda=["src/deform_conv_cuda.cpp", "src/deform_conv_cuda_kernel.cu"],
+ ),
+ make_cuda_ext(
+ name="fused_act_ext",
+ module="ops.fused_act",
+ sources=["src/fused_bias_act.cpp"],
+ sources_cuda=["src/fused_bias_act_kernel.cu"],
+ ),
+ make_cuda_ext(
+ name="upfirdn2d_ext",
+ module="ops.upfirdn2d",
+ sources=["src/upfirdn2d.cpp"],
+ sources_cuda=["src/upfirdn2d_kernel.cu"],
+ ),
+ ]
+ sys.argv.remove("--cuda_ext")
+ else:
+ ext_modules = []
+
+ write_version_py()
+ setup(
+ name="basicsr",
+ version=get_version(),
+ description="Open Source Image and Video Super-Resolution Toolbox",
+ long_description=readme(),
+ long_description_content_type="text/markdown",
+ author="Xintao Wang",
+ author_email="xintao.wang@outlook.com",
+ keywords="computer vision, restoration, super resolution",
+ url="https://github.com/xinntao/BasicSR",
+ include_package_data=True,
+ packages=find_packages(exclude=("options", "datasets", "experiments", "results", "tb_logger", "wandb")),
+ classifiers=[
+ "Development Status :: 4 - Beta",
+ "License :: OSI Approved :: Apache Software License",
+ "Operating System :: OS Independent",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.7",
+ "Programming Language :: Python :: 3.8",
+ ],
+ license="Apache License 2.0",
+ setup_requires=["cython", "numpy"],
+ install_requires=get_requirements(),
+ ext_modules=ext_modules,
+ cmdclass={"build_ext": BuildExtension},
+ zip_safe=False,
+ )
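+
+# Typical invocations (a sketch; only the --cuda_ext flag is handled specially by this script,
+# the rest is standard setuptools):
+#   python setup.py develop --cuda_ext   # also compile the dcn / fused_act / upfirdn2d extensions
+#   python setup.py develop              # pure-Python install; the ops fall back to JIT builds
+#                                        # (BASICSR_JIT=True) or native implementations at runtime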
diff --git a/hordelib/nodes/facerestore_cf/basicsr/train.py b/hordelib/nodes/facerestore_cf/basicsr/train.py
new file mode 100644
index 00000000..7e84af31
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/train.py
@@ -0,0 +1,250 @@
+import argparse
+import datetime
+import logging
+import math
+import random
+import time
+import warnings
+from os import path as osp
+
+import torch
+
+from hordelib.nodes.facerestore_cf.basicsr.data import build_dataloader, build_dataset
+from hordelib.nodes.facerestore_cf.basicsr.data.data_sampler import EnlargedSampler
+from hordelib.nodes.facerestore_cf.basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
+from hordelib.nodes.facerestore_cf.basicsr.models import build_model
+from hordelib.nodes.facerestore_cf.basicsr.utils import (
+ MessageLogger,
+ check_resume,
+ get_env_info,
+ get_root_logger,
+ init_tb_logger,
+ init_wandb_logger,
+ make_exp_dirs,
+ mkdir_and_rename,
+ set_random_seed,
+)
+from hordelib.nodes.facerestore_cf.basicsr.utils.dist_util import get_dist_info, init_dist
+from hordelib.nodes.facerestore_cf.basicsr.utils.options import dict2str, parse
+
+# ignore UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`.
+warnings.filterwarnings("ignore", category=UserWarning)
+
+
+def parse_options(root_path, is_train=True):
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-opt", type=str, required=True, help="Path to option YAML file.")
+ parser.add_argument("--launcher", choices=["none", "pytorch", "slurm"], default="none", help="job launcher")
+ parser.add_argument("--local_rank", type=int, default=0)
+ args = parser.parse_args()
+ opt = parse(args.opt, root_path, is_train=is_train)
+
+ # distributed settings
+ if args.launcher == "none":
+ opt["dist"] = False
+ print("Disable distributed.", flush=True)
+ else:
+ opt["dist"] = True
+ if args.launcher == "slurm" and "dist_params" in opt:
+ init_dist(args.launcher, **opt["dist_params"])
+ else:
+ init_dist(args.launcher)
+
+ opt["rank"], opt["world_size"] = get_dist_info()
+
+ # random seed
+ seed = opt.get("manual_seed")
+ if seed is None:
+ seed = random.randint(1, 10000)
+ opt["manual_seed"] = seed
+ set_random_seed(seed + opt["rank"])
+
+ return opt
+
+
+def init_loggers(opt):
+ log_file = osp.join(opt["path"]["log"], f"train_{opt['name']}.log")
+ logger = get_root_logger(logger_name="basicsr", log_level=logging.INFO, log_file=log_file)
+ logger.info(get_env_info())
+ logger.info(dict2str(opt))
+
+ # initialize wandb logger before tensorboard logger to allow proper sync:
+ if (opt["logger"].get("wandb") is not None) and (opt["logger"]["wandb"].get("project") is not None):
+ assert opt["logger"].get("use_tb_logger") is True, "should turn on tensorboard when using wandb"
+ init_wandb_logger(opt)
+ tb_logger = None
+ if opt["logger"].get("use_tb_logger"):
+ tb_logger = init_tb_logger(log_dir=osp.join("tb_logger", opt["name"]))
+ return logger, tb_logger
+
+
+def create_train_val_dataloader(opt, logger):
+ # create train and val dataloaders
+ train_loader, val_loader = None, None
+ for phase, dataset_opt in opt["datasets"].items():
+ if phase == "train":
+ dataset_enlarge_ratio = dataset_opt.get("dataset_enlarge_ratio", 1)
+ train_set = build_dataset(dataset_opt)
+ train_sampler = EnlargedSampler(train_set, opt["world_size"], opt["rank"], dataset_enlarge_ratio)
+ train_loader = build_dataloader(
+ train_set,
+ dataset_opt,
+ num_gpu=opt["num_gpu"],
+ dist=opt["dist"],
+ sampler=train_sampler,
+ seed=opt["manual_seed"],
+ )
+
+ num_iter_per_epoch = math.ceil(
+ len(train_set) * dataset_enlarge_ratio / (dataset_opt["batch_size_per_gpu"] * opt["world_size"]),
+ )
+ total_iters = int(opt["train"]["total_iter"])
+ total_epochs = math.ceil(total_iters / (num_iter_per_epoch))
+ logger.info(
+ "Training statistics:"
+ f"\n\tNumber of train images: {len(train_set)}"
+ f"\n\tDataset enlarge ratio: {dataset_enlarge_ratio}"
+ f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}'
+ f'\n\tWorld size (gpu number): {opt["world_size"]}'
+ f"\n\tRequire iter number per epoch: {num_iter_per_epoch}"
+ f"\n\tTotal epochs: {total_epochs}; iters: {total_iters}.",
+ )
+
+ elif phase == "val":
+ val_set = build_dataset(dataset_opt)
+ val_loader = build_dataloader(
+ val_set,
+ dataset_opt,
+ num_gpu=opt["num_gpu"],
+ dist=opt["dist"],
+ sampler=None,
+ seed=opt["manual_seed"],
+ )
+ logger.info(f'Number of val images/folders in {dataset_opt["name"]}: ' f"{len(val_set)}")
+ else:
+ raise ValueError(f"Dataset phase {phase} is not recognized.")
+
+ return train_loader, train_sampler, val_loader, total_epochs, total_iters
+
+
+def train_pipeline(root_path):
+ # parse options, set distributed setting, set random seed
+ opt = parse_options(root_path, is_train=True)
+
+ torch.backends.cudnn.benchmark = True
+ # torch.backends.cudnn.deterministic = True
+
+ # load resume states if necessary
+ if opt["path"].get("resume_state"):
+ device_id = torch.cuda.current_device()
+ resume_state = torch.load(
+ opt["path"]["resume_state"],
+ map_location=lambda storage, loc: storage.cuda(device_id),
+ )
+ else:
+ resume_state = None
+
+ # mkdir for experiments and logger
+ if resume_state is None:
+ make_exp_dirs(opt)
+ if opt["logger"].get("use_tb_logger") and opt["rank"] == 0:
+ mkdir_and_rename(osp.join("tb_logger", opt["name"]))
+
+ # initialize loggers
+ logger, tb_logger = init_loggers(opt)
+
+ # create train and validation dataloaders
+ result = create_train_val_dataloader(opt, logger)
+ train_loader, train_sampler, val_loader, total_epochs, total_iters = result
+
+ # create model
+ if resume_state: # resume training
+ check_resume(opt, resume_state["iter"])
+ model = build_model(opt)
+ model.resume_training(resume_state) # handle optimizers and schedulers
+ logger.info(f"Resuming training from epoch: {resume_state['epoch']}, " f"iter: {resume_state['iter']}.")
+ start_epoch = resume_state["epoch"]
+ current_iter = resume_state["iter"]
+ else:
+ model = build_model(opt)
+ start_epoch = 0
+ current_iter = 0
+
+ # create message logger (formatted outputs)
+ msg_logger = MessageLogger(opt, current_iter, tb_logger)
+
+ # dataloader prefetcher
+ prefetch_mode = opt["datasets"]["train"].get("prefetch_mode")
+ if prefetch_mode is None or prefetch_mode == "cpu":
+ prefetcher = CPUPrefetcher(train_loader)
+ elif prefetch_mode == "cuda":
+ prefetcher = CUDAPrefetcher(train_loader, opt)
+ logger.info(f"Use {prefetch_mode} prefetch dataloader")
+ if opt["datasets"]["train"].get("pin_memory") is not True:
+ raise ValueError("Please set pin_memory=True for CUDAPrefetcher.")
+ else:
+ raise ValueError(f"Wrong prefetch_mode {prefetch_mode}. Supported ones are: None, 'cuda', 'cpu'.")
+
+ # training
+ logger.info(f"Start training from epoch: {start_epoch}, iter: {current_iter+1}")
+ data_time, iter_time = time.time(), time.time()
+ start_time = time.time()
+
+ for epoch in range(start_epoch, total_epochs + 1):
+ train_sampler.set_epoch(epoch)
+ prefetcher.reset()
+ train_data = prefetcher.next()
+
+ while train_data is not None:
+ data_time = time.time() - data_time
+
+ current_iter += 1
+ if current_iter > total_iters:
+ break
+ # update learning rate
+ model.update_learning_rate(current_iter, warmup_iter=opt["train"].get("warmup_iter", -1))
+ # training
+ model.feed_data(train_data)
+ model.optimize_parameters(current_iter)
+ iter_time = time.time() - iter_time
+ # log
+ if current_iter % opt["logger"]["print_freq"] == 0:
+ log_vars = {"epoch": epoch, "iter": current_iter}
+ log_vars.update({"lrs": model.get_current_learning_rate()})
+ log_vars.update({"time": iter_time, "data_time": data_time})
+ log_vars.update(model.get_current_log())
+ msg_logger(log_vars)
+
+ # save models and training states
+ if current_iter % opt["logger"]["save_checkpoint_freq"] == 0:
+ logger.info("Saving models and training states.")
+ model.save(epoch, current_iter)
+
+ # validation
+ if (
+ opt.get("val") is not None
+ and opt["datasets"].get("val") is not None
+ and (current_iter % opt["val"]["val_freq"] == 0)
+ ):
+ model.validation(val_loader, current_iter, tb_logger, opt["val"]["save_img"])
+
+ data_time = time.time()
+ iter_time = time.time()
+ train_data = prefetcher.next()
+ # end of iter
+
+ # end of epoch
+
+ consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time)))
+ logger.info(f"End of training. Time consumed: {consumed_time}")
+ logger.info("Save the latest model.")
+ model.save(epoch=-1, current_iter=-1) # -1 stands for the latest
+ if opt.get("val") is not None and opt["datasets"].get("val"):
+ model.validation(val_loader, current_iter, tb_logger, opt["val"]["save_img"])
+ if tb_logger:
+ tb_logger.close()
+
+
+if __name__ == "__main__":
+ root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
+ train_pipeline(root_path)
diff --git a/hordelib/nodes/facerestore_cf/basicsr/utils/__init__.py b/hordelib/nodes/facerestore_cf/basicsr/utils/__init__.py
new file mode 100644
index 00000000..b1d60bc8
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/utils/__init__.py
@@ -0,0 +1,29 @@
+from .file_client import FileClient
+from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img
+from .logger import MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger
+from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt
+
+__all__ = [
+ # file_client.py
+ "FileClient",
+ # img_util.py
+ "img2tensor",
+ "tensor2img",
+ "imfrombytes",
+ "imwrite",
+ "crop_border",
+ # logger.py
+ "MessageLogger",
+ "init_tb_logger",
+ "init_wandb_logger",
+ "get_root_logger",
+ "get_env_info",
+ # misc.py
+ "set_random_seed",
+ "get_time_str",
+ "mkdir_and_rename",
+ "make_exp_dirs",
+ "scandir",
+ "check_resume",
+ "sizeof_fmt",
+]
diff --git a/hordelib/nodes/facerestore_cf/basicsr/utils/dist_util.py b/hordelib/nodes/facerestore_cf/basicsr/utils/dist_util.py
new file mode 100644
index 00000000..a252a461
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/utils/dist_util.py
@@ -0,0 +1,83 @@
+# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py
+import functools
+import os
+import subprocess
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+
+def init_dist(launcher, backend="nccl", **kwargs):
+ if mp.get_start_method(allow_none=True) is None:
+ mp.set_start_method("spawn")
+ if launcher == "pytorch":
+ _init_dist_pytorch(backend, **kwargs)
+ elif launcher == "slurm":
+ _init_dist_slurm(backend, **kwargs)
+ else:
+ raise ValueError(f"Invalid launcher type: {launcher}")
+
+
+def _init_dist_pytorch(backend, **kwargs):
+ rank = int(os.environ["RANK"])
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(rank % num_gpus)
+ dist.init_process_group(backend=backend, **kwargs)
+
+
+def _init_dist_slurm(backend, port=None):
+ """Initialize slurm distributed training environment.
+
+    If argument ``port`` is not specified, the master port will be read from the
+    system environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not set,
+    the default port ``29500`` will be used.
+
+ Args:
+ backend (str): Backend of torch.distributed.
+ port (int, optional): Master port. Defaults to None.
+ """
+ proc_id = int(os.environ["SLURM_PROCID"])
+ ntasks = int(os.environ["SLURM_NTASKS"])
+ node_list = os.environ["SLURM_NODELIST"]
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(proc_id % num_gpus)
+ addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1")
+ # specify master port
+ if port is not None:
+ os.environ["MASTER_PORT"] = str(port)
+ elif "MASTER_PORT" in os.environ:
+ pass # use MASTER_PORT in the environment variable
+ else:
+ # 29500 is torch.distributed default port
+ os.environ["MASTER_PORT"] = "29500"
+ os.environ["MASTER_ADDR"] = addr
+ os.environ["WORLD_SIZE"] = str(ntasks)
+ os.environ["LOCAL_RANK"] = str(proc_id % num_gpus)
+ os.environ["RANK"] = str(proc_id)
+ dist.init_process_group(backend=backend)
+
+
+def get_dist_info():
+ if dist.is_available():
+ initialized = dist.is_initialized()
+ else:
+ initialized = False
+ if initialized:
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ else:
+ rank = 0
+ world_size = 1
+ return rank, world_size
+
+
+def master_only(func):
+
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ rank, _ = get_dist_info()
+ if rank == 0:
+ return func(*args, **kwargs)
+
+ return wrapper
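+
+
+if __name__ == "__main__":
+    # A minimal usage sketch (not part of the upstream basicsr API): without an
+    # initialized process group, get_dist_info() falls back to rank 0 and a
+    # world size of 1, so master_only-decorated functions still execute.
+    rank, world_size = get_dist_info()
+    print(f"rank={rank}, world_size={world_size}")
+
+    @master_only
+    def report():
+        print("only the rank-0 process prints this")
+
+    report()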
diff --git a/hordelib/nodes/facerestore_cf/basicsr/utils/download_util.py b/hordelib/nodes/facerestore_cf/basicsr/utils/download_util.py
new file mode 100644
index 00000000..620c9f73
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/utils/download_util.py
@@ -0,0 +1,83 @@
+import math
+import os
+from urllib.parse import urlparse
+
+import requests
+from torch.hub import download_url_to_file, get_dir
+from tqdm import tqdm
+
+from .misc import sizeof_fmt
+from hordelib.shared_model_manager import SharedModelManager
+
+
+def download_file_from_google_drive(file_id, save_path):
+ """Download files from google drive.
+ Ref:
+ https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501
+ Args:
+ file_id (str): File id.
+ save_path (str): Save path.
+ """
+
+ session = requests.Session()
+ URL = "https://docs.google.com/uc?export=download"
+ params = {"id": file_id}
+
+ response = session.get(URL, params=params, stream=True)
+ token = get_confirm_token(response)
+ if token:
+ params["confirm"] = token
+ response = session.get(URL, params=params, stream=True)
+
+ # get file size
+ response_file_size = session.get(URL, params=params, stream=True, headers={"Range": "bytes=0-2"})
+ print(response_file_size)
+ if "Content-Range" in response_file_size.headers:
+ file_size = int(response_file_size.headers["Content-Range"].split("/")[1])
+ else:
+ file_size = None
+
+ save_response_content(response, save_path, file_size)
+
+
+def get_confirm_token(response):
+ for key, value in response.cookies.items():
+ if key.startswith("download_warning"):
+ return value
+ return None
+
+
+def save_response_content(response, destination, file_size=None, chunk_size=32768):
+ if file_size is not None:
+ pbar = tqdm(total=math.ceil(file_size / chunk_size), unit="chunk")
+
+ readable_file_size = sizeof_fmt(file_size)
+ else:
+ pbar = None
+
+ with open(destination, "wb") as f:
+ downloaded_size = 0
+ for chunk in response.iter_content(chunk_size):
+ downloaded_size += chunk_size
+ if pbar is not None:
+ pbar.update(1)
+ pbar.set_description(f"Download {sizeof_fmt(downloaded_size)} / {readable_file_size}")
+ if chunk: # filter out keep-alive new chunks
+ f.write(chunk)
+ if pbar is not None:
+ pbar.close()
+
+
+def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
+ """Load file form http url, will download models if necessary.
+ Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
+ Args:
+ url (str): URL to be downloaded.
+ model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
+ Default: None.
+ progress (bool): Whether to show the download progress. Default: True.
+ file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
+ Returns:
+ str: The path to the downloaded file.
+ """
+ return str(SharedModelManager.manager.gfpgan.model_folder_path / file_name)
diff --git a/hordelib/nodes/facerestore_cf/basicsr/utils/file_client.py b/hordelib/nodes/facerestore_cf/basicsr/utils/file_client.py
new file mode 100644
index 00000000..d9c1d273
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/utils/file_client.py
@@ -0,0 +1,172 @@
+# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py
+from abc import ABCMeta, abstractmethod
+
+
+class BaseStorageBackend(metaclass=ABCMeta):
+ """Abstract class of storage backends.
+
+ All backends need to implement two apis: ``get()`` and ``get_text()``.
+ ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
+ as texts.
+ """
+
+ @abstractmethod
+ def get(self, filepath):
+ pass
+
+ @abstractmethod
+ def get_text(self, filepath):
+ pass
+
+
+class MemcachedBackend(BaseStorageBackend):
+ """Memcached storage backend.
+
+ Attributes:
+ server_list_cfg (str): Config file for memcached server list.
+ client_cfg (str): Config file for memcached client.
+ sys_path (str | None): Additional path to be appended to `sys.path`.
+ Default: None.
+ """
+
+ def __init__(self, server_list_cfg, client_cfg, sys_path=None):
+ if sys_path is not None:
+ import sys
+
+ sys.path.append(sys_path)
+ try:
+ import mc
+ except ImportError:
+ raise ImportError("Please install memcached to enable MemcachedBackend.")
+
+ self.server_list_cfg = server_list_cfg
+ self.client_cfg = client_cfg
+ self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
+        # mc.pyvector serves as a pointer to a memory cache
+ self._mc_buffer = mc.pyvector()
+
+ def get(self, filepath):
+ filepath = str(filepath)
+ import mc
+
+ self._client.Get(filepath, self._mc_buffer)
+ value_buf = mc.ConvertBuffer(self._mc_buffer)
+ return value_buf
+
+ def get_text(self, filepath):
+ raise NotImplementedError
+
+
+class HardDiskBackend(BaseStorageBackend):
+ """Raw hard disks storage backend."""
+
+ def get(self, filepath):
+ filepath = str(filepath)
+ with open(filepath, "rb") as f:
+ value_buf = f.read()
+ return value_buf
+
+ def get_text(self, filepath):
+ filepath = str(filepath)
+ with open(filepath) as f:
+ value_buf = f.read()
+ return value_buf
+
+
+class LmdbBackend(BaseStorageBackend):
+ """Lmdb storage backend.
+
+ Args:
+ db_paths (str | list[str]): Lmdb database paths.
+ client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
+ readonly (bool, optional): Lmdb environment parameter. If True,
+ disallow any write operations. Default: True.
+ lock (bool, optional): Lmdb environment parameter. If False, when
+ concurrent access occurs, do not lock the database. Default: False.
+ readahead (bool, optional): Lmdb environment parameter. If False,
+ disable the OS filesystem readahead mechanism, which may improve
+ random read performance when a database is larger than RAM.
+ Default: False.
+
+ Attributes:
+ db_paths (list): Lmdb database path.
+ _client (list): A list of several lmdb envs.
+ """
+
+ def __init__(self, db_paths, client_keys="default", readonly=True, lock=False, readahead=False, **kwargs):
+ try:
+ import lmdb
+ except ImportError:
+ raise ImportError("Please install lmdb to enable LmdbBackend.")
+
+ if isinstance(client_keys, str):
+ client_keys = [client_keys]
+
+ if isinstance(db_paths, list):
+ self.db_paths = [str(v) for v in db_paths]
+ elif isinstance(db_paths, str):
+ self.db_paths = [str(db_paths)]
+ assert len(client_keys) == len(self.db_paths), (
+ "client_keys and db_paths should have the same length, "
+ f"but received {len(client_keys)} and {len(self.db_paths)}."
+ )
+
+ self._client = {}
+ for client, path in zip(client_keys, self.db_paths, strict=False):
+ self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)
+
+ def get(self, filepath, client_key):
+ """Get values according to the filepath from one lmdb named client_key.
+
+ Args:
+ filepath (str | obj:`Path`): Here, filepath is the lmdb key.
+            client_key (str): Used for distinguishing different lmdb envs.
+ """
+ filepath = str(filepath)
+ assert client_key in self._client, f"client_key {client_key} is not " "in lmdb clients."
+ client = self._client[client_key]
+ with client.begin(write=False) as txn:
+ value_buf = txn.get(filepath.encode("ascii"))
+ return value_buf
+
+ def get_text(self, filepath):
+ raise NotImplementedError
+
+
+class FileClient:
+ """A general file client to access files in different backend.
+
+    The client loads a file or text from its path using the specified backend
+    and returns it as a binary buffer. It can also register other backend
+    accessors with a given name and backend class.
+
+ Attributes:
+ backend (str): The storage backend type. Options are "disk",
+ "memcached" and "lmdb".
+ client (:obj:`BaseStorageBackend`): The backend object.
+ """
+
+ _backends = {
+ "disk": HardDiskBackend,
+ "memcached": MemcachedBackend,
+ "lmdb": LmdbBackend,
+ }
+
+ def __init__(self, backend="disk", **kwargs):
+ if backend not in self._backends:
+ raise ValueError(
+ f"Backend {backend} is not supported. Currently supported ones" f" are {list(self._backends.keys())}",
+ )
+ self.backend = backend
+ self.client = self._backends[backend](**kwargs)
+
+ def get(self, filepath, client_key="default"):
+ # client_key is used only for lmdb, where different fileclients have
+ # different lmdb environments.
+ if self.backend == "lmdb":
+ return self.client.get(filepath, client_key)
+ else:
+ return self.client.get(filepath)
+
+ def get_text(self, filepath):
+ return self.client.get_text(filepath)
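+
+
+if __name__ == "__main__":
+    # A minimal usage sketch with the default disk backend; the temporary file
+    # below is only illustrative.
+    import tempfile
+
+    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as tmp:
+        tmp.write("hello")
+        tmp_path = tmp.name
+
+    client = FileClient(backend="disk")
+    print(client.get(tmp_path))       # b'hello'
+    print(client.get_text(tmp_path))  # 'hello'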
diff --git a/hordelib/nodes/facerestore_cf/basicsr/utils/img_util.py b/hordelib/nodes/facerestore_cf/basicsr/utils/img_util.py
new file mode 100644
index 00000000..da64b7aa
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/utils/img_util.py
@@ -0,0 +1,171 @@
+import math
+import os
+
+import cv2
+import numpy as np
+import torch
+from torchvision.utils import make_grid
+
+
+def img2tensor(imgs, bgr2rgb=True, float32=True):
+ """Numpy array to tensor.
+
+ Args:
+ imgs (list[ndarray] | ndarray): Input images.
+ bgr2rgb (bool): Whether to change bgr to rgb.
+ float32 (bool): Whether to change to float32.
+
+ Returns:
+ list[tensor] | tensor: Tensor images. If returned results only have
+ one element, just return tensor.
+ """
+
+ def _totensor(img, bgr2rgb, float32):
+ if img.shape[2] == 3 and bgr2rgb:
+ if img.dtype == "float64":
+ img = img.astype("float32")
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img = torch.from_numpy(img.transpose(2, 0, 1))
+ if float32:
+ img = img.float()
+ return img
+
+ if isinstance(imgs, list):
+ return [_totensor(img, bgr2rgb, float32) for img in imgs]
+ else:
+ return _totensor(imgs, bgr2rgb, float32)
+
+
+def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
+ """Convert torch Tensors into image numpy arrays.
+
+ After clamping to [min, max], values will be normalized to [0, 1].
+
+ Args:
+ tensor (Tensor or list[Tensor]): Accept shapes:
+ 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
+ 2) 3D Tensor of shape (3/1 x H x W);
+ 3) 2D Tensor of shape (H x W).
+ Tensor channel should be in RGB order.
+ rgb2bgr (bool): Whether to change rgb to bgr.
+ out_type (numpy type): output types. If ``np.uint8``, transform outputs
+ to uint8 type with range [0, 255]; otherwise, float type with
+ range [0, 1]. Default: ``np.uint8``.
+ min_max (tuple[int]): min and max values for clamp.
+
+ Returns:
+ (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
+ shape (H x W). The channel order is BGR.
+ """
+ if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
+ raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}")
+
+ if torch.is_tensor(tensor):
+ tensor = [tensor]
+ result = []
+ for _tensor in tensor:
+ _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
+ _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
+
+ n_dim = _tensor.dim()
+ if n_dim == 4:
+ img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
+ img_np = img_np.transpose(1, 2, 0)
+ if rgb2bgr:
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+ elif n_dim == 3:
+ img_np = _tensor.numpy()
+ img_np = img_np.transpose(1, 2, 0)
+ if img_np.shape[2] == 1: # gray image
+ img_np = np.squeeze(img_np, axis=2)
+ else:
+ if rgb2bgr:
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+ elif n_dim == 2:
+ img_np = _tensor.numpy()
+ else:
+ raise TypeError("Only support 4D, 3D or 2D tensor. " f"But received with dimension: {n_dim}")
+ if out_type == np.uint8:
+            # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
+ img_np = (img_np * 255.0).round()
+ img_np = img_np.astype(out_type)
+ result.append(img_np)
+ if len(result) == 1:
+ result = result[0]
+ return result
+
+
+def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
+ """This implementation is slightly faster than tensor2img.
+ It now only supports torch tensor with shape (1, c, h, w).
+
+ Args:
+ tensor (Tensor): Now only support torch tensor with (1, c, h, w).
+ rgb2bgr (bool): Whether to change rgb to bgr. Default: True.
+ min_max (tuple[int]): min and max values for clamp.
+ """
+ output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0)
+ output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255
+ output = output.type(torch.uint8).cpu().numpy()
+ if rgb2bgr:
+ output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
+ return output
+
+
+def imfrombytes(content, flag="color", float32=False):
+ """Read an image from bytes.
+
+ Args:
+ content (bytes): Image bytes got from files or other streams.
+ flag (str): Flags specifying the color type of a loaded image,
+ candidates are `color`, `grayscale` and `unchanged`.
+        float32 (bool): Whether to change to float32. If True, will also
+            normalize to [0, 1]. Default: False.
+
+ Returns:
+ ndarray: Loaded image array.
+ """
+ img_np = np.frombuffer(content, np.uint8)
+ imread_flags = {"color": cv2.IMREAD_COLOR, "grayscale": cv2.IMREAD_GRAYSCALE, "unchanged": cv2.IMREAD_UNCHANGED}
+ img = cv2.imdecode(img_np, imread_flags[flag])
+ if float32:
+ img = img.astype(np.float32) / 255.0
+ return img
+
+
+def imwrite(img, file_path, params=None, auto_mkdir=True):
+ """Write image to file.
+
+ Args:
+ img (ndarray): Image array to be written.
+ file_path (str): Image file path.
+ params (None or list): Same as opencv's :func:`imwrite` interface.
+ auto_mkdir (bool): If the parent folder of `file_path` does not exist,
+ whether to create it automatically.
+
+ Returns:
+ bool: Successful or not.
+ """
+ if auto_mkdir:
+ dir_name = os.path.abspath(os.path.dirname(file_path))
+ os.makedirs(dir_name, exist_ok=True)
+ return cv2.imwrite(file_path, img, params)
+
+
+def crop_border(imgs, crop_border):
+ """Crop borders of images.
+
+ Args:
+ imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
+        crop_border (int): Crop border for each end of height and width.
+
+ Returns:
+ list[ndarray]: Cropped images.
+ """
+ if crop_border == 0:
+ return imgs
+ else:
+ if isinstance(imgs, list):
+ return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs]
+ else:
+ return imgs[crop_border:-crop_border, crop_border:-crop_border, ...]
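+
+
+if __name__ == "__main__":
+    # A minimal round-trip sketch: BGR array -> CHW float tensor -> BGR uint8 array.
+    # The random image stands in for a real frame and is only illustrative.
+    rng = np.random.default_rng(0)
+    img_bgr = rng.integers(0, 256, size=(64, 64, 3), dtype=np.uint8)
+
+    tensor = img2tensor(img_bgr.astype(np.float32) / 255.0, bgr2rgb=True, float32=True)
+    restored = tensor2img(tensor, rgb2bgr=True, min_max=(0, 1))
+    print(tensor.shape, restored.shape)  # torch.Size([3, 64, 64]) (64, 64, 3)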
diff --git a/hordelib/nodes/facerestore_cf/basicsr/utils/lmdb_util.py b/hordelib/nodes/facerestore_cf/basicsr/utils/lmdb_util.py
new file mode 100644
index 00000000..21ecbfd5
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/utils/lmdb_util.py
@@ -0,0 +1,200 @@
+import sys
+from multiprocessing import Pool
+from os import path as osp
+
+import cv2
+import lmdb
+from tqdm import tqdm
+
+
+def make_lmdb_from_imgs(
+ data_path,
+ lmdb_path,
+ img_path_list,
+ keys,
+ batch=5000,
+ compress_level=1,
+ multiprocessing_read=False,
+ n_thread=40,
+ map_size=None,
+):
+ """Make lmdb from images.
+
+ Contents of lmdb. The file structure is:
+ example.lmdb
+ ├── data.mdb
+ ├── lock.mdb
+ ├── meta_info.txt
+
+ The data.mdb and lock.mdb are standard lmdb files and you can refer to
+ https://lmdb.readthedocs.io/en/release/ for more details.
+
+ The meta_info.txt is a specified txt file to record the meta information
+ of our datasets. It will be automatically created when preparing
+ datasets by our provided dataset tools.
+    Each line in the txt file records 1) image name (with extension),
+    2) image shape, and 3) compression level, separated by a white space.
+
+ For example, the meta information could be:
+ `000_00000000.png (720,1280,3) 1`, which means:
+ 1) image name (with extension): 000_00000000.png;
+ 2) image shape: (720,1280,3);
+ 3) compression level: 1
+
+ We use the image name without extension as the lmdb key.
+
+ If `multiprocessing_read` is True, it will read all the images to memory
+ using multiprocessing. Thus, your server needs to have enough memory.
+
+ Args:
+ data_path (str): Data path for reading images.
+ lmdb_path (str): Lmdb save path.
+ img_path_list (str): Image path list.
+ keys (str): Used for lmdb keys.
+ batch (int): After processing batch images, lmdb commits.
+ Default: 5000.
+ compress_level (int): Compress level when encoding images. Default: 1.
+ multiprocessing_read (bool): Whether use multiprocessing to read all
+ the images to memory. Default: False.
+ n_thread (int): For multiprocessing.
+ map_size (int | None): Map size for lmdb env. If None, use the
+ estimated size from images. Default: None
+ """
+
+ assert len(img_path_list) == len(keys), (
+ "img_path_list and keys should have the same length, " f"but got {len(img_path_list)} and {len(keys)}"
+ )
+ print(f"Create lmdb for {data_path}, save to {lmdb_path}...")
+ print(f"Totoal images: {len(img_path_list)}")
+ if not lmdb_path.endswith(".lmdb"):
+ raise ValueError("lmdb_path must end with '.lmdb'.")
+ if osp.exists(lmdb_path):
+ print(f"Folder {lmdb_path} already exists. Exit.")
+ sys.exit(1)
+
+ if multiprocessing_read:
+ # read all the images to memory (multiprocessing)
+ dataset = {} # use dict to keep the order for multiprocessing
+ shapes = {}
+ print(f"Read images with multiprocessing, #thread: {n_thread} ...")
+ pbar = tqdm(total=len(img_path_list), unit="image")
+
+ def callback(arg):
+ """get the image data and update pbar."""
+ key, dataset[key], shapes[key] = arg
+ pbar.update(1)
+ pbar.set_description(f"Read {key}")
+
+ pool = Pool(n_thread)
+ for path, key in zip(img_path_list, keys, strict=False):
+ pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback)
+ pool.close()
+ pool.join()
+ pbar.close()
+ print(f"Finish reading {len(img_path_list)} images.")
+
+ # create lmdb environment
+ if map_size is None:
+ # obtain data size for one image
+ img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
+ _, img_byte = cv2.imencode(".png", img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
+ data_size_per_img = img_byte.nbytes
+ print("Data size per image is: ", data_size_per_img)
+ data_size = data_size_per_img * len(img_path_list)
+ map_size = data_size * 10
+
+ env = lmdb.open(lmdb_path, map_size=map_size)
+
+ # write data to lmdb
+ pbar = tqdm(total=len(img_path_list), unit="chunk")
+ txn = env.begin(write=True)
+ txt_file = open(osp.join(lmdb_path, "meta_info.txt"), "w")
+ for idx, (path, key) in enumerate(zip(img_path_list, keys, strict=False)):
+ pbar.update(1)
+ pbar.set_description(f"Write {key}")
+ key_byte = key.encode("ascii")
+ if multiprocessing_read:
+ img_byte = dataset[key]
+ h, w, c = shapes[key]
+ else:
+ _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level)
+ h, w, c = img_shape
+
+ txn.put(key_byte, img_byte)
+ # write meta information
+ txt_file.write(f"{key}.png ({h},{w},{c}) {compress_level}\n")
+ if idx % batch == 0:
+ txn.commit()
+ txn = env.begin(write=True)
+ pbar.close()
+ txn.commit()
+ env.close()
+ txt_file.close()
+ print("\nFinish writing lmdb.")
+
+
+def read_img_worker(path, key, compress_level):
+ """Read image worker.
+
+ Args:
+ path (str): Image path.
+ key (str): Image key.
+ compress_level (int): Compress level when encoding images.
+
+ Returns:
+ str: Image key.
+ byte: Image byte.
+ tuple[int]: Image shape.
+ """
+
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+ if img.ndim == 2:
+ h, w = img.shape
+ c = 1
+ else:
+ h, w, c = img.shape
+ _, img_byte = cv2.imencode(".png", img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
+ return (key, img_byte, (h, w, c))
+
+
+class LmdbMaker:
+ """LMDB Maker.
+
+ Args:
+ lmdb_path (str): Lmdb save path.
+ map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
+ batch (int): After processing batch images, lmdb commits.
+ Default: 5000.
+ compress_level (int): Compress level when encoding images. Default: 1.
+ """
+
+ def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1):
+ if not lmdb_path.endswith(".lmdb"):
+ raise ValueError("lmdb_path must end with '.lmdb'.")
+ if osp.exists(lmdb_path):
+ print(f"Folder {lmdb_path} already exists. Exit.")
+ sys.exit(1)
+
+ self.lmdb_path = lmdb_path
+ self.batch = batch
+ self.compress_level = compress_level
+ self.env = lmdb.open(lmdb_path, map_size=map_size)
+ self.txn = self.env.begin(write=True)
+ self.txt_file = open(osp.join(lmdb_path, "meta_info.txt"), "w")
+ self.counter = 0
+
+ def put(self, img_byte, key, img_shape):
+ self.counter += 1
+ key_byte = key.encode("ascii")
+ self.txn.put(key_byte, img_byte)
+ # write meta information
+ h, w, c = img_shape
+ self.txt_file.write(f"{key}.png ({h},{w},{c}) {self.compress_level}\n")
+ if self.counter % self.batch == 0:
+ self.txn.commit()
+ self.txn = self.env.begin(write=True)
+
+ def close(self):
+ self.txn.commit()
+ self.env.close()
+ self.txt_file.close()
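+
+
+if __name__ == "__main__":
+    # A minimal usage sketch of LmdbMaker: encode one synthetic image and store it.
+    # Assumes the lmdb package is installed; the temporary path and dummy image
+    # are only illustrative.
+    import tempfile
+
+    import numpy as np
+
+    tmp_lmdb = osp.join(tempfile.mkdtemp(), "demo.lmdb")
+    img = np.zeros((32, 32, 3), dtype=np.uint8)
+    _, img_byte = cv2.imencode(".png", img, [cv2.IMWRITE_PNG_COMPRESSION, 1])
+
+    maker = LmdbMaker(tmp_lmdb, map_size=10 * 1024 * 1024)
+    maker.put(img_byte, "demo_000", (32, 32, 3))
+    maker.close()
+    print(f"Wrote {tmp_lmdb}")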
diff --git a/hordelib/nodes/facerestore_cf/basicsr/utils/logger.py b/hordelib/nodes/facerestore_cf/basicsr/utils/logger.py
new file mode 100644
index 00000000..1201c879
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/utils/logger.py
@@ -0,0 +1,174 @@
+import datetime
+import logging
+import time
+
+from .dist_util import get_dist_info, master_only
+
+initialized_logger = {}
+
+
+class MessageLogger:
+ """Message logger for printing.
+ Args:
+ opt (dict): Config. It contains the following keys:
+ name (str): Exp name.
+            logger (dict): Contains 'print_freq' (int) for logger interval.
+ train (dict): Contains 'total_iter' (int) for total iters.
+ use_tb_logger (bool): Use tensorboard logger.
+ start_iter (int): Start iter. Default: 1.
+ tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
+ """
+
+ def __init__(self, opt, start_iter=1, tb_logger=None):
+ self.exp_name = opt["name"]
+ self.interval = opt["logger"]["print_freq"]
+ self.start_iter = start_iter
+ self.max_iters = opt["train"]["total_iter"]
+ self.use_tb_logger = opt["logger"]["use_tb_logger"]
+ self.tb_logger = tb_logger
+ self.start_time = time.time()
+ self.logger = get_root_logger()
+
+ @master_only
+ def __call__(self, log_vars):
+ """Format logging message.
+ Args:
+ log_vars (dict): It contains the following keys:
+ epoch (int): Epoch number.
+ iter (int): Current iter.
+ lrs (list): List for learning rates.
+ time (float): Iter time.
+ data_time (float): Data time for each iter.
+ """
+ # epoch, iter, learning rates
+ epoch = log_vars.pop("epoch")
+ current_iter = log_vars.pop("iter")
+ lrs = log_vars.pop("lrs")
+
+ message = f"[{self.exp_name[:5]}..][epoch:{epoch:3d}, " f"iter:{current_iter:8,d}, lr:("
+ for v in lrs:
+ message += f"{v:.3e},"
+ message += ")] "
+
+ # time and estimated time
+ if "time" in log_vars.keys():
+ iter_time = log_vars.pop("time")
+ data_time = log_vars.pop("data_time")
+
+ total_time = time.time() - self.start_time
+ time_sec_avg = total_time / (current_iter - self.start_iter + 1)
+ eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
+ eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
+ message += f"[eta: {eta_str}, "
+ message += f"time (data): {iter_time:.3f} ({data_time:.3f})] "
+
+ # other items, especially losses
+ for k, v in log_vars.items():
+ message += f"{k}: {v:.4e} "
+ # tensorboard logger
+ if self.use_tb_logger:
+ if k.startswith("l_"):
+ self.tb_logger.add_scalar(f"losses/{k}", v, current_iter)
+ else:
+ self.tb_logger.add_scalar(k, v, current_iter)
+ self.logger.info(message)
+
+
+@master_only
+def init_tb_logger(log_dir):
+ from torch.utils.tensorboard import SummaryWriter
+
+ tb_logger = SummaryWriter(log_dir=log_dir)
+ return tb_logger
+
+
+@master_only
+def init_wandb_logger(opt):
+ """We now only use wandb to sync tensorboard log."""
+ import wandb
+
+ logger = logging.getLogger("basicsr")
+
+ project = opt["logger"]["wandb"]["project"]
+ resume_id = opt["logger"]["wandb"].get("resume_id")
+ if resume_id:
+ wandb_id = resume_id
+ resume = "allow"
+ logger.warning(f"Resume wandb logger with id={wandb_id}.")
+ else:
+ wandb_id = wandb.util.generate_id()
+ resume = "never"
+
+ wandb.init(id=wandb_id, resume=resume, name=opt["name"], config=opt, project=project, sync_tensorboard=True)
+
+ logger.info(f"Use wandb logger with id={wandb_id}; project={project}.")
+
+
+def get_root_logger(logger_name="basicsr", log_level=logging.INFO, log_file=None):
+ """Get the root logger.
+ The logger will be initialized if it has not been initialized. By default a
+ StreamHandler will be added. If `log_file` is specified, a FileHandler will
+ also be added.
+ Args:
+ logger_name (str): root logger name. Default: 'basicsr'.
+ log_file (str | None): The log filename. If specified, a FileHandler
+ will be added to the root logger.
+ log_level (int): The root logger level. Note that only the process of
+ rank 0 is affected, while other processes will set the level to
+ "Error" and be silent most of the time.
+ Returns:
+ logging.Logger: The root logger.
+ """
+ logger = logging.getLogger(logger_name)
+ # if the logger has been initialized, just return it
+ if logger_name in initialized_logger:
+ return logger
+
+ format_str = "%(asctime)s %(levelname)s: %(message)s"
+ stream_handler = logging.StreamHandler()
+ stream_handler.setFormatter(logging.Formatter(format_str))
+ logger.addHandler(stream_handler)
+ logger.propagate = False
+ rank, _ = get_dist_info()
+ if rank != 0:
+ logger.setLevel("ERROR")
+ elif log_file is not None:
+ logger.setLevel(log_level)
+ # add file handler
+ # file_handler = logging.FileHandler(log_file, 'w')
+ file_handler = logging.FileHandler(log_file, "a") # Shangchen: keep the previous log
+ file_handler.setFormatter(logging.Formatter(format_str))
+ file_handler.setLevel(log_level)
+ logger.addHandler(file_handler)
+ initialized_logger[logger_name] = True
+ return logger
+
+
+def get_env_info():
+ """Get environment information.
+ Currently, only log the software version.
+ """
+ import torch
+ import torchvision
+
+ from hordelib.nodes.facerestore_cf.basicsr.version import __version__
+
+ msg = r"""
+ ____ _ _____ ____
+ / __ ) ____ _ _____ (_)_____/ ___/ / __ \
+ / __ |/ __ `// ___// // ___/\__ \ / /_/ /
+ / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/
+ /_____/ \__,_//____//_/ \___//____//_/ |_|
+ ______ __ __ __ __
+ / ____/____ ____ ____/ / / / __ __ _____ / /__ / /
+ / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / /
+ / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/
+ \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_)
+ """
+ msg += (
+ "\nVersion Information: "
+ f"\n\tBasicSR: {__version__}"
+ f"\n\tPyTorch: {torch.__version__}"
+ f"\n\tTorchVision: {torchvision.__version__}"
+ )
+ return msg
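+
+
+if __name__ == "__main__":
+    # A minimal usage sketch. Without a log_file the level is left at the library
+    # default, so a warning is emitted here to guarantee visible output.
+    logger = get_root_logger()
+    logger.warning("basicsr root logger initialized")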
diff --git a/hordelib/nodes/facerestore_cf/basicsr/utils/matlab_functions.py b/hordelib/nodes/facerestore_cf/basicsr/utils/matlab_functions.py
new file mode 100644
index 00000000..5a637988
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/utils/matlab_functions.py
@@ -0,0 +1,372 @@
+import math
+
+import numpy as np
+import torch
+
+
+def cubic(x):
+ """cubic function used for calculate_weights_indices."""
+ absx = torch.abs(x)
+ absx2 = absx**2
+ absx3 = absx**3
+ return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + (
+ -0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2
+ ) * (((absx > 1) * (absx <= 2)).type_as(absx))
+
+
+def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
+ """Calculate weights and indices, used for imresize function.
+
+ Args:
+ in_length (int): Input length.
+ out_length (int): Output length.
+ scale (float): Scale factor.
+        kernel (str): Kernel name; only cubic interpolation is currently supported.
+        kernel_width (int): Kernel width.
+        antialiasing (bool): Whether to apply anti-aliasing when downsampling.
+ """
+
+ if (scale < 1) and antialiasing:
+ # Use a modified kernel (larger kernel width) to simultaneously
+ # interpolate and antialias
+ kernel_width = kernel_width / scale
+
+ # Output-space coordinates
+ x = torch.linspace(1, out_length, out_length)
+
+ # Input-space coordinates. Calculate the inverse mapping such that 0.5
+ # in output space maps to 0.5 in input space, and 0.5 + scale in output
+ # space maps to 1.5 in input space.
+ u = x / scale + 0.5 * (1 - 1 / scale)
+
+ # What is the left-most pixel that can be involved in the computation?
+ left = torch.floor(u - kernel_width / 2)
+
+ # What is the maximum number of pixels that can be involved in the
+ # computation? Note: it's OK to use an extra pixel here; if the
+ # corresponding weights are all zero, it will be eliminated at the end
+ # of this function.
+ p = math.ceil(kernel_width) + 2
+
+ # The indices of the input pixels involved in computing the k-th output
+ # pixel are in row k of the indices matrix.
+ indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
+ out_length,
+ p,
+ )
+
+ # The weights used to compute the k-th output pixel are in row k of the
+ # weights matrix.
+ distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices
+
+ # apply cubic kernel
+ if (scale < 1) and antialiasing:
+ weights = scale * cubic(distance_to_center * scale)
+ else:
+ weights = cubic(distance_to_center)
+
+ # Normalize the weights matrix so that each row sums to 1.
+ weights_sum = torch.sum(weights, 1).view(out_length, 1)
+ weights = weights / weights_sum.expand(out_length, p)
+
+ # If a column in weights is all zero, get rid of it. only consider the
+ # first and last column.
+ weights_zero_tmp = torch.sum((weights == 0), 0)
+ if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 1, p - 2)
+ weights = weights.narrow(1, 1, p - 2)
+ if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 0, p - 2)
+ weights = weights.narrow(1, 0, p - 2)
+ weights = weights.contiguous()
+ indices = indices.contiguous()
+ sym_len_s = -indices.min() + 1
+ sym_len_e = indices.max() - in_length
+ indices = indices + sym_len_s - 1
+ return weights, indices, int(sym_len_s), int(sym_len_e)
+
+
+@torch.no_grad()
+def imresize(img, scale, antialiasing=True):
+ """imresize function same as MATLAB.
+
+ It now only supports bicubic.
+ The same scale applies for both height and width.
+
+ Args:
+ img (Tensor | Numpy array):
+ Tensor: Input image with shape (c, h, w), [0, 1] range.
+ Numpy: Input image with shape (h, w, c), [0, 1] range.
+ scale (float): Scale factor. The same scale applies for both height
+ and width.
+        antialiasing (bool): Whether to apply anti-aliasing when downsampling.
+ Default: True.
+
+ Returns:
+ Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round.
+ """
+ if type(img).__module__ == np.__name__: # numpy type
+ numpy_type = True
+ img = torch.from_numpy(img.transpose(2, 0, 1)).float()
+ else:
+ numpy_type = False
+
+ in_c, in_h, in_w = img.size()
+ out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale)
+ kernel_width = 4
+ kernel = "cubic"
+
+ # get weights and indices
+ weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(
+ in_h,
+ out_h,
+ scale,
+ kernel,
+ kernel_width,
+ antialiasing,
+ )
+ weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(
+ in_w,
+ out_w,
+ scale,
+ kernel,
+ kernel_width,
+ antialiasing,
+ )
+ # process H dimension
+ # symmetric copying
+ img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
+ img_aug.narrow(1, sym_len_hs, in_h).copy_(img)
+
+ sym_patch = img[:, :sym_len_hs, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)
+
+ sym_patch = img[:, -sym_len_he:, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)
+
+ out_1 = torch.FloatTensor(in_c, out_h, in_w)
+ kernel_width = weights_h.size(1)
+ for i in range(out_h):
+ idx = int(indices_h[i][0])
+ for j in range(in_c):
+ out_1[j, i, :] = img_aug[j, idx : idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])
+
+ # process W dimension
+ # symmetric copying
+ out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
+ out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)
+
+ sym_patch = out_1[:, :, :sym_len_ws]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)
+
+ sym_patch = out_1[:, :, -sym_len_we:]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)
+
+ out_2 = torch.FloatTensor(in_c, out_h, out_w)
+ kernel_width = weights_w.size(1)
+ for i in range(out_w):
+ idx = int(indices_w[i][0])
+ for j in range(in_c):
+ out_2[j, :, i] = out_1_aug[j, :, idx : idx + kernel_width].mv(weights_w[i])
+
+ if numpy_type:
+ out_2 = out_2.numpy().transpose(1, 2, 0)
+ return out_2
+
+
+def rgb2ycbcr(img, y_only=False):
+ """Convert a RGB image to YCbCr image.
+
+ This function produces the same results as Matlab's `rgb2ycbcr` function.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+ y_only (bool): Whether to only return Y channel. Default: False.
+
+ Returns:
+ ndarray: The converted YCbCr image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img)
+ if y_only:
+ out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
+ else:
+ out_img = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [
+ 16,
+ 128,
+ 128,
+ ]
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def bgr2ycbcr(img, y_only=False):
+ """Convert a BGR image to YCbCr image.
+
+ The bgr version of rgb2ycbcr.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+ y_only (bool): Whether to only return Y channel. Default: False.
+
+ Returns:
+ ndarray: The converted YCbCr image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img)
+ if y_only:
+ out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
+ else:
+ out_img = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [
+ 16,
+ 128,
+ 128,
+ ]
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def ycbcr2rgb(img):
+ """Convert a YCbCr image to RGB image.
+
+ This function produces the same results as Matlab's ycbcr2rgb function.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+
+ Returns:
+ ndarray: The converted RGB image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img) * 255
+ out_img = np.matmul(
+ img,
+ [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], [0.00625893, -0.00318811, 0]],
+ ) * 255.0 + [-222.921, 135.576, -276.836]
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def ycbcr2bgr(img):
+ """Convert a YCbCr image to BGR image.
+
+ The bgr version of ycbcr2rgb.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+
+ Returns:
+ ndarray: The converted BGR image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img) * 255
+ out_img = np.matmul(
+ img,
+ [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0], [0, -0.00318811, 0.00625893]],
+ ) * 255.0 + [-276.836, 135.576, -222.921]
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def _convert_input_type_range(img):
+ """Convert the type and range of the input image.
+
+ It converts the input image to np.float32 type and range of [0, 1].
+ It is mainly used for pre-processing the input image in colorspace
+    conversion functions such as rgb2ycbcr and ycbcr2rgb.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+
+ Returns:
+ (ndarray): The converted image with type of np.float32 and range of
+ [0, 1].
+ """
+ img_type = img.dtype
+ img = img.astype(np.float32)
+ if img_type == np.float32:
+ pass
+ elif img_type == np.uint8:
+ img /= 255.0
+ else:
+ raise TypeError("The img type should be np.float32 or np.uint8, " f"but got {img_type}")
+ return img
+
+
+def _convert_output_type_range(img, dst_type):
+ """Convert the type and range of the image according to dst_type.
+
+ It converts the image to desired type and range. If `dst_type` is np.uint8,
+ images will be converted to np.uint8 type with range [0, 255]. If
+ `dst_type` is np.float32, it converts the image to np.float32 type with
+ range [0, 1].
+    It is mainly used for post-processing images in colorspace conversion
+ functions such as rgb2ycbcr and ycbcr2rgb.
+
+ Args:
+ img (ndarray): The image to be converted with np.float32 type and
+ range [0, 255].
+ dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
+ converts the image to np.uint8 type with range [0, 255]. If
+ dst_type is np.float32, it converts the image to np.float32 type
+ with range [0, 1].
+
+ Returns:
+ (ndarray): The converted image with desired type and range.
+ """
+ if dst_type not in (np.uint8, np.float32):
+ raise TypeError("The dst_type should be np.float32 or np.uint8, " f"but got {dst_type}")
+ if dst_type == np.uint8:
+ img = img.round()
+ else:
+ img /= 255.0
+ return img.astype(dst_type)
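+
+
+if __name__ == "__main__":
+    # A minimal usage sketch on a synthetic image; the sizes are only illustrative.
+    img = np.random.rand(32, 48, 3).astype(np.float32)  # (h, w, c) in [0, 1]
+
+    small = imresize(img, 0.5)       # numpy in, numpy out: (16, 24, 3)
+    y = rgb2ycbcr(img, y_only=True)  # luma channel only: (32, 48)
+    print(small.shape, y.shape)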
diff --git a/hordelib/nodes/facerestore_cf/basicsr/utils/misc.py b/hordelib/nodes/facerestore_cf/basicsr/utils/misc.py
new file mode 100644
index 00000000..be8fe502
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/utils/misc.py
@@ -0,0 +1,136 @@
+import os
+import random
+import time
+from os import path as osp
+
+import numpy as np
+import torch
+
+from .dist_util import master_only
+from .logger import get_root_logger
+
+
+def set_random_seed(seed):
+ """Set random seeds."""
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+
+def get_time_str():
+ return time.strftime("%Y%m%d_%H%M%S", time.localtime())
+
+
+def mkdir_and_rename(path):
+ """mkdirs. If path exists, rename it with timestamp and create a new one.
+
+ Args:
+ path (str): Folder path.
+ """
+ if osp.exists(path):
+ new_name = path + "_archived_" + get_time_str()
+ print(f"Path already exists. Rename it to {new_name}", flush=True)
+ os.rename(path, new_name)
+ os.makedirs(path, exist_ok=True)
+
+
+@master_only
+def make_exp_dirs(opt):
+ """Make dirs for experiments."""
+ path_opt = opt["path"].copy()
+ if opt["is_train"]:
+ mkdir_and_rename(path_opt.pop("experiments_root"))
+ else:
+ mkdir_and_rename(path_opt.pop("results_root"))
+ for key, path in path_opt.items():
+ if ("strict_load" not in key) and ("pretrain_network" not in key) and ("resume" not in key):
+ os.makedirs(path, exist_ok=True)
+
+
+def scandir(dir_path, suffix=None, recursive=False, full_path=False):
+ """Scan a directory to find the interested files.
+
+ Args:
+ dir_path (str): Path of the directory.
+ suffix (str | tuple(str), optional): File suffix that we are
+ interested in. Default: None.
+ recursive (bool, optional): If set to True, recursively scan the
+ directory. Default: False.
+ full_path (bool, optional): If set to True, include the dir_path.
+ Default: False.
+
+ Returns:
+        A generator for all the interested files with relative paths.
+ """
+
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+ raise TypeError('"suffix" must be a string or tuple of strings')
+
+ root = dir_path
+
+ def _scandir(dir_path, suffix, recursive):
+ for entry in os.scandir(dir_path):
+ if not entry.name.startswith(".") and entry.is_file():
+ if full_path:
+ return_path = entry.path
+ else:
+ return_path = osp.relpath(entry.path, root)
+
+ if suffix is None:
+ yield return_path
+ elif return_path.endswith(suffix):
+ yield return_path
+ else:
+ if recursive:
+ yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
+ else:
+ continue
+
+ return _scandir(dir_path, suffix=suffix, recursive=recursive)
+
+
+def check_resume(opt, resume_iter):
+ """Check resume states and pretrain_network paths.
+
+ Args:
+ opt (dict): Options.
+ resume_iter (int): Resume iteration.
+ """
+ logger = get_root_logger()
+ if opt["path"]["resume_state"]:
+ # get all the networks
+ networks = [key for key in opt.keys() if key.startswith("network_")]
+ flag_pretrain = False
+ for network in networks:
+ if opt["path"].get(f"pretrain_{network}") is not None:
+ flag_pretrain = True
+ if flag_pretrain:
+ logger.warning("pretrain_network path will be ignored during resuming.")
+ # set pretrained model paths
+ for network in networks:
+ name = f"pretrain_{network}"
+ basename = network.replace("network_", "")
+ if opt["path"].get("ignore_resume_networks") is None or (
+ basename not in opt["path"]["ignore_resume_networks"]
+ ):
+ opt["path"][name] = osp.join(opt["path"]["models"], f"net_{basename}_{resume_iter}.pth")
+ logger.info(f"Set {name} to {opt['path'][name]}")
+
+
+def sizeof_fmt(size, suffix="B"):
+ """Get human readable file size.
+
+ Args:
+ size (int): File size.
+ suffix (str): Suffix. Default: 'B'.
+
+ Return:
+        str: Formatted file size.
+ """
+ for unit in ["", "K", "M", "G", "T", "P", "E", "Z"]:
+ if abs(size) < 1024.0:
+ return f"{size:3.1f} {unit}{suffix}"
+ size /= 1024.0
+ return f"{size:3.1f} Y{suffix}"
diff --git a/hordelib/nodes/facerestore_cf/basicsr/utils/options.py b/hordelib/nodes/facerestore_cf/basicsr/utils/options.py
new file mode 100644
index 00000000..909e4576
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/utils/options.py
@@ -0,0 +1,109 @@
+from collections import OrderedDict
+from os import path as osp
+
+import yaml
+
+from hordelib.nodes.facerestore_cf.basicsr.utils.misc import get_time_str
+
+
+def ordered_yaml():
+ """Support OrderedDict for yaml.
+
+ Returns:
+ yaml Loader and Dumper.
+ """
+ try:
+ from yaml import CDumper as Dumper
+ from yaml import CLoader as Loader
+ except ImportError:
+ from yaml import Dumper, Loader
+
+ _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
+
+ def dict_representer(dumper, data):
+ return dumper.represent_dict(data.items())
+
+ def dict_constructor(loader, node):
+ return OrderedDict(loader.construct_pairs(node))
+
+ Dumper.add_representer(OrderedDict, dict_representer)
+ Loader.add_constructor(_mapping_tag, dict_constructor)
+ return Loader, Dumper
+
+
+def parse(opt_path, root_path, is_train=True):
+ """Parse option file.
+
+ Args:
+ opt_path (str): Option file path.
+        root_path (str): Root path used to build the experiment/result directories.
+        is_train (bool): Indicate whether in training or not. Default: True.
+
+ Returns:
+ (dict): Options.
+ """
+ with open(opt_path) as f:
+ Loader, _ = ordered_yaml()
+ opt = yaml.load(f, Loader=Loader)
+
+ opt["is_train"] = is_train
+
+ # opt['name'] = f"{get_time_str()}_{opt['name']}"
+ if opt["path"].get("resume_state", None): # Shangchen added
+ resume_state_path = opt["path"].get("resume_state")
+ opt["name"] = resume_state_path.split("/")[-3]
+ else:
+ opt["name"] = f"{get_time_str()}_{opt['name']}"
+
+ # datasets
+ for phase, dataset in opt["datasets"].items():
+ # for several datasets, e.g., test_1, test_2
+ phase = phase.split("_")[0]
+ dataset["phase"] = phase
+ if "scale" in opt:
+ dataset["scale"] = opt["scale"]
+ if dataset.get("dataroot_gt") is not None:
+ dataset["dataroot_gt"] = osp.expanduser(dataset["dataroot_gt"])
+ if dataset.get("dataroot_lq") is not None:
+ dataset["dataroot_lq"] = osp.expanduser(dataset["dataroot_lq"])
+
+ # paths
+ for key, val in opt["path"].items():
+ if (val is not None) and ("resume_state" in key or "pretrain_network" in key):
+ opt["path"][key] = osp.expanduser(val)
+
+ if is_train:
+ experiments_root = osp.join(root_path, "experiments", opt["name"])
+ opt["path"]["experiments_root"] = experiments_root
+ opt["path"]["models"] = osp.join(experiments_root, "models")
+ opt["path"]["training_states"] = osp.join(experiments_root, "training_states")
+ opt["path"]["log"] = experiments_root
+ opt["path"]["visualization"] = osp.join(experiments_root, "visualization")
+
+ else: # test
+ results_root = osp.join(root_path, "results", opt["name"])
+ opt["path"]["results_root"] = results_root
+ opt["path"]["log"] = results_root
+ opt["path"]["visualization"] = osp.join(results_root, "visualization")
+
+ return opt
+
+
+def dict2str(opt, indent_level=1):
+ """dict to string for printing options.
+
+ Args:
+ opt (dict): Option dict.
+ indent_level (int): Indent level. Default: 1.
+
+ Return:
+ (str): Option string for printing.
+ """
+ msg = "\n"
+ for k, v in opt.items():
+ if isinstance(v, dict):
+ msg += " " * (indent_level * 2) + k + ":["
+ msg += dict2str(v, indent_level + 1)
+ msg += " " * (indent_level * 2) + "]\n"
+ else:
+ msg += " " * (indent_level * 2) + k + ": " + str(v) + "\n"
+ return msg
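+
+
+if __name__ == "__main__":
+    # A minimal usage sketch; the option dict below is illustrative only and is
+    # far smaller than a real training/test configuration.
+    demo_opt = {"name": "demo", "val": {"val_freq": 5000, "save_img": False}}
+    print(dict2str(demo_opt))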
diff --git a/hordelib/nodes/facerestore_cf/basicsr/utils/realesrgan_utils.py b/hordelib/nodes/facerestore_cf/basicsr/utils/realesrgan_utils.py
new file mode 100644
index 00000000..73c4da85
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/utils/realesrgan_utils.py
@@ -0,0 +1,319 @@
+import math
+import os
+import queue
+import threading
+
+import cv2
+import numpy as np
+import torch
+from torch.nn import functional as F
+
+from hordelib.nodes.facerestore_cf.basicsr.utils.download_util import load_file_from_url
+
+# ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+
+
+class RealESRGANer:
+ """A helper class for upsampling images with RealESRGAN.
+
+ Args:
+ scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
+        model_path (str): The path to the pretrained model. It can also be a URL, which is resolved to a local
+            file first.
+        model (nn.Module): The defined network. Default: None.
+        tile (int): Because very large inputs can run out of GPU memory, this option first crops the input
+            image into tiles and processes each tile separately before merging them back into one image.
+            0 disables tiling. Default: 0.
+ tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
+ pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
+        half (bool): Whether to use half precision during inference. Default: False.
+ """
+
+ def __init__(
+ self,
+ scale,
+ model_path,
+ model=None,
+ tile=0,
+ tile_pad=10,
+ pre_pad=10,
+ half=False,
+ device=None,
+ gpu_id=None,
+ ):
+ self.scale = scale
+ self.tile_size = tile
+ self.tile_pad = tile_pad
+ self.pre_pad = pre_pad
+ self.mod_scale = None
+ self.half = half
+
+ # initialize model
+ if gpu_id:
+ self.device = (
+ torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu") if device is None else device
+ )
+ else:
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
+        # if the model_path starts with https, resolve it to a local model file via load_file_from_url
+ if model_path.startswith("https://"):
+ model_path = load_file_from_url(
+ url=model_path,
+ model_dir=os.path.join("weights/realesrgan"),
+ progress=True,
+ file_name=None,
+ )
+ loadnet = torch.load(model_path, map_location=torch.device("cpu"))
+ # prefer to use params_ema
+ if "params_ema" in loadnet:
+ keyname = "params_ema"
+ else:
+ keyname = "params"
+ model.load_state_dict(loadnet[keyname], strict=True)
+ model.eval()
+ self.model = model.to(self.device)
+ if self.half:
+ self.model = self.model.half()
+
+ def pre_process(self, img):
+ """Pre-process, such as pre-pad and mod pad, so that the images can be divisible"""
+ img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
+ self.img = img.unsqueeze(0).to(self.device)
+ if self.half:
+ self.img = self.img.half()
+
+ # pre_pad
+ self.img_pre_pad = self.img.clone()
+ if self.pre_pad != 0:
+ self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), "reflect")
+ # mod pad for divisible borders
+ if self.scale == 2:
+ self.mod_scale = 2
+ elif self.scale == 1:
+ self.mod_scale = 4
+ if self.mod_scale is not None:
+ self.mod_pad_h, self.mod_pad_w = 0, 0
+ _, _, h, w = self.img.size()
+ if h % self.mod_scale != 0:
+ self.mod_pad_h = self.mod_scale - h % self.mod_scale
+ if w % self.mod_scale != 0:
+ self.mod_pad_w = self.mod_scale - w % self.mod_scale
+ self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), "reflect")
+
+ def process(self):
+ # model inference
+ self.output = self.model(self.img)
+
+ def tile_process(self):
+ """It will first crop input images to tiles, and then process each tile.
+        Finally, all the processed tiles are merged into one image.
+
+ Modified from: https://github.com/ata4/esrgan-launcher
+ """
+ batch, channel, height, width = self.img.shape
+ output_height = height * self.scale
+ output_width = width * self.scale
+ output_shape = (batch, channel, output_height, output_width)
+
+ # start with black image
+ self.output = self.img.new_zeros(output_shape)
+ tiles_x = math.ceil(width / self.tile_size)
+ tiles_y = math.ceil(height / self.tile_size)
+
+ # loop over all tiles
+ for y in range(tiles_y):
+ for x in range(tiles_x):
+ # extract tile from input image
+ ofs_x = x * self.tile_size
+ ofs_y = y * self.tile_size
+ # input tile area on total image
+ input_start_x = ofs_x
+ input_end_x = min(ofs_x + self.tile_size, width)
+ input_start_y = ofs_y
+ input_end_y = min(ofs_y + self.tile_size, height)
+
+ # input tile area on total image with padding
+ input_start_x_pad = max(input_start_x - self.tile_pad, 0)
+ input_end_x_pad = min(input_end_x + self.tile_pad, width)
+ input_start_y_pad = max(input_start_y - self.tile_pad, 0)
+ input_end_y_pad = min(input_end_y + self.tile_pad, height)
+
+ # input tile dimensions
+ input_tile_width = input_end_x - input_start_x
+ input_tile_height = input_end_y - input_start_y
+ tile_idx = y * tiles_x + x + 1
+ input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
+
+ # upscale tile
+ try:
+ with torch.no_grad():
+ output_tile = self.model(input_tile)
+ except RuntimeError as error:
+ print("Error", error)
+ # print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
+
+ # output tile area on total image
+ output_start_x = input_start_x * self.scale
+ output_end_x = input_end_x * self.scale
+ output_start_y = input_start_y * self.scale
+ output_end_y = input_end_y * self.scale
+
+ # output tile area without padding
+ output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
+ output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
+ output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
+ output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
+
+ # put tile into output image
+ self.output[:, :, output_start_y:output_end_y, output_start_x:output_end_x] = output_tile[
+ :,
+ :,
+ output_start_y_tile:output_end_y_tile,
+ output_start_x_tile:output_end_x_tile,
+ ]
+
+ def post_process(self):
+ # remove extra pad
+ if self.mod_scale is not None:
+ _, _, h, w = self.output.size()
+ self.output = self.output[:, :, 0 : h - self.mod_pad_h * self.scale, 0 : w - self.mod_pad_w * self.scale]
+ # remove prepad
+ if self.pre_pad != 0:
+ _, _, h, w = self.output.size()
+ self.output = self.output[:, :, 0 : h - self.pre_pad * self.scale, 0 : w - self.pre_pad * self.scale]
+ return self.output
+
+ @torch.no_grad()
+ def enhance(self, img, outscale=None, alpha_upsampler="realesrgan"):
+ h_input, w_input = img.shape[0:2]
+ # img: numpy
+ img = img.astype(np.float32)
+ if np.max(img) > 256: # 16-bit image
+ max_range = 65535
+ print("\tInput is a 16-bit image")
+ else:
+ max_range = 255
+ img = img / max_range
+ if len(img.shape) == 2: # gray image
+ img_mode = "L"
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
+ elif img.shape[2] == 4: # RGBA image with alpha channel
+ img_mode = "RGBA"
+ alpha = img[:, :, 3]
+ img = img[:, :, 0:3]
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ if alpha_upsampler == "realesrgan":
+ alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
+ else:
+ img_mode = "RGB"
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+ # ------------------- process image (without the alpha channel) ------------------- #
+ try:
+ with torch.no_grad():
+ self.pre_process(img)
+ if self.tile_size > 0:
+ self.tile_process()
+ else:
+ self.process()
+ output_img_t = self.post_process()
+ output_img = output_img_t.data.squeeze().float().cpu().clamp_(0, 1).numpy()
+ output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
+ if img_mode == "L":
+ output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
+ del output_img_t
+ torch.cuda.empty_cache()
+ except RuntimeError as error:
+ output_img = cv2.resize(
+ self.img_pre_pad,
+ (w_input * self.scale, h_input * self.scale),
+ interpolation=cv2.INTER_LINEAR,
+ )
+ print(f"Failed inference for RealESRGAN: {error}")
+
+ # ------------------- process the alpha channel if necessary ------------------- #
+ if img_mode == "RGBA":
+ if alpha_upsampler == "realesrgan":
+ self.pre_process(alpha)
+ if self.tile_size > 0:
+ self.tile_process()
+ else:
+ self.process()
+ output_alpha = self.post_process()
+ output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
+ output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
+ output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
+ else: # use the cv2 resize for alpha channel
+ h, w = alpha.shape[0:2]
+ output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
+
+ # merge the alpha channel
+ output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
+ output_img[:, :, 3] = output_alpha
+
+ # ------------------------------ return ------------------------------ #
+ if max_range == 65535: # 16-bit image
+ output = (output_img * 65535.0).round().astype(np.uint16)
+ else:
+ output = (output_img * 255.0).round().astype(np.uint8)
+
+ if outscale is not None and outscale != float(self.scale):
+ output = cv2.resize(
+ output,
+ (
+ int(w_input * outscale),
+ int(h_input * outscale),
+ ),
+ interpolation=cv2.INTER_LANCZOS4,
+ )
+
+ return output, img_mode
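+
+    # Illustrative usage sketch, not part of the original module: `upsampler` is
+    # assumed to be an instance of the enclosing upscaler class and `bgr` a uint8
+    # BGR image loaded with OpenCV.
+    #
+    #   bgr = cv2.imread("input.png", cv2.IMREAD_UNCHANGED)
+    #   upscaled, img_mode = upsampler.enhance(bgr, outscale=4)
+    #   cv2.imwrite("output.png", upscaled)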
+
+
+class PrefetchReader(threading.Thread):
+ """Prefetch images.
+
+ Args:
+        img_list (list[str]): A list of image paths to be read.
+        num_prefetch_queue (int): Maximum size of the prefetch queue.
+ """
+
+ def __init__(self, img_list, num_prefetch_queue):
+ super().__init__()
+ self.que = queue.Queue(num_prefetch_queue)
+ self.img_list = img_list
+
+ def run(self):
+ for img_path in self.img_list:
+ img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
+ self.que.put(img)
+
+ self.que.put(None)
+
+ def __next__(self):
+ next_item = self.que.get()
+ if next_item is None:
+ raise StopIteration
+ return next_item
+
+ def __iter__(self):
+ return self
+
+
+class IOConsumer(threading.Thread):
+
+ def __init__(self, opt, que, qid):
+ super().__init__()
+ self._queue = que
+ self.qid = qid
+ self.opt = opt
+
+ def run(self):
+ while True:
+ msg = self._queue.get()
+ if isinstance(msg, str) and msg == "quit":
+ break
+
+ output = msg["output"]
+ save_path = msg["save_path"]
+ cv2.imwrite(save_path, output)
+ print(f"IO worker {self.qid} is done.")
diff --git a/hordelib/nodes/facerestore_cf/basicsr/utils/registry.py b/hordelib/nodes/facerestore_cf/basicsr/utils/registry.py
new file mode 100644
index 00000000..1eec4a97
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/utils/registry.py
@@ -0,0 +1,83 @@
+# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py
+
+
+class Registry:
+ """
+ The registry that provides name -> object mapping, to support third-party
+ users' custom modules.
+
+ To create a registry (e.g. a backbone registry):
+
+ .. code-block:: python
+
+ BACKBONE_REGISTRY = Registry('BACKBONE')
+
+ To register an object:
+
+ .. code-block:: python
+
+ @BACKBONE_REGISTRY.register()
+ class MyBackbone():
+ ...
+
+ Or:
+
+ .. code-block:: python
+
+ BACKBONE_REGISTRY.register(MyBackbone)
+ """
+
+ def __init__(self, name):
+ """
+ Args:
+ name (str): the name of this registry
+ """
+ self._name = name
+ self._obj_map = {}
+
+ def _do_register(self, name, obj):
+ assert name not in self._obj_map, (
+ f"An object named '{name}' was already registered " f"in '{self._name}' registry!"
+ )
+ self._obj_map[name] = obj
+
+ def register(self, obj=None):
+ """
+        Register the given object under the name `obj.__name__`.
+        Can be used either as a decorator or as a regular function call.
+ See docstring of this class for usage.
+ """
+ if obj is None:
+ # used as a decorator
+ def deco(func_or_class):
+ name = func_or_class.__name__
+ self._do_register(name, func_or_class)
+ return func_or_class
+
+ return deco
+
+ # used as a function call
+ name = obj.__name__
+ self._do_register(name, obj)
+
+ def get(self, name):
+ ret = self._obj_map.get(name)
+ if ret is None:
+ raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
+ return ret
+
+ def __contains__(self, name):
+ return name in self._obj_map
+
+ def __iter__(self):
+ return iter(self._obj_map.items())
+
+ def keys(self):
+ return self._obj_map.keys()
+
+
+DATASET_REGISTRY = Registry("dataset")
+ARCH_REGISTRY = Registry("arch")
+MODEL_REGISTRY = Registry("model")
+LOSS_REGISTRY = Registry("loss")
+METRIC_REGISTRY = Registry("metric")
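+
+# Illustrative usage sketch (the `MyArch` class below is hypothetical):
+#
+#   @ARCH_REGISTRY.register()
+#   class MyArch:
+#       ...
+#
+#   arch_cls = ARCH_REGISTRY.get("MyArch")  # raises KeyError for unknown names
+#   net = arch_cls()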
diff --git a/hordelib/nodes/facerestore_cf/basicsr/version.py b/hordelib/nodes/facerestore_cf/basicsr/version.py
new file mode 100644
index 00000000..677c4699
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/version.py
@@ -0,0 +1,5 @@
+# GENERATED VERSION FILE
+# TIME: Sun Aug 7 15:14:26 2022
+__version__ = "1.3.2"
+__gitsha__ = "6f94023"
+version_info = (1, 3, 2)
diff --git a/hordelib/nodes/facerestore/facelib/detection/yolov5face/__init__.py b/hordelib/nodes/facerestore_cf/facelib/__init__.py
similarity index 100%
rename from hordelib/nodes/facerestore/facelib/detection/yolov5face/__init__.py
rename to hordelib/nodes/facerestore_cf/facelib/__init__.py
diff --git a/hordelib/nodes/facerestore/facelib/detection/__init__.py b/hordelib/nodes/facerestore_cf/facelib/detection/__init__.py
similarity index 83%
rename from hordelib/nodes/facerestore/facelib/detection/__init__.py
rename to hordelib/nodes/facerestore_cf/facelib/detection/__init__.py
index 7ec791bf..9bf408a8 100644
--- a/hordelib/nodes/facerestore/facelib/detection/__init__.py
+++ b/hordelib/nodes/facerestore_cf/facelib/detection/__init__.py
@@ -1,116 +1,116 @@
-import os
-import torch
-from torch import nn
-from copy import deepcopy
-
-from hordelib.nodes.facerestore.facelib.utils import load_file_from_url
-from hordelib.nodes.facerestore.facelib.utils import download_pretrained_models
-from hordelib.nodes.facerestore.facelib.detection.yolov5face.models.common import Conv
-
-from .retinaface.retinaface import RetinaFace
-from .yolov5face.face_detector import YoloDetector
-
-
-def init_detection_model(model_name, half=False, device="cuda"):
- if "retinaface" in model_name:
- model = init_retinaface_model(model_name, half, device)
- elif "YOLOv5" in model_name:
- model = init_yolov5face_model(model_name, device)
- else:
- raise NotImplementedError(f"{model_name} is not implemented.")
-
- return model
-
-
-def init_retinaface_model(model_name, half=False, device="cuda"):
- if model_name == "retinaface_resnet50":
- model = RetinaFace(network_name="resnet50", half=half)
- model_url = "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth"
- elif model_name == "retinaface_mobile0.25":
- model = RetinaFace(network_name="mobile0.25", half=half)
- model_url = "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth"
- else:
- raise NotImplementedError(f"{model_name} is not implemented.")
-
- model_path = load_file_from_url(
- url=model_url,
- model_dir="../../models/facedetection",
- progress=True,
- file_name=None,
- )
- load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
- # remove unnecessary 'module.'
- for k, v in deepcopy(load_net).items():
- if k.startswith("module."):
- load_net[k[7:]] = v
- load_net.pop(k)
- model.load_state_dict(load_net, strict=True)
- model.eval()
- model = model.to(device)
-
- return model
-
-
-def init_yolov5face_model(model_name, device="cuda"):
- if model_name == "YOLOv5l":
- model = YoloDetector(
- config_name="nodes/facerestore/facelib/detection/yolov5face/models/yolov5l.yaml",
- device=device,
- )
- model_url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth"
- elif model_name == "YOLOv5n":
- model = YoloDetector(
- config_name="nodes/facerestore/facelib/detection/yolov5face/models/yolov5n.yaml",
- device=device,
- )
- model_url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5n-face.pth"
- else:
- raise NotImplementedError(f"{model_name} is not implemented.")
-
- model_path = load_file_from_url(
- url=model_url,
- model_dir="../../models/facedetection",
- progress=True,
- file_name=None,
- )
- load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
- model.detector.load_state_dict(load_net, strict=True)
- model.detector.eval()
- model.detector = model.detector.to(device).float()
-
- for m in model.detector.modules():
- if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
- m.inplace = True # pytorch 1.7.0 compatibility
- elif isinstance(m, Conv):
- m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
-
- return model
-
-
-# Download from Google Drive
-# def init_yolov5face_model(model_name, device='cuda'):
-# if model_name == 'YOLOv5l':
-# model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device)
-# f_id = {'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV'}
-# elif model_name == 'YOLOv5n':
-# model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', device=device)
-# f_id = {'yolov5n-face.pth': '1fhcpFvWZqghpGXjYPIne2sw1Fy4yhw6o'}
-# else:
-# raise NotImplementedError(f'{model_name} is not implemented.')
-
-# model_path = os.path.join('../../models/facedetection', list(f_id.keys())[0])
-# if not os.path.exists(model_path):
-# download_pretrained_models(file_ids=f_id, save_path_root='../../models/facedetection')
-
-# load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
-# model.detector.load_state_dict(load_net, strict=True)
-# model.detector.eval()
-# model.detector = model.detector.to(device).float()
-
-# for m in model.detector.modules():
-# if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
-# m.inplace = True # pytorch 1.7.0 compatibility
-# elif isinstance(m, Conv):
-# m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
-
-# return model
+import os
+import pathlib
+from copy import deepcopy
+
+import torch
+from torch import nn
+
+from hordelib.nodes.facerestore_cf.facelib.detection.yolov5face.models.common import Conv
+from hordelib.nodes.facerestore_cf.facelib.utils import download_pretrained_models, load_file_from_url
+
+from .retinaface.retinaface import RetinaFace
+from .yolov5face.face_detector import YoloDetector
+
+
+def init_detection_model(model_name, half=False, device="cuda"):
+ if "retinaface" in model_name:
+ model = init_retinaface_model(model_name, half, device)
+ elif "YOLOv5" in model_name:
+ model = init_yolov5face_model(model_name, device)
+ else:
+ raise NotImplementedError(f"{model_name} is not implemented.")
+
+ return model
+
+
+def init_retinaface_model(model_name, half=False, device="cuda"):
+ if model_name == "retinaface_resnet50":
+ model = RetinaFace(network_name="resnet50", half=half)
+ model_url = "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth"
+ filename = "detection_Resnet50_Final.pth"
+ elif model_name == "retinaface_mobile0.25":
+ model = RetinaFace(network_name="mobile0.25", half=half)
+ model_url = "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth"
+ filename = "detection_mobilenet0.25_Final.pth"
+ else:
+ raise NotImplementedError(f"{model_name} is not implemented.")
+
+ model_path = load_file_from_url(
+ url=model_url,
+ model_dir="../../models/facedetection",
+ progress=True,
+ file_name=filename,
+ )
+ load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
+ # remove unnecessary 'module.'
+ for k, v in deepcopy(load_net).items():
+ if k.startswith("module."):
+ load_net[k[7:]] = v
+ load_net.pop(k)
+ model.load_state_dict(load_net, strict=True)
+ model.eval()
+ model = model.to(device)
+
+ return model
+
+
+def init_yolov5face_model(model_name, device="cuda"):
+ current_dir = str(pathlib.Path(__file__).parent.resolve())
+ if model_name == "YOLOv5l":
+ model = YoloDetector(config_name=current_dir + "/yolov5face/models/yolov5l.yaml", device=device)
+ model_url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth"
+ filename = "yolov5l-face.pth"
+ elif model_name == "YOLOv5n":
+ model = YoloDetector(config_name=current_dir + "/yolov5face/models/yolov5n.yaml", device=device)
+ model_url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5n-face.pth"
+ filename = "yolov5n-face.pth"
+ else:
+ raise NotImplementedError(f"{model_name} is not implemented.")
+
+ model_path = load_file_from_url(
+ url=model_url,
+ model_dir="../../models/facedetection",
+ progress=True,
+ file_name=filename,
+ )
+ load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
+ model.detector.load_state_dict(load_net, strict=True)
+ model.detector.eval()
+ model.detector = model.detector.to(device).float()
+
+ for m in model.detector.modules():
+ if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
+ m.inplace = True # pytorch 1.7.0 compatibility
+ elif isinstance(m, Conv):
+ m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
+
+ return model
+
+
+# Download from Google Drive
+# def init_yolov5face_model(model_name, device='cuda'):
+# if model_name == 'YOLOv5l':
+# model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device)
+# f_id = {'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV'}
+# elif model_name == 'YOLOv5n':
+# model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', device=device)
+# f_id = {'yolov5n-face.pth': '1fhcpFvWZqghpGXjYPIne2sw1Fy4yhw6o'}
+# else:
+# raise NotImplementedError(f'{model_name} is not implemented.')
+
+# model_path = os.path.join('../../models/facedetection', list(f_id.keys())[0])
+# if not os.path.exists(model_path):
+# download_pretrained_models(file_ids=f_id, save_path_root='../../models/facedetection')
+
+# load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
+# model.detector.load_state_dict(load_net, strict=True)
+# model.detector.eval()
+# model.detector = model.detector.to(device).float()
+
+# for m in model.detector.modules():
+# if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
+# m.inplace = True # pytorch 1.7.0 compatibility
+# elif isinstance(m, Conv):
+# m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
+
+# return model
diff --git a/hordelib/nodes/facerestore/facelib/detection/align_trans.py b/hordelib/nodes/facerestore_cf/facelib/detection/align_trans.py
similarity index 75%
rename from hordelib/nodes/facerestore/facelib/detection/align_trans.py
rename to hordelib/nodes/facerestore_cf/facelib/detection/align_trans.py
index 07f1eb36..84e7a16d 100644
--- a/hordelib/nodes/facerestore/facelib/detection/align_trans.py
+++ b/hordelib/nodes/facerestore_cf/facelib/detection/align_trans.py
@@ -1,219 +1,233 @@
-import cv2
-import numpy as np
-
-from .matlab_cp2tform import get_similarity_transform_for_cv2
-
-# reference facial points, a list of coordinates (x,y)
-REFERENCE_FACIAL_POINTS = [[30.29459953, 51.69630051], [65.53179932, 51.50139999], [48.02519989, 71.73660278],
- [33.54930115, 92.3655014], [62.72990036, 92.20410156]]
-
-DEFAULT_CROP_SIZE = (96, 112)
-
-
-class FaceWarpException(Exception):
-
- def __str__(self):
- return 'In File {}:{}'.format(__file__, super.__str__(self))
-
-
-def get_reference_facial_points(output_size=None, inner_padding_factor=0.0, outer_padding=(0, 0), default_square=False):
- """
- Function:
- ----------
- get reference 5 key points according to crop settings:
- 0. Set default crop_size:
- if default_square:
- crop_size = (112, 112)
- else:
- crop_size = (96, 112)
- 1. Pad the crop_size by inner_padding_factor in each side;
- 2. Resize crop_size into (output_size - outer_padding*2),
- pad into output_size with outer_padding;
- 3. Output reference_5point;
- Parameters:
- ----------
- @output_size: (w, h) or None
- size of aligned face image
- @inner_padding_factor: (w_factor, h_factor)
- padding factor for inner (w, h)
- @outer_padding: (w_pad, h_pad)
- each row is a pair of coordinates (x, y)
- @default_square: True or False
- if True:
- default crop_size = (112, 112)
- else:
- default crop_size = (96, 112);
- !!! make sure, if output_size is not None:
- (output_size - outer_padding)
- = some_scale * (default crop_size * (1.0 +
- inner_padding_factor))
- Returns:
- ----------
- @reference_5point: 5x2 np.array
- each row is a pair of transformed coordinates (x, y)
- """
-
- tmp_5pts = np.array(REFERENCE_FACIAL_POINTS)
- tmp_crop_size = np.array(DEFAULT_CROP_SIZE)
-
- # 0) make the inner region a square
- if default_square:
- size_diff = max(tmp_crop_size) - tmp_crop_size
- tmp_5pts += size_diff / 2
- tmp_crop_size += size_diff
-
- if (output_size and output_size[0] == tmp_crop_size[0] and output_size[1] == tmp_crop_size[1]):
-
- return tmp_5pts
-
- if (inner_padding_factor == 0 and outer_padding == (0, 0)):
- if output_size is None:
- return tmp_5pts
- else:
- raise FaceWarpException('No paddings to do, output_size must be None or {}'.format(tmp_crop_size))
-
- # check output size
- if not (0 <= inner_padding_factor <= 1.0):
- raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)')
-
- if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0) and output_size is None):
- output_size = tmp_crop_size * \
- (1 + inner_padding_factor * 2).astype(np.int32)
- output_size += np.array(outer_padding)
- if not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1]):
- raise FaceWarpException('Not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1])')
-
- # 1) pad the inner region according inner_padding_factor
- if inner_padding_factor > 0:
- size_diff = tmp_crop_size * inner_padding_factor * 2
- tmp_5pts += size_diff / 2
- tmp_crop_size += np.round(size_diff).astype(np.int32)
-
- # 2) resize the padded inner region
- size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2
-
- if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]:
- raise FaceWarpException('Must have (output_size - outer_padding)'
- '= some_scale * (crop_size * (1.0 + inner_padding_factor)')
-
- scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0]
- tmp_5pts = tmp_5pts * scale_factor
- # size_diff = tmp_crop_size * (scale_factor - min(scale_factor))
- # tmp_5pts = tmp_5pts + size_diff / 2
- tmp_crop_size = size_bf_outer_pad
-
- # 3) add outer_padding to make output_size
- reference_5point = tmp_5pts + np.array(outer_padding)
- tmp_crop_size = output_size
-
- return reference_5point
-
-
-def get_affine_transform_matrix(src_pts, dst_pts):
- """
- Function:
- ----------
- get affine transform matrix 'tfm' from src_pts to dst_pts
- Parameters:
- ----------
- @src_pts: Kx2 np.array
- source points matrix, each row is a pair of coordinates (x, y)
- @dst_pts: Kx2 np.array
- destination points matrix, each row is a pair of coordinates (x, y)
- Returns:
- ----------
- @tfm: 2x3 np.array
- transform matrix from src_pts to dst_pts
- """
-
- tfm = np.float32([[1, 0, 0], [0, 1, 0]])
- n_pts = src_pts.shape[0]
- ones = np.ones((n_pts, 1), src_pts.dtype)
- src_pts_ = np.hstack([src_pts, ones])
- dst_pts_ = np.hstack([dst_pts, ones])
-
- A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_)
-
- if rank == 3:
- tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], [A[0, 1], A[1, 1], A[2, 1]]])
- elif rank == 2:
- tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]])
-
- return tfm
-
-
-def warp_and_crop_face(src_img, facial_pts, reference_pts=None, crop_size=(96, 112), align_type='smilarity'):
- """
- Function:
- ----------
- apply affine transform 'trans' to uv
- Parameters:
- ----------
- @src_img: 3x3 np.array
- input image
- @facial_pts: could be
- 1)a list of K coordinates (x,y)
- or
- 2) Kx2 or 2xK np.array
- each row or col is a pair of coordinates (x, y)
- @reference_pts: could be
- 1) a list of K coordinates (x,y)
- or
- 2) Kx2 or 2xK np.array
- each row or col is a pair of coordinates (x, y)
- or
- 3) None
- if None, use default reference facial points
- @crop_size: (w, h)
- output face image size
- @align_type: transform type, could be one of
- 1) 'similarity': use similarity transform
- 2) 'cv2_affine': use the first 3 points to do affine transform,
- by calling cv2.getAffineTransform()
- 3) 'affine': use all points to do affine transform
- Returns:
- ----------
- @face_img: output face image with size (w, h) = @crop_size
- """
-
- if reference_pts is None:
- if crop_size[0] == 96 and crop_size[1] == 112:
- reference_pts = REFERENCE_FACIAL_POINTS
- else:
- default_square = False
- inner_padding_factor = 0
- outer_padding = (0, 0)
- output_size = crop_size
-
- reference_pts = get_reference_facial_points(output_size, inner_padding_factor, outer_padding,
- default_square)
-
- ref_pts = np.float32(reference_pts)
- ref_pts_shp = ref_pts.shape
- if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2:
- raise FaceWarpException('reference_pts.shape must be (K,2) or (2,K) and K>2')
-
- if ref_pts_shp[0] == 2:
- ref_pts = ref_pts.T
-
- src_pts = np.float32(facial_pts)
- src_pts_shp = src_pts.shape
- if max(src_pts_shp) < 3 or min(src_pts_shp) != 2:
- raise FaceWarpException('facial_pts.shape must be (K,2) or (2,K) and K>2')
-
- if src_pts_shp[0] == 2:
- src_pts = src_pts.T
-
- if src_pts.shape != ref_pts.shape:
- raise FaceWarpException('facial_pts and reference_pts must have the same shape')
-
- if align_type == 'cv2_affine':
- tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3])
- elif align_type == 'affine':
- tfm = get_affine_transform_matrix(src_pts, ref_pts)
- else:
- tfm = get_similarity_transform_for_cv2(src_pts, ref_pts)
-
- face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]))
-
- return face_img
+import cv2
+import numpy as np
+
+from .matlab_cp2tform import get_similarity_transform_for_cv2
+
+# reference facial points, a list of coordinates (x,y)
+REFERENCE_FACIAL_POINTS = [
+ [30.29459953, 51.69630051],
+ [65.53179932, 51.50139999],
+ [48.02519989, 71.73660278],
+ [33.54930115, 92.3655014],
+ [62.72990036, 92.20410156],
+]
+
+DEFAULT_CROP_SIZE = (96, 112)
+
+
+class FaceWarpException(Exception):
+
+ def __str__(self):
+        return f"In File {__file__}:{super().__str__()}"
+
+
+def get_reference_facial_points(
+ output_size=None,
+ inner_padding_factor=0.0,
+ outer_padding=(0, 0),
+ default_square=False,
+):
+ """
+ Function:
+ ----------
+ get reference 5 key points according to crop settings:
+ 0. Set default crop_size:
+ if default_square:
+ crop_size = (112, 112)
+ else:
+ crop_size = (96, 112)
+ 1. Pad the crop_size by inner_padding_factor in each side;
+ 2. Resize crop_size into (output_size - outer_padding*2),
+ pad into output_size with outer_padding;
+ 3. Output reference_5point;
+ Parameters:
+ ----------
+ @output_size: (w, h) or None
+ size of aligned face image
+ @inner_padding_factor: (w_factor, h_factor)
+ padding factor for inner (w, h)
+ @outer_padding: (w_pad, h_pad)
+ each row is a pair of coordinates (x, y)
+ @default_square: True or False
+ if True:
+ default crop_size = (112, 112)
+ else:
+ default crop_size = (96, 112);
+ !!! make sure, if output_size is not None:
+ (output_size - outer_padding)
+ = some_scale * (default crop_size * (1.0 +
+ inner_padding_factor))
+ Returns:
+ ----------
+ @reference_5point: 5x2 np.array
+ each row is a pair of transformed coordinates (x, y)
+ """
+
+ tmp_5pts = np.array(REFERENCE_FACIAL_POINTS)
+ tmp_crop_size = np.array(DEFAULT_CROP_SIZE)
+
+ # 0) make the inner region a square
+ if default_square:
+ size_diff = max(tmp_crop_size) - tmp_crop_size
+ tmp_5pts += size_diff / 2
+ tmp_crop_size += size_diff
+
+    if output_size and output_size[0] == tmp_crop_size[0] and output_size[1] == tmp_crop_size[1]:
+        return tmp_5pts
+
+ if inner_padding_factor == 0 and outer_padding == (0, 0):
+ if output_size is None:
+ return tmp_5pts
+ else:
+ raise FaceWarpException(f"No paddings to do, output_size must be None or {tmp_crop_size}")
+
+ # check output size
+ if not (0 <= inner_padding_factor <= 1.0):
+ raise FaceWarpException("Not (0 <= inner_padding_factor <= 1.0)")
+
+ if (inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0) and output_size is None:
+ output_size = tmp_crop_size * (1 + inner_padding_factor * 2).astype(np.int32)
+ output_size += np.array(outer_padding)
+ if not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1]):
+ raise FaceWarpException("Not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1])")
+
+    # 1) pad the inner region according to inner_padding_factor
+ if inner_padding_factor > 0:
+ size_diff = tmp_crop_size * inner_padding_factor * 2
+ tmp_5pts += size_diff / 2
+ tmp_crop_size += np.round(size_diff).astype(np.int32)
+
+ # 2) resize the padded inner region
+ size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2
+
+ if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]:
+        raise FaceWarpException(
+            "Must have (output_size - outer_padding) = some_scale * (crop_size * (1.0 + inner_padding_factor))",
+        )
+
+ scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0]
+ tmp_5pts = tmp_5pts * scale_factor
+ # size_diff = tmp_crop_size * (scale_factor - min(scale_factor))
+ # tmp_5pts = tmp_5pts + size_diff / 2
+ tmp_crop_size = size_bf_outer_pad
+
+ # 3) add outer_padding to make output_size
+ reference_5point = tmp_5pts + np.array(outer_padding)
+ tmp_crop_size = output_size
+
+ return reference_5point
+
+
+def get_affine_transform_matrix(src_pts, dst_pts):
+ """
+ Function:
+ ----------
+ get affine transform matrix 'tfm' from src_pts to dst_pts
+ Parameters:
+ ----------
+ @src_pts: Kx2 np.array
+ source points matrix, each row is a pair of coordinates (x, y)
+ @dst_pts: Kx2 np.array
+ destination points matrix, each row is a pair of coordinates (x, y)
+ Returns:
+ ----------
+ @tfm: 2x3 np.array
+ transform matrix from src_pts to dst_pts
+ """
+
+ tfm = np.float32([[1, 0, 0], [0, 1, 0]])
+ n_pts = src_pts.shape[0]
+ ones = np.ones((n_pts, 1), src_pts.dtype)
+ src_pts_ = np.hstack([src_pts, ones])
+ dst_pts_ = np.hstack([dst_pts, ones])
+
+ A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_)
+
+ if rank == 3:
+ tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], [A[0, 1], A[1, 1], A[2, 1]]])
+ elif rank == 2:
+ tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]])
+
+ return tfm
+
+
+def warp_and_crop_face(src_img, facial_pts, reference_pts=None, crop_size=(96, 112), align_type="similarity"):
+ """
+ Function:
+ ----------
+    warp src_img so that facial_pts align with reference_pts, then crop to crop_size
+    Parameters:
+    ----------
+    @src_img: HxWx3 np.array
+        input image
+ @facial_pts: could be
+ 1)a list of K coordinates (x,y)
+ or
+ 2) Kx2 or 2xK np.array
+ each row or col is a pair of coordinates (x, y)
+ @reference_pts: could be
+ 1) a list of K coordinates (x,y)
+ or
+ 2) Kx2 or 2xK np.array
+ each row or col is a pair of coordinates (x, y)
+ or
+ 3) None
+ if None, use default reference facial points
+ @crop_size: (w, h)
+ output face image size
+ @align_type: transform type, could be one of
+ 1) 'similarity': use similarity transform
+ 2) 'cv2_affine': use the first 3 points to do affine transform,
+ by calling cv2.getAffineTransform()
+ 3) 'affine': use all points to do affine transform
+ Returns:
+ ----------
+ @face_img: output face image with size (w, h) = @crop_size
+ """
+
+ if reference_pts is None:
+ if crop_size[0] == 96 and crop_size[1] == 112:
+ reference_pts = REFERENCE_FACIAL_POINTS
+ else:
+ default_square = False
+ inner_padding_factor = 0
+ outer_padding = (0, 0)
+ output_size = crop_size
+
+ reference_pts = get_reference_facial_points(
+ output_size,
+ inner_padding_factor,
+ outer_padding,
+ default_square,
+ )
+
+ ref_pts = np.float32(reference_pts)
+ ref_pts_shp = ref_pts.shape
+ if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2:
+ raise FaceWarpException("reference_pts.shape must be (K,2) or (2,K) and K>2")
+
+ if ref_pts_shp[0] == 2:
+ ref_pts = ref_pts.T
+
+ src_pts = np.float32(facial_pts)
+ src_pts_shp = src_pts.shape
+ if max(src_pts_shp) < 3 or min(src_pts_shp) != 2:
+ raise FaceWarpException("facial_pts.shape must be (K,2) or (2,K) and K>2")
+
+ if src_pts_shp[0] == 2:
+ src_pts = src_pts.T
+
+ if src_pts.shape != ref_pts.shape:
+ raise FaceWarpException("facial_pts and reference_pts must have the same shape")
+
+ if align_type == "cv2_affine":
+ tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3])
+ elif align_type == "affine":
+ tfm = get_affine_transform_matrix(src_pts, ref_pts)
+ else:
+ tfm = get_similarity_transform_for_cv2(src_pts, ref_pts)
+
+ face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]))
+
+ return face_img
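+
+# Illustrative usage sketch: `img` is assumed to be a BGR numpy image and
+# `landmarks5` a 5x2 array of (x, y) facial points from a detector elsewhere in facelib.
+#
+#   reference = get_reference_facial_points(default_square=True)
+#   aligned = warp_and_crop_face(np.array(img), landmarks5, reference, crop_size=(112, 112))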
diff --git a/hordelib/nodes/facerestore/facelib/detection/matlab_cp2tform.py b/hordelib/nodes/facerestore_cf/facelib/detection/matlab_cp2tform.py
similarity index 88%
rename from hordelib/nodes/facerestore/facelib/detection/matlab_cp2tform.py
rename to hordelib/nodes/facerestore_cf/facelib/detection/matlab_cp2tform.py
index b2a8b54a..7bd7b4cd 100644
--- a/hordelib/nodes/facerestore/facelib/detection/matlab_cp2tform.py
+++ b/hordelib/nodes/facerestore_cf/facelib/detection/matlab_cp2tform.py
@@ -1,317 +1,316 @@
-import numpy as np
-from numpy.linalg import inv, lstsq
-from numpy.linalg import matrix_rank as rank
-from numpy.linalg import norm
-
-
-class MatlabCp2tormException(Exception):
-
- def __str__(self):
- return 'In File {}:{}'.format(__file__, super.__str__(self))
-
-
-def tformfwd(trans, uv):
- """
- Function:
- ----------
- apply affine transform 'trans' to uv
-
- Parameters:
- ----------
- @trans: 3x3 np.array
- transform matrix
- @uv: Kx2 np.array
- each row is a pair of coordinates (x, y)
-
- Returns:
- ----------
- @xy: Kx2 np.array
- each row is a pair of transformed coordinates (x, y)
- """
- uv = np.hstack((uv, np.ones((uv.shape[0], 1))))
- xy = np.dot(uv, trans)
- xy = xy[:, 0:-1]
- return xy
-
-
-def tforminv(trans, uv):
- """
- Function:
- ----------
- apply the inverse of affine transform 'trans' to uv
-
- Parameters:
- ----------
- @trans: 3x3 np.array
- transform matrix
- @uv: Kx2 np.array
- each row is a pair of coordinates (x, y)
-
- Returns:
- ----------
- @xy: Kx2 np.array
- each row is a pair of inverse-transformed coordinates (x, y)
- """
- Tinv = inv(trans)
- xy = tformfwd(Tinv, uv)
- return xy
-
-
-def findNonreflectiveSimilarity(uv, xy, options=None):
- options = {'K': 2}
-
- K = options['K']
- M = xy.shape[0]
- x = xy[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
- y = xy[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
-
- tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1))))
- tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1))))
- X = np.vstack((tmp1, tmp2))
-
- u = uv[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
- v = uv[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
- U = np.vstack((u, v))
-
- # We know that X * r = U
- if rank(X) >= 2 * K:
- r, _, _, _ = lstsq(X, U, rcond=-1)
- r = np.squeeze(r)
- else:
- raise Exception('cp2tform:twoUniquePointsReq')
- sc = r[0]
- ss = r[1]
- tx = r[2]
- ty = r[3]
-
- Tinv = np.array([[sc, -ss, 0], [ss, sc, 0], [tx, ty, 1]])
- T = inv(Tinv)
- T[:, 2] = np.array([0, 0, 1])
-
- return T, Tinv
-
-
-def findSimilarity(uv, xy, options=None):
- options = {'K': 2}
-
- # uv = np.array(uv)
- # xy = np.array(xy)
-
- # Solve for trans1
- trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options)
-
- # Solve for trans2
-
- # manually reflect the xy data across the Y-axis
- xyR = xy
- xyR[:, 0] = -1 * xyR[:, 0]
-
- trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options)
-
- # manually reflect the tform to undo the reflection done on xyR
- TreflectY = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]])
-
- trans2 = np.dot(trans2r, TreflectY)
-
- # Figure out if trans1 or trans2 is better
- xy1 = tformfwd(trans1, uv)
- norm1 = norm(xy1 - xy)
-
- xy2 = tformfwd(trans2, uv)
- norm2 = norm(xy2 - xy)
-
- if norm1 <= norm2:
- return trans1, trans1_inv
- else:
- trans2_inv = inv(trans2)
- return trans2, trans2_inv
-
-
-def get_similarity_transform(src_pts, dst_pts, reflective=True):
- """
- Function:
- ----------
- Find Similarity Transform Matrix 'trans':
- u = src_pts[:, 0]
- v = src_pts[:, 1]
- x = dst_pts[:, 0]
- y = dst_pts[:, 1]
- [x, y, 1] = [u, v, 1] * trans
-
- Parameters:
- ----------
- @src_pts: Kx2 np.array
- source points, each row is a pair of coordinates (x, y)
- @dst_pts: Kx2 np.array
- destination points, each row is a pair of transformed
- coordinates (x, y)
- @reflective: True or False
- if True:
- use reflective similarity transform
- else:
- use non-reflective similarity transform
-
- Returns:
- ----------
- @trans: 3x3 np.array
- transform matrix from uv to xy
- trans_inv: 3x3 np.array
- inverse of trans, transform matrix from xy to uv
- """
-
- if reflective:
- trans, trans_inv = findSimilarity(src_pts, dst_pts)
- else:
- trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts)
-
- return trans, trans_inv
-
-
-def cvt_tform_mat_for_cv2(trans):
- """
- Function:
- ----------
- Convert Transform Matrix 'trans' into 'cv2_trans' which could be
- directly used by cv2.warpAffine():
- u = src_pts[:, 0]
- v = src_pts[:, 1]
- x = dst_pts[:, 0]
- y = dst_pts[:, 1]
- [x, y].T = cv_trans * [u, v, 1].T
-
- Parameters:
- ----------
- @trans: 3x3 np.array
- transform matrix from uv to xy
-
- Returns:
- ----------
- @cv2_trans: 2x3 np.array
- transform matrix from src_pts to dst_pts, could be directly used
- for cv2.warpAffine()
- """
- cv2_trans = trans[:, 0:2].T
-
- return cv2_trans
-
-
-def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True):
- """
- Function:
- ----------
- Find Similarity Transform Matrix 'cv2_trans' which could be
- directly used by cv2.warpAffine():
- u = src_pts[:, 0]
- v = src_pts[:, 1]
- x = dst_pts[:, 0]
- y = dst_pts[:, 1]
- [x, y].T = cv_trans * [u, v, 1].T
-
- Parameters:
- ----------
- @src_pts: Kx2 np.array
- source points, each row is a pair of coordinates (x, y)
- @dst_pts: Kx2 np.array
- destination points, each row is a pair of transformed
- coordinates (x, y)
- reflective: True or False
- if True:
- use reflective similarity transform
- else:
- use non-reflective similarity transform
-
- Returns:
- ----------
- @cv2_trans: 2x3 np.array
- transform matrix from src_pts to dst_pts, could be directly used
- for cv2.warpAffine()
- """
- trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective)
- cv2_trans = cvt_tform_mat_for_cv2(trans)
-
- return cv2_trans
-
-
-if __name__ == '__main__':
- """
- u = [0, 6, -2]
- v = [0, 3, 5]
- x = [-1, 0, 4]
- y = [-1, -10, 4]
-
- # In Matlab, run:
- #
- # uv = [u'; v'];
- # xy = [x'; y'];
- # tform_sim=cp2tform(uv,xy,'similarity');
- #
- # trans = tform_sim.tdata.T
- # ans =
- # -0.0764 -1.6190 0
- # 1.6190 -0.0764 0
- # -3.2156 0.0290 1.0000
- # trans_inv = tform_sim.tdata.Tinv
- # ans =
- #
- # -0.0291 0.6163 0
- # -0.6163 -0.0291 0
- # -0.0756 1.9826 1.0000
- # xy_m=tformfwd(tform_sim, u,v)
- #
- # xy_m =
- #
- # -3.2156 0.0290
- # 1.1833 -9.9143
- # 5.0323 2.8853
- # uv_m=tforminv(tform_sim, x,y)
- #
- # uv_m =
- #
- # 0.5698 1.3953
- # 6.0872 2.2733
- # -2.6570 4.3314
- """
- u = [0, 6, -2]
- v = [0, 3, 5]
- x = [-1, 0, 4]
- y = [-1, -10, 4]
-
- uv = np.array((u, v)).T
- xy = np.array((x, y)).T
-
- print('\n--->uv:')
- print(uv)
- print('\n--->xy:')
- print(xy)
-
- trans, trans_inv = get_similarity_transform(uv, xy)
-
- print('\n--->trans matrix:')
- print(trans)
-
- print('\n--->trans_inv matrix:')
- print(trans_inv)
-
- print('\n---> apply transform to uv')
- print('\nxy_m = uv_augmented * trans')
- uv_aug = np.hstack((uv, np.ones((uv.shape[0], 1))))
- xy_m = np.dot(uv_aug, trans)
- print(xy_m)
-
- print('\nxy_m = tformfwd(trans, uv)')
- xy_m = tformfwd(trans, uv)
- print(xy_m)
-
- print('\n---> apply inverse transform to xy')
- print('\nuv_m = xy_augmented * trans_inv')
- xy_aug = np.hstack((xy, np.ones((xy.shape[0], 1))))
- uv_m = np.dot(xy_aug, trans_inv)
- print(uv_m)
-
- print('\nuv_m = tformfwd(trans_inv, xy)')
- uv_m = tformfwd(trans_inv, xy)
- print(uv_m)
-
- uv_m = tforminv(trans, xy)
- print('\nuv_m = tforminv(trans, xy)')
- print(uv_m)
+import numpy as np
+from numpy.linalg import inv, lstsq, norm
+from numpy.linalg import matrix_rank as rank
+
+
+class MatlabCp2tormException(Exception):
+
+ def __str__(self):
+        return f"In File {__file__}:{super().__str__()}"
+
+
+def tformfwd(trans, uv):
+ """
+ Function:
+ ----------
+ apply affine transform 'trans' to uv
+
+ Parameters:
+ ----------
+ @trans: 3x3 np.array
+ transform matrix
+ @uv: Kx2 np.array
+ each row is a pair of coordinates (x, y)
+
+ Returns:
+ ----------
+ @xy: Kx2 np.array
+ each row is a pair of transformed coordinates (x, y)
+ """
+ uv = np.hstack((uv, np.ones((uv.shape[0], 1))))
+ xy = np.dot(uv, trans)
+ xy = xy[:, 0:-1]
+ return xy
+
+
+def tforminv(trans, uv):
+ """
+ Function:
+ ----------
+ apply the inverse of affine transform 'trans' to uv
+
+ Parameters:
+ ----------
+ @trans: 3x3 np.array
+ transform matrix
+ @uv: Kx2 np.array
+ each row is a pair of coordinates (x, y)
+
+ Returns:
+ ----------
+ @xy: Kx2 np.array
+ each row is a pair of inverse-transformed coordinates (x, y)
+ """
+ Tinv = inv(trans)
+ xy = tformfwd(Tinv, uv)
+ return xy
+
+
+def findNonreflectiveSimilarity(uv, xy, options=None):
+    if options is None:
+        options = {"K": 2}
+
+ K = options["K"]
+ M = xy.shape[0]
+ x = xy[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
+ y = xy[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
+
+ tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1))))
+ tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1))))
+ X = np.vstack((tmp1, tmp2))
+
+ u = uv[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
+ v = uv[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
+ U = np.vstack((u, v))
+
+ # We know that X * r = U
+ if rank(X) >= 2 * K:
+ r, _, _, _ = lstsq(X, U, rcond=-1)
+ r = np.squeeze(r)
+ else:
+ raise Exception("cp2tform:twoUniquePointsReq")
+ sc = r[0]
+ ss = r[1]
+ tx = r[2]
+ ty = r[3]
+
+ Tinv = np.array([[sc, -ss, 0], [ss, sc, 0], [tx, ty, 1]])
+ T = inv(Tinv)
+ T[:, 2] = np.array([0, 0, 1])
+
+ return T, Tinv
+
+
+def findSimilarity(uv, xy, options=None):
+    if options is None:
+        options = {"K": 2}
+
+ # uv = np.array(uv)
+ # xy = np.array(xy)
+
+ # Solve for trans1
+ trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options)
+
+ # Solve for trans2
+
+ # manually reflect the xy data across the Y-axis
+ xyR = xy
+ xyR[:, 0] = -1 * xyR[:, 0]
+
+ trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options)
+
+ # manually reflect the tform to undo the reflection done on xyR
+ TreflectY = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]])
+
+ trans2 = np.dot(trans2r, TreflectY)
+
+ # Figure out if trans1 or trans2 is better
+ xy1 = tformfwd(trans1, uv)
+ norm1 = norm(xy1 - xy)
+
+ xy2 = tformfwd(trans2, uv)
+ norm2 = norm(xy2 - xy)
+
+ if norm1 <= norm2:
+ return trans1, trans1_inv
+ else:
+ trans2_inv = inv(trans2)
+ return trans2, trans2_inv
+
+
+def get_similarity_transform(src_pts, dst_pts, reflective=True):
+ """
+ Function:
+ ----------
+ Find Similarity Transform Matrix 'trans':
+ u = src_pts[:, 0]
+ v = src_pts[:, 1]
+ x = dst_pts[:, 0]
+ y = dst_pts[:, 1]
+ [x, y, 1] = [u, v, 1] * trans
+
+ Parameters:
+ ----------
+ @src_pts: Kx2 np.array
+ source points, each row is a pair of coordinates (x, y)
+ @dst_pts: Kx2 np.array
+ destination points, each row is a pair of transformed
+ coordinates (x, y)
+ @reflective: True or False
+ if True:
+ use reflective similarity transform
+ else:
+ use non-reflective similarity transform
+
+ Returns:
+ ----------
+ @trans: 3x3 np.array
+ transform matrix from uv to xy
+ trans_inv: 3x3 np.array
+ inverse of trans, transform matrix from xy to uv
+ """
+
+ if reflective:
+ trans, trans_inv = findSimilarity(src_pts, dst_pts)
+ else:
+ trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts)
+
+ return trans, trans_inv
+
+
+def cvt_tform_mat_for_cv2(trans):
+ """
+ Function:
+ ----------
+ Convert Transform Matrix 'trans' into 'cv2_trans' which could be
+ directly used by cv2.warpAffine():
+ u = src_pts[:, 0]
+ v = src_pts[:, 1]
+ x = dst_pts[:, 0]
+ y = dst_pts[:, 1]
+ [x, y].T = cv_trans * [u, v, 1].T
+
+ Parameters:
+ ----------
+ @trans: 3x3 np.array
+ transform matrix from uv to xy
+
+ Returns:
+ ----------
+ @cv2_trans: 2x3 np.array
+ transform matrix from src_pts to dst_pts, could be directly used
+ for cv2.warpAffine()
+ """
+ cv2_trans = trans[:, 0:2].T
+
+ return cv2_trans
+
+
+def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True):
+ """
+ Function:
+ ----------
+ Find Similarity Transform Matrix 'cv2_trans' which could be
+ directly used by cv2.warpAffine():
+ u = src_pts[:, 0]
+ v = src_pts[:, 1]
+ x = dst_pts[:, 0]
+ y = dst_pts[:, 1]
+ [x, y].T = cv_trans * [u, v, 1].T
+
+ Parameters:
+ ----------
+ @src_pts: Kx2 np.array
+ source points, each row is a pair of coordinates (x, y)
+ @dst_pts: Kx2 np.array
+ destination points, each row is a pair of transformed
+ coordinates (x, y)
+ reflective: True or False
+ if True:
+ use reflective similarity transform
+ else:
+ use non-reflective similarity transform
+
+ Returns:
+ ----------
+ @cv2_trans: 2x3 np.array
+ transform matrix from src_pts to dst_pts, could be directly used
+ for cv2.warpAffine()
+ """
+ trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective)
+ cv2_trans = cvt_tform_mat_for_cv2(trans)
+
+ return cv2_trans
+
+
+if __name__ == "__main__":
+ """
+ u = [0, 6, -2]
+ v = [0, 3, 5]
+ x = [-1, 0, 4]
+ y = [-1, -10, 4]
+
+ # In Matlab, run:
+ #
+ # uv = [u'; v'];
+ # xy = [x'; y'];
+ # tform_sim=cp2tform(uv,xy,'similarity');
+ #
+ # trans = tform_sim.tdata.T
+ # ans =
+ # -0.0764 -1.6190 0
+ # 1.6190 -0.0764 0
+ # -3.2156 0.0290 1.0000
+ # trans_inv = tform_sim.tdata.Tinv
+ # ans =
+ #
+ # -0.0291 0.6163 0
+ # -0.6163 -0.0291 0
+ # -0.0756 1.9826 1.0000
+ # xy_m=tformfwd(tform_sim, u,v)
+ #
+ # xy_m =
+ #
+ # -3.2156 0.0290
+ # 1.1833 -9.9143
+ # 5.0323 2.8853
+ # uv_m=tforminv(tform_sim, x,y)
+ #
+ # uv_m =
+ #
+ # 0.5698 1.3953
+ # 6.0872 2.2733
+ # -2.6570 4.3314
+ """
+ u = [0, 6, -2]
+ v = [0, 3, 5]
+ x = [-1, 0, 4]
+ y = [-1, -10, 4]
+
+ uv = np.array((u, v)).T
+ xy = np.array((x, y)).T
+
+ print("\n--->uv:")
+ print(uv)
+ print("\n--->xy:")
+ print(xy)
+
+ trans, trans_inv = get_similarity_transform(uv, xy)
+
+ print("\n--->trans matrix:")
+ print(trans)
+
+ print("\n--->trans_inv matrix:")
+ print(trans_inv)
+
+ print("\n---> apply transform to uv")
+ print("\nxy_m = uv_augmented * trans")
+ uv_aug = np.hstack((uv, np.ones((uv.shape[0], 1))))
+ xy_m = np.dot(uv_aug, trans)
+ print(xy_m)
+
+ print("\nxy_m = tformfwd(trans, uv)")
+ xy_m = tformfwd(trans, uv)
+ print(xy_m)
+
+ print("\n---> apply inverse transform to xy")
+ print("\nuv_m = xy_augmented * trans_inv")
+ xy_aug = np.hstack((xy, np.ones((xy.shape[0], 1))))
+ uv_m = np.dot(xy_aug, trans_inv)
+ print(uv_m)
+
+ print("\nuv_m = tformfwd(trans_inv, xy)")
+ uv_m = tformfwd(trans_inv, xy)
+ print(uv_m)
+
+ uv_m = tforminv(trans, xy)
+ print("\nuv_m = tforminv(trans, xy)")
+ print(uv_m)
diff --git a/hordelib/nodes/facerestore/facelib/detection/retinaface/retinaface.py b/hordelib/nodes/facerestore_cf/facelib/detection/retinaface/retinaface.py
similarity index 83%
rename from hordelib/nodes/facerestore/facelib/detection/retinaface/retinaface.py
rename to hordelib/nodes/facerestore_cf/facelib/detection/retinaface/retinaface.py
index 16881365..bfe398b4 100644
--- a/hordelib/nodes/facerestore/facelib/detection/retinaface/retinaface.py
+++ b/hordelib/nodes/facerestore_cf/facelib/detection/retinaface/retinaface.py
@@ -1,420 +1,389 @@
-import cv2
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from PIL import Image
-from torchvision.models._utils import IntermediateLayerGetter as IntermediateLayerGetter
-
-from hordelib.nodes.facerestore.facelib.detection.align_trans import (
- get_reference_facial_points,
- warp_and_crop_face,
-)
-from hordelib.nodes.facerestore.facelib.detection.retinaface.retinaface_net import (
- FPN,
- SSH,
- MobileNetV1,
- make_bbox_head,
- make_class_head,
- make_landmark_head,
-)
-from hordelib.nodes.facerestore.facelib.detection.retinaface.retinaface_utils import (
- PriorBox,
- batched_decode,
- batched_decode_landm,
- decode,
- decode_landm,
- py_cpu_nms,
-)
-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-
-def generate_config(network_name):
- cfg_mnet = {
- "name": "mobilenet0.25",
- "min_sizes": [[16, 32], [64, 128], [256, 512]],
- "steps": [8, 16, 32],
- "variance": [0.1, 0.2],
- "clip": False,
- "loc_weight": 2.0,
- "gpu_train": True,
- "batch_size": 32,
- "ngpu": 1,
- "epoch": 250,
- "decay1": 190,
- "decay2": 220,
- "image_size": 640,
- "return_layers": {"stage1": 1, "stage2": 2, "stage3": 3},
- "in_channel": 32,
- "out_channel": 64,
- }
-
- cfg_re50 = {
- "name": "Resnet50",
- "min_sizes": [[16, 32], [64, 128], [256, 512]],
- "steps": [8, 16, 32],
- "variance": [0.1, 0.2],
- "clip": False,
- "loc_weight": 2.0,
- "gpu_train": True,
- "batch_size": 24,
- "ngpu": 4,
- "epoch": 100,
- "decay1": 70,
- "decay2": 90,
- "image_size": 840,
- "return_layers": {"layer2": 1, "layer3": 2, "layer4": 3},
- "in_channel": 256,
- "out_channel": 256,
- }
-
- if network_name == "mobile0.25":
- return cfg_mnet
- elif network_name == "resnet50":
- return cfg_re50
- else:
- raise NotImplementedError(f"network_name={network_name}")
-
-
-class RetinaFace(nn.Module):
- def __init__(self, network_name="resnet50", half=False, phase="test"):
- super(RetinaFace, self).__init__()
- self.half_inference = half
- cfg = generate_config(network_name)
- self.backbone = cfg["name"]
-
- self.model_name = f"retinaface_{network_name}"
- self.cfg = cfg
- self.phase = phase
- self.target_size, self.max_size = 1600, 2150
- self.resize, self.scale, self.scale1 = 1.0, None, None
- self.mean_tensor = torch.tensor([[[[104.0]], [[117.0]], [[123.0]]]]).to(device)
- self.reference = get_reference_facial_points(default_square=True)
- # Build network.
- backbone = None
- if cfg["name"] == "mobilenet0.25":
- backbone = MobileNetV1()
- self.body = IntermediateLayerGetter(backbone, cfg["return_layers"])
- elif cfg["name"] == "Resnet50":
- import torchvision.models as models
-
- backbone = models.resnet50(pretrained=False)
- self.body = IntermediateLayerGetter(backbone, cfg["return_layers"])
-
- in_channels_stage2 = cfg["in_channel"]
- in_channels_list = [
- in_channels_stage2 * 2,
- in_channels_stage2 * 4,
- in_channels_stage2 * 8,
- ]
-
- out_channels = cfg["out_channel"]
- self.fpn = FPN(in_channels_list, out_channels)
- self.ssh1 = SSH(out_channels, out_channels)
- self.ssh2 = SSH(out_channels, out_channels)
- self.ssh3 = SSH(out_channels, out_channels)
-
- self.ClassHead = make_class_head(fpn_num=3, inchannels=cfg["out_channel"])
- self.BboxHead = make_bbox_head(fpn_num=3, inchannels=cfg["out_channel"])
- self.LandmarkHead = make_landmark_head(fpn_num=3, inchannels=cfg["out_channel"])
-
- self.to(device)
- self.eval()
- if self.half_inference:
- self.half()
-
- def forward(self, inputs):
- out = self.body(inputs)
-
- if self.backbone == "mobilenet0.25" or self.backbone == "Resnet50":
- out = list(out.values())
- # FPN
- fpn = self.fpn(out)
-
- # SSH
- feature1 = self.ssh1(fpn[0])
- feature2 = self.ssh2(fpn[1])
- feature3 = self.ssh3(fpn[2])
- features = [feature1, feature2, feature3]
-
- bbox_regressions = torch.cat(
- [self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1
- )
- classifications = torch.cat(
- [self.ClassHead[i](feature) for i, feature in enumerate(features)], dim=1
- )
- tmp = [self.LandmarkHead[i](feature) for i, feature in enumerate(features)]
- ldm_regressions = torch.cat(tmp, dim=1)
-
- if self.phase == "train":
- output = (bbox_regressions, classifications, ldm_regressions)
- else:
- output = (
- bbox_regressions,
- F.softmax(classifications, dim=-1),
- ldm_regressions,
- )
- return output
-
- def __detect_faces(self, inputs):
- # get scale
- height, width = inputs.shape[2:]
- self.scale = torch.tensor(
- [width, height, width, height], dtype=torch.float32
- ).to(device)
- tmp = [
- width,
- height,
- width,
- height,
- width,
- height,
- width,
- height,
- width,
- height,
- ]
- self.scale1 = torch.tensor(tmp, dtype=torch.float32).to(device)
-
- # forawrd
- inputs = inputs.to(device)
- if self.half_inference:
- inputs = inputs.half()
- loc, conf, landmarks = self(inputs)
-
- # get priorbox
- priorbox = PriorBox(self.cfg, image_size=inputs.shape[2:])
- priors = priorbox.forward().to(device)
-
- return loc, conf, landmarks, priors
-
- # single image detection
- def transform(self, image, use_origin_size):
- # convert to opencv format
- if isinstance(image, Image.Image):
- image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
- image = image.astype(np.float32)
-
- # testing scale
- im_size_min = np.min(image.shape[0:2])
- im_size_max = np.max(image.shape[0:2])
- resize = float(self.target_size) / float(im_size_min)
-
- # prevent bigger axis from being more than max_size
- if np.round(resize * im_size_max) > self.max_size:
- resize = float(self.max_size) / float(im_size_max)
- resize = 1 if use_origin_size else resize
-
- # resize
- if resize != 1:
- image = cv2.resize(
- image, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR
- )
-
- # convert to torch.tensor format
- # image -= (104, 117, 123)
- image = image.transpose(2, 0, 1)
- image = torch.from_numpy(image).unsqueeze(0)
-
- return image, resize
-
- def detect_faces(
- self,
- image,
- conf_threshold=0.8,
- nms_threshold=0.4,
- use_origin_size=True,
- ):
- """
- Params:
- imgs: BGR image
- """
- image, self.resize = self.transform(image, use_origin_size)
- image = image.to(device)
- if self.half_inference:
- image = image.half()
- image = image - self.mean_tensor
-
- loc, conf, landmarks, priors = self.__detect_faces(image)
-
- boxes = decode(loc.data.squeeze(0), priors.data, self.cfg["variance"])
- boxes = boxes * self.scale / self.resize
- boxes = boxes.cpu().numpy()
-
- scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
-
- landmarks = decode_landm(landmarks.squeeze(0), priors, self.cfg["variance"])
- landmarks = landmarks * self.scale1 / self.resize
- landmarks = landmarks.cpu().numpy()
-
- # ignore low scores
- inds = np.where(scores > conf_threshold)[0]
- boxes, landmarks, scores = boxes[inds], landmarks[inds], scores[inds]
-
- # sort
- order = scores.argsort()[::-1]
- boxes, landmarks, scores = boxes[order], landmarks[order], scores[order]
-
- # do NMS
- bounding_boxes = np.hstack((boxes, scores[:, np.newaxis])).astype(
- np.float32, copy=False
- )
- keep = py_cpu_nms(bounding_boxes, nms_threshold)
- bounding_boxes, landmarks = bounding_boxes[keep, :], landmarks[keep]
- # self.t['forward_pass'].toc()
- # print(self.t['forward_pass'].average_time)
- # import sys
- # sys.stdout.flush()
- return np.concatenate((bounding_boxes, landmarks), axis=1)
-
- def __align_multi(self, image, boxes, landmarks, limit=None):
- if len(boxes) < 1:
- return [], []
-
- if limit:
- boxes = boxes[:limit]
- landmarks = landmarks[:limit]
-
- faces = []
- for landmark in landmarks:
- facial5points = [[landmark[2 * j], landmark[2 * j + 1]] for j in range(5)]
-
- warped_face = warp_and_crop_face(
- np.array(image), facial5points, self.reference, crop_size=(112, 112)
- )
- faces.append(warped_face)
-
- return np.concatenate((boxes, landmarks), axis=1), faces
-
- def align_multi(self, img, conf_threshold=0.8, limit=None):
- rlt = self.detect_faces(img, conf_threshold=conf_threshold)
- boxes, landmarks = rlt[:, 0:5], rlt[:, 5:]
-
- return self.__align_multi(img, boxes, landmarks, limit)
-
- # batched detection
- def batched_transform(self, frames, use_origin_size):
- """
- Arguments:
- frames: a list of PIL.Image, or torch.Tensor(shape=[n, h, w, c],
- type=np.float32, BGR format).
- use_origin_size: whether to use origin size.
- """
- from_PIL = True if isinstance(frames[0], Image.Image) else False
-
- # convert to opencv format
- if from_PIL:
- frames = [
- cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR) for frame in frames
- ]
- frames = np.asarray(frames, dtype=np.float32)
-
- # testing scale
- im_size_min = np.min(frames[0].shape[0:2])
- im_size_max = np.max(frames[0].shape[0:2])
- resize = float(self.target_size) / float(im_size_min)
-
- # prevent bigger axis from being more than max_size
- if np.round(resize * im_size_max) > self.max_size:
- resize = float(self.max_size) / float(im_size_max)
- resize = 1 if use_origin_size else resize
-
- # resize
- if resize != 1:
- if not from_PIL:
- frames = F.interpolate(frames, scale_factor=resize)
- else:
- frames = [
- cv2.resize(
- frame,
- None,
- None,
- fx=resize,
- fy=resize,
- interpolation=cv2.INTER_LINEAR,
- )
- for frame in frames
- ]
-
- # convert to torch.tensor format
- if not from_PIL:
- frames = frames.transpose(1, 2).transpose(1, 3).contiguous()
- else:
- frames = frames.transpose((0, 3, 1, 2))
- frames = torch.from_numpy(frames)
-
- return frames, resize
-
- def batched_detect_faces(
- self, frames, conf_threshold=0.8, nms_threshold=0.4, use_origin_size=True
- ):
- """
- Arguments:
- frames: a list of PIL.Image, or np.array(shape=[n, h, w, c],
- type=np.uint8, BGR format).
- conf_threshold: confidence threshold.
- nms_threshold: nms threshold.
- use_origin_size: whether to use origin size.
- Returns:
- final_bounding_boxes: list of np.array ([n_boxes, 5],
- type=np.float32).
- final_landmarks: list of np.array ([n_boxes, 10], type=np.float32).
- """
- # self.t['forward_pass'].tic()
- frames, self.resize = self.batched_transform(frames, use_origin_size)
- frames = frames.to(device)
- frames = frames - self.mean_tensor
-
- b_loc, b_conf, b_landmarks, priors = self.__detect_faces(frames)
-
- final_bounding_boxes, final_landmarks = [], []
-
- # decode
- priors = priors.unsqueeze(0)
- b_loc = (
- batched_decode(b_loc, priors, self.cfg["variance"])
- * self.scale
- / self.resize
- )
- b_landmarks = (
- batched_decode_landm(b_landmarks, priors, self.cfg["variance"])
- * self.scale1
- / self.resize
- )
- b_conf = b_conf[:, :, 1]
-
- # index for selection
- b_indice = b_conf > conf_threshold
-
- # concat
- b_loc_and_conf = torch.cat((b_loc, b_conf.unsqueeze(-1)), dim=2).float()
-
- for pred, landm, inds in zip(b_loc_and_conf, b_landmarks, b_indice):
- # ignore low scores
- pred, landm = pred[inds, :], landm[inds, :]
- if pred.shape[0] == 0:
- final_bounding_boxes.append(np.array([], dtype=np.float32))
- final_landmarks.append(np.array([], dtype=np.float32))
- continue
-
- # sort
- # order = score.argsort(descending=True)
- # box, landm, score = box[order], landm[order], score[order]
-
- # to CPU
- bounding_boxes, landm = pred.cpu().numpy(), landm.cpu().numpy()
-
- # NMS
- keep = py_cpu_nms(bounding_boxes, nms_threshold)
- bounding_boxes, landmarks = bounding_boxes[keep, :], landm[keep]
-
- # append
- final_bounding_boxes.append(bounding_boxes)
- final_landmarks.append(landmarks)
- # self.t['forward_pass'].toc(average=True)
- # self.batch_time += self.t['forward_pass'].diff
- # self.total_frame += len(frames)
- # print(self.batch_time / self.total_frame)
-
- return final_bounding_boxes, final_landmarks
+import cv2
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from comfy import model_management
+from PIL import Image
+from torchvision.models._utils import IntermediateLayerGetter as IntermediateLayerGetter
+
+from hordelib.nodes.facerestore_cf.facelib.detection.align_trans import (
+ get_reference_facial_points,
+ warp_and_crop_face,
+)
+from hordelib.nodes.facerestore_cf.facelib.detection.retinaface.retinaface_net import (
+ FPN,
+ SSH,
+ MobileNetV1,
+ make_bbox_head,
+ make_class_head,
+ make_landmark_head,
+)
+from hordelib.nodes.facerestore_cf.facelib.detection.retinaface.retinaface_utils import (
+ PriorBox,
+ batched_decode,
+ batched_decode_landm,
+ decode,
+ decode_landm,
+ py_cpu_nms,
+)
+
+# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+device = model_management.get_torch_device()
+
+
+def generate_config(network_name):
+
+ cfg_mnet = {
+ "name": "mobilenet0.25",
+ "min_sizes": [[16, 32], [64, 128], [256, 512]],
+ "steps": [8, 16, 32],
+ "variance": [0.1, 0.2],
+ "clip": False,
+ "loc_weight": 2.0,
+ "gpu_train": True,
+ "batch_size": 32,
+ "ngpu": 1,
+ "epoch": 250,
+ "decay1": 190,
+ "decay2": 220,
+ "image_size": 640,
+ "return_layers": {
+ "stage1": 1,
+ "stage2": 2,
+ "stage3": 3,
+ },
+ "in_channel": 32,
+ "out_channel": 64,
+ }
+
+ cfg_re50 = {
+ "name": "Resnet50",
+ "min_sizes": [[16, 32], [64, 128], [256, 512]],
+ "steps": [8, 16, 32],
+ "variance": [0.1, 0.2],
+ "clip": False,
+ "loc_weight": 2.0,
+ "gpu_train": True,
+ "batch_size": 24,
+ "ngpu": 4,
+ "epoch": 100,
+ "decay1": 70,
+ "decay2": 90,
+ "image_size": 840,
+ "return_layers": {
+ "layer2": 1,
+ "layer3": 2,
+ "layer4": 3,
+ },
+ "in_channel": 256,
+ "out_channel": 256,
+ }
+
+ if network_name == "mobile0.25":
+ return cfg_mnet
+ elif network_name == "resnet50":
+ return cfg_re50
+ else:
+ raise NotImplementedError(f"network_name={network_name}")
+
+
+class RetinaFace(nn.Module):
+
+ def __init__(self, network_name="resnet50", half=False, phase="test"):
+ super(RetinaFace, self).__init__()
+ self.half_inference = half
+ cfg = generate_config(network_name)
+ self.backbone = cfg["name"]
+
+ self.model_name = f"retinaface_{network_name}"
+ self.cfg = cfg
+ self.phase = phase
+ self.target_size, self.max_size = 1600, 2150
+ self.resize, self.scale, self.scale1 = 1.0, None, None
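+        # per-channel BGR means, subtracted from the input tensor before detection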
+ self.mean_tensor = torch.tensor([[[[104.0]], [[117.0]], [[123.0]]]]).to(device)
+ self.reference = get_reference_facial_points(default_square=True)
+ # Build network.
+ backbone = None
+ if cfg["name"] == "mobilenet0.25":
+ backbone = MobileNetV1()
+ self.body = IntermediateLayerGetter(backbone, cfg["return_layers"])
+ elif cfg["name"] == "Resnet50":
+ import torchvision.models as models
+
+ backbone = models.resnet50(pretrained=False)
+ self.body = IntermediateLayerGetter(backbone, cfg["return_layers"])
+
+ in_channels_stage2 = cfg["in_channel"]
+ in_channels_list = [
+ in_channels_stage2 * 2,
+ in_channels_stage2 * 4,
+ in_channels_stage2 * 8,
+ ]
+
+ out_channels = cfg["out_channel"]
+ self.fpn = FPN(in_channels_list, out_channels)
+ self.ssh1 = SSH(out_channels, out_channels)
+ self.ssh2 = SSH(out_channels, out_channels)
+ self.ssh3 = SSH(out_channels, out_channels)
+
+ self.ClassHead = make_class_head(fpn_num=3, inchannels=cfg["out_channel"])
+ self.BboxHead = make_bbox_head(fpn_num=3, inchannels=cfg["out_channel"])
+ self.LandmarkHead = make_landmark_head(fpn_num=3, inchannels=cfg["out_channel"])
+
+ self.to(device)
+ self.eval()
+ if self.half_inference:
+ self.half()
+
+ def forward(self, inputs):
+ out = self.body(inputs)
+
+ if self.backbone == "mobilenet0.25" or self.backbone == "Resnet50":
+ out = list(out.values())
+ # FPN
+ fpn = self.fpn(out)
+
+ # SSH
+ feature1 = self.ssh1(fpn[0])
+ feature2 = self.ssh2(fpn[1])
+ feature3 = self.ssh3(fpn[2])
+ features = [feature1, feature2, feature3]
+
+ bbox_regressions = torch.cat([self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1)
+ classifications = torch.cat([self.ClassHead[i](feature) for i, feature in enumerate(features)], dim=1)
+ tmp = [self.LandmarkHead[i](feature) for i, feature in enumerate(features)]
+ ldm_regressions = torch.cat(tmp, dim=1)
+
+ if self.phase == "train":
+ output = (bbox_regressions, classifications, ldm_regressions)
+ else:
+ output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions)
+ return output
+
+ def __detect_faces(self, inputs):
+ # get scale
+ height, width = inputs.shape[2:]
+ self.scale = torch.tensor([width, height, width, height], dtype=torch.float32).to(device)
+ tmp = [width, height, width, height, width, height, width, height, width, height]
+ self.scale1 = torch.tensor(tmp, dtype=torch.float32).to(device)
+
+        # forward pass
+ inputs = inputs.to(device)
+ if self.half_inference:
+ inputs = inputs.half()
+ loc, conf, landmarks = self(inputs)
+
+ # get priorbox
+ priorbox = PriorBox(self.cfg, image_size=inputs.shape[2:])
+ priors = priorbox.forward().to(device)
+
+ return loc, conf, landmarks, priors
+
+ # single image detection
+ def transform(self, image, use_origin_size):
+ # convert to opencv format
+ if isinstance(image, Image.Image):
+ image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
+ image = image.astype(np.float32)
+
+ # testing scale
+ im_size_min = np.min(image.shape[0:2])
+ im_size_max = np.max(image.shape[0:2])
+ resize = float(self.target_size) / float(im_size_min)
+
+ # prevent bigger axis from being more than max_size
+ if np.round(resize * im_size_max) > self.max_size:
+ resize = float(self.max_size) / float(im_size_max)
+ resize = 1 if use_origin_size else resize
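+        # resize == 1 keeps the original size; otherwise the short side scales to target_size, capped by max_size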
+
+ # resize
+ if resize != 1:
+ image = cv2.resize(image, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR)
+
+ # convert to torch.tensor format
+ # image -= (104, 117, 123)
+ image = image.transpose(2, 0, 1)
+ image = torch.from_numpy(image).unsqueeze(0)
+
+ return image, resize
+
+ def detect_faces(
+ self,
+ image,
+ conf_threshold=0.8,
+ nms_threshold=0.4,
+ use_origin_size=True,
+ ):
+ """
+ Params:
+            image: BGR image (numpy array or PIL.Image)
+ """
+ image, self.resize = self.transform(image, use_origin_size)
+ image = image.to(device)
+ if self.half_inference:
+ image = image.half()
+ image = image - self.mean_tensor
+
+ loc, conf, landmarks, priors = self.__detect_faces(image)
+
+ boxes = decode(loc.data.squeeze(0), priors.data, self.cfg["variance"])
+ boxes = boxes * self.scale / self.resize
+ boxes = boxes.cpu().numpy()
+
+ scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
+
+ landmarks = decode_landm(landmarks.squeeze(0), priors, self.cfg["variance"])
+ landmarks = landmarks * self.scale1 / self.resize
+ landmarks = landmarks.cpu().numpy()
+
+ # ignore low scores
+ inds = np.where(scores > conf_threshold)[0]
+ boxes, landmarks, scores = boxes[inds], landmarks[inds], scores[inds]
+
+ # sort
+ order = scores.argsort()[::-1]
+ boxes, landmarks, scores = boxes[order], landmarks[order], scores[order]
+
+ # do NMS
+ bounding_boxes = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
+ keep = py_cpu_nms(bounding_boxes, nms_threshold)
+ bounding_boxes, landmarks = bounding_boxes[keep, :], landmarks[keep]
+ # self.t['forward_pass'].toc()
+ # print(self.t['forward_pass'].average_time)
+ # import sys
+ # sys.stdout.flush()
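+        # each row: [x1, y1, x2, y2, score, x1_lm, y1_lm, ..., x5_lm, y5_lm]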
+ return np.concatenate((bounding_boxes, landmarks), axis=1)
+
+ def __align_multi(self, image, boxes, landmarks, limit=None):
+
+ if len(boxes) < 1:
+ return [], []
+
+ if limit:
+ boxes = boxes[:limit]
+ landmarks = landmarks[:limit]
+
+ faces = []
+ for landmark in landmarks:
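+            # regroup the flat [x1, y1, ..., x5, y5] landmark vector into five (x, y) points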
+ facial5points = [[landmark[2 * j], landmark[2 * j + 1]] for j in range(5)]
+
+ warped_face = warp_and_crop_face(np.array(image), facial5points, self.reference, crop_size=(112, 112))
+ faces.append(warped_face)
+
+ return np.concatenate((boxes, landmarks), axis=1), faces
+
+ def align_multi(self, img, conf_threshold=0.8, limit=None):
+
+ rlt = self.detect_faces(img, conf_threshold=conf_threshold)
+ boxes, landmarks = rlt[:, 0:5], rlt[:, 5:]
+
+ return self.__align_multi(img, boxes, landmarks, limit)
+
+ # batched detection
+ def batched_transform(self, frames, use_origin_size):
+ """
+ Arguments:
+ frames: a list of PIL.Image, or torch.Tensor(shape=[n, h, w, c],
+ type=np.float32, BGR format).
+ use_origin_size: whether to use origin size.
+ """
+        from_PIL = isinstance(frames[0], Image.Image)
+
+ # convert to opencv format
+ if from_PIL:
+ frames = [cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR) for frame in frames]
+ frames = np.asarray(frames, dtype=np.float32)
+
+ # testing scale
+ im_size_min = np.min(frames[0].shape[0:2])
+ im_size_max = np.max(frames[0].shape[0:2])
+ resize = float(self.target_size) / float(im_size_min)
+
+ # prevent bigger axis from being more than max_size
+ if np.round(resize * im_size_max) > self.max_size:
+ resize = float(self.max_size) / float(im_size_max)
+ resize = 1 if use_origin_size else resize
+
+ # resize
+ if resize != 1:
+ if not from_PIL:
+ frames = F.interpolate(frames, scale_factor=resize)
+ else:
+ frames = [
+ cv2.resize(frame, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR)
+ for frame in frames
+ ]
+
+ # convert to torch.tensor format
+ if not from_PIL:
+ frames = frames.transpose(1, 2).transpose(1, 3).contiguous()
+ else:
+ frames = frames.transpose((0, 3, 1, 2))
+ frames = torch.from_numpy(frames)
+
+ return frames, resize
+
+ def batched_detect_faces(self, frames, conf_threshold=0.8, nms_threshold=0.4, use_origin_size=True):
+ """
+ Arguments:
+ frames: a list of PIL.Image, or np.array(shape=[n, h, w, c],
+ type=np.uint8, BGR format).
+ conf_threshold: confidence threshold.
+ nms_threshold: nms threshold.
+ use_origin_size: whether to use origin size.
+ Returns:
+ final_bounding_boxes: list of np.array ([n_boxes, 5],
+ type=np.float32).
+ final_landmarks: list of np.array ([n_boxes, 10], type=np.float32).
+ """
+ # self.t['forward_pass'].tic()
+ frames, self.resize = self.batched_transform(frames, use_origin_size)
+ frames = frames.to(device)
+ frames = frames - self.mean_tensor
+
+ b_loc, b_conf, b_landmarks, priors = self.__detect_faces(frames)
+
+ final_bounding_boxes, final_landmarks = [], []
+
+ # decode
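+        # give the shared priors a batch axis so they broadcast across the batch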
+ priors = priors.unsqueeze(0)
+ b_loc = batched_decode(b_loc, priors, self.cfg["variance"]) * self.scale / self.resize
+ b_landmarks = batched_decode_landm(b_landmarks, priors, self.cfg["variance"]) * self.scale1 / self.resize
+ b_conf = b_conf[:, :, 1]
+
+ # index for selection
+ b_indice = b_conf > conf_threshold
+
+ # concat
+ b_loc_and_conf = torch.cat((b_loc, b_conf.unsqueeze(-1)), dim=2).float()
+
+ for pred, landm, inds in zip(b_loc_and_conf, b_landmarks, b_indice, strict=False):
+
+ # ignore low scores
+ pred, landm = pred[inds, :], landm[inds, :]
+ if pred.shape[0] == 0:
+ final_bounding_boxes.append(np.array([], dtype=np.float32))
+ final_landmarks.append(np.array([], dtype=np.float32))
+ continue
+
+ # sort
+ # order = score.argsort(descending=True)
+ # box, landm, score = box[order], landm[order], score[order]
+
+ # to CPU
+ bounding_boxes, landm = pred.cpu().numpy(), landm.cpu().numpy()
+
+ # NMS
+ keep = py_cpu_nms(bounding_boxes, nms_threshold)
+ bounding_boxes, landmarks = bounding_boxes[keep, :], landm[keep]
+
+ # append
+ final_bounding_boxes.append(bounding_boxes)
+ final_landmarks.append(landmarks)
+ # self.t['forward_pass'].toc(average=True)
+ # self.batch_time += self.t['forward_pass'].diff
+ # self.total_frame += len(frames)
+ # print(self.batch_time / self.total_frame)
+
+ return final_bounding_boxes, final_landmarks
diff --git a/hordelib/nodes/facerestore/facelib/detection/retinaface/retinaface_net.py b/hordelib/nodes/facerestore_cf/facelib/detection/retinaface/retinaface_net.py
similarity index 90%
rename from hordelib/nodes/facerestore/facelib/detection/retinaface/retinaface_net.py
rename to hordelib/nodes/facerestore_cf/facelib/detection/retinaface/retinaface_net.py
index ab6aa82d..bd5f6816 100644
--- a/hordelib/nodes/facerestore/facelib/detection/retinaface/retinaface_net.py
+++ b/hordelib/nodes/facerestore_cf/facelib/detection/retinaface/retinaface_net.py
@@ -1,196 +1,200 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-def conv_bn(inp, oup, stride=1, leaky=0):
- return nn.Sequential(
- nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup),
- nn.LeakyReLU(negative_slope=leaky, inplace=True))
-
-
-def conv_bn_no_relu(inp, oup, stride):
- return nn.Sequential(
- nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
- nn.BatchNorm2d(oup),
- )
-
-
-def conv_bn1X1(inp, oup, stride, leaky=0):
- return nn.Sequential(
- nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), nn.BatchNorm2d(oup),
- nn.LeakyReLU(negative_slope=leaky, inplace=True))
-
-
-def conv_dw(inp, oup, stride, leaky=0.1):
- return nn.Sequential(
- nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
- nn.BatchNorm2d(inp),
- nn.LeakyReLU(negative_slope=leaky, inplace=True),
- nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
- nn.BatchNorm2d(oup),
- nn.LeakyReLU(negative_slope=leaky, inplace=True),
- )
-
-
-class SSH(nn.Module):
-
- def __init__(self, in_channel, out_channel):
- super(SSH, self).__init__()
- assert out_channel % 4 == 0
- leaky = 0
- if (out_channel <= 64):
- leaky = 0.1
- self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1)
-
- self.conv5X5_1 = conv_bn(in_channel, out_channel // 4, stride=1, leaky=leaky)
- self.conv5X5_2 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1)
-
- self.conv7X7_2 = conv_bn(out_channel // 4, out_channel // 4, stride=1, leaky=leaky)
- self.conv7x7_3 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1)
-
- def forward(self, input):
- conv3X3 = self.conv3X3(input)
-
- conv5X5_1 = self.conv5X5_1(input)
- conv5X5 = self.conv5X5_2(conv5X5_1)
-
- conv7X7_2 = self.conv7X7_2(conv5X5_1)
- conv7X7 = self.conv7x7_3(conv7X7_2)
-
- out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1)
- out = F.relu(out)
- return out
-
-
-class FPN(nn.Module):
-
- def __init__(self, in_channels_list, out_channels):
- super(FPN, self).__init__()
- leaky = 0
- if (out_channels <= 64):
- leaky = 0.1
- self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride=1, leaky=leaky)
- self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride=1, leaky=leaky)
- self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride=1, leaky=leaky)
-
- self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky)
- self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky)
-
- def forward(self, input):
- # names = list(input.keys())
- # input = list(input.values())
-
- output1 = self.output1(input[0])
- output2 = self.output2(input[1])
- output3 = self.output3(input[2])
-
- up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode='nearest')
- output2 = output2 + up3
- output2 = self.merge2(output2)
-
- up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode='nearest')
- output1 = output1 + up2
- output1 = self.merge1(output1)
-
- out = [output1, output2, output3]
- return out
-
-
-class MobileNetV1(nn.Module):
-
- def __init__(self):
- super(MobileNetV1, self).__init__()
- self.stage1 = nn.Sequential(
- conv_bn(3, 8, 2, leaky=0.1), # 3
- conv_dw(8, 16, 1), # 7
- conv_dw(16, 32, 2), # 11
- conv_dw(32, 32, 1), # 19
- conv_dw(32, 64, 2), # 27
- conv_dw(64, 64, 1), # 43
- )
- self.stage2 = nn.Sequential(
- conv_dw(64, 128, 2), # 43 + 16 = 59
- conv_dw(128, 128, 1), # 59 + 32 = 91
- conv_dw(128, 128, 1), # 91 + 32 = 123
- conv_dw(128, 128, 1), # 123 + 32 = 155
- conv_dw(128, 128, 1), # 155 + 32 = 187
- conv_dw(128, 128, 1), # 187 + 32 = 219
- )
- self.stage3 = nn.Sequential(
- conv_dw(128, 256, 2), # 219 +3 2 = 241
- conv_dw(256, 256, 1), # 241 + 64 = 301
- )
- self.avg = nn.AdaptiveAvgPool2d((1, 1))
- self.fc = nn.Linear(256, 1000)
-
- def forward(self, x):
- x = self.stage1(x)
- x = self.stage2(x)
- x = self.stage3(x)
- x = self.avg(x)
- # x = self.model(x)
- x = x.view(-1, 256)
- x = self.fc(x)
- return x
-
-
-class ClassHead(nn.Module):
-
- def __init__(self, inchannels=512, num_anchors=3):
- super(ClassHead, self).__init__()
- self.num_anchors = num_anchors
- self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0)
-
- def forward(self, x):
- out = self.conv1x1(x)
- out = out.permute(0, 2, 3, 1).contiguous()
-
- return out.view(out.shape[0], -1, 2)
-
-
-class BboxHead(nn.Module):
-
- def __init__(self, inchannels=512, num_anchors=3):
- super(BboxHead, self).__init__()
- self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0)
-
- def forward(self, x):
- out = self.conv1x1(x)
- out = out.permute(0, 2, 3, 1).contiguous()
-
- return out.view(out.shape[0], -1, 4)
-
-
-class LandmarkHead(nn.Module):
-
- def __init__(self, inchannels=512, num_anchors=3):
- super(LandmarkHead, self).__init__()
- self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0)
-
- def forward(self, x):
- out = self.conv1x1(x)
- out = out.permute(0, 2, 3, 1).contiguous()
-
- return out.view(out.shape[0], -1, 10)
-
-
-def make_class_head(fpn_num=3, inchannels=64, anchor_num=2):
- classhead = nn.ModuleList()
- for i in range(fpn_num):
- classhead.append(ClassHead(inchannels, anchor_num))
- return classhead
-
-
-def make_bbox_head(fpn_num=3, inchannels=64, anchor_num=2):
- bboxhead = nn.ModuleList()
- for i in range(fpn_num):
- bboxhead.append(BboxHead(inchannels, anchor_num))
- return bboxhead
-
-
-def make_landmark_head(fpn_num=3, inchannels=64, anchor_num=2):
- landmarkhead = nn.ModuleList()
- for i in range(fpn_num):
- landmarkhead.append(LandmarkHead(inchannels, anchor_num))
- return landmarkhead
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def conv_bn(inp, oup, stride=1, leaky=0):
+ return nn.Sequential(
+ nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
+ nn.BatchNorm2d(oup),
+ nn.LeakyReLU(negative_slope=leaky, inplace=True),
+ )
+
+
+def conv_bn_no_relu(inp, oup, stride):
+ return nn.Sequential(
+ nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
+ nn.BatchNorm2d(oup),
+ )
+
+
+def conv_bn1X1(inp, oup, stride, leaky=0):
+ return nn.Sequential(
+ nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False),
+ nn.BatchNorm2d(oup),
+ nn.LeakyReLU(negative_slope=leaky, inplace=True),
+ )
+
+
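+# depthwise 3x3 convolution followed by a pointwise 1x1 convolution (a depthwise-separable block)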
+def conv_dw(inp, oup, stride, leaky=0.1):
+ return nn.Sequential(
+ nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
+ nn.BatchNorm2d(inp),
+ nn.LeakyReLU(negative_slope=leaky, inplace=True),
+ nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(oup),
+ nn.LeakyReLU(negative_slope=leaky, inplace=True),
+ )
+
+
+class SSH(nn.Module):
+
+ def __init__(self, in_channel, out_channel):
+ super(SSH, self).__init__()
+ assert out_channel % 4 == 0
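+        # branch outputs (out_channel//2 + out_channel//4 + out_channel//4) are concatenated in forward()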
+ leaky = 0
+ if out_channel <= 64:
+ leaky = 0.1
+ self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1)
+
+ self.conv5X5_1 = conv_bn(in_channel, out_channel // 4, stride=1, leaky=leaky)
+ self.conv5X5_2 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1)
+
+ self.conv7X7_2 = conv_bn(out_channel // 4, out_channel // 4, stride=1, leaky=leaky)
+ self.conv7x7_3 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1)
+
+ def forward(self, input):
+ conv3X3 = self.conv3X3(input)
+
+ conv5X5_1 = self.conv5X5_1(input)
+ conv5X5 = self.conv5X5_2(conv5X5_1)
+
+ conv7X7_2 = self.conv7X7_2(conv5X5_1)
+ conv7X7 = self.conv7x7_3(conv7X7_2)
+
+ out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1)
+ out = F.relu(out)
+ return out
+
+
+class FPN(nn.Module):
+
+ def __init__(self, in_channels_list, out_channels):
+ super(FPN, self).__init__()
+ leaky = 0
+ if out_channels <= 64:
+ leaky = 0.1
+ self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride=1, leaky=leaky)
+ self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride=1, leaky=leaky)
+ self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride=1, leaky=leaky)
+
+ self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky)
+ self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky)
+
+ def forward(self, input):
+ # names = list(input.keys())
+ # input = list(input.values())
+
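+        # top-down pathway: upsample deeper features and merge them into shallower ones (output3 -> output1)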
+ output1 = self.output1(input[0])
+ output2 = self.output2(input[1])
+ output3 = self.output3(input[2])
+
+ up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode="nearest")
+ output2 = output2 + up3
+ output2 = self.merge2(output2)
+
+ up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode="nearest")
+ output1 = output1 + up2
+ output1 = self.merge1(output1)
+
+ out = [output1, output2, output3]
+ return out
+
+
+class MobileNetV1(nn.Module):
+
+ def __init__(self):
+ super(MobileNetV1, self).__init__()
+ self.stage1 = nn.Sequential(
+ conv_bn(3, 8, 2, leaky=0.1), # 3
+ conv_dw(8, 16, 1), # 7
+ conv_dw(16, 32, 2), # 11
+ conv_dw(32, 32, 1), # 19
+ conv_dw(32, 64, 2), # 27
+ conv_dw(64, 64, 1), # 43
+ )
+ self.stage2 = nn.Sequential(
+ conv_dw(64, 128, 2), # 43 + 16 = 59
+ conv_dw(128, 128, 1), # 59 + 32 = 91
+ conv_dw(128, 128, 1), # 91 + 32 = 123
+ conv_dw(128, 128, 1), # 123 + 32 = 155
+ conv_dw(128, 128, 1), # 155 + 32 = 187
+ conv_dw(128, 128, 1), # 187 + 32 = 219
+ )
+ self.stage3 = nn.Sequential(
+            conv_dw(128, 256, 2),  # 219 + 32 = 251
+            conv_dw(256, 256, 1),  # 251 + 64 = 315
+ )
+ self.avg = nn.AdaptiveAvgPool2d((1, 1))
+ self.fc = nn.Linear(256, 1000)
+
+ def forward(self, x):
+ x = self.stage1(x)
+ x = self.stage2(x)
+ x = self.stage3(x)
+ x = self.avg(x)
+ # x = self.model(x)
+ x = x.view(-1, 256)
+ x = self.fc(x)
+ return x
+
+
+class ClassHead(nn.Module):
+
+ def __init__(self, inchannels=512, num_anchors=3):
+ super(ClassHead, self).__init__()
+ self.num_anchors = num_anchors
+ self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0)
+
+ def forward(self, x):
+ out = self.conv1x1(x)
+ out = out.permute(0, 2, 3, 1).contiguous()
+
+ return out.view(out.shape[0], -1, 2)
+
+
+class BboxHead(nn.Module):
+
+ def __init__(self, inchannels=512, num_anchors=3):
+ super(BboxHead, self).__init__()
+ self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0)
+
+ def forward(self, x):
+ out = self.conv1x1(x)
+ out = out.permute(0, 2, 3, 1).contiguous()
+
+ return out.view(out.shape[0], -1, 4)
+
+
+class LandmarkHead(nn.Module):
+
+ def __init__(self, inchannels=512, num_anchors=3):
+ super(LandmarkHead, self).__init__()
+ self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0)
+
+ def forward(self, x):
+ out = self.conv1x1(x)
+ out = out.permute(0, 2, 3, 1).contiguous()
+
+ return out.view(out.shape[0], -1, 10)
+
+
+def make_class_head(fpn_num=3, inchannels=64, anchor_num=2):
+ classhead = nn.ModuleList()
+ for i in range(fpn_num):
+ classhead.append(ClassHead(inchannels, anchor_num))
+ return classhead
+
+
+def make_bbox_head(fpn_num=3, inchannels=64, anchor_num=2):
+ bboxhead = nn.ModuleList()
+ for i in range(fpn_num):
+ bboxhead.append(BboxHead(inchannels, anchor_num))
+ return bboxhead
+
+
+def make_landmark_head(fpn_num=3, inchannels=64, anchor_num=2):
+ landmarkhead = nn.ModuleList()
+ for i in range(fpn_num):
+ landmarkhead.append(LandmarkHead(inchannels, anchor_num))
+ return landmarkhead
diff --git a/hordelib/nodes/facerestore/facelib/detection/retinaface/retinaface_utils.py b/hordelib/nodes/facerestore_cf/facelib/detection/retinaface/retinaface_utils.py
similarity index 92%
rename from hordelib/nodes/facerestore/facelib/detection/retinaface/retinaface_utils.py
rename to hordelib/nodes/facerestore_cf/facelib/detection/retinaface/retinaface_utils.py
index 8c357757..1900c0b7 100644
--- a/hordelib/nodes/facerestore/facelib/detection/retinaface/retinaface_utils.py
+++ b/hordelib/nodes/facerestore_cf/facelib/detection/retinaface/retinaface_utils.py
@@ -1,421 +1,420 @@
-import numpy as np
-import torch
-import torchvision
-from itertools import product as product
-from math import ceil
-
-
-class PriorBox(object):
-
- def __init__(self, cfg, image_size=None, phase='train'):
- super(PriorBox, self).__init__()
- self.min_sizes = cfg['min_sizes']
- self.steps = cfg['steps']
- self.clip = cfg['clip']
- self.image_size = image_size
- self.feature_maps = [[ceil(self.image_size[0] / step), ceil(self.image_size[1] / step)] for step in self.steps]
- self.name = 's'
-
- def forward(self):
- anchors = []
- for k, f in enumerate(self.feature_maps):
- min_sizes = self.min_sizes[k]
- for i, j in product(range(f[0]), range(f[1])):
- for min_size in min_sizes:
- s_kx = min_size / self.image_size[1]
- s_ky = min_size / self.image_size[0]
- dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]]
- dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]]
- for cy, cx in product(dense_cy, dense_cx):
- anchors += [cx, cy, s_kx, s_ky]
-
- # back to torch land
- output = torch.Tensor(anchors).view(-1, 4)
- if self.clip:
- output.clamp_(max=1, min=0)
- return output
-
-
-def py_cpu_nms(dets, thresh):
- """Pure Python NMS baseline."""
- keep = torchvision.ops.nms(
- boxes=torch.Tensor(dets[:, :4]),
- scores=torch.Tensor(dets[:, 4]),
- iou_threshold=thresh,
- )
-
- return list(keep)
-
-
-def point_form(boxes):
- """ Convert prior_boxes to (xmin, ymin, xmax, ymax)
- representation for comparison to point form ground truth data.
- Args:
- boxes: (tensor) center-size default boxes from priorbox layers.
- Return:
- boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
- """
- return torch.cat(
- (
- boxes[:, :2] - boxes[:, 2:] / 2, # xmin, ymin
- boxes[:, :2] + boxes[:, 2:] / 2),
- 1) # xmax, ymax
-
-
-def center_size(boxes):
- """ Convert prior_boxes to (cx, cy, w, h)
- representation for comparison to center-size form ground truth data.
- Args:
- boxes: (tensor) point_form boxes
- Return:
- boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
- """
- return torch.cat(
- (boxes[:, 2:] + boxes[:, :2]) / 2, # cx, cy
- boxes[:, 2:] - boxes[:, :2],
- 1) # w, h
-
-
-def intersect(box_a, box_b):
- """ We resize both tensors to [A,B,2] without new malloc:
- [A,2] -> [A,1,2] -> [A,B,2]
- [B,2] -> [1,B,2] -> [A,B,2]
- Then we compute the area of intersect between box_a and box_b.
- Args:
- box_a: (tensor) bounding boxes, Shape: [A,4].
- box_b: (tensor) bounding boxes, Shape: [B,4].
- Return:
- (tensor) intersection area, Shape: [A,B].
- """
- A = box_a.size(0)
- B = box_b.size(0)
- max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
- min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), box_b[:, :2].unsqueeze(0).expand(A, B, 2))
- inter = torch.clamp((max_xy - min_xy), min=0)
- return inter[:, :, 0] * inter[:, :, 1]
-
-
-def jaccard(box_a, box_b):
- """Compute the jaccard overlap of two sets of boxes. The jaccard overlap
- is simply the intersection over union of two boxes. Here we operate on
- ground truth boxes and default boxes.
- E.g.:
- A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
- Args:
- box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4]
- box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4]
- Return:
- jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
- """
- inter = intersect(box_a, box_b)
- area_a = ((box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B]
- area_b = ((box_b[:, 2] - box_b[:, 0]) * (box_b[:, 3] - box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B]
- union = area_a + area_b - inter
- return inter / union # [A,B]
-
-
-def matrix_iou(a, b):
- """
- return iou of a and b, numpy version for data augenmentation
- """
- lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
- rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
-
- area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
- area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
- area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
- return area_i / (area_a[:, np.newaxis] + area_b - area_i)
-
-
-def matrix_iof(a, b):
- """
- return iof of a and b, numpy version for data augenmentation
- """
- lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
- rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
-
- area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
- area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
- return area_i / np.maximum(area_a[:, np.newaxis], 1)
-
-
-def match(threshold, truths, priors, variances, labels, landms, loc_t, conf_t, landm_t, idx):
- """Match each prior box with the ground truth box of the highest jaccard
- overlap, encode the bounding boxes, then return the matched indices
- corresponding to both confidence and location preds.
- Args:
- threshold: (float) The overlap threshold used when matching boxes.
- truths: (tensor) Ground truth boxes, Shape: [num_obj, 4].
- priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4].
- variances: (tensor) Variances corresponding to each prior coord,
- Shape: [num_priors, 4].
- labels: (tensor) All the class labels for the image, Shape: [num_obj].
- landms: (tensor) Ground truth landms, Shape [num_obj, 10].
- loc_t: (tensor) Tensor to be filled w/ encoded location targets.
- conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds.
- landm_t: (tensor) Tensor to be filled w/ encoded landm targets.
- idx: (int) current batch index
- Return:
- The matched indices corresponding to 1)location 2)confidence
- 3)landm preds.
- """
- # jaccard index
- overlaps = jaccard(truths, point_form(priors))
- # (Bipartite Matching)
- # [1,num_objects] best prior for each ground truth
- best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)
-
- # ignore hard gt
- valid_gt_idx = best_prior_overlap[:, 0] >= 0.2
- best_prior_idx_filter = best_prior_idx[valid_gt_idx, :]
- if best_prior_idx_filter.shape[0] <= 0:
- loc_t[idx] = 0
- conf_t[idx] = 0
- return
-
- # [1,num_priors] best ground truth for each prior
- best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
- best_truth_idx.squeeze_(0)
- best_truth_overlap.squeeze_(0)
- best_prior_idx.squeeze_(1)
- best_prior_idx_filter.squeeze_(1)
- best_prior_overlap.squeeze_(1)
- best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2) # ensure best prior
- # TODO refactor: index best_prior_idx with long tensor
- # ensure every gt matches with its prior of max overlap
- for j in range(best_prior_idx.size(0)): # 判别此anchor是预测哪一个boxes
- best_truth_idx[best_prior_idx[j]] = j
- matches = truths[best_truth_idx] # Shape: [num_priors,4] 此处为每一个anchor对应的bbox取出来
- conf = labels[best_truth_idx] # Shape: [num_priors] 此处为每一个anchor对应的label取出来
- conf[best_truth_overlap < threshold] = 0 # label as background overlap<0.35的全部作为负样本
- loc = encode(matches, priors, variances)
-
- matches_landm = landms[best_truth_idx]
- landm = encode_landm(matches_landm, priors, variances)
- loc_t[idx] = loc # [num_priors,4] encoded offsets to learn
- conf_t[idx] = conf # [num_priors] top class label for each prior
- landm_t[idx] = landm
-
-
-def encode(matched, priors, variances):
- """Encode the variances from the priorbox layers into the ground truth boxes
- we have matched (based on jaccard overlap) with the prior boxes.
- Args:
- matched: (tensor) Coords of ground truth for each prior in point-form
- Shape: [num_priors, 4].
- priors: (tensor) Prior boxes in center-offset form
- Shape: [num_priors,4].
- variances: (list[float]) Variances of priorboxes
- Return:
- encoded boxes (tensor), Shape: [num_priors, 4]
- """
-
- # dist b/t match center and prior's center
- g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
- # encode variance
- g_cxcy /= (variances[0] * priors[:, 2:])
- # match wh / prior wh
- g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
- g_wh = torch.log(g_wh) / variances[1]
- # return target for smooth_l1_loss
- return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
-
-
-def encode_landm(matched, priors, variances):
- """Encode the variances from the priorbox layers into the ground truth boxes
- we have matched (based on jaccard overlap) with the prior boxes.
- Args:
- matched: (tensor) Coords of ground truth for each prior in point-form
- Shape: [num_priors, 10].
- priors: (tensor) Prior boxes in center-offset form
- Shape: [num_priors,4].
- variances: (list[float]) Variances of priorboxes
- Return:
- encoded landm (tensor), Shape: [num_priors, 10]
- """
-
- # dist b/t match center and prior's center
- matched = torch.reshape(matched, (matched.size(0), 5, 2))
- priors_cx = priors[:, 0].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
- priors_cy = priors[:, 1].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
- priors_w = priors[:, 2].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
- priors_h = priors[:, 3].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
- priors = torch.cat([priors_cx, priors_cy, priors_w, priors_h], dim=2)
- g_cxcy = matched[:, :, :2] - priors[:, :, :2]
- # encode variance
- g_cxcy /= (variances[0] * priors[:, :, 2:])
- # g_cxcy /= priors[:, :, 2:]
- g_cxcy = g_cxcy.reshape(g_cxcy.size(0), -1)
- # return target for smooth_l1_loss
- return g_cxcy
-
-
-# Adapted from https://github.com/Hakuyume/chainer-ssd
-def decode(loc, priors, variances):
- """Decode locations from predictions using priors to undo
- the encoding we did for offset regression at train time.
- Args:
- loc (tensor): location predictions for loc layers,
- Shape: [num_priors,4]
- priors (tensor): Prior boxes in center-offset form.
- Shape: [num_priors,4].
- variances: (list[float]) Variances of priorboxes
- Return:
- decoded bounding box predictions
- """
-
- boxes = torch.cat((priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
- priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
- boxes[:, :2] -= boxes[:, 2:] / 2
- boxes[:, 2:] += boxes[:, :2]
- return boxes
-
-
-def decode_landm(pre, priors, variances):
- """Decode landm from predictions using priors to undo
- the encoding we did for offset regression at train time.
- Args:
- pre (tensor): landm predictions for loc layers,
- Shape: [num_priors,10]
- priors (tensor): Prior boxes in center-offset form.
- Shape: [num_priors,4].
- variances: (list[float]) Variances of priorboxes
- Return:
- decoded landm predictions
- """
- tmp = (
- priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:],
- priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:],
- priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:],
- priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:],
- priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:],
- )
- landms = torch.cat(tmp, dim=1)
- return landms
-
-
-def batched_decode(b_loc, priors, variances):
- """Decode locations from predictions using priors to undo
- the encoding we did for offset regression at train time.
- Args:
- b_loc (tensor): location predictions for loc layers,
- Shape: [num_batches,num_priors,4]
- priors (tensor): Prior boxes in center-offset form.
- Shape: [1,num_priors,4].
- variances: (list[float]) Variances of priorboxes
- Return:
- decoded bounding box predictions
- """
- boxes = (
- priors[:, :, :2] + b_loc[:, :, :2] * variances[0] * priors[:, :, 2:],
- priors[:, :, 2:] * torch.exp(b_loc[:, :, 2:] * variances[1]),
- )
- boxes = torch.cat(boxes, dim=2)
-
- boxes[:, :, :2] -= boxes[:, :, 2:] / 2
- boxes[:, :, 2:] += boxes[:, :, :2]
- return boxes
-
-
-def batched_decode_landm(pre, priors, variances):
- """Decode landm from predictions using priors to undo
- the encoding we did for offset regression at train time.
- Args:
- pre (tensor): landm predictions for loc layers,
- Shape: [num_batches,num_priors,10]
- priors (tensor): Prior boxes in center-offset form.
- Shape: [1,num_priors,4].
- variances: (list[float]) Variances of priorboxes
- Return:
- decoded landm predictions
- """
- landms = (
- priors[:, :, :2] + pre[:, :, :2] * variances[0] * priors[:, :, 2:],
- priors[:, :, :2] + pre[:, :, 2:4] * variances[0] * priors[:, :, 2:],
- priors[:, :, :2] + pre[:, :, 4:6] * variances[0] * priors[:, :, 2:],
- priors[:, :, :2] + pre[:, :, 6:8] * variances[0] * priors[:, :, 2:],
- priors[:, :, :2] + pre[:, :, 8:10] * variances[0] * priors[:, :, 2:],
- )
- landms = torch.cat(landms, dim=2)
- return landms
-
-
-def log_sum_exp(x):
- """Utility function for computing log_sum_exp while determining
- This will be used to determine unaveraged confidence loss across
- all examples in a batch.
- Args:
- x (Variable(tensor)): conf_preds from conf layers
- """
- x_max = x.data.max()
- return torch.log(torch.sum(torch.exp(x - x_max), 1, keepdim=True)) + x_max
-
-
-# Original author: Francisco Massa:
-# https://github.com/fmassa/object-detection.torch
-# Ported to PyTorch by Max deGroot (02/01/2017)
-def nms(boxes, scores, overlap=0.5, top_k=200):
- """Apply non-maximum suppression at test time to avoid detecting too many
- overlapping bounding boxes for a given object.
- Args:
- boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
- scores: (tensor) The class predscores for the img, Shape:[num_priors].
- overlap: (float) The overlap thresh for suppressing unnecessary boxes.
- top_k: (int) The Maximum number of box preds to consider.
- Return:
- The indices of the kept boxes with respect to num_priors.
- """
-
- keep = torch.Tensor(scores.size(0)).fill_(0).long()
- if boxes.numel() == 0:
- return keep
- x1 = boxes[:, 0]
- y1 = boxes[:, 1]
- x2 = boxes[:, 2]
- y2 = boxes[:, 3]
- area = torch.mul(x2 - x1, y2 - y1)
- v, idx = scores.sort(0) # sort in ascending order
- # I = I[v >= 0.01]
- idx = idx[-top_k:] # indices of the top-k largest vals
- xx1 = boxes.new()
- yy1 = boxes.new()
- xx2 = boxes.new()
- yy2 = boxes.new()
- w = boxes.new()
- h = boxes.new()
-
- # keep = torch.Tensor()
- count = 0
- while idx.numel() > 0:
- i = idx[-1] # index of current largest val
- # keep.append(i)
- keep[count] = i
- count += 1
- if idx.size(0) == 1:
- break
- idx = idx[:-1] # remove kept element from view
- # load bboxes of next highest vals
- torch.index_select(x1, 0, idx, out=xx1)
- torch.index_select(y1, 0, idx, out=yy1)
- torch.index_select(x2, 0, idx, out=xx2)
- torch.index_select(y2, 0, idx, out=yy2)
- # store element-wise max with next highest score
- xx1 = torch.clamp(xx1, min=x1[i])
- yy1 = torch.clamp(yy1, min=y1[i])
- xx2 = torch.clamp(xx2, max=x2[i])
- yy2 = torch.clamp(yy2, max=y2[i])
- w.resize_as_(xx2)
- h.resize_as_(yy2)
- w = xx2 - xx1
- h = yy2 - yy1
- # check sizes of xx1 and xx2.. after each iteration
- w = torch.clamp(w, min=0.0)
- h = torch.clamp(h, min=0.0)
- inter = w * h
- # IoU = i / (area(a) + area(b) - i)
- rem_areas = torch.index_select(area, 0, idx) # load remaining areas)
- union = (rem_areas - inter) + area[i]
- IoU = inter / union # store result in iou
- # keep only elements with an IoU <= overlap
- idx = idx[IoU.le(overlap)]
- return keep, count
+from itertools import product as product
+from math import ceil
+
+import numpy as np
+import torch
+import torchvision
+
+
+class PriorBox:
+
+ def __init__(self, cfg, image_size=None, phase="train"):
+ super(PriorBox, self).__init__()
+ self.min_sizes = cfg["min_sizes"]
+ self.steps = cfg["steps"]
+ self.clip = cfg["clip"]
+ self.image_size = image_size
+ self.feature_maps = [[ceil(self.image_size[0] / step), ceil(self.image_size[1] / step)] for step in self.steps]
+ self.name = "s"
+
+ def forward(self):
+ anchors = []
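+        # one anchor per (feature-map cell, min_size), stored as normalized (cx, cy, w, h)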
+ for k, f in enumerate(self.feature_maps):
+ min_sizes = self.min_sizes[k]
+ for i, j in product(range(f[0]), range(f[1])):
+ for min_size in min_sizes:
+ s_kx = min_size / self.image_size[1]
+ s_ky = min_size / self.image_size[0]
+ dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]]
+ dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]]
+ for cy, cx in product(dense_cy, dense_cx):
+ anchors += [cx, cy, s_kx, s_ky]
+
+ # back to torch land
+ output = torch.Tensor(anchors).view(-1, 4)
+ if self.clip:
+ output.clamp_(max=1, min=0)
+ return output
+
+
+def py_cpu_nms(dets, thresh):
+ """Pure Python NMS baseline."""
+ keep = torchvision.ops.nms(
+ boxes=torch.Tensor(dets[:, :4]),
+ scores=torch.Tensor(dets[:, 4]),
+ iou_threshold=thresh,
+ )
+
+ return list(keep)
+
+
+def point_form(boxes):
+ """Convert prior_boxes to (xmin, ymin, xmax, ymax)
+ representation for comparison to point form ground truth data.
+ Args:
+ boxes: (tensor) center-size default boxes from priorbox layers.
+ Return:
+ boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
+ """
+    return torch.cat((boxes[:, :2] - boxes[:, 2:] / 2, boxes[:, :2] + boxes[:, 2:] / 2), 1)  # xmin, ymin, xmax, ymax
+
+
+def center_size(boxes):
+ """Convert prior_boxes to (cx, cy, w, h)
+ representation for comparison to center-size form ground truth data.
+ Args:
+ boxes: (tensor) point_form boxes
+ Return:
+        boxes: (tensor) Converted cx, cy, w, h form of boxes.
+    """
+    return torch.cat(((boxes[:, 2:] + boxes[:, :2]) / 2, boxes[:, 2:] - boxes[:, :2]), 1)  # cx, cy, w, h
+
+
+def intersect(box_a, box_b):
+ """We resize both tensors to [A,B,2] without new malloc:
+ [A,2] -> [A,1,2] -> [A,B,2]
+ [B,2] -> [1,B,2] -> [A,B,2]
+ Then we compute the area of intersect between box_a and box_b.
+ Args:
+ box_a: (tensor) bounding boxes, Shape: [A,4].
+ box_b: (tensor) bounding boxes, Shape: [B,4].
+ Return:
+ (tensor) intersection area, Shape: [A,B].
+ """
+ A = box_a.size(0)
+ B = box_b.size(0)
+ max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
+ min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), box_b[:, :2].unsqueeze(0).expand(A, B, 2))
+ inter = torch.clamp((max_xy - min_xy), min=0)
+ return inter[:, :, 0] * inter[:, :, 1]
+
+
+def jaccard(box_a, box_b):
+ """Compute the jaccard overlap of two sets of boxes. The jaccard overlap
+ is simply the intersection over union of two boxes. Here we operate on
+ ground truth boxes and default boxes.
+ E.g.:
+ A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
+ Args:
+ box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4]
+ box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4]
+ Return:
+ jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
+ """
+ inter = intersect(box_a, box_b)
+ area_a = ((box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B]
+ area_b = ((box_b[:, 2] - box_b[:, 0]) * (box_b[:, 3] - box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B]
+ union = area_a + area_b - inter
+ return inter / union # [A,B]
+
+
+def matrix_iou(a, b):
+ """
+    Return the IoU of a and b (numpy version, used for data augmentation).
+ """
+ lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
+ rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
+
+ area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
+ area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
+ area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
+ return area_i / (area_a[:, np.newaxis] + area_b - area_i)
+
+
+def matrix_iof(a, b):
+ """
+    Return the IoF of a and b (numpy version, used for data augmentation).
+ """
+ lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
+ rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
+
+ area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
+ area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
+ return area_i / np.maximum(area_a[:, np.newaxis], 1)
+
+
+def match(threshold, truths, priors, variances, labels, landms, loc_t, conf_t, landm_t, idx):
+ """Match each prior box with the ground truth box of the highest jaccard
+ overlap, encode the bounding boxes, then return the matched indices
+ corresponding to both confidence and location preds.
+ Args:
+ threshold: (float) The overlap threshold used when matching boxes.
+ truths: (tensor) Ground truth boxes, Shape: [num_obj, 4].
+ priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4].
+ variances: (tensor) Variances corresponding to each prior coord,
+ Shape: [num_priors, 4].
+ labels: (tensor) All the class labels for the image, Shape: [num_obj].
+ landms: (tensor) Ground truth landms, Shape [num_obj, 10].
+ loc_t: (tensor) Tensor to be filled w/ encoded location targets.
+ conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds.
+ landm_t: (tensor) Tensor to be filled w/ encoded landm targets.
+ idx: (int) current batch index
+ Return:
+ The matched indices corresponding to 1)location 2)confidence
+ 3)landm preds.
+ """
+ # jaccard index
+ overlaps = jaccard(truths, point_form(priors))
+ # (Bipartite Matching)
+ # [1,num_objects] best prior for each ground truth
+ best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)
+
+ # ignore hard gt
+ valid_gt_idx = best_prior_overlap[:, 0] >= 0.2
+ best_prior_idx_filter = best_prior_idx[valid_gt_idx, :]
+ if best_prior_idx_filter.shape[0] <= 0:
+ loc_t[idx] = 0
+ conf_t[idx] = 0
+ return
+
+ # [1,num_priors] best ground truth for each prior
+ best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
+ best_truth_idx.squeeze_(0)
+ best_truth_overlap.squeeze_(0)
+ best_prior_idx.squeeze_(1)
+ best_prior_idx_filter.squeeze_(1)
+ best_prior_overlap.squeeze_(1)
+ best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2) # ensure best prior
+ # TODO refactor: index best_prior_idx with long tensor
+ # ensure every gt matches with its prior of max overlap
+    for j in range(best_prior_idx.size(0)):  # decide which ground-truth box each best-matching anchor predicts
+        best_truth_idx[best_prior_idx[j]] = j
+    matches = truths[best_truth_idx]  # Shape: [num_priors,4]  the bbox matched to every anchor
+    conf = labels[best_truth_idx]  # Shape: [num_priors]  the label matched to every anchor
+    conf[best_truth_overlap < threshold] = 0  # label as background: overlaps < 0.35 are all treated as negatives
+ loc = encode(matches, priors, variances)
+
+ matches_landm = landms[best_truth_idx]
+ landm = encode_landm(matches_landm, priors, variances)
+ loc_t[idx] = loc # [num_priors,4] encoded offsets to learn
+ conf_t[idx] = conf # [num_priors] top class label for each prior
+ landm_t[idx] = landm
+
+
+def encode(matched, priors, variances):
+ """Encode the variances from the priorbox layers into the ground truth boxes
+ we have matched (based on jaccard overlap) with the prior boxes.
+ Args:
+ matched: (tensor) Coords of ground truth for each prior in point-form
+ Shape: [num_priors, 4].
+ priors: (tensor) Prior boxes in center-offset form
+ Shape: [num_priors,4].
+ variances: (list[float]) Variances of priorboxes
+ Return:
+ encoded boxes (tensor), Shape: [num_priors, 4]
+ """
+
+ # dist b/t match center and prior's center
+ g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
+ # encode variance
+ g_cxcy /= variances[0] * priors[:, 2:]
+ # match wh / prior wh
+ g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
+ g_wh = torch.log(g_wh) / variances[1]
+ # return target for smooth_l1_loss
+ return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
+
+
+def encode_landm(matched, priors, variances):
+ """Encode the variances from the priorbox layers into the ground truth boxes
+ we have matched (based on jaccard overlap) with the prior boxes.
+ Args:
+ matched: (tensor) Coords of ground truth for each prior in point-form
+ Shape: [num_priors, 10].
+ priors: (tensor) Prior boxes in center-offset form
+ Shape: [num_priors,4].
+ variances: (list[float]) Variances of priorboxes
+ Return:
+ encoded landm (tensor), Shape: [num_priors, 10]
+ """
+
+ # dist b/t match center and prior's center
+ matched = torch.reshape(matched, (matched.size(0), 5, 2))
+ priors_cx = priors[:, 0].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
+ priors_cy = priors[:, 1].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
+ priors_w = priors[:, 2].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
+ priors_h = priors[:, 3].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
+ priors = torch.cat([priors_cx, priors_cy, priors_w, priors_h], dim=2)
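+    # priors is now [num_priors, 5, 4]; each of the five landmark points is encoded against the same prior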
+ g_cxcy = matched[:, :, :2] - priors[:, :, :2]
+ # encode variance
+ g_cxcy /= variances[0] * priors[:, :, 2:]
+ # g_cxcy /= priors[:, :, 2:]
+ g_cxcy = g_cxcy.reshape(g_cxcy.size(0), -1)
+ # return target for smooth_l1_loss
+ return g_cxcy
+
+
+# Adapted from https://github.com/Hakuyume/chainer-ssd
+def decode(loc, priors, variances):
+ """Decode locations from predictions using priors to undo
+ the encoding we did for offset regression at train time.
+ Args:
+ loc (tensor): location predictions for loc layers,
+ Shape: [num_priors,4]
+ priors (tensor): Prior boxes in center-offset form.
+ Shape: [num_priors,4].
+ variances: (list[float]) Variances of priorboxes
+ Return:
+ decoded bounding box predictions
+ """
+
+ boxes = torch.cat(
+ (
+ priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
+ priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1]),
+ ),
+ 1,
+ )
+ boxes[:, :2] -= boxes[:, 2:] / 2
+ boxes[:, 2:] += boxes[:, :2]
+ return boxes
+
+
+def decode_landm(pre, priors, variances):
+ """Decode landm from predictions using priors to undo
+ the encoding we did for offset regression at train time.
+ Args:
+ pre (tensor): landm predictions for loc layers,
+ Shape: [num_priors,10]
+ priors (tensor): Prior boxes in center-offset form.
+ Shape: [num_priors,4].
+ variances: (list[float]) Variances of priorboxes
+ Return:
+ decoded landm predictions
+ """
+ tmp = (
+ priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:],
+ priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:],
+ priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:],
+ priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:],
+ priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:],
+ )
+ landms = torch.cat(tmp, dim=1)
+ return landms
+
+
+def batched_decode(b_loc, priors, variances):
+ """Decode locations from predictions using priors to undo
+ the encoding we did for offset regression at train time.
+ Args:
+ b_loc (tensor): location predictions for loc layers,
+ Shape: [num_batches,num_priors,4]
+ priors (tensor): Prior boxes in center-offset form.
+ Shape: [1,num_priors,4].
+ variances: (list[float]) Variances of priorboxes
+ Return:
+ decoded bounding box predictions
+ """
+ boxes = (
+ priors[:, :, :2] + b_loc[:, :, :2] * variances[0] * priors[:, :, 2:],
+ priors[:, :, 2:] * torch.exp(b_loc[:, :, 2:] * variances[1]),
+ )
+ boxes = torch.cat(boxes, dim=2)
+
+ boxes[:, :, :2] -= boxes[:, :, 2:] / 2
+ boxes[:, :, 2:] += boxes[:, :, :2]
+ return boxes
+
+
+def batched_decode_landm(pre, priors, variances):
+ """Decode landm from predictions using priors to undo
+ the encoding we did for offset regression at train time.
+ Args:
+ pre (tensor): landm predictions for loc layers,
+ Shape: [num_batches,num_priors,10]
+ priors (tensor): Prior boxes in center-offset form.
+ Shape: [1,num_priors,4].
+ variances: (list[float]) Variances of priorboxes
+ Return:
+ decoded landm predictions
+ """
+ landms = (
+ priors[:, :, :2] + pre[:, :, :2] * variances[0] * priors[:, :, 2:],
+ priors[:, :, :2] + pre[:, :, 2:4] * variances[0] * priors[:, :, 2:],
+ priors[:, :, :2] + pre[:, :, 4:6] * variances[0] * priors[:, :, 2:],
+ priors[:, :, :2] + pre[:, :, 6:8] * variances[0] * priors[:, :, 2:],
+ priors[:, :, :2] + pre[:, :, 8:10] * variances[0] * priors[:, :, 2:],
+ )
+ landms = torch.cat(landms, dim=2)
+ return landms
+
+
+def log_sum_exp(x):
+    """Utility function for computing log_sum_exp in a numerically stable way.
+    This will be used to determine the unaveraged confidence loss across
+    all examples in a batch.
+ Args:
+ x (Variable(tensor)): conf_preds from conf layers
+ """
+ x_max = x.data.max()
+ return torch.log(torch.sum(torch.exp(x - x_max), 1, keepdim=True)) + x_max
+
+
+# Original author: Francisco Massa:
+# https://github.com/fmassa/object-detection.torch
+# Ported to PyTorch by Max deGroot (02/01/2017)
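+# Note: the detector above calls py_cpu_nms (torchvision.ops.nms); this manual variant is retained from the original code.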
+def nms(boxes, scores, overlap=0.5, top_k=200):
+ """Apply non-maximum suppression at test time to avoid detecting too many
+ overlapping bounding boxes for a given object.
+ Args:
+ boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
+        scores: (tensor) The class prediction scores for the img, Shape: [num_priors].
+ overlap: (float) The overlap thresh for suppressing unnecessary boxes.
+ top_k: (int) The Maximum number of box preds to consider.
+ Return:
+ The indices of the kept boxes with respect to num_priors.
+ """
+
+ keep = torch.Tensor(scores.size(0)).fill_(0).long()
+ if boxes.numel() == 0:
+ return keep
+ x1 = boxes[:, 0]
+ y1 = boxes[:, 1]
+ x2 = boxes[:, 2]
+ y2 = boxes[:, 3]
+ area = torch.mul(x2 - x1, y2 - y1)
+ v, idx = scores.sort(0) # sort in ascending order
+ # I = I[v >= 0.01]
+ idx = idx[-top_k:] # indices of the top-k largest vals
+ xx1 = boxes.new()
+ yy1 = boxes.new()
+ xx2 = boxes.new()
+ yy2 = boxes.new()
+ w = boxes.new()
+ h = boxes.new()
+
+ # keep = torch.Tensor()
+ count = 0
+ while idx.numel() > 0:
+ i = idx[-1] # index of current largest val
+ # keep.append(i)
+ keep[count] = i
+ count += 1
+ if idx.size(0) == 1:
+ break
+ idx = idx[:-1] # remove kept element from view
+ # load bboxes of next highest vals
+ torch.index_select(x1, 0, idx, out=xx1)
+ torch.index_select(y1, 0, idx, out=yy1)
+ torch.index_select(x2, 0, idx, out=xx2)
+ torch.index_select(y2, 0, idx, out=yy2)
+ # store element-wise max with next highest score
+ xx1 = torch.clamp(xx1, min=x1[i])
+ yy1 = torch.clamp(yy1, min=y1[i])
+ xx2 = torch.clamp(xx2, max=x2[i])
+ yy2 = torch.clamp(yy2, max=y2[i])
+ w.resize_as_(xx2)
+ h.resize_as_(yy2)
+ w = xx2 - xx1
+ h = yy2 - yy1
+ # check sizes of xx1 and xx2.. after each iteration
+ w = torch.clamp(w, min=0.0)
+ h = torch.clamp(h, min=0.0)
+ inter = w * h
+ # IoU = i / (area(a) + area(b) - i)
+        rem_areas = torch.index_select(area, 0, idx)  # load remaining areas
+ union = (rem_areas - inter) + area[i]
+ IoU = inter / union # store result in iou
+ # keep only elements with an IoU <= overlap
+ idx = idx[IoU.le(overlap)]
+ return keep, count
diff --git a/hordelib/nodes/facerestore/facelib/detection/yolov5face/models/__init__.py b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/__init__.py
similarity index 100%
rename from hordelib/nodes/facerestore/facelib/detection/yolov5face/models/__init__.py
rename to hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/__init__.py
diff --git a/hordelib/nodes/facerestore/facelib/detection/yolov5face/face_detector.py b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/face_detector.py
similarity index 71%
rename from hordelib/nodes/facerestore/facelib/detection/yolov5face/face_detector.py
rename to hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/face_detector.py
index 5ea44d65..7dd6378c 100644
--- a/hordelib/nodes/facerestore/facelib/detection/yolov5face/face_detector.py
+++ b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/face_detector.py
@@ -1,175 +1,148 @@
-import copy
-import os
-from pathlib import Path
-
-import cv2
-import numpy as np
-import torch
-from torch import nn
-
-from hordelib.nodes.facerestore.facelib.detection.yolov5face.models.common import Conv
-from hordelib.nodes.facerestore.facelib.detection.yolov5face.models.yolo import Model
-from hordelib.nodes.facerestore.facelib.detection.yolov5face.utils.datasets import letterbox
-from hordelib.nodes.facerestore.facelib.detection.yolov5face.utils.general import (
- check_img_size,
- non_max_suppression_face,
- scale_coords,
- scale_coords_landmarks,
-)
-
-
-def is_high_version():
- from packaging import version
- try:
- torch_v = version.parse(torch.__version__)
- return torch_v > version.parse("1.9.0")
- except Exception:
- return True
-
-
-def isListempty(inList):
- if isinstance(inList, list): # Is a list
- return all(map(isListempty, inList))
- return False # Not a list
-
-
-class YoloDetector:
- def __init__(
- self,
- config_name,
- min_face=10,
- target_size=None,
- device="cuda",
- ):
- """
- config_name: name of .yaml config with network configuration from models/ folder.
- min_face : minimal face size in pixels.
- target_size : target size of smaller image axis (choose lower for faster work). e.g. 480, 720, 1080.
- None for original resolution.
- """
- self._class_path = Path(__file__).parent.absolute()
- self.target_size = target_size
- self.min_face = min_face
- self.detector = Model(cfg=config_name)
- self.device = device
-
- def _preprocess(self, imgs):
- """
- Preprocessing image before passing through the network. Resize and conversion to torch tensor.
- """
- pp_imgs = []
- for img in imgs:
- h0, w0 = img.shape[:2] # orig hw
- if self.target_size:
- r = self.target_size / min(h0, w0) # resize image to img_size
- if r < 1:
- img = cv2.resize(
- img, (int(w0 * r), int(h0 * r)), interpolation=cv2.INTER_LINEAR
- )
-
- imgsz = check_img_size(
- max(img.shape[:2]), s=self.detector.stride.max()
- ) # check img_size
- img = letterbox(img, new_shape=imgsz)[0]
- pp_imgs.append(img)
- pp_imgs = np.array(pp_imgs)
- pp_imgs = pp_imgs.transpose(0, 3, 1, 2)
- pp_imgs = torch.from_numpy(pp_imgs).to(self.device)
- pp_imgs = pp_imgs.float() # uint8 to fp16/32
- return pp_imgs / 255.0 # 0 - 255 to 0.0 - 1.0
-
- def _postprocess(self, imgs, origimgs, pred, conf_thres, iou_thres):
- """
- Postprocessing of raw pytorch model output.
- Returns:
- bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2.
- points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners).
- """
- bboxes = [[] for _ in range(len(origimgs))]
- landmarks = [[] for _ in range(len(origimgs))]
-
- pred = non_max_suppression_face(pred, conf_thres, iou_thres)
-
- for image_id, origimg in enumerate(origimgs):
- img_shape = origimg.shape
- image_height, image_width = img_shape[:2]
- gn = torch.tensor(img_shape)[[1, 0, 1, 0]] # normalization gain whwh
- gn_lks = torch.tensor(img_shape)[
- [1, 0, 1, 0, 1, 0, 1, 0, 1, 0]
- ] # normalization gain landmarks
- det = pred[image_id].cpu()
- scale_coords(imgs[image_id].shape[1:], det[:, :4], img_shape).round()
- scale_coords_landmarks(
- imgs[image_id].shape[1:], det[:, 5:15], img_shape
- ).round()
-
- for j in range(det.size()[0]):
- box = (det[j, :4].view(1, 4) / gn).view(-1).tolist()
- box = list(
- map(
- int,
- [
- box[0] * image_width,
- box[1] * image_height,
- box[2] * image_width,
- box[3] * image_height,
- ],
- )
- )
- if box[3] - box[1] < self.min_face:
- continue
- lm = (det[j, 5:15].view(1, 10) / gn_lks).view(-1).tolist()
- lm = list(
- map(
- int,
- [
- i * image_width if j % 2 == 0 else i * image_height
- for j, i in enumerate(lm)
- ],
- )
- )
- lm = [lm[i : i + 2] for i in range(0, len(lm), 2)]
- bboxes[image_id].append(box)
- landmarks[image_id].append(lm)
- return bboxes, landmarks
-
- def detect_faces(self, imgs, conf_thres=0.7, iou_thres=0.5):
- """
- Get bbox coordinates and keypoints of faces on original image.
- Params:
- imgs: image or list of images to detect faces on with BGR order (convert to RGB order for inference)
- conf_thres: confidence threshold for each prediction
- iou_thres: threshold for NMS (filter of intersecting bboxes)
- Returns:
- bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2.
- points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners).
- """
- # Pass input images through face detector
- images = imgs if isinstance(imgs, list) else [imgs]
- images = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in images]
- origimgs = copy.deepcopy(images)
-
- images = self._preprocess(images)
-
- if is_high_version():
- with torch.inference_mode(): # for pytorch>=1.9
- pred = self.detector(images)[0]
- else:
- with torch.no_grad(): # for pytorch<1.9
- pred = self.detector(images)[0]
-
- bboxes, points = self._postprocess(
- images, origimgs, pred, conf_thres, iou_thres
- )
-
- # return bboxes, points
- if not isListempty(points):
- bboxes = np.array(bboxes).reshape(-1, 4)
- points = np.array(points).reshape(-1, 10)
- padding = bboxes[:, 0].reshape(-1, 1)
- return np.concatenate((bboxes, padding, points), axis=1)
- else:
- return None
-
- def __call__(self, *args):
- return self.predict(*args)
+import copy
+from pathlib import Path
+
+import cv2
+import numpy as np
+import torch
+
+from hordelib.nodes.facerestore_cf.facelib.detection.yolov5face.models.yolo import Model
+from hordelib.nodes.facerestore_cf.facelib.detection.yolov5face.utils.datasets import letterbox
+from hordelib.nodes.facerestore_cf.facelib.detection.yolov5face.utils.general import (
+ check_img_size,
+ non_max_suppression_face,
+ scale_coords,
+ scale_coords_landmarks,
+)
+
+try:
+ version_str = torch.__version__.split("+")[0]
+ major, minor, patch = map(int, version_str.split("."))
+ IS_HIGH_VERSION = (major, minor, patch) >= (1, 9, 0)
+except ValueError:
+    # Development/nightly builds (e.g. "2.1.0.dev...") fail to parse; fall back to the pre-1.9 path
+ IS_HIGH_VERSION = False
+
+
+def isListempty(inList):
+ if isinstance(inList, list): # Is a list
+ return all(map(isListempty, inList))
+ return False # Not a list
+
+
+class YoloDetector:
+ def __init__(
+ self,
+ config_name,
+ min_face=10,
+ target_size=None,
+ device="cuda",
+ ):
+ """
+        config_name: path to the .yaml config describing the network architecture (from the models/ folder).
+        min_face: minimal face size in pixels.
+        target_size: target size of the smaller image axis (choose a lower value for faster inference),
+                     e.g. 480, 720, 1080. None keeps the original resolution.
+ """
+ self._class_path = Path(__file__).parent.absolute()
+ self.target_size = target_size
+ self.min_face = min_face
+ self.detector = Model(cfg=config_name)
+ self.device = device
+
+ def _preprocess(self, imgs):
+ """
+        Preprocess images before passing them through the network: resize and convert to a torch tensor.
+ """
+ pp_imgs = []
+ for img in imgs:
+ h0, w0 = img.shape[:2] # orig hw
+ if self.target_size:
+ r = self.target_size / min(h0, w0) # resize image to img_size
+ if r < 1:
+ img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=cv2.INTER_LINEAR)
+
+ imgsz = check_img_size(max(img.shape[:2]), s=self.detector.stride.max()) # check img_size
+ img = letterbox(img, new_shape=imgsz)[0]
+ pp_imgs.append(img)
+ pp_imgs = np.array(pp_imgs)
+ pp_imgs = pp_imgs.transpose(0, 3, 1, 2)
+ pp_imgs = torch.from_numpy(pp_imgs).to(self.device)
+ pp_imgs = pp_imgs.float() # uint8 to fp16/32
+ return pp_imgs / 255.0 # 0 - 255 to 0.0 - 1.0
+
+ def _postprocess(self, imgs, origimgs, pred, conf_thres, iou_thres):
+ """
+ Postprocessing of raw pytorch model output.
+ Returns:
+ bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2.
+ points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners).
+ """
+ bboxes = [[] for _ in range(len(origimgs))]
+ landmarks = [[] for _ in range(len(origimgs))]
+
+ pred = non_max_suppression_face(pred, conf_thres, iou_thres)
+
+ for image_id, origimg in enumerate(origimgs):
+ img_shape = origimg.shape
+ image_height, image_width = img_shape[:2]
+ gn = torch.tensor(img_shape)[[1, 0, 1, 0]] # normalization gain whwh
+ gn_lks = torch.tensor(img_shape)[[1, 0, 1, 0, 1, 0, 1, 0, 1, 0]] # normalization gain landmarks
+ det = pred[image_id].cpu()
+ scale_coords(imgs[image_id].shape[1:], det[:, :4], img_shape).round()
+ scale_coords_landmarks(imgs[image_id].shape[1:], det[:, 5:15], img_shape).round()
+
+ for j in range(det.size()[0]):
+ box = (det[j, :4].view(1, 4) / gn).view(-1).tolist()
+ box = list(
+ map(
+ int,
+ [box[0] * image_width, box[1] * image_height, box[2] * image_width, box[3] * image_height],
+ ),
+ )
+ if box[3] - box[1] < self.min_face:
+ continue
+ lm = (det[j, 5:15].view(1, 10) / gn_lks).view(-1).tolist()
+ lm = list(map(int, [i * image_width if j % 2 == 0 else i * image_height for j, i in enumerate(lm)]))
+ lm = [lm[i : i + 2] for i in range(0, len(lm), 2)]
+ bboxes[image_id].append(box)
+ landmarks[image_id].append(lm)
+ return bboxes, landmarks
+
+ def detect_faces(self, imgs, conf_thres=0.7, iou_thres=0.5):
+ """
+ Get bbox coordinates and keypoints of faces on original image.
+ Params:
+            imgs: image or list of images in BGR channel order (converted to RGB internally before inference)
+ conf_thres: confidence threshold for each prediction
+ iou_thres: threshold for NMS (filter of intersecting bboxes)
+ Returns:
+ bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2.
+ points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners).
+ """
+ # Pass input images through face detector
+ images = imgs if isinstance(imgs, list) else [imgs]
+ images = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in images]
+ origimgs = copy.deepcopy(images)
+
+ images = self._preprocess(images)
+
+ if IS_HIGH_VERSION:
+ with torch.inference_mode(): # for pytorch>=1.9
+ pred = self.detector(images)[0]
+ else:
+ with torch.no_grad(): # for pytorch<1.9
+ pred = self.detector(images)[0]
+
+ bboxes, points = self._postprocess(images, origimgs, pred, conf_thres, iou_thres)
+
+ # return bboxes, points
+ if not isListempty(points):
+ bboxes = np.array(bboxes).reshape(-1, 4)
+ points = np.array(points).reshape(-1, 10)
+ padding = bboxes[:, 0].reshape(-1, 1)
+ return np.concatenate((bboxes, padding, points), axis=1)
+ else:
+ return None
+
+ def __call__(self, *args):
+        return self.detect_faces(*args)  # the class defines detect_faces(); there is no predict() method
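For orientation, a minimal usage sketch of the renamed detector follows. The module path, the yolov5n.yaml
location, and the test image are assumptions, and the sketch presumes detector weights are loaded elsewhere.

# Hypothetical usage sketch; import path and file names are assumptions.
import cv2

from hordelib.nodes.facerestore_cf.facelib.detection.yolov5face.face_detector import YoloDetector

detector = YoloDetector(config_name="yolov5n.yaml", min_face=10, target_size=640, device="cpu")
img_bgr = cv2.imread("example.jpg")  # detect_faces() expects BGR input and converts to RGB itself
result = detector.detect_faces(img_bgr, conf_thres=0.7, iou_thres=0.5)
if result is not None:
    for row in result:
        x1, y1, x2, y2 = row[:4]   # bounding box
        landmarks = row[5:15]      # 5 (x, y) facial keypoints, flattened
        print((x1, y1, x2, y2), landmarks)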
diff --git a/hordelib/nodes/facerestore/facelib/detection/yolov5face/utils/__init__.py b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/models/__init__.py
similarity index 100%
rename from hordelib/nodes/facerestore/facelib/detection/yolov5face/utils/__init__.py
rename to hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/models/__init__.py
diff --git a/hordelib/nodes/facerestore/facelib/detection/yolov5face/models/common.py b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/models/common.py
similarity index 92%
rename from hordelib/nodes/facerestore/facelib/detection/yolov5face/models/common.py
rename to hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/models/common.py
index 9c6c5eaa..d9559779 100644
--- a/hordelib/nodes/facerestore/facelib/detection/yolov5face/models/common.py
+++ b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/models/common.py
@@ -6,8 +6,8 @@
import torch
from torch import nn
-from hordelib.nodes.facerestore.facelib.detection.yolov5face.utils.datasets import letterbox
-from hordelib.nodes.facerestore.facelib.detection.yolov5face.utils.general import (
+from hordelib.nodes.facerestore_cf.facelib.detection.yolov5face.utils.datasets import letterbox
+from hordelib.nodes.facerestore_cf.facelib.detection.yolov5face.utils.general import (
make_divisible,
non_max_suppression,
scale_coords,
@@ -149,22 +149,9 @@ def __init__(self, inp, oup, stride):
),
nn.BatchNorm2d(branch_features),
nn.SiLU(),
- self.depthwise_conv(
- branch_features,
- branch_features,
- kernel_size=3,
- stride=self.stride,
- padding=1,
- ),
+ self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
nn.BatchNorm2d(branch_features),
- nn.Conv2d(
- branch_features,
- branch_features,
- kernel_size=1,
- stride=1,
- padding=0,
- bias=False,
- ),
+ nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(branch_features),
nn.SiLU(),
)
@@ -204,17 +191,7 @@ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, k
self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
- return self.conv(
- torch.cat(
- [
- x[..., ::2, ::2],
- x[..., 1::2, ::2],
- x[..., ::2, 1::2],
- x[..., 1::2, 1::2],
- ],
- 1,
- )
- )
+ return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
class Concat(nn.Module):
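The collapsed one-liner above is the Focus space-to-depth step. A standalone sketch of what that slicing
does before the convolution:

import torch

x = torch.randn(1, 3, 8, 8)
y = torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1)
print(y.shape)  # torch.Size([1, 12, 4, 4]): 4x the channels, half the height and width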
diff --git a/hordelib/nodes/facerestore/facelib/detection/yolov5face/models/experimental.py b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/models/experimental.py
similarity index 78%
rename from hordelib/nodes/facerestore/facelib/detection/yolov5face/models/experimental.py
rename to hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/models/experimental.py
index c33acba5..e04914e9 100644
--- a/hordelib/nodes/facerestore/facelib/detection/yolov5face/models/experimental.py
+++ b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/models/experimental.py
@@ -4,7 +4,7 @@
import torch
from torch import nn
-from hordelib.nodes.facerestore.facelib.detection.yolov5face.models.common import Conv
+from hordelib.nodes.facerestore_cf.facelib.detection.yolov5face.models.common import Conv
class CrossConv(nn.Module):
@@ -35,16 +35,9 @@ def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
a -= np.roll(a, 1, axis=1)
a *= np.array(k) ** 2
a[0] = 1
- c_ = np.linalg.lstsq(a, b, rcond=None)[
- 0
- ].round() # solve for equal weight indices, ax = b
-
- self.m = nn.ModuleList(
- [
- nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False)
- for g in range(groups)
- ]
- )
+ c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b
+
+ self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)])
self.bn = nn.BatchNorm2d(c2)
self.act = nn.LeakyReLU(0.1, inplace=True)
diff --git a/hordelib/nodes/facerestore/facelib/detection/yolov5face/models/yolo.py b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/models/yolo.py
similarity index 64%
rename from hordelib/nodes/facerestore/facelib/detection/yolov5face/models/yolo.py
rename to hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/models/yolo.py
index b11af3e4..fe22178d 100644
--- a/hordelib/nodes/facerestore/facelib/detection/yolov5face/models/yolo.py
+++ b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/models/yolo.py
@@ -6,7 +6,7 @@
import yaml # for torch hub
from torch import nn
-from hordelib.nodes.facerestore.facelib.detection.yolov5face.models.common import (
+from hordelib.nodes.facerestore_cf.facelib.detection.yolov5face.models.common import (
C3,
NMS,
SPP,
@@ -20,16 +20,10 @@
ShuffleV2Block,
StemBlock,
)
-from hordelib.nodes.facerestore.facelib.detection.yolov5face.models.experimental import (
- CrossConv,
- MixConv2d,
-)
-from hordelib.nodes.facerestore.facelib.detection.yolov5face.utils.autoanchor import check_anchor_order
-from hordelib.nodes.facerestore.facelib.detection.yolov5face.utils.general import make_divisible
-from hordelib.nodes.facerestore.facelib.detection.yolov5face.utils.torch_utils import (
- copy_attr,
- fuse_conv_and_bn,
-)
+from hordelib.nodes.facerestore_cf.facelib.detection.yolov5face.models.experimental import CrossConv, MixConv2d
+from hordelib.nodes.facerestore_cf.facelib.detection.yolov5face.utils.autoanchor import check_anchor_order
+from hordelib.nodes.facerestore_cf.facelib.detection.yolov5face.utils.general import make_divisible
+from hordelib.nodes.facerestore_cf.facelib.detection.yolov5face.utils.torch_utils import copy_attr, fuse_conv_and_bn
class Detect(nn.Module):
@@ -46,12 +40,8 @@ def __init__(self, nc=80, anchors=(), ch=()): # detection layer
self.grid = [torch.zeros(1)] * self.nl # init grid
a = torch.tensor(anchors).float().view(self.nl, -1, 2)
self.register_buffer("anchors", a) # shape(nl,na,2)
- self.register_buffer(
- "anchor_grid", a.clone().view(self.nl, 1, -1, 1, 1, 2)
- ) # shape(nl,1,na,1,1,2)
- self.m = nn.ModuleList(
- nn.Conv2d(x, self.no * self.na, 1) for x in ch
- ) # output conv
+ self.register_buffer("anchor_grid", a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
+ self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
def forward(self, x):
z = [] # inference output
@@ -62,12 +52,7 @@ def forward(self, x):
for i in range(self.nl):
x[i] = self.m[i](x[i]) # conv
bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
- x[i] = (
- x[i]
- .view(bs, self.na, self.no, ny, nx)
- .permute(0, 1, 3, 4, 2)
- .contiguous()
- )
+ x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
if not self.training: # inference
if self.grid[i].shape[2:4] != x[i].shape[2:4]:
@@ -77,32 +62,23 @@ def forward(self, x):
y[..., [0, 1, 2, 3, 4, 15]] = x[i][..., [0, 1, 2, 3, 4, 15]].sigmoid()
y[..., 5:15] = x[i][..., 5:15]
- y[..., 0:2] = (
- y[..., 0:2] * 2.0 - 0.5 + self.grid[i].to(x[i].device)
- ) * self.stride[
- i
- ] # xy
+ y[..., 0:2] = (y[..., 0:2] * 2.0 - 0.5 + self.grid[i].to(x[i].device)) * self.stride[i] # xy
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
y[..., 5:7] = (
- y[..., 5:7] * self.anchor_grid[i]
- + self.grid[i].to(x[i].device) * self.stride[i]
+ y[..., 5:7] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i]
) # landmark x1 y1
y[..., 7:9] = (
- y[..., 7:9] * self.anchor_grid[i]
- + self.grid[i].to(x[i].device) * self.stride[i]
+ y[..., 7:9] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i]
) # landmark x2 y2
y[..., 9:11] = (
- y[..., 9:11] * self.anchor_grid[i]
- + self.grid[i].to(x[i].device) * self.stride[i]
+ y[..., 9:11] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i]
) # landmark x3 y3
y[..., 11:13] = (
- y[..., 11:13] * self.anchor_grid[i]
- + self.grid[i].to(x[i].device) * self.stride[i]
+ y[..., 11:13] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i]
) # landmark x4 y4
y[..., 13:15] = (
- y[..., 13:15] * self.anchor_grid[i]
- + self.grid[i].to(x[i].device) * self.stride[i]
+ y[..., 13:15] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i]
) # landmark x5 y5
z.append(y.view(bs, -1, self.no))
@@ -117,9 +93,7 @@ def _make_grid(nx=20, ny=20):
class Model(nn.Module):
- def __init__(
- self, cfg="yolov5s.yaml", ch=3, nc=None
- ): # model, input channels, number of classes
+ def __init__(self, cfg="yolov5s.yaml", ch=3, nc=None): # model, input channels, number of classes
super().__init__()
self.yaml_file = Path(cfg).name
with Path(cfg).open(encoding="utf8") as f:
@@ -130,18 +104,14 @@ def __init__(
if nc and nc != self.yaml["nc"]:
self.yaml["nc"] = nc # override yaml value
- self.model, self.save = parse_model(
- deepcopy(self.yaml), ch=[ch]
- ) # model, savelist
+ self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
self.names = [str(i) for i in range(self.yaml["nc"])] # default names
# Build strides, anchors
m = self.model[-1] # Detect()
if isinstance(m, Detect):
s = 128 # 2x min stride
- m.stride = torch.tensor(
- [s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]
- ) # forward
+ m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
m.anchors /= m.stride.view(-1, 1, 1)
check_anchor_order(m)
self.stride = m.stride
@@ -154,42 +124,27 @@ def forward_once(self, x):
y = [] # outputs
for m in self.model:
if m.f != -1: # if not from previous layer
- x = (
- y[m.f]
- if isinstance(m.f, int)
- else [x if j == -1 else y[j] for j in m.f]
- ) # from earlier layers
+ x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
x = m(x) # run
y.append(x if m.i in self.save else None) # save output
return x
- def _initialize_biases(
- self, cf=None
- ): # initialize biases into Detect(), cf is class frequency
+ def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
# https://arxiv.org/abs/1708.02002 section 3.3
m = self.model[-1] # Detect() module
for mi, s in zip(m.m, m.stride): # from
b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85)
- b.data[:, 4] += math.log(
- 8 / (640 / s) ** 2
- ) # obj (8 objects per 640 image)
- b.data[:, 5:] += (
- math.log(0.6 / (m.nc - 0.99))
- if cf is None
- else torch.log(cf / cf.sum())
- ) # cls
+ b.data[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image)
+ b.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls
mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
def _print_biases(self):
m = self.model[-1] # Detect() module
for mi in m.m: # from
b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85)
- print(
- ("%6g Conv2d.bias:" + "%10.3g" * 6)
- % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean())
- )
+ print(("%6g Conv2d.bias:" + "%10.3g" * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))
def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
print("Fusing layers... ")
@@ -219,28 +174,17 @@ def nms(self, mode=True): # add or remove NMS module
def autoshape(self): # add autoShape module
print("Adding autoShape... ")
m = AutoShape(self) # wrap model
- copy_attr(
- m, self, include=("yaml", "nc", "hyp", "names", "stride"), exclude=()
- ) # copy attributes
+ copy_attr(m, self, include=("yaml", "nc", "hyp", "names", "stride"), exclude=()) # copy attributes
return m
def parse_model(d, ch): # model_dict, input_channels(3)
- anchors, nc, gd, gw = (
- d["anchors"],
- d["nc"],
- d["depth_multiple"],
- d["width_multiple"],
- )
- na = (
- (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors
- ) # number of anchors
+ anchors, nc, gd, gw = d["anchors"], d["nc"], d["depth_multiple"], d["width_multiple"]
+ na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
- for i, (f, n, m, args) in enumerate(
- d["backbone"] + d["head"]
- ): # from, number, module, args
+ for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]): # from, number, module, args
m = eval(m) if isinstance(m, str) else m # eval strings
for j, a in enumerate(args):
try:
@@ -281,20 +225,11 @@ def parse_model(d, ch): # model_dict, input_channels(3)
else:
c2 = ch[f]
- m_ = (
- nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)
- ) # module
+ m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
t = str(m)[8:-2].replace("__main__.", "") # module type
np = sum(x.numel() for x in m_.parameters()) # number params
- m_.i, m_.f, m_.type, m_.np = (
- i,
- f,
- t,
- np,
- ) # attach index, 'from' index, type, number params
- save.extend(
- x % i for x in ([f] if isinstance(f, int) else f) if x != -1
- ) # append to savelist
+ m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
+ save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
layers.append(m_)
ch.append(c2)
return nn.Sequential(*layers), sorted(save)
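The reflowed decode in Detect.forward keeps the YOLOv5-face box math. A toy sketch of that arithmetic for a
single anchor at one grid cell, with made-up numbers:

import torch

stride = 8.0
grid_xy = torch.tensor([3.0, 5.0])            # grid cell indices
anchor_wh = torch.tensor([10.0, 13.0])        # anchor size in pixels at this stride
xy_sig = torch.tensor([0.6, 0.4])             # sigmoid outputs for x, y
wh_sig = torch.tensor([0.7, 0.5])             # sigmoid outputs for w, h

xy = (xy_sig * 2.0 - 0.5 + grid_xy) * stride  # box centre in pixels, as in Detect.forward
wh = (wh_sig * 2.0) ** 2 * anchor_wh          # box size in pixels
print(xy, wh)                                 # centre ~[29.6, 42.4], size ~[19.6, 13.0]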
diff --git a/hordelib/nodes/facerestore/facelib/detection/yolov5face/models/yolov5l.yaml b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/models/yolov5l.yaml
similarity index 100%
rename from hordelib/nodes/facerestore/facelib/detection/yolov5face/models/yolov5l.yaml
rename to hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/models/yolov5l.yaml
diff --git a/hordelib/nodes/facerestore/facelib/detection/yolov5face/models/yolov5n.yaml b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/models/yolov5n.yaml
similarity index 100%
rename from hordelib/nodes/facerestore/facelib/detection/yolov5face/models/yolov5n.yaml
rename to hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/models/yolov5n.yaml
diff --git a/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/utils/__init__.py b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/hordelib/nodes/facerestore/facelib/detection/yolov5face/utils/autoanchor.py b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/utils/autoanchor.py
similarity index 97%
rename from hordelib/nodes/facerestore/facelib/detection/yolov5face/utils/autoanchor.py
rename to hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/utils/autoanchor.py
index a4eba3e9..cb0de894 100644
--- a/hordelib/nodes/facerestore/facelib/detection/yolov5face/utils/autoanchor.py
+++ b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/utils/autoanchor.py
@@ -1,12 +1,12 @@
-# Auto-anchor utils
-
-
-def check_anchor_order(m):
- # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
- a = m.anchor_grid.prod(-1).view(-1) # anchor area
- da = a[-1] - a[0] # delta a
- ds = m.stride[-1] - m.stride[0] # delta s
- if da.sign() != ds.sign(): # same order
- print("Reversing anchor order")
- m.anchors[:] = m.anchors.flip(0)
- m.anchor_grid[:] = m.anchor_grid.flip(0)
+# Auto-anchor utils
+
+
+def check_anchor_order(m):
+ # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
+ a = m.anchor_grid.prod(-1).view(-1) # anchor area
+ da = a[-1] - a[0] # delta a
+ ds = m.stride[-1] - m.stride[0] # delta s
+ if da.sign() != ds.sign(): # same order
+ print("Reversing anchor order")
+ m.anchors[:] = m.anchors.flip(0)
+ m.anchor_grid[:] = m.anchor_grid.flip(0)
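A standalone sketch of what check_anchor_order corrects, using a dummy object with the three attributes it
reads (SimpleNamespace stands in for the Detect module here):

import torch
from types import SimpleNamespace

from hordelib.nodes.facerestore_cf.facelib.detection.yolov5face.utils.autoanchor import check_anchor_order

# anchors listed large-to-small while strides grow -> the order gets reversed
m = SimpleNamespace(
    anchors=torch.tensor([[[10.0, 13.0]], [[4.0, 5.0]]]),
    anchor_grid=torch.tensor([[[10.0, 13.0]], [[4.0, 5.0]]]).view(2, 1, 1, 1, 1, 2),
    stride=torch.tensor([8.0, 16.0]),
)
check_anchor_order(m)
print(m.anchors[0])  # tensor([[4., 5.]]) -- the smaller anchor now comes first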
diff --git a/hordelib/nodes/facerestore/facelib/detection/yolov5face/utils/datasets.py b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/utils/datasets.py
similarity index 97%
rename from hordelib/nodes/facerestore/facelib/detection/yolov5face/utils/datasets.py
rename to hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/utils/datasets.py
index e672b136..a72609b4 100644
--- a/hordelib/nodes/facerestore/facelib/detection/yolov5face/utils/datasets.py
+++ b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/utils/datasets.py
@@ -1,35 +1,35 @@
-import cv2
-import numpy as np
-
-
-def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scale_fill=False, scaleup=True):
- # Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232
- shape = img.shape[:2] # current shape [height, width]
- if isinstance(new_shape, int):
- new_shape = (new_shape, new_shape)
-
- # Scale ratio (new / old)
- r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
- if not scaleup: # only scale down, do not scale up (for better test mAP)
- r = min(r, 1.0)
-
- # Compute padding
- ratio = r, r # width, height ratios
- new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
- dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
- if auto: # minimum rectangle
- dw, dh = np.mod(dw, 64), np.mod(dh, 64) # wh padding
- elif scale_fill: # stretch
- dw, dh = 0.0, 0.0
- new_unpad = (new_shape[1], new_shape[0])
- ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
-
- dw /= 2 # divide padding into 2 sides
- dh /= 2
-
- if shape[::-1] != new_unpad: # resize
- img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
- top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
- left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
- img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
- return img, ratio, (dw, dh)
+import cv2
+import numpy as np
+
+
+def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scale_fill=False, scaleup=True):
+ # Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232
+ shape = img.shape[:2] # current shape [height, width]
+ if isinstance(new_shape, int):
+ new_shape = (new_shape, new_shape)
+
+ # Scale ratio (new / old)
+ r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
+ if not scaleup: # only scale down, do not scale up (for better test mAP)
+ r = min(r, 1.0)
+
+ # Compute padding
+ ratio = r, r # width, height ratios
+ new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
+ dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
+ if auto: # minimum rectangle
+ dw, dh = np.mod(dw, 64), np.mod(dh, 64) # wh padding
+ elif scale_fill: # stretch
+ dw, dh = 0.0, 0.0
+ new_unpad = (new_shape[1], new_shape[0])
+ ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
+
+ dw /= 2 # divide padding into 2 sides
+ dh /= 2
+
+ if shape[::-1] != new_unpad: # resize
+ img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
+ top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
+ left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
+ img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
+ return img, ratio, (dw, dh)
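A quick standalone check of letterbox as imported from the renamed package: with auto=True it pads to a
minimal 64-multiple rectangle rather than a full square.

import numpy as np

from hordelib.nodes.facerestore_cf.facelib.detection.yolov5face.utils.datasets import letterbox

img = np.zeros((480, 640, 3), dtype=np.uint8)      # H x W x C
out, ratio, (dw, dh) = letterbox(img, new_shape=640, auto=True)
print(out.shape, ratio, (dw, dh))                  # (512, 640, 3) (1.0, 1.0) (0.0, 16.0)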
diff --git a/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/utils/extract_ckpt.py b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/utils/extract_ckpt.py
new file mode 100644
index 00000000..cddcbb07
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/utils/extract_ckpt.py
@@ -0,0 +1,7 @@
+import sys
+
+import torch
+
+sys.path.insert(0, "./facelib/detection/yolov5face")
+model = torch.load("facelib/detection/yolov5face/yolov5n-face.pt", map_location="cpu")["model"]
+torch.save(model.state_dict(), "../../models/facedetection")
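The new extract_ckpt.py helper strips a full yolov5n-face checkpoint down to a bare state_dict. A hedged
sketch of loading that state_dict back into the architecture (paths and the yaml name are assumptions):

import torch

from hordelib.nodes.facerestore_cf.facelib.detection.yolov5face.models.yolo import Model

net = Model(cfg="yolov5n.yaml")                                 # architecture config; path is an assumption
state = torch.load("models/facedetection", map_location="cpu")  # file written by extract_ckpt.py
net.load_state_dict(state, strict=True)
net.eval()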
diff --git a/hordelib/nodes/facerestore/facelib/detection/yolov5face/utils/general.py b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/utils/general.py
similarity index 97%
rename from hordelib/nodes/facerestore/facelib/detection/yolov5face/utils/general.py
rename to hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/utils/general.py
index 1c8e14f5..618d2f31 100644
--- a/hordelib/nodes/facerestore/facelib/detection/yolov5face/utils/general.py
+++ b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/utils/general.py
@@ -1,271 +1,271 @@
-import math
-import time
-
-import numpy as np
-import torch
-import torchvision
-
-
-def check_img_size(img_size, s=32):
- # Verify img_size is a multiple of stride s
- new_size = make_divisible(img_size, int(s)) # ceil gs-multiple
- # if new_size != img_size:
- # print(f"WARNING: --img-size {img_size:g} must be multiple of max stride {s:g}, updating to {new_size:g}")
- return new_size
-
-
-def make_divisible(x, divisor):
- # Returns x evenly divisible by divisor
- return math.ceil(x / divisor) * divisor
-
-
-def xyxy2xywh(x):
- # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
- y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
- y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
- y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
- y[:, 2] = x[:, 2] - x[:, 0] # width
- y[:, 3] = x[:, 3] - x[:, 1] # height
- return y
-
-
-def xywh2xyxy(x):
- # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
- y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
- y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
- y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
- y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
- y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
- return y
-
-
-def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
- # Rescale coords (xyxy) from img1_shape to img0_shape
- if ratio_pad is None: # calculate from img0_shape
- gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
- pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
- else:
- gain = ratio_pad[0][0]
- pad = ratio_pad[1]
-
- coords[:, [0, 2]] -= pad[0] # x padding
- coords[:, [1, 3]] -= pad[1] # y padding
- coords[:, :4] /= gain
- clip_coords(coords, img0_shape)
- return coords
-
-
-def clip_coords(boxes, img_shape):
- # Clip bounding xyxy bounding boxes to image shape (height, width)
- boxes[:, 0].clamp_(0, img_shape[1]) # x1
- boxes[:, 1].clamp_(0, img_shape[0]) # y1
- boxes[:, 2].clamp_(0, img_shape[1]) # x2
- boxes[:, 3].clamp_(0, img_shape[0]) # y2
-
-
-def box_iou(box1, box2):
- # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
- """
- Return intersection-over-union (Jaccard index) of boxes.
- Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
- Arguments:
- box1 (Tensor[N, 4])
- box2 (Tensor[M, 4])
- Returns:
- iou (Tensor[N, M]): the NxM matrix containing the pairwise
- IoU values for every element in boxes1 and boxes2
- """
-
- def box_area(box):
- return (box[2] - box[0]) * (box[3] - box[1])
-
- area1 = box_area(box1.T)
- area2 = box_area(box2.T)
-
- inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
- return inter / (area1[:, None] + area2 - inter)
-
-
-def non_max_suppression_face(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()):
- """Performs Non-Maximum Suppression (NMS) on inference results
- Returns:
- detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
- """
-
- nc = prediction.shape[2] - 15 # number of classes
- xc = prediction[..., 4] > conf_thres # candidates
-
- # Settings
- # (pixels) maximum box width and height
- max_wh = 4096
- time_limit = 10.0 # seconds to quit after
- redundant = True # require redundant detections
- multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img)
- merge = False # use merge-NMS
-
- t = time.time()
- output = [torch.zeros((0, 16), device=prediction.device)] * prediction.shape[0]
- for xi, x in enumerate(prediction): # image index, image inference
- # Apply constraints
- x = x[xc[xi]] # confidence
-
- # Cat apriori labels if autolabelling
- if labels and len(labels[xi]):
- label = labels[xi]
- v = torch.zeros((len(label), nc + 15), device=x.device)
- v[:, :4] = label[:, 1:5] # box
- v[:, 4] = 1.0 # conf
- v[range(len(label)), label[:, 0].long() + 15] = 1.0 # cls
- x = torch.cat((x, v), 0)
-
- # If none remain process next image
- if not x.shape[0]:
- continue
-
- # Compute conf
- x[:, 15:] *= x[:, 4:5] # conf = obj_conf * cls_conf
-
- # Box (center x, center y, width, height) to (x1, y1, x2, y2)
- box = xywh2xyxy(x[:, :4])
-
- # Detections matrix nx6 (xyxy, conf, landmarks, cls)
- if multi_label:
- i, j = (x[:, 15:] > conf_thres).nonzero(as_tuple=False).T
- x = torch.cat((box[i], x[i, j + 15, None], x[:, 5:15], j[:, None].float()), 1)
- else: # best class only
- conf, j = x[:, 15:].max(1, keepdim=True)
- x = torch.cat((box, conf, x[:, 5:15], j.float()), 1)[conf.view(-1) > conf_thres]
-
- # Filter by class
- if classes is not None:
- x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
-
- # If none remain process next image
- n = x.shape[0] # number of boxes
- if not n:
- continue
-
- # Batched NMS
- c = x[:, 15:16] * (0 if agnostic else max_wh) # classes
- boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
- i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
-
- if merge and (1 < n < 3e3): # Merge NMS (boxes merged using weighted mean)
- # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
- iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
- weights = iou * scores[None] # box weights
- x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
- if redundant:
- i = i[iou.sum(1) > 1] # require redundancy
-
- output[xi] = x[i]
- if (time.time() - t) > time_limit:
- break # time limit exceeded
-
- return output
-
-
-def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()):
- """Performs Non-Maximum Suppression (NMS) on inference results
-
- Returns:
- detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
- """
-
- nc = prediction.shape[2] - 5 # number of classes
- xc = prediction[..., 4] > conf_thres # candidates
-
- # Settings
- # (pixels) maximum box width and height
- max_wh = 4096
- time_limit = 10.0 # seconds to quit after
- redundant = True # require redundant detections
- multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img)
- merge = False # use merge-NMS
-
- t = time.time()
- output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
- for xi, x in enumerate(prediction): # image index, image inference
- x = x[xc[xi]] # confidence
-
- # Cat apriori labels if autolabelling
- if labels and len(labels[xi]):
- label_id = labels[xi]
- v = torch.zeros((len(label_id), nc + 5), device=x.device)
- v[:, :4] = label_id[:, 1:5] # box
- v[:, 4] = 1.0 # conf
- v[range(len(label_id)), label_id[:, 0].long() + 5] = 1.0 # cls
- x = torch.cat((x, v), 0)
-
- # If none remain process next image
- if not x.shape[0]:
- continue
-
- # Compute conf
- x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
-
- # Box (center x, center y, width, height) to (x1, y1, x2, y2)
- box = xywh2xyxy(x[:, :4])
-
- # Detections matrix nx6 (xyxy, conf, cls)
- if multi_label:
- i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
- x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
- else: # best class only
- conf, j = x[:, 5:].max(1, keepdim=True)
- x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
-
- # Filter by class
- if classes is not None:
- x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
-
- # Check shape
- n = x.shape[0] # number of boxes
- if not n: # no boxes
- continue
-
- x = x[x[:, 4].argsort(descending=True)] # sort by confidence
-
- # Batched NMS
- c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
- boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
- i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
- if merge and (1 < n < 3e3): # Merge NMS (boxes merged using weighted mean)
- # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
- iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
- weights = iou * scores[None] # box weights
- x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
- if redundant:
- i = i[iou.sum(1) > 1] # require redundancy
-
- output[xi] = x[i]
- if (time.time() - t) > time_limit:
- print(f"WARNING: NMS time limit {time_limit}s exceeded")
- break # time limit exceeded
-
- return output
-
-
-def scale_coords_landmarks(img1_shape, coords, img0_shape, ratio_pad=None):
- # Rescale coords (xyxy) from img1_shape to img0_shape
- if ratio_pad is None: # calculate from img0_shape
- gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
- pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
- else:
- gain = ratio_pad[0][0]
- pad = ratio_pad[1]
-
- coords[:, [0, 2, 4, 6, 8]] -= pad[0] # x padding
- coords[:, [1, 3, 5, 7, 9]] -= pad[1] # y padding
- coords[:, :10] /= gain
- coords[:, 0].clamp_(0, img0_shape[1]) # x1
- coords[:, 1].clamp_(0, img0_shape[0]) # y1
- coords[:, 2].clamp_(0, img0_shape[1]) # x2
- coords[:, 3].clamp_(0, img0_shape[0]) # y2
- coords[:, 4].clamp_(0, img0_shape[1]) # x3
- coords[:, 5].clamp_(0, img0_shape[0]) # y3
- coords[:, 6].clamp_(0, img0_shape[1]) # x4
- coords[:, 7].clamp_(0, img0_shape[0]) # y4
- coords[:, 8].clamp_(0, img0_shape[1]) # x5
- coords[:, 9].clamp_(0, img0_shape[0]) # y5
- return coords
+import math
+import time
+
+import numpy as np
+import torch
+import torchvision
+
+
+def check_img_size(img_size, s=32):
+ # Verify img_size is a multiple of stride s
+ new_size = make_divisible(img_size, int(s)) # ceil gs-multiple
+ # if new_size != img_size:
+ # print(f"WARNING: --img-size {img_size:g} must be multiple of max stride {s:g}, updating to {new_size:g}")
+ return new_size
+
+
+def make_divisible(x, divisor):
+ # Returns x evenly divisible by divisor
+ return math.ceil(x / divisor) * divisor
+
+
+def xyxy2xywh(x):
+ # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+ y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
+ y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
+ y[:, 2] = x[:, 2] - x[:, 0] # width
+ y[:, 3] = x[:, 3] - x[:, 1] # height
+ return y
+
+
+def xywh2xyxy(x):
+ # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+ y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
+ y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
+ y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
+ y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
+ return y
+
+
+def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
+ # Rescale coords (xyxy) from img1_shape to img0_shape
+ if ratio_pad is None: # calculate from img0_shape
+ gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
+ pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
+ else:
+ gain = ratio_pad[0][0]
+ pad = ratio_pad[1]
+
+ coords[:, [0, 2]] -= pad[0] # x padding
+ coords[:, [1, 3]] -= pad[1] # y padding
+ coords[:, :4] /= gain
+ clip_coords(coords, img0_shape)
+ return coords
+
+
+def clip_coords(boxes, img_shape):
+    # Clip xyxy bounding boxes to image shape (height, width)
+ boxes[:, 0].clamp_(0, img_shape[1]) # x1
+ boxes[:, 1].clamp_(0, img_shape[0]) # y1
+ boxes[:, 2].clamp_(0, img_shape[1]) # x2
+ boxes[:, 3].clamp_(0, img_shape[0]) # y2
+
+
+def box_iou(box1, box2):
+ # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
+ """
+ Return intersection-over-union (Jaccard index) of boxes.
+ Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
+ Arguments:
+ box1 (Tensor[N, 4])
+ box2 (Tensor[M, 4])
+ Returns:
+ iou (Tensor[N, M]): the NxM matrix containing the pairwise
+ IoU values for every element in boxes1 and boxes2
+ """
+
+ def box_area(box):
+ return (box[2] - box[0]) * (box[3] - box[1])
+
+ area1 = box_area(box1.T)
+ area2 = box_area(box2.T)
+
+ inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
+ return inter / (area1[:, None] + area2 - inter)
+
+
+def non_max_suppression_face(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()):
+ """Performs Non-Maximum Suppression (NMS) on inference results
+ Returns:
+ detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
+         detections with shape: nx16 (x1, y1, x2, y2, conf, 10 landmark coordinates, cls)
+
+ nc = prediction.shape[2] - 15 # number of classes
+ xc = prediction[..., 4] > conf_thres # candidates
+
+ # Settings
+ # (pixels) maximum box width and height
+ max_wh = 4096
+ time_limit = 10.0 # seconds to quit after
+ redundant = True # require redundant detections
+ multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img)
+ merge = False # use merge-NMS
+
+ t = time.time()
+ output = [torch.zeros((0, 16), device=prediction.device)] * prediction.shape[0]
+ for xi, x in enumerate(prediction): # image index, image inference
+ # Apply constraints
+ x = x[xc[xi]] # confidence
+
+ # Cat apriori labels if autolabelling
+ if labels and len(labels[xi]):
+ label = labels[xi]
+ v = torch.zeros((len(label), nc + 15), device=x.device)
+ v[:, :4] = label[:, 1:5] # box
+ v[:, 4] = 1.0 # conf
+ v[range(len(label)), label[:, 0].long() + 15] = 1.0 # cls
+ x = torch.cat((x, v), 0)
+
+ # If none remain process next image
+ if not x.shape[0]:
+ continue
+
+ # Compute conf
+ x[:, 15:] *= x[:, 4:5] # conf = obj_conf * cls_conf
+
+ # Box (center x, center y, width, height) to (x1, y1, x2, y2)
+ box = xywh2xyxy(x[:, :4])
+
+ # Detections matrix nx6 (xyxy, conf, landmarks, cls)
+ if multi_label:
+ i, j = (x[:, 15:] > conf_thres).nonzero(as_tuple=False).T
+ x = torch.cat((box[i], x[i, j + 15, None], x[:, 5:15], j[:, None].float()), 1)
+ else: # best class only
+ conf, j = x[:, 15:].max(1, keepdim=True)
+ x = torch.cat((box, conf, x[:, 5:15], j.float()), 1)[conf.view(-1) > conf_thres]
+
+ # Filter by class
+ if classes is not None:
+ x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
+
+ # If none remain process next image
+ n = x.shape[0] # number of boxes
+ if not n:
+ continue
+
+ # Batched NMS
+ c = x[:, 15:16] * (0 if agnostic else max_wh) # classes
+ boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
+ i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
+
+ if merge and (1 < n < 3e3): # Merge NMS (boxes merged using weighted mean)
+ # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
+ iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
+ weights = iou * scores[None] # box weights
+ x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
+ if redundant:
+ i = i[iou.sum(1) > 1] # require redundancy
+
+ output[xi] = x[i]
+ if (time.time() - t) > time_limit:
+ break # time limit exceeded
+
+ return output
+
+
+def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()):
+ """Performs Non-Maximum Suppression (NMS) on inference results
+
+ Returns:
+ detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
+ """
+
+ nc = prediction.shape[2] - 5 # number of classes
+ xc = prediction[..., 4] > conf_thres # candidates
+
+ # Settings
+ # (pixels) maximum box width and height
+ max_wh = 4096
+ time_limit = 10.0 # seconds to quit after
+ redundant = True # require redundant detections
+ multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img)
+ merge = False # use merge-NMS
+
+ t = time.time()
+ output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
+ for xi, x in enumerate(prediction): # image index, image inference
+ x = x[xc[xi]] # confidence
+
+ # Cat apriori labels if autolabelling
+ if labels and len(labels[xi]):
+ label_id = labels[xi]
+ v = torch.zeros((len(label_id), nc + 5), device=x.device)
+ v[:, :4] = label_id[:, 1:5] # box
+ v[:, 4] = 1.0 # conf
+ v[range(len(label_id)), label_id[:, 0].long() + 5] = 1.0 # cls
+ x = torch.cat((x, v), 0)
+
+ # If none remain process next image
+ if not x.shape[0]:
+ continue
+
+ # Compute conf
+ x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
+
+ # Box (center x, center y, width, height) to (x1, y1, x2, y2)
+ box = xywh2xyxy(x[:, :4])
+
+ # Detections matrix nx6 (xyxy, conf, cls)
+ if multi_label:
+ i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
+ x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
+ else: # best class only
+ conf, j = x[:, 5:].max(1, keepdim=True)
+ x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
+
+ # Filter by class
+ if classes is not None:
+ x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
+
+ # Check shape
+ n = x.shape[0] # number of boxes
+ if not n: # no boxes
+ continue
+
+ x = x[x[:, 4].argsort(descending=True)] # sort by confidence
+
+ # Batched NMS
+ c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
+ boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
+ i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
+ if merge and (1 < n < 3e3): # Merge NMS (boxes merged using weighted mean)
+ # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
+ iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
+ weights = iou * scores[None] # box weights
+ x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
+ if redundant:
+ i = i[iou.sum(1) > 1] # require redundancy
+
+ output[xi] = x[i]
+ if (time.time() - t) > time_limit:
+ print(f"WARNING: NMS time limit {time_limit}s exceeded")
+ break # time limit exceeded
+
+ return output
+
+
+def scale_coords_landmarks(img1_shape, coords, img0_shape, ratio_pad=None):
+ # Rescale coords (xyxy) from img1_shape to img0_shape
+ if ratio_pad is None: # calculate from img0_shape
+ gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
+ pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
+ else:
+ gain = ratio_pad[0][0]
+ pad = ratio_pad[1]
+
+ coords[:, [0, 2, 4, 6, 8]] -= pad[0] # x padding
+ coords[:, [1, 3, 5, 7, 9]] -= pad[1] # y padding
+ coords[:, :10] /= gain
+ coords[:, 0].clamp_(0, img0_shape[1]) # x1
+ coords[:, 1].clamp_(0, img0_shape[0]) # y1
+ coords[:, 2].clamp_(0, img0_shape[1]) # x2
+ coords[:, 3].clamp_(0, img0_shape[0]) # y2
+ coords[:, 4].clamp_(0, img0_shape[1]) # x3
+ coords[:, 5].clamp_(0, img0_shape[0]) # y3
+ coords[:, 6].clamp_(0, img0_shape[1]) # x4
+ coords[:, 7].clamp_(0, img0_shape[0]) # y4
+ coords[:, 8].clamp_(0, img0_shape[1]) # x5
+ coords[:, 9].clamp_(0, img0_shape[0]) # y5
+ return coords
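A tiny standalone check of the box conversions defined above; xywh2xyxy and xyxy2xywh are exact inverses
for these values.

import torch

from hordelib.nodes.facerestore_cf.facelib.detection.yolov5face.utils.general import xywh2xyxy, xyxy2xywh

xywh = torch.tensor([[50.0, 40.0, 20.0, 10.0]])   # centre x, centre y, width, height
xyxy = xywh2xyxy(xywh)
print(xyxy)                                       # tensor([[40., 35., 60., 45.]])
assert torch.equal(xyxy2xywh(xyxy), xywh)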
diff --git a/hordelib/nodes/facerestore/facelib/detection/yolov5face/utils/torch_utils.py b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/utils/torch_utils.py
similarity index 97%
rename from hordelib/nodes/facerestore/facelib/detection/yolov5face/utils/torch_utils.py
rename to hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/utils/torch_utils.py
index af2d0658..f7029623 100644
--- a/hordelib/nodes/facerestore/facelib/detection/yolov5face/utils/torch_utils.py
+++ b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/utils/torch_utils.py
@@ -1,40 +1,40 @@
-import torch
-from torch import nn
-
-
-def fuse_conv_and_bn(conv, bn):
- # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
- fusedconv = (
- nn.Conv2d(
- conv.in_channels,
- conv.out_channels,
- kernel_size=conv.kernel_size,
- stride=conv.stride,
- padding=conv.padding,
- groups=conv.groups,
- bias=True,
- )
- .requires_grad_(False)
- .to(conv.weight.device)
- )
-
- # prepare filters
- w_conv = conv.weight.clone().view(conv.out_channels, -1)
- w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
- fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
-
- # prepare spatial bias
- b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
- b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
- fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
-
- return fusedconv
-
-
-def copy_attr(a, b, include=(), exclude=()):
- # Copy attributes from b to a, options to only include [...] and to exclude [...]
- for k, v in b.__dict__.items():
- if (include and k not in include) or k.startswith("_") or k in exclude:
- continue
-
- setattr(a, k, v)
+import torch
+from torch import nn
+
+
+def fuse_conv_and_bn(conv, bn):
+ # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
+ fusedconv = (
+ nn.Conv2d(
+ conv.in_channels,
+ conv.out_channels,
+ kernel_size=conv.kernel_size,
+ stride=conv.stride,
+ padding=conv.padding,
+ groups=conv.groups,
+ bias=True,
+ )
+ .requires_grad_(False)
+ .to(conv.weight.device)
+ )
+
+ # prepare filters
+ w_conv = conv.weight.clone().view(conv.out_channels, -1)
+ w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
+ fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
+
+ # prepare spatial bias
+ b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
+ b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
+ fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
+
+ return fusedconv
+
+
+def copy_attr(a, b, include=(), exclude=()):
+ # Copy attributes from b to a, options to only include [...] and to exclude [...]
+ for k, v in b.__dict__.items():
+ if (include and k not in include) or k.startswith("_") or k in exclude:
+ continue
+
+ setattr(a, k, v)
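A standalone sanity check of fuse_conv_and_bn: in eval mode the fused convolution should reproduce conv
followed by batchnorm (random weights, small tolerance).

import torch
from torch import nn

from hordelib.nodes.facerestore_cf.facelib.detection.yolov5face.utils.torch_utils import fuse_conv_and_bn

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False).eval()
bn = nn.BatchNorm2d(8).eval()        # fusion matches the inference-time batchnorm statistics

x = torch.randn(1, 3, 16, 16)
with torch.no_grad():
    fused = fuse_conv_and_bn(conv, bn)
    assert torch.allclose(fused(x), bn(conv(x)), atol=1e-5)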
diff --git a/hordelib/nodes/facerestore/facelib/parsing/__init__.py b/hordelib/nodes/facerestore_cf/facelib/parsing/__init__.py
similarity index 81%
rename from hordelib/nodes/facerestore/facelib/parsing/__init__.py
rename to hordelib/nodes/facerestore_cf/facelib/parsing/__init__.py
index b95f2d8b..f030fbfd 100644
--- a/hordelib/nodes/facerestore/facelib/parsing/__init__.py
+++ b/hordelib/nodes/facerestore_cf/facelib/parsing/__init__.py
@@ -1,28 +1,31 @@
-import torch
-
-from hordelib.nodes.facerestore.facelib.utils import load_file_from_url
-from .bisenet import BiSeNet
-from .parsenet import ParseNet
-
-
-def init_parsing_model(model_name="bisenet", half=False, device="cuda"):
- if model_name == "bisenet":
- model = BiSeNet(num_class=19)
- model_url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_bisenet.pth"
- elif model_name == "parsenet":
- model = ParseNet(in_size=512, out_size=512, parsing_ch=19)
- model_url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth"
- else:
- raise NotImplementedError(f"{model_name} is not implemented.")
-
- model_path = load_file_from_url(
- url=model_url,
- model_dir="../../models/facedetection",
- progress=True,
- file_name=None,
- )
- load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
- model.load_state_dict(load_net, strict=True)
- model.eval()
- model = model.to(device)
- return model
+import torch
+
+from hordelib.nodes.facerestore_cf.facelib.utils import load_file_from_url
+
+from .bisenet import BiSeNet
+from .parsenet import ParseNet
+
+
+def init_parsing_model(model_name="bisenet", half=False, device="cuda"):
+ if model_name == "bisenet":
+ model = BiSeNet(num_class=19)
+ model_url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_bisenet.pth"
+ filename = "parsing_bisenet.pth"
+ elif model_name == "parsenet":
+ model = ParseNet(in_size=512, out_size=512, parsing_ch=19)
+ model_url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth"
+ filename = "parsing_parsenet.pth"
+ else:
+ raise NotImplementedError(f"{model_name} is not implemented.")
+
+ model_path = load_file_from_url(
+ url=model_url,
+ model_dir="../../models/facedetection",
+ progress=True,
+ file_name=filename,
+ )
+ load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
+ model.load_state_dict(load_net, strict=True)
+ model.eval()
+ model = model.to(device)
+ return model
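The parsing factory now passes an explicit file_name to load_file_from_url, so the cached weight file gets a
stable name. A hedged usage sketch; the first call downloads the weights into the relative
../../models/facedetection directory, so network access and the working directory matter.

from hordelib.nodes.facerestore_cf.facelib.parsing import init_parsing_model

# Downloads parsing_parsenet.pth on first use (19 parsing classes, 512x512 in/out per the code above).
parsenet = init_parsing_model(model_name="parsenet", device="cpu")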
diff --git a/hordelib/nodes/facerestore/facelib/parsing/bisenet.py b/hordelib/nodes/facerestore_cf/facelib/parsing/bisenet.py
similarity index 85%
rename from hordelib/nodes/facerestore/facelib/parsing/bisenet.py
rename to hordelib/nodes/facerestore_cf/facelib/parsing/bisenet.py
index 3898cab7..051eec1d 100644
--- a/hordelib/nodes/facerestore/facelib/parsing/bisenet.py
+++ b/hordelib/nodes/facerestore_cf/facelib/parsing/bisenet.py
@@ -1,140 +1,140 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from .resnet import ResNet18
-
-
-class ConvBNReLU(nn.Module):
-
- def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1):
- super(ConvBNReLU, self).__init__()
- self.conv = nn.Conv2d(in_chan, out_chan, kernel_size=ks, stride=stride, padding=padding, bias=False)
- self.bn = nn.BatchNorm2d(out_chan)
-
- def forward(self, x):
- x = self.conv(x)
- x = F.relu(self.bn(x))
- return x
-
-
-class BiSeNetOutput(nn.Module):
-
- def __init__(self, in_chan, mid_chan, num_class):
- super(BiSeNetOutput, self).__init__()
- self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
- self.conv_out = nn.Conv2d(mid_chan, num_class, kernel_size=1, bias=False)
-
- def forward(self, x):
- feat = self.conv(x)
- out = self.conv_out(feat)
- return out, feat
-
-
-class AttentionRefinementModule(nn.Module):
-
- def __init__(self, in_chan, out_chan):
- super(AttentionRefinementModule, self).__init__()
- self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
- self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
- self.bn_atten = nn.BatchNorm2d(out_chan)
- self.sigmoid_atten = nn.Sigmoid()
-
- def forward(self, x):
- feat = self.conv(x)
- atten = F.avg_pool2d(feat, feat.size()[2:])
- atten = self.conv_atten(atten)
- atten = self.bn_atten(atten)
- atten = self.sigmoid_atten(atten)
- out = torch.mul(feat, atten)
- return out
-
-
-class ContextPath(nn.Module):
-
- def __init__(self):
- super(ContextPath, self).__init__()
- self.resnet = ResNet18()
- self.arm16 = AttentionRefinementModule(256, 128)
- self.arm32 = AttentionRefinementModule(512, 128)
- self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
- self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
- self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
-
- def forward(self, x):
- feat8, feat16, feat32 = self.resnet(x)
- h8, w8 = feat8.size()[2:]
- h16, w16 = feat16.size()[2:]
- h32, w32 = feat32.size()[2:]
-
- avg = F.avg_pool2d(feat32, feat32.size()[2:])
- avg = self.conv_avg(avg)
- avg_up = F.interpolate(avg, (h32, w32), mode='nearest')
-
- feat32_arm = self.arm32(feat32)
- feat32_sum = feat32_arm + avg_up
- feat32_up = F.interpolate(feat32_sum, (h16, w16), mode='nearest')
- feat32_up = self.conv_head32(feat32_up)
-
- feat16_arm = self.arm16(feat16)
- feat16_sum = feat16_arm + feat32_up
- feat16_up = F.interpolate(feat16_sum, (h8, w8), mode='nearest')
- feat16_up = self.conv_head16(feat16_up)
-
- return feat8, feat16_up, feat32_up # x8, x8, x16
-
-
-class FeatureFusionModule(nn.Module):
-
- def __init__(self, in_chan, out_chan):
- super(FeatureFusionModule, self).__init__()
- self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
- self.conv1 = nn.Conv2d(out_chan, out_chan // 4, kernel_size=1, stride=1, padding=0, bias=False)
- self.conv2 = nn.Conv2d(out_chan // 4, out_chan, kernel_size=1, stride=1, padding=0, bias=False)
- self.relu = nn.ReLU(inplace=True)
- self.sigmoid = nn.Sigmoid()
-
- def forward(self, fsp, fcp):
- fcat = torch.cat([fsp, fcp], dim=1)
- feat = self.convblk(fcat)
- atten = F.avg_pool2d(feat, feat.size()[2:])
- atten = self.conv1(atten)
- atten = self.relu(atten)
- atten = self.conv2(atten)
- atten = self.sigmoid(atten)
- feat_atten = torch.mul(feat, atten)
- feat_out = feat_atten + feat
- return feat_out
-
-
-class BiSeNet(nn.Module):
-
- def __init__(self, num_class):
- super(BiSeNet, self).__init__()
- self.cp = ContextPath()
- self.ffm = FeatureFusionModule(256, 256)
- self.conv_out = BiSeNetOutput(256, 256, num_class)
- self.conv_out16 = BiSeNetOutput(128, 64, num_class)
- self.conv_out32 = BiSeNetOutput(128, 64, num_class)
-
- def forward(self, x, return_feat=False):
- h, w = x.size()[2:]
- feat_res8, feat_cp8, feat_cp16 = self.cp(x) # return res3b1 feature
- feat_sp = feat_res8 # replace spatial path feature with res3b1 feature
- feat_fuse = self.ffm(feat_sp, feat_cp8)
-
- out, feat = self.conv_out(feat_fuse)
- out16, feat16 = self.conv_out16(feat_cp8)
- out32, feat32 = self.conv_out32(feat_cp16)
-
- out = F.interpolate(out, (h, w), mode='bilinear', align_corners=True)
- out16 = F.interpolate(out16, (h, w), mode='bilinear', align_corners=True)
- out32 = F.interpolate(out32, (h, w), mode='bilinear', align_corners=True)
-
- if return_feat:
- feat = F.interpolate(feat, (h, w), mode='bilinear', align_corners=True)
- feat16 = F.interpolate(feat16, (h, w), mode='bilinear', align_corners=True)
- feat32 = F.interpolate(feat32, (h, w), mode='bilinear', align_corners=True)
- return out, out16, out32, feat, feat16, feat32
- else:
- return out, out16, out32
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .resnet import ResNet18
+
+
+class ConvBNReLU(nn.Module):
+
+ def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1):
+ super(ConvBNReLU, self).__init__()
+ self.conv = nn.Conv2d(in_chan, out_chan, kernel_size=ks, stride=stride, padding=padding, bias=False)
+ self.bn = nn.BatchNorm2d(out_chan)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = F.relu(self.bn(x))
+ return x
+
+
+class BiSeNetOutput(nn.Module):
+
+ def __init__(self, in_chan, mid_chan, num_class):
+ super(BiSeNetOutput, self).__init__()
+ self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
+ self.conv_out = nn.Conv2d(mid_chan, num_class, kernel_size=1, bias=False)
+
+ def forward(self, x):
+ feat = self.conv(x)
+ out = self.conv_out(feat)
+ return out, feat
+
+
+class AttentionRefinementModule(nn.Module):
+
+ def __init__(self, in_chan, out_chan):
+ super(AttentionRefinementModule, self).__init__()
+ self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
+ self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
+ self.bn_atten = nn.BatchNorm2d(out_chan)
+ self.sigmoid_atten = nn.Sigmoid()
+
+ def forward(self, x):
+ feat = self.conv(x)
+ atten = F.avg_pool2d(feat, feat.size()[2:])
+ atten = self.conv_atten(atten)
+ atten = self.bn_atten(atten)
+ atten = self.sigmoid_atten(atten)
+ out = torch.mul(feat, atten)
+ return out
+
+
+class ContextPath(nn.Module):
+
+ def __init__(self):
+ super(ContextPath, self).__init__()
+ self.resnet = ResNet18()
+ self.arm16 = AttentionRefinementModule(256, 128)
+ self.arm32 = AttentionRefinementModule(512, 128)
+ self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
+ self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
+ self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
+
+ def forward(self, x):
+ feat8, feat16, feat32 = self.resnet(x)
+ h8, w8 = feat8.size()[2:]
+ h16, w16 = feat16.size()[2:]
+ h32, w32 = feat32.size()[2:]
+
+ avg = F.avg_pool2d(feat32, feat32.size()[2:])
+ avg = self.conv_avg(avg)
+ avg_up = F.interpolate(avg, (h32, w32), mode="nearest")
+
+ feat32_arm = self.arm32(feat32)
+ feat32_sum = feat32_arm + avg_up
+ feat32_up = F.interpolate(feat32_sum, (h16, w16), mode="nearest")
+ feat32_up = self.conv_head32(feat32_up)
+
+ feat16_arm = self.arm16(feat16)
+ feat16_sum = feat16_arm + feat32_up
+ feat16_up = F.interpolate(feat16_sum, (h8, w8), mode="nearest")
+ feat16_up = self.conv_head16(feat16_up)
+
+ return feat8, feat16_up, feat32_up # x8, x8, x16
+
+
+class FeatureFusionModule(nn.Module):
+
+ def __init__(self, in_chan, out_chan):
+ super(FeatureFusionModule, self).__init__()
+ self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
+ self.conv1 = nn.Conv2d(out_chan, out_chan // 4, kernel_size=1, stride=1, padding=0, bias=False)
+ self.conv2 = nn.Conv2d(out_chan // 4, out_chan, kernel_size=1, stride=1, padding=0, bias=False)
+ self.relu = nn.ReLU(inplace=True)
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, fsp, fcp):
+ fcat = torch.cat([fsp, fcp], dim=1)
+ feat = self.convblk(fcat)
+ atten = F.avg_pool2d(feat, feat.size()[2:])
+ atten = self.conv1(atten)
+ atten = self.relu(atten)
+ atten = self.conv2(atten)
+ atten = self.sigmoid(atten)
+ feat_atten = torch.mul(feat, atten)
+ feat_out = feat_atten + feat
+ return feat_out
+
+
+class BiSeNet(nn.Module):
+
+ def __init__(self, num_class):
+ super(BiSeNet, self).__init__()
+ self.cp = ContextPath()
+ self.ffm = FeatureFusionModule(256, 256)
+ self.conv_out = BiSeNetOutput(256, 256, num_class)
+ self.conv_out16 = BiSeNetOutput(128, 64, num_class)
+ self.conv_out32 = BiSeNetOutput(128, 64, num_class)
+
+ def forward(self, x, return_feat=False):
+ h, w = x.size()[2:]
+ feat_res8, feat_cp8, feat_cp16 = self.cp(x) # return res3b1 feature
+ feat_sp = feat_res8 # replace spatial path feature with res3b1 feature
+ feat_fuse = self.ffm(feat_sp, feat_cp8)
+
+ out, feat = self.conv_out(feat_fuse)
+ out16, feat16 = self.conv_out16(feat_cp8)
+ out32, feat32 = self.conv_out32(feat_cp16)
+
+ out = F.interpolate(out, (h, w), mode="bilinear", align_corners=True)
+ out16 = F.interpolate(out16, (h, w), mode="bilinear", align_corners=True)
+ out32 = F.interpolate(out32, (h, w), mode="bilinear", align_corners=True)
+
+ if return_feat:
+ feat = F.interpolate(feat, (h, w), mode="bilinear", align_corners=True)
+ feat16 = F.interpolate(feat16, (h, w), mode="bilinear", align_corners=True)
+ feat32 = F.interpolate(feat32, (h, w), mode="bilinear", align_corners=True)
+ return out, out16, out32, feat, feat16, feat32
+ else:
+ return out, out16, out32
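
A quick smoke test for the reformatted BiSeNet head above could look like the following sketch. The module path is an assumption based on the facerestore_cf rename in this PR, and eval() matters because the attention branches batch-normalize 1x1 pooled features, which errors in training mode with a batch of one.

import torch

from hordelib.nodes.facerestore_cf.facelib.parsing.bisenet import BiSeNet  # assumed module path

net = BiSeNet(num_class=19).eval()  # 19 classes, matching the 19-entry parse-mask colormap used downstream
with torch.no_grad():
    out, out16, out32 = net(torch.randn(1, 3, 512, 512))
print(out.shape)  # torch.Size([1, 19, 512, 512]); out16/out32 are the auxiliary outputs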
diff --git a/hordelib/nodes/facerestore/facelib/parsing/parsenet.py b/hordelib/nodes/facerestore_cf/facelib/parsing/parsenet.py
similarity index 70%
rename from hordelib/nodes/facerestore/facelib/parsing/parsenet.py
rename to hordelib/nodes/facerestore_cf/facelib/parsing/parsenet.py
index e178ebe4..b2ed997a 100644
--- a/hordelib/nodes/facerestore/facelib/parsing/parsenet.py
+++ b/hordelib/nodes/facerestore_cf/facelib/parsing/parsenet.py
@@ -1,194 +1,199 @@
-"""Modified from https://github.com/chaofengc/PSFRGAN
-"""
-import numpy as np
-import torch.nn as nn
-from torch.nn import functional as F
-
-
-class NormLayer(nn.Module):
- """Normalization Layers.
-
- Args:
- channels: input channels, for batch norm and instance norm.
- input_size: input shape without batch size, for layer norm.
- """
-
- def __init__(self, channels, normalize_shape=None, norm_type='bn'):
- super(NormLayer, self).__init__()
- norm_type = norm_type.lower()
- self.norm_type = norm_type
- if norm_type == 'bn':
- self.norm = nn.BatchNorm2d(channels, affine=True)
- elif norm_type == 'in':
- self.norm = nn.InstanceNorm2d(channels, affine=False)
- elif norm_type == 'gn':
- self.norm = nn.GroupNorm(32, channels, affine=True)
- elif norm_type == 'pixel':
- self.norm = lambda x: F.normalize(x, p=2, dim=1)
- elif norm_type == 'layer':
- self.norm = nn.LayerNorm(normalize_shape)
- elif norm_type == 'none':
- self.norm = lambda x: x * 1.0
- else:
- assert 1 == 0, f'Norm type {norm_type} not support.'
-
- def forward(self, x, ref=None):
- if self.norm_type == 'spade':
- return self.norm(x, ref)
- else:
- return self.norm(x)
-
-
-class ReluLayer(nn.Module):
- """Relu Layer.
-
- Args:
- relu type: type of relu layer, candidates are
- - ReLU
- - LeakyReLU: default relu slope 0.2
- - PRelu
- - SELU
- - none: direct pass
- """
-
- def __init__(self, channels, relu_type='relu'):
- super(ReluLayer, self).__init__()
- relu_type = relu_type.lower()
- if relu_type == 'relu':
- self.func = nn.ReLU(True)
- elif relu_type == 'leakyrelu':
- self.func = nn.LeakyReLU(0.2, inplace=True)
- elif relu_type == 'prelu':
- self.func = nn.PReLU(channels)
- elif relu_type == 'selu':
- self.func = nn.SELU(True)
- elif relu_type == 'none':
- self.func = lambda x: x * 1.0
- else:
- assert 1 == 0, f'Relu type {relu_type} not support.'
-
- def forward(self, x):
- return self.func(x)
-
-
-class ConvLayer(nn.Module):
-
- def __init__(self,
- in_channels,
- out_channels,
- kernel_size=3,
- scale='none',
- norm_type='none',
- relu_type='none',
- use_pad=True,
- bias=True):
- super(ConvLayer, self).__init__()
- self.use_pad = use_pad
- self.norm_type = norm_type
- if norm_type in ['bn']:
- bias = False
-
- stride = 2 if scale == 'down' else 1
-
- self.scale_func = lambda x: x
- if scale == 'up':
- self.scale_func = lambda x: nn.functional.interpolate(x, scale_factor=2, mode='nearest')
-
- self.reflection_pad = nn.ReflectionPad2d(int(np.ceil((kernel_size - 1.) / 2)))
- self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias)
-
- self.relu = ReluLayer(out_channels, relu_type)
- self.norm = NormLayer(out_channels, norm_type=norm_type)
-
- def forward(self, x):
- out = self.scale_func(x)
- if self.use_pad:
- out = self.reflection_pad(out)
- out = self.conv2d(out)
- out = self.norm(out)
- out = self.relu(out)
- return out
-
-
-class ResidualBlock(nn.Module):
- """
- Residual block recommended in: http://torch.ch/blog/2016/02/04/resnets.html
- """
-
- def __init__(self, c_in, c_out, relu_type='prelu', norm_type='bn', scale='none'):
- super(ResidualBlock, self).__init__()
-
- if scale == 'none' and c_in == c_out:
- self.shortcut_func = lambda x: x
- else:
- self.shortcut_func = ConvLayer(c_in, c_out, 3, scale)
-
- scale_config_dict = {'down': ['none', 'down'], 'up': ['up', 'none'], 'none': ['none', 'none']}
- scale_conf = scale_config_dict[scale]
-
- self.conv1 = ConvLayer(c_in, c_out, 3, scale_conf[0], norm_type=norm_type, relu_type=relu_type)
- self.conv2 = ConvLayer(c_out, c_out, 3, scale_conf[1], norm_type=norm_type, relu_type='none')
-
- def forward(self, x):
- identity = self.shortcut_func(x)
-
- res = self.conv1(x)
- res = self.conv2(res)
- return identity + res
-
-
-class ParseNet(nn.Module):
-
- def __init__(self,
- in_size=128,
- out_size=128,
- min_feat_size=32,
- base_ch=64,
- parsing_ch=19,
- res_depth=10,
- relu_type='LeakyReLU',
- norm_type='bn',
- ch_range=[32, 256]):
- super().__init__()
- self.res_depth = res_depth
- act_args = {'norm_type': norm_type, 'relu_type': relu_type}
- min_ch, max_ch = ch_range
-
- ch_clip = lambda x: max(min_ch, min(x, max_ch)) # noqa: E731
- min_feat_size = min(in_size, min_feat_size)
-
- down_steps = int(np.log2(in_size // min_feat_size))
- up_steps = int(np.log2(out_size // min_feat_size))
-
- # =============== define encoder-body-decoder ====================
- self.encoder = []
- self.encoder.append(ConvLayer(3, base_ch, 3, 1))
- head_ch = base_ch
- for i in range(down_steps):
- cin, cout = ch_clip(head_ch), ch_clip(head_ch * 2)
- self.encoder.append(ResidualBlock(cin, cout, scale='down', **act_args))
- head_ch = head_ch * 2
-
- self.body = []
- for i in range(res_depth):
- self.body.append(ResidualBlock(ch_clip(head_ch), ch_clip(head_ch), **act_args))
-
- self.decoder = []
- for i in range(up_steps):
- cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2)
- self.decoder.append(ResidualBlock(cin, cout, scale='up', **act_args))
- head_ch = head_ch // 2
-
- self.encoder = nn.Sequential(*self.encoder)
- self.body = nn.Sequential(*self.body)
- self.decoder = nn.Sequential(*self.decoder)
- self.out_img_conv = ConvLayer(ch_clip(head_ch), 3)
- self.out_mask_conv = ConvLayer(ch_clip(head_ch), parsing_ch)
-
- def forward(self, x):
- feat = self.encoder(x)
- x = feat + self.body(feat)
- x = self.decoder(x)
- out_img = self.out_img_conv(x)
- out_mask = self.out_mask_conv(x)
- return out_mask, out_img
+"""Modified from https://github.com/chaofengc/PSFRGAN
+"""
+
+import numpy as np
+import torch.nn as nn
+from torch.nn import functional as F
+
+
+class NormLayer(nn.Module):
+ """Normalization Layers.
+
+ Args:
+ channels: input channels, for batch norm and instance norm.
+        normalize_shape: input shape without batch size, for layer norm.
+ """
+
+ def __init__(self, channels, normalize_shape=None, norm_type="bn"):
+ super(NormLayer, self).__init__()
+ norm_type = norm_type.lower()
+ self.norm_type = norm_type
+ if norm_type == "bn":
+ self.norm = nn.BatchNorm2d(channels, affine=True)
+ elif norm_type == "in":
+ self.norm = nn.InstanceNorm2d(channels, affine=False)
+ elif norm_type == "gn":
+ self.norm = nn.GroupNorm(32, channels, affine=True)
+ elif norm_type == "pixel":
+ self.norm = lambda x: F.normalize(x, p=2, dim=1)
+ elif norm_type == "layer":
+ self.norm = nn.LayerNorm(normalize_shape)
+ elif norm_type == "none":
+ self.norm = lambda x: x * 1.0
+ else:
+ assert 1 == 0, f"Norm type {norm_type} not support."
+
+ def forward(self, x, ref=None):
+ if self.norm_type == "spade":
+ return self.norm(x, ref)
+ else:
+ return self.norm(x)
+
+
+class ReluLayer(nn.Module):
+ """Relu Layer.
+
+ Args:
+        relu_type: type of relu layer; candidates are
+ - ReLU
+ - LeakyReLU: default relu slope 0.2
+ - PRelu
+ - SELU
+ - none: direct pass
+ """
+
+ def __init__(self, channels, relu_type="relu"):
+ super(ReluLayer, self).__init__()
+ relu_type = relu_type.lower()
+ if relu_type == "relu":
+ self.func = nn.ReLU(True)
+ elif relu_type == "leakyrelu":
+ self.func = nn.LeakyReLU(0.2, inplace=True)
+ elif relu_type == "prelu":
+ self.func = nn.PReLU(channels)
+ elif relu_type == "selu":
+ self.func = nn.SELU(True)
+ elif relu_type == "none":
+ self.func = lambda x: x * 1.0
+ else:
+ assert 1 == 0, f"Relu type {relu_type} not support."
+
+ def forward(self, x):
+ return self.func(x)
+
+
+class ConvLayer(nn.Module):
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ scale="none",
+ norm_type="none",
+ relu_type="none",
+ use_pad=True,
+ bias=True,
+ ):
+ super(ConvLayer, self).__init__()
+ self.use_pad = use_pad
+ self.norm_type = norm_type
+ if norm_type in ["bn"]:
+ bias = False
+
+ stride = 2 if scale == "down" else 1
+
+ self.scale_func = lambda x: x
+ if scale == "up":
+ self.scale_func = lambda x: nn.functional.interpolate(x, scale_factor=2, mode="nearest")
+
+ self.reflection_pad = nn.ReflectionPad2d(int(np.ceil((kernel_size - 1.0) / 2)))
+ self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias)
+
+ self.relu = ReluLayer(out_channels, relu_type)
+ self.norm = NormLayer(out_channels, norm_type=norm_type)
+
+ def forward(self, x):
+ out = self.scale_func(x)
+ if self.use_pad:
+ out = self.reflection_pad(out)
+ out = self.conv2d(out)
+ out = self.norm(out)
+ out = self.relu(out)
+ return out
+
+
+class ResidualBlock(nn.Module):
+ """
+ Residual block recommended in: http://torch.ch/blog/2016/02/04/resnets.html
+ """
+
+ def __init__(self, c_in, c_out, relu_type="prelu", norm_type="bn", scale="none"):
+ super(ResidualBlock, self).__init__()
+
+ if scale == "none" and c_in == c_out:
+ self.shortcut_func = lambda x: x
+ else:
+ self.shortcut_func = ConvLayer(c_in, c_out, 3, scale)
+
+ scale_config_dict = {"down": ["none", "down"], "up": ["up", "none"], "none": ["none", "none"]}
+ scale_conf = scale_config_dict[scale]
+
+ self.conv1 = ConvLayer(c_in, c_out, 3, scale_conf[0], norm_type=norm_type, relu_type=relu_type)
+ self.conv2 = ConvLayer(c_out, c_out, 3, scale_conf[1], norm_type=norm_type, relu_type="none")
+
+ def forward(self, x):
+ identity = self.shortcut_func(x)
+
+ res = self.conv1(x)
+ res = self.conv2(res)
+ return identity + res
+
+
+class ParseNet(nn.Module):
+
+ def __init__(
+ self,
+ in_size=128,
+ out_size=128,
+ min_feat_size=32,
+ base_ch=64,
+ parsing_ch=19,
+ res_depth=10,
+ relu_type="LeakyReLU",
+ norm_type="bn",
+ ch_range=[32, 256],
+ ):
+ super().__init__()
+ self.res_depth = res_depth
+ act_args = {"norm_type": norm_type, "relu_type": relu_type}
+ min_ch, max_ch = ch_range
+
+ ch_clip = lambda x: max(min_ch, min(x, max_ch)) # noqa: E731
+ min_feat_size = min(in_size, min_feat_size)
+
+ down_steps = int(np.log2(in_size // min_feat_size))
+ up_steps = int(np.log2(out_size // min_feat_size))
+
+ # =============== define encoder-body-decoder ====================
+ self.encoder = []
+ self.encoder.append(ConvLayer(3, base_ch, 3, 1))
+ head_ch = base_ch
+ for i in range(down_steps):
+ cin, cout = ch_clip(head_ch), ch_clip(head_ch * 2)
+ self.encoder.append(ResidualBlock(cin, cout, scale="down", **act_args))
+ head_ch = head_ch * 2
+
+ self.body = []
+ for i in range(res_depth):
+ self.body.append(ResidualBlock(ch_clip(head_ch), ch_clip(head_ch), **act_args))
+
+ self.decoder = []
+ for i in range(up_steps):
+ cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2)
+ self.decoder.append(ResidualBlock(cin, cout, scale="up", **act_args))
+ head_ch = head_ch // 2
+
+ self.encoder = nn.Sequential(*self.encoder)
+ self.body = nn.Sequential(*self.body)
+ self.decoder = nn.Sequential(*self.decoder)
+ self.out_img_conv = ConvLayer(ch_clip(head_ch), 3)
+ self.out_mask_conv = ConvLayer(ch_clip(head_ch), parsing_ch)
+
+ def forward(self, x):
+ feat = self.encoder(x)
+ x = feat + self.body(feat)
+ x = self.decoder(x)
+ out_img = self.out_img_conv(x)
+ out_mask = self.out_mask_conv(x)
+ return out_mask, out_img
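
As a hedged usage sketch of the ParseNet definition above (module path assumed from the rename): a 512x512 face crop produces a 19-channel parsing mask alongside a reconstructed RGB image, both at the input resolution.

import torch

from hordelib.nodes.facerestore_cf.facelib.parsing.parsenet import ParseNet  # assumed module path

net = ParseNet(in_size=512, out_size=512, parsing_ch=19).eval()
with torch.no_grad():
    out_mask, out_img = net(torch.randn(1, 3, 512, 512))
print(out_mask.shape, out_img.shape)  # (1, 19, 512, 512) and (1, 3, 512, 512)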
diff --git a/hordelib/nodes/facerestore/facelib/parsing/resnet.py b/hordelib/nodes/facerestore_cf/facelib/parsing/resnet.py
similarity index 97%
rename from hordelib/nodes/facerestore/facelib/parsing/resnet.py
rename to hordelib/nodes/facerestore_cf/facelib/parsing/resnet.py
index fec8e82c..e7cc283d 100644
--- a/hordelib/nodes/facerestore/facelib/parsing/resnet.py
+++ b/hordelib/nodes/facerestore_cf/facelib/parsing/resnet.py
@@ -1,69 +1,69 @@
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-def conv3x3(in_planes, out_planes, stride=1):
- """3x3 convolution with padding"""
- return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
-
-
-class BasicBlock(nn.Module):
-
- def __init__(self, in_chan, out_chan, stride=1):
- super(BasicBlock, self).__init__()
- self.conv1 = conv3x3(in_chan, out_chan, stride)
- self.bn1 = nn.BatchNorm2d(out_chan)
- self.conv2 = conv3x3(out_chan, out_chan)
- self.bn2 = nn.BatchNorm2d(out_chan)
- self.relu = nn.ReLU(inplace=True)
- self.downsample = None
- if in_chan != out_chan or stride != 1:
- self.downsample = nn.Sequential(
- nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=stride, bias=False),
- nn.BatchNorm2d(out_chan),
- )
-
- def forward(self, x):
- residual = self.conv1(x)
- residual = F.relu(self.bn1(residual))
- residual = self.conv2(residual)
- residual = self.bn2(residual)
-
- shortcut = x
- if self.downsample is not None:
- shortcut = self.downsample(x)
-
- out = shortcut + residual
- out = self.relu(out)
- return out
-
-
-def create_layer_basic(in_chan, out_chan, bnum, stride=1):
- layers = [BasicBlock(in_chan, out_chan, stride=stride)]
- for i in range(bnum - 1):
- layers.append(BasicBlock(out_chan, out_chan, stride=1))
- return nn.Sequential(*layers)
-
-
-class ResNet18(nn.Module):
-
- def __init__(self):
- super(ResNet18, self).__init__()
- self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
- self.bn1 = nn.BatchNorm2d(64)
- self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
- self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
- self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
- self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
- self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
-
- def forward(self, x):
- x = self.conv1(x)
- x = F.relu(self.bn1(x))
- x = self.maxpool(x)
-
- x = self.layer1(x)
- feat8 = self.layer2(x) # 1/8
- feat16 = self.layer3(feat8) # 1/16
- feat32 = self.layer4(feat16) # 1/32
- return feat8, feat16, feat32
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
+
+
+class BasicBlock(nn.Module):
+
+ def __init__(self, in_chan, out_chan, stride=1):
+ super(BasicBlock, self).__init__()
+ self.conv1 = conv3x3(in_chan, out_chan, stride)
+ self.bn1 = nn.BatchNorm2d(out_chan)
+ self.conv2 = conv3x3(out_chan, out_chan)
+ self.bn2 = nn.BatchNorm2d(out_chan)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = None
+ if in_chan != out_chan or stride != 1:
+ self.downsample = nn.Sequential(
+ nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(out_chan),
+ )
+
+ def forward(self, x):
+ residual = self.conv1(x)
+ residual = F.relu(self.bn1(residual))
+ residual = self.conv2(residual)
+ residual = self.bn2(residual)
+
+ shortcut = x
+ if self.downsample is not None:
+ shortcut = self.downsample(x)
+
+ out = shortcut + residual
+ out = self.relu(out)
+ return out
+
+
+def create_layer_basic(in_chan, out_chan, bnum, stride=1):
+ layers = [BasicBlock(in_chan, out_chan, stride=stride)]
+ for i in range(bnum - 1):
+ layers.append(BasicBlock(out_chan, out_chan, stride=1))
+ return nn.Sequential(*layers)
+
+
+class ResNet18(nn.Module):
+
+ def __init__(self):
+ super(ResNet18, self).__init__()
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = nn.BatchNorm2d(64)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
+ self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
+ self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
+ self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = F.relu(self.bn1(x))
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ feat8 = self.layer2(x) # 1/8
+ feat16 = self.layer3(feat8) # 1/16
+ feat32 = self.layer4(feat16) # 1/32
+ return feat8, feat16, feat32
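
The trimmed ResNet18 backbone exposes only its three deepest feature maps; a minimal check of the expected strides and channel counts (module path assumed from the rename):

import torch

from hordelib.nodes.facerestore_cf.facelib.parsing.resnet import ResNet18  # assumed module path

backbone = ResNet18().eval()
with torch.no_grad():
    feat8, feat16, feat32 = backbone(torch.randn(1, 3, 512, 512))
# strides 1/8, 1/16 and 1/32 of the input resolution
print(feat8.shape, feat16.shape, feat32.shape)  # (1, 128, 64, 64), (1, 256, 32, 32), (1, 512, 16, 16)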
diff --git a/hordelib/nodes/facerestore_cf/facelib/utils/__init__.py b/hordelib/nodes/facerestore_cf/facelib/utils/__init__.py
new file mode 100644
index 00000000..47ef2010
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/facelib/utils/__init__.py
@@ -0,0 +1,13 @@
+from .face_utils import align_crop_face_landmarks, compute_increased_bbox, get_valid_bboxes, paste_face_back
+from .misc import download_pretrained_models, img2tensor, load_file_from_url, scandir
+
+__all__ = [
+ "align_crop_face_landmarks",
+ "compute_increased_bbox",
+ "get_valid_bboxes",
+ "load_file_from_url",
+ "download_pretrained_models",
+ "paste_face_back",
+ "img2tensor",
+ "scandir",
+]
diff --git a/hordelib/nodes/facerestore/facelib/utils/face_restoration_helper.py b/hordelib/nodes/facerestore_cf/facelib/utils/face_restoration_helper.py
similarity index 78%
rename from hordelib/nodes/facerestore/facelib/utils/face_restoration_helper.py
rename to hordelib/nodes/facerestore_cf/facelib/utils/face_restoration_helper.py
index 7917344a..ab6ce3c0 100644
--- a/hordelib/nodes/facerestore/facelib/utils/face_restoration_helper.py
+++ b/hordelib/nodes/facerestore_cf/facelib/utils/face_restoration_helper.py
@@ -1,561 +1,476 @@
-import cv2
-import numpy as np
-import os
-import torch
-from torchvision.transforms.functional import normalize
-
-from hordelib.nodes.facerestore.facelib.detection import init_detection_model
-from hordelib.nodes.facerestore.facelib.parsing import init_parsing_model
-from hordelib.nodes.facerestore.facelib.utils.misc import img2tensor, imwrite
-
-
-def get_largest_face(det_faces, h, w):
- def get_location(val, length):
- if val < 0:
- return 0
- elif val > length:
- return length
- else:
- return val
-
- face_areas = []
- for det_face in det_faces:
- left = get_location(det_face[0], w)
- right = get_location(det_face[2], w)
- top = get_location(det_face[1], h)
- bottom = get_location(det_face[3], h)
- face_area = (right - left) * (bottom - top)
- face_areas.append(face_area)
- largest_idx = face_areas.index(max(face_areas))
- return det_faces[largest_idx], largest_idx
-
-
-def get_center_face(det_faces, h=0, w=0, center=None):
- if center is not None:
- center = np.array(center)
- else:
- center = np.array([w / 2, h / 2])
- center_dist = []
- for det_face in det_faces:
- face_center = np.array(
- [(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2]
- )
- dist = np.linalg.norm(face_center - center)
- center_dist.append(dist)
- center_idx = center_dist.index(min(center_dist))
- return det_faces[center_idx], center_idx
-
-
-class FaceRestoreHelper(object):
- """Helper for the face restoration pipeline (base class)."""
-
- def __init__(
- self,
- upscale_factor,
- face_size=512,
- crop_ratio=(1, 1),
- det_model="retinaface_resnet50",
- save_ext="png",
- template_3points=False,
- pad_blur=False,
- use_parse=False,
- device=None,
- ):
- self.template_3points = template_3points # improve robustness
- self.upscale_factor = upscale_factor
- # the cropped face ratio based on the square face
- self.crop_ratio = crop_ratio # (h, w)
- assert (
- self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1
- ), "crop ration only supports >=1"
- self.face_size = (
- int(face_size * self.crop_ratio[1]),
- int(face_size * self.crop_ratio[0]),
- )
-
- if self.template_3points:
- self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
- else:
- # standard 5 landmarks for FFHQ faces with 512 x 512
- # facexlib
- self.face_template = np.array(
- [
- [192.98138, 239.94708],
- [318.90277, 240.1936],
- [256.63416, 314.01935],
- [201.26117, 371.41043],
- [313.08905, 371.15118],
- ]
- )
-
- # dlib: left_eye: 36:41 right_eye: 42:47 nose: 30,32,33,34 left mouth corner: 48 right mouth corner: 54
- # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894],
- # [198.22603, 372.82502], [313.91018, 372.75659]])
-
- self.face_template = self.face_template * (face_size / 512.0)
- if self.crop_ratio[0] > 1:
- self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
- if self.crop_ratio[1] > 1:
- self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2
- self.save_ext = save_ext
- self.pad_blur = pad_blur
- if self.pad_blur is True:
- self.template_3points = False
-
- self.all_landmarks_5 = []
- self.det_faces = []
- self.affine_matrices = []
- self.inverse_affine_matrices = []
- self.cropped_faces = []
- self.restored_faces = []
- self.pad_input_imgs = []
-
- if device is None:
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- else:
- self.device = device
-
- # init face detection model
- self.face_det = init_detection_model(det_model, half=False, device=self.device)
-
- # init face parsing model
- self.use_parse = use_parse
- self.face_parse = init_parsing_model(model_name="parsenet", device=self.device)
-
- def set_upscale_factor(self, upscale_factor):
- self.upscale_factor = upscale_factor
-
- def read_image(self, img):
- """img can be image path or cv2 loaded image."""
- # self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255]
- if isinstance(img, str):
- img = cv2.imread(img)
-
- if np.max(img) > 256: # 16-bit image
- img = img / 65535 * 255
- if len(img.shape) == 2: # gray image
- img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
- elif img.shape[2] == 4: # BGRA image with alpha channel
- img = img[:, :, 0:3]
-
- self.input_img = img
-
- if min(self.input_img.shape[:2]) < 512:
- f = 512.0 / min(self.input_img.shape[:2])
- self.input_img = cv2.resize(
- self.input_img, (0, 0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR
- )
-
- def get_face_landmarks_5(
- self,
- only_keep_largest=False,
- only_center_face=False,
- resize=None,
- blur_ratio=0.01,
- eye_dist_threshold=None,
- ):
- if resize is None:
- scale = 1
- input_img = self.input_img
- else:
- h, w = self.input_img.shape[0:2]
- scale = resize / min(h, w)
- scale = max(1, scale) # always scale up
- h, w = int(h * scale), int(w * scale)
- interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
- input_img = cv2.resize(self.input_img, (w, h), interpolation=interp)
-
- with torch.no_grad():
- bboxes = self.face_det.detect_faces(input_img)
-
- if bboxes is None or bboxes.shape[0] == 0:
- return 0
- else:
- bboxes = bboxes / scale
-
- for bbox in bboxes:
- # remove faces with too small eye distance: side faces or too small faces
- eye_dist = np.linalg.norm([bbox[6] - bbox[8], bbox[7] - bbox[9]])
- if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold):
- continue
-
- if self.template_3points:
- landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)])
- else:
- landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)])
- self.all_landmarks_5.append(landmark)
- self.det_faces.append(bbox[0:5])
-
- if len(self.det_faces) == 0:
- return 0
- if only_keep_largest:
- h, w, _ = self.input_img.shape
- self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w)
- self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]]
- elif only_center_face:
- h, w, _ = self.input_img.shape
- self.det_faces, center_idx = get_center_face(self.det_faces, h, w)
- self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]
-
- # pad blurry images
- if self.pad_blur:
- self.pad_input_imgs = []
- for landmarks in self.all_landmarks_5:
- # get landmarks
- eye_left = landmarks[0, :]
- eye_right = landmarks[1, :]
- eye_avg = (eye_left + eye_right) * 0.5
- mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5
- eye_to_eye = eye_right - eye_left
- eye_to_mouth = mouth_avg - eye_avg
-
- # Get the oriented crop rectangle
- # x: half width of the oriented crop rectangle
- x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
- # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
- # norm with the hypotenuse: get the direction
- x /= np.hypot(*x) # get the hypotenuse of a right triangle
- rect_scale = 1.5
- x *= max(
- np.hypot(*eye_to_eye) * 2.0 * rect_scale,
- np.hypot(*eye_to_mouth) * 1.8 * rect_scale,
- )
- # y: half height of the oriented crop rectangle
- y = np.flipud(x) * [-1, 1]
-
- # c: center
- c = eye_avg + eye_to_mouth * 0.1
- # quad: (left_top, left_bottom, right_bottom, right_top)
- quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
- # qsize: side length of the square
- qsize = np.hypot(*x) * 2
- border = max(int(np.rint(qsize * 0.1)), 3)
-
- # get pad
- # pad: (width_left, height_top, width_right, height_bottom)
- pad = (
- int(np.floor(min(quad[:, 0]))),
- int(np.floor(min(quad[:, 1]))),
- int(np.ceil(max(quad[:, 0]))),
- int(np.ceil(max(quad[:, 1]))),
- )
- pad = [
- max(-pad[0] + border, 1),
- max(-pad[1] + border, 1),
- max(pad[2] - self.input_img.shape[0] + border, 1),
- max(pad[3] - self.input_img.shape[1] + border, 1),
- ]
-
- if max(pad) > 1:
- # pad image
- pad_img = np.pad(
- self.input_img,
- ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)),
- "reflect",
- )
- # modify landmark coords
- landmarks[:, 0] += pad[0]
- landmarks[:, 1] += pad[1]
- # blur pad images
- h, w, _ = pad_img.shape
- y, x, _ = np.ogrid[:h, :w, :1]
- mask = np.maximum(
- 1.0
- - np.minimum(
- np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]
- ),
- 1.0
- - np.minimum(
- np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]
- ),
- )
- blur = int(qsize * blur_ratio)
- if blur % 2 == 0:
- blur += 1
- blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur))
- # blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0)
-
- pad_img = pad_img.astype("float32")
- pad_img += (blur_img - pad_img) * np.clip(
- mask * 3.0 + 1.0, 0.0, 1.0
- )
- pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(
- mask, 0.0, 1.0
- )
- pad_img = np.clip(pad_img, 0, 255) # float32, [0, 255]
- self.pad_input_imgs.append(pad_img)
- else:
- self.pad_input_imgs.append(np.copy(self.input_img))
-
- return len(self.all_landmarks_5)
-
- def align_warp_face(self, save_cropped_path=None, border_mode="constant"):
- """Align and warp faces with face template."""
- if self.pad_blur:
- assert len(self.pad_input_imgs) == len(
- self.all_landmarks_5
- ), f"Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}"
- for idx, landmark in enumerate(self.all_landmarks_5):
- # use 5 landmarks to get affine matrix
- # use cv2.LMEDS method for the equivalence to skimage transform
- # ref: https://blog.csdn.net/yichxi/article/details/115827338
- affine_matrix = cv2.estimateAffinePartial2D(
- landmark, self.face_template, method=cv2.LMEDS
- )[0]
- self.affine_matrices.append(affine_matrix)
- # warp and crop faces
- if border_mode == "constant":
- border_mode = cv2.BORDER_CONSTANT
- elif border_mode == "reflect101":
- border_mode = cv2.BORDER_REFLECT101
- elif border_mode == "reflect":
- border_mode = cv2.BORDER_REFLECT
- if self.pad_blur:
- input_img = self.pad_input_imgs[idx]
- else:
- input_img = self.input_img
- cropped_face = cv2.warpAffine(
- input_img,
- affine_matrix,
- self.face_size,
- borderMode=border_mode,
- borderValue=(135, 133, 132),
- ) # gray
- self.cropped_faces.append(cropped_face)
- # save the cropped face
- if save_cropped_path is not None:
- path = os.path.splitext(save_cropped_path)[0]
- save_path = f"{path}_{idx:02d}.{self.save_ext}"
- imwrite(cropped_face, save_path)
-
- def get_inverse_affine(self, save_inverse_affine_path=None):
- """Get inverse affine matrix."""
- for idx, affine_matrix in enumerate(self.affine_matrices):
- inverse_affine = cv2.invertAffineTransform(affine_matrix)
- inverse_affine *= self.upscale_factor
- self.inverse_affine_matrices.append(inverse_affine)
- # save inverse affine matrices
- if save_inverse_affine_path is not None:
- path, _ = os.path.splitext(save_inverse_affine_path)
- save_path = f"{path}_{idx:02d}.pth"
- torch.save(inverse_affine, save_path)
-
- def add_restored_face(self, face):
- self.restored_faces.append(face)
-
- def paste_faces_to_input_image(
- self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None
- ):
- h, w, _ = self.input_img.shape
- h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)
-
- if upsample_img is None:
- # simply resize the background
- # upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
- upsample_img = cv2.resize(
- self.input_img, (w_up, h_up), interpolation=cv2.INTER_LINEAR
- )
- else:
- upsample_img = cv2.resize(
- upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4
- )
-
- assert len(self.restored_faces) == len(
- self.inverse_affine_matrices
- ), "length of restored_faces and affine_matrices are different."
-
- inv_mask_borders = []
- for restored_face, inverse_affine in zip(
- self.restored_faces, self.inverse_affine_matrices
- ):
- if face_upsampler is not None:
- restored_face = face_upsampler.enhance(
- restored_face, outscale=self.upscale_factor
- )[0]
- inverse_affine /= self.upscale_factor
- inverse_affine[:, 2] *= self.upscale_factor
- face_size = (
- self.face_size[0] * self.upscale_factor,
- self.face_size[1] * self.upscale_factor,
- )
- else:
- # Add an offset to inverse affine matrix, for more precise back alignment
- if self.upscale_factor > 1:
- extra_offset = 0.5 * self.upscale_factor
- else:
- extra_offset = 0
- inverse_affine[:, 2] += extra_offset
- face_size = self.face_size
- inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))
-
- # if draw_box or not self.use_parse: # use square parse maps
- # mask = np.ones(face_size, dtype=np.float32)
- # inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
- # # remove the black borders
- # inv_mask_erosion = cv2.erode(
- # inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
- # pasted_face = inv_mask_erosion[:, :, None] * inv_restored
- # total_face_area = np.sum(inv_mask_erosion) # // 3
- # # add border
- # if draw_box:
- # h, w = face_size
- # mask_border = np.ones((h, w, 3), dtype=np.float32)
- # border = int(1400/np.sqrt(total_face_area))
- # mask_border[border:h-border, border:w-border,:] = 0
- # inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
- # inv_mask_borders.append(inv_mask_border)
- # if not self.use_parse:
- # # compute the fusion edge based on the area of face
- # w_edge = int(total_face_area**0.5) // 20
- # erosion_radius = w_edge * 2
- # inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
- # blur_size = w_edge * 2
- # inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
- # if len(upsample_img.shape) == 2: # upsample_img is gray image
- # upsample_img = upsample_img[:, :, None]
- # inv_soft_mask = inv_soft_mask[:, :, None]
-
- # always use square mask
- mask = np.ones(face_size, dtype=np.float32)
- inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
- # remove the black borders
- inv_mask_erosion = cv2.erode(
- inv_mask,
- np.ones(
- (int(2 * self.upscale_factor), int(2 * self.upscale_factor)),
- np.uint8,
- ),
- )
- pasted_face = inv_mask_erosion[:, :, None] * inv_restored
- total_face_area = np.sum(inv_mask_erosion) # // 3
- # add border
- if draw_box:
- h, w = face_size
- mask_border = np.ones((h, w, 3), dtype=np.float32)
- border = int(1400 / np.sqrt(total_face_area))
- mask_border[border : h - border, border : w - border, :] = 0
- inv_mask_border = cv2.warpAffine(
- mask_border, inverse_affine, (w_up, h_up)
- )
- inv_mask_borders.append(inv_mask_border)
- # compute the fusion edge based on the area of face
- w_edge = int(total_face_area**0.5) // 20
- erosion_radius = w_edge * 2
- inv_mask_center = cv2.erode(
- inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8)
- )
- blur_size = w_edge * 2
- inv_soft_mask = cv2.GaussianBlur(
- inv_mask_center, (blur_size + 1, blur_size + 1), 0
- )
- if len(upsample_img.shape) == 2: # upsample_img is gray image
- upsample_img = upsample_img[:, :, None]
- inv_soft_mask = inv_soft_mask[:, :, None]
-
- # parse mask
- if self.use_parse:
- # inference
- face_input = cv2.resize(
- restored_face, (512, 512), interpolation=cv2.INTER_LINEAR
- )
- face_input = img2tensor(
- face_input.astype("float32") / 255.0, bgr2rgb=True, float32=True
- )
- normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
- face_input = torch.unsqueeze(face_input, 0).to(self.device)
- with torch.no_grad():
- out = self.face_parse(face_input)[0]
- out = out.argmax(dim=1).squeeze().cpu().numpy()
-
- parse_mask = np.zeros(out.shape)
- MASK_COLORMAP = [
- 0,
- 255,
- 255,
- 255,
- 255,
- 255,
- 255,
- 255,
- 255,
- 255,
- 255,
- 255,
- 255,
- 255,
- 0,
- 255,
- 0,
- 0,
- 0,
- ]
- for idx, color in enumerate(MASK_COLORMAP):
- parse_mask[out == idx] = color
- # blur the mask
- parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
- parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
- # remove the black borders
- thres = 10
- parse_mask[:thres, :] = 0
- parse_mask[-thres:, :] = 0
- parse_mask[:, :thres] = 0
- parse_mask[:, -thres:] = 0
- parse_mask = parse_mask / 255.0
-
- parse_mask = cv2.resize(parse_mask, face_size)
- parse_mask = cv2.warpAffine(
- parse_mask, inverse_affine, (w_up, h_up), flags=3
- )
- inv_soft_parse_mask = parse_mask[:, :, None]
- # pasted_face = inv_restored
- fuse_mask = (inv_soft_parse_mask < inv_soft_mask).astype("int")
- inv_soft_mask = inv_soft_parse_mask * fuse_mask + inv_soft_mask * (
- 1 - fuse_mask
- )
-
- if (
- len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4
- ): # alpha channel
- alpha = upsample_img[:, :, 3:]
- upsample_img = (
- inv_soft_mask * pasted_face
- + (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
- )
- upsample_img = np.concatenate((upsample_img, alpha), axis=2)
- else:
- upsample_img = (
- inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img
- )
-
- if np.max(upsample_img) > 256: # 16-bit image
- upsample_img = upsample_img.astype(np.uint16)
- else:
- upsample_img = upsample_img.astype(np.uint8)
-
- # draw bounding box
- if draw_box:
- # upsample_input_img = cv2.resize(input_img, (w_up, h_up))
- img_color = np.ones([*upsample_img.shape], dtype=np.float32)
- img_color[:, :, 0] = 0
- img_color[:, :, 1] = 255
- img_color[:, :, 2] = 0
- for inv_mask_border in inv_mask_borders:
- upsample_img = (
- inv_mask_border * img_color + (1 - inv_mask_border) * upsample_img
- )
- # upsample_input_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_input_img
-
- if save_path is not None:
- path = os.path.splitext(save_path)[0]
- save_path = f"{path}.{self.save_ext}"
- imwrite(upsample_img, save_path)
- return upsample_img
-
- def clean_all(self):
- self.all_landmarks_5 = []
- self.restored_faces = []
- self.affine_matrices = []
- self.cropped_faces = []
- self.inverse_affine_matrices = []
- self.det_faces = []
- self.pad_input_imgs = []
+import os
+
+import cv2
+import numpy as np
+import torch
+from torchvision.transforms.functional import normalize
+
+from hordelib.nodes.facerestore_cf.facelib.detection import init_detection_model
+from hordelib.nodes.facerestore_cf.facelib.parsing import init_parsing_model
+from hordelib.nodes.facerestore_cf.facelib.utils.misc import img2tensor, imwrite
+
+
+def get_largest_face(det_faces, h, w):
+
+ def get_location(val, length):
+ if val < 0:
+ return 0
+ elif val > length:
+ return length
+ else:
+ return val
+
+ face_areas = []
+ for det_face in det_faces:
+ left = get_location(det_face[0], w)
+ right = get_location(det_face[2], w)
+ top = get_location(det_face[1], h)
+ bottom = get_location(det_face[3], h)
+ face_area = (right - left) * (bottom - top)
+ face_areas.append(face_area)
+ largest_idx = face_areas.index(max(face_areas))
+ return det_faces[largest_idx], largest_idx
+
+
+def get_center_face(det_faces, h=0, w=0, center=None):
+ if center is not None:
+ center = np.array(center)
+ else:
+ center = np.array([w / 2, h / 2])
+ center_dist = []
+ for det_face in det_faces:
+ face_center = np.array([(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2])
+ dist = np.linalg.norm(face_center - center)
+ center_dist.append(dist)
+ center_idx = center_dist.index(min(center_dist))
+ return det_faces[center_idx], center_idx
+
+
+class FaceRestoreHelper:
+ """Helper for the face restoration pipeline (base class)."""
+
+ def __init__(
+ self,
+ upscale_factor,
+ face_size=512,
+ crop_ratio=(1, 1),
+ det_model="retinaface_resnet50",
+ save_ext="png",
+ template_3points=False,
+ pad_blur=False,
+ use_parse=False,
+ device=None,
+ ):
+ self.template_3points = template_3points # improve robustness
+ self.upscale_factor = upscale_factor
+ # the cropped face ratio based on the square face
+ self.crop_ratio = crop_ratio # (h, w)
+        assert self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1, "crop ratio only supports >=1"
+ self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))
+
+ if self.template_3points:
+ self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
+ else:
+ # standard 5 landmarks for FFHQ faces with 512 x 512
+ # facexlib
+ self.face_template = np.array(
+ [
+ [192.98138, 239.94708],
+ [318.90277, 240.1936],
+ [256.63416, 314.01935],
+ [201.26117, 371.41043],
+ [313.08905, 371.15118],
+ ],
+ )
+
+ # dlib: left_eye: 36:41 right_eye: 42:47 nose: 30,32,33,34 left mouth corner: 48 right mouth corner: 54
+ # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894],
+ # [198.22603, 372.82502], [313.91018, 372.75659]])
+
+ self.face_template = self.face_template * (face_size / 512.0)
+ if self.crop_ratio[0] > 1:
+ self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
+ if self.crop_ratio[1] > 1:
+ self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2
+ self.save_ext = save_ext
+ self.pad_blur = pad_blur
+ if self.pad_blur is True:
+ self.template_3points = False
+
+ self.all_landmarks_5 = []
+ self.det_faces = []
+ self.affine_matrices = []
+ self.inverse_affine_matrices = []
+ self.cropped_faces = []
+ self.restored_faces = []
+ self.pad_input_imgs = []
+
+ if device is None:
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ else:
+ self.device = device
+
+ # init face detection model
+ self.face_det = init_detection_model(det_model, half=False, device=self.device)
+
+ # init face parsing model
+ self.use_parse = use_parse
+ self.face_parse = init_parsing_model(model_name="parsenet", device=self.device)
+
+ def set_upscale_factor(self, upscale_factor):
+ self.upscale_factor = upscale_factor
+
+ def read_image(self, img):
+ """img can be image path or cv2 loaded image."""
+ # self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255]
+ if isinstance(img, str):
+ img = cv2.imread(img)
+
+ if np.max(img) > 256: # 16-bit image
+ img = img / 65535 * 255
+ if len(img.shape) == 2: # gray image
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+ elif img.shape[2] == 4: # BGRA image with alpha channel
+ img = img[:, :, 0:3]
+
+ self.input_img = img
+
+ if min(self.input_img.shape[:2]) < 512:
+ f = 512.0 / min(self.input_img.shape[:2])
+ self.input_img = cv2.resize(self.input_img, (0, 0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR)
+
+ def get_face_landmarks_5(
+ self,
+ only_keep_largest=False,
+ only_center_face=False,
+ resize=None,
+ blur_ratio=0.01,
+ eye_dist_threshold=None,
+ ):
+ if resize is None:
+ scale = 1
+ input_img = self.input_img
+ else:
+ h, w = self.input_img.shape[0:2]
+ scale = resize / min(h, w)
+ scale = max(1, scale) # always scale up
+ h, w = int(h * scale), int(w * scale)
+ interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
+ input_img = cv2.resize(self.input_img, (w, h), interpolation=interp)
+
+ with torch.no_grad():
+ bboxes = self.face_det.detect_faces(input_img)
+
+ if bboxes is None or bboxes.shape[0] == 0:
+ return 0
+ else:
+ bboxes = bboxes / scale
+
+ for bbox in bboxes:
+ # remove faces with too small eye distance: side faces or too small faces
+ eye_dist = np.linalg.norm([bbox[6] - bbox[8], bbox[7] - bbox[9]])
+ if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold):
+ continue
+
+ if self.template_3points:
+ landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)])
+ else:
+ landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)])
+ self.all_landmarks_5.append(landmark)
+ self.det_faces.append(bbox[0:5])
+
+ if len(self.det_faces) == 0:
+ return 0
+ if only_keep_largest:
+ h, w, _ = self.input_img.shape
+ self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w)
+ self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]]
+ elif only_center_face:
+ h, w, _ = self.input_img.shape
+ self.det_faces, center_idx = get_center_face(self.det_faces, h, w)
+ self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]
+
+ # pad blurry images
+ if self.pad_blur:
+ self.pad_input_imgs = []
+ for landmarks in self.all_landmarks_5:
+ # get landmarks
+ eye_left = landmarks[0, :]
+ eye_right = landmarks[1, :]
+ eye_avg = (eye_left + eye_right) * 0.5
+ mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5
+ eye_to_eye = eye_right - eye_left
+ eye_to_mouth = mouth_avg - eye_avg
+
+ # Get the oriented crop rectangle
+ # x: half width of the oriented crop rectangle
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
+ # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
+ # norm with the hypotenuse: get the direction
+ x /= np.hypot(*x) # get the hypotenuse of a right triangle
+ rect_scale = 1.5
+ x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
+ # y: half height of the oriented crop rectangle
+ y = np.flipud(x) * [-1, 1]
+
+ # c: center
+ c = eye_avg + eye_to_mouth * 0.1
+ # quad: (left_top, left_bottom, right_bottom, right_top)
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
+ # qsize: side length of the square
+ qsize = np.hypot(*x) * 2
+ border = max(int(np.rint(qsize * 0.1)), 3)
+
+ # get pad
+ # pad: (width_left, height_top, width_right, height_bottom)
+ pad = (
+ int(np.floor(min(quad[:, 0]))),
+ int(np.floor(min(quad[:, 1]))),
+ int(np.ceil(max(quad[:, 0]))),
+ int(np.ceil(max(quad[:, 1]))),
+ )
+ pad = [
+ max(-pad[0] + border, 1),
+ max(-pad[1] + border, 1),
+ max(pad[2] - self.input_img.shape[0] + border, 1),
+ max(pad[3] - self.input_img.shape[1] + border, 1),
+ ]
+
+ if max(pad) > 1:
+ # pad image
+ pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), "reflect")
+ # modify landmark coords
+ landmarks[:, 0] += pad[0]
+ landmarks[:, 1] += pad[1]
+ # blur pad images
+ h, w, _ = pad_img.shape
+ y, x, _ = np.ogrid[:h, :w, :1]
+ mask = np.maximum(
+ 1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]),
+ 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]),
+ )
+ blur = int(qsize * blur_ratio)
+ if blur % 2 == 0:
+ blur += 1
+ blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur))
+ # blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0)
+
+ pad_img = pad_img.astype("float32")
+ pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
+ pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0)
+ pad_img = np.clip(pad_img, 0, 255) # float32, [0, 255]
+ self.pad_input_imgs.append(pad_img)
+ else:
+ self.pad_input_imgs.append(np.copy(self.input_img))
+
+ return len(self.all_landmarks_5)
+
+ def align_warp_face(self, save_cropped_path=None, border_mode="constant"):
+ """Align and warp faces with face template."""
+ if self.pad_blur:
+ assert len(self.pad_input_imgs) == len(
+ self.all_landmarks_5,
+ ), f"Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}"
+ for idx, landmark in enumerate(self.all_landmarks_5):
+ # use 5 landmarks to get affine matrix
+ # use cv2.LMEDS method for the equivalence to skimage transform
+ # ref: https://blog.csdn.net/yichxi/article/details/115827338
+ affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template, method=cv2.LMEDS)[0]
+ self.affine_matrices.append(affine_matrix)
+ # warp and crop faces
+ if border_mode == "constant":
+ border_mode = cv2.BORDER_CONSTANT
+ elif border_mode == "reflect101":
+ border_mode = cv2.BORDER_REFLECT101
+ elif border_mode == "reflect":
+ border_mode = cv2.BORDER_REFLECT
+ if self.pad_blur:
+ input_img = self.pad_input_imgs[idx]
+ else:
+ input_img = self.input_img
+ cropped_face = cv2.warpAffine(
+ input_img,
+ affine_matrix,
+ self.face_size,
+ borderMode=border_mode,
+ borderValue=(135, 133, 132),
+ ) # gray
+ self.cropped_faces.append(cropped_face)
+ # save the cropped face
+ if save_cropped_path is not None:
+ path = os.path.splitext(save_cropped_path)[0]
+ save_path = f"{path}_{idx:02d}.{self.save_ext}"
+ imwrite(cropped_face, save_path)
+
+ def get_inverse_affine(self, save_inverse_affine_path=None):
+ """Get inverse affine matrix."""
+ for idx, affine_matrix in enumerate(self.affine_matrices):
+ inverse_affine = cv2.invertAffineTransform(affine_matrix)
+ inverse_affine *= self.upscale_factor
+ self.inverse_affine_matrices.append(inverse_affine)
+ # save inverse affine matrices
+ if save_inverse_affine_path is not None:
+ path, _ = os.path.splitext(save_inverse_affine_path)
+ save_path = f"{path}_{idx:02d}.pth"
+ torch.save(inverse_affine, save_path)
+
+ def add_restored_face(self, face):
+ self.restored_faces.append(face)
+
+ def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None):
+ h, w, _ = self.input_img.shape
+ h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)
+
+ if upsample_img is None:
+ # simply resize the background
+ # upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
+ upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LINEAR)
+ else:
+ upsample_img = cv2.resize(upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
+
+        assert len(self.restored_faces) == len(
+            self.inverse_affine_matrices,
+        ), "lengths of restored_faces and inverse_affine_matrices differ."
+
+ inv_mask_borders = []
+ for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices, strict=False):
+ if face_upsampler is not None:
+ restored_face = face_upsampler.enhance(restored_face, outscale=self.upscale_factor)[0]
+ inverse_affine /= self.upscale_factor
+ inverse_affine[:, 2] *= self.upscale_factor
+ face_size = (self.face_size[0] * self.upscale_factor, self.face_size[1] * self.upscale_factor)
+ else:
+ # Add an offset to inverse affine matrix, for more precise back alignment
+ if self.upscale_factor > 1:
+ extra_offset = 0.5 * self.upscale_factor
+ else:
+ extra_offset = 0
+ inverse_affine[:, 2] += extra_offset
+ face_size = self.face_size
+ inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))
+
+ # if draw_box or not self.use_parse: # use square parse maps
+ # mask = np.ones(face_size, dtype=np.float32)
+ # inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
+ # # remove the black borders
+ # inv_mask_erosion = cv2.erode(
+ # inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
+ # pasted_face = inv_mask_erosion[:, :, None] * inv_restored
+ # total_face_area = np.sum(inv_mask_erosion) # // 3
+ # # add border
+ # if draw_box:
+ # h, w = face_size
+ # mask_border = np.ones((h, w, 3), dtype=np.float32)
+ # border = int(1400/np.sqrt(total_face_area))
+ # mask_border[border:h-border, border:w-border,:] = 0
+ # inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
+ # inv_mask_borders.append(inv_mask_border)
+ # if not self.use_parse:
+ # # compute the fusion edge based on the area of face
+ # w_edge = int(total_face_area**0.5) // 20
+ # erosion_radius = w_edge * 2
+ # inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
+ # blur_size = w_edge * 2
+ # inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
+ # if len(upsample_img.shape) == 2: # upsample_img is gray image
+ # upsample_img = upsample_img[:, :, None]
+ # inv_soft_mask = inv_soft_mask[:, :, None]
+
+ # always use square mask
+ mask = np.ones(face_size, dtype=np.float32)
+ inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
+ # remove the black borders
+ inv_mask_erosion = cv2.erode(
+ inv_mask,
+ np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8),
+ )
+ pasted_face = inv_mask_erosion[:, :, None] * inv_restored
+ total_face_area = np.sum(inv_mask_erosion) # // 3
+ # add border
+ if draw_box:
+ h, w = face_size
+ mask_border = np.ones((h, w, 3), dtype=np.float32)
+ border = int(1400 / np.sqrt(total_face_area))
+ mask_border[border : h - border, border : w - border, :] = 0
+ inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
+ inv_mask_borders.append(inv_mask_border)
+ # compute the fusion edge based on the area of face
+ w_edge = int(total_face_area**0.5) // 20
+ erosion_radius = w_edge * 2
+ inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
+ blur_size = w_edge * 2
+ inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
+ if len(upsample_img.shape) == 2: # upsample_img is gray image
+ upsample_img = upsample_img[:, :, None]
+ inv_soft_mask = inv_soft_mask[:, :, None]
+
+ # parse mask
+ if self.use_parse:
+ # inference
+ face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR)
+ face_input = img2tensor(face_input.astype("float32") / 255.0, bgr2rgb=True, float32=True)
+ normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+ face_input = torch.unsqueeze(face_input, 0).to(self.device)
+ with torch.no_grad():
+ out = self.face_parse(face_input)[0]
+ out = out.argmax(dim=1).squeeze().cpu().numpy()
+
+ parse_mask = np.zeros(out.shape)
+ MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
+ for idx, color in enumerate(MASK_COLORMAP):
+ parse_mask[out == idx] = color
+ # blur the mask
+ parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
+ parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
+ # remove the black borders
+ thres = 10
+ parse_mask[:thres, :] = 0
+ parse_mask[-thres:, :] = 0
+ parse_mask[:, :thres] = 0
+ parse_mask[:, -thres:] = 0
+ parse_mask = parse_mask / 255.0
+
+ parse_mask = cv2.resize(parse_mask, face_size)
+ parse_mask = cv2.warpAffine(parse_mask, inverse_affine, (w_up, h_up), flags=3)
+ inv_soft_parse_mask = parse_mask[:, :, None]
+ # pasted_face = inv_restored
+ fuse_mask = (inv_soft_parse_mask < inv_soft_mask).astype("int")
+ inv_soft_mask = inv_soft_parse_mask * fuse_mask + inv_soft_mask * (1 - fuse_mask)
+
+ if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4: # alpha channel
+ alpha = upsample_img[:, :, 3:]
+ upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
+ upsample_img = np.concatenate((upsample_img, alpha), axis=2)
+ else:
+ upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img
+
+ if np.max(upsample_img) > 256: # 16-bit image
+ upsample_img = upsample_img.astype(np.uint16)
+ else:
+ upsample_img = upsample_img.astype(np.uint8)
+
+ # draw bounding box
+ if draw_box:
+ # upsample_input_img = cv2.resize(input_img, (w_up, h_up))
+ img_color = np.ones([*upsample_img.shape], dtype=np.float32)
+ img_color[:, :, 0] = 0
+ img_color[:, :, 1] = 255
+ img_color[:, :, 2] = 0
+ for inv_mask_border in inv_mask_borders:
+ upsample_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_img
+ # upsample_input_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_input_img
+
+ if save_path is not None:
+ path = os.path.splitext(save_path)[0]
+ save_path = f"{path}.{self.save_ext}"
+ imwrite(upsample_img, save_path)
+ return upsample_img
+
+ def clean_all(self):
+ self.all_landmarks_5 = []
+ self.restored_faces = []
+ self.affine_matrices = []
+ self.cropped_faces = []
+ self.inverse_affine_matrices = []
+ self.det_faces = []
+ self.pad_input_imgs = []
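
The helper above is driven through a fixed call sequence: detect faces, align and crop, restore each crop, then paste the results back. A hedged outline of that sequence, assuming the import path follows the rename and that the detection/parsing weights are available locally or downloadable:

import numpy as np
import torch

from hordelib.nodes.facerestore_cf.facelib.utils.face_restoration_helper import FaceRestoreHelper

helper = FaceRestoreHelper(upscale_factor=2, face_size=512, use_parse=True, device=torch.device("cpu"))
helper.read_image(np.zeros((768, 768, 3), dtype=np.uint8))  # accepts a BGR uint8 array or an image path
if helper.get_face_landmarks_5(eye_dist_threshold=5) > 0:
    helper.align_warp_face()  # fills helper.cropped_faces
    for cropped_face in helper.cropped_faces:
        restored_face = cropped_face  # a real pipeline would run CodeFormer/GFPGAN on the crop here
        helper.add_restored_face(restored_face)
    helper.get_inverse_affine()
    result = helper.paste_faces_to_input_image()  # BGR uint8, upscale_factor times larger than the input
helper.clean_all()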
diff --git a/hordelib/nodes/facerestore/facelib/utils/face_utils.py b/hordelib/nodes/facerestore_cf/facelib/utils/face_utils.py
similarity index 93%
rename from hordelib/nodes/facerestore/facelib/utils/face_utils.py
rename to hordelib/nodes/facerestore_cf/facelib/utils/face_utils.py
index 5ee39570..5769a0f7 100644
--- a/hordelib/nodes/facerestore/facelib/utils/face_utils.py
+++ b/hordelib/nodes/facerestore_cf/facelib/utils/face_utils.py
@@ -1,283 +1,275 @@
-import cv2
-import numpy as np
-import torch
-
-
-def compute_increased_bbox(bbox, increase_area, preserve_aspect=True):
- left, top, right, bot = bbox
- width = right - left
- height = bot - top
-
- if preserve_aspect:
- width_increase = max(increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width))
- height_increase = max(increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height))
- else:
- width_increase = height_increase = increase_area
- left = int(left - width_increase * width)
- top = int(top - height_increase * height)
- right = int(right + width_increase * width)
- bot = int(bot + height_increase * height)
- return (left, top, right, bot)
-
-
-def get_valid_bboxes(bboxes, h, w):
- left = max(bboxes[0], 0)
- top = max(bboxes[1], 0)
- right = min(bboxes[2], w)
- bottom = min(bboxes[3], h)
- return (left, top, right, bottom)
-
-
-def align_crop_face_landmarks(
- img,
- landmarks,
- output_size,
- transform_size=None,
- enable_padding=True,
- return_inverse_affine=False,
- shrink_ratio=(1, 1),
-):
- """Align and crop face with landmarks.
-
- The output_size and transform_size are based on width. The height is
- adjusted based on shrink_ratio_h/shring_ration_w.
-
- Modified from:
- https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py
-
- Args:
- img (Numpy array): Input image.
- landmarks (Numpy array): 5 or 68 or 98 landmarks.
- output_size (int): Output face size.
- transform_size (ing): Transform size. Usually the four time of
- output_size.
- enable_padding (float): Default: True.
- shrink_ratio (float | tuple[float] | list[float]): Shring the whole
- face for height and width (crop larger area). Default: (1, 1).
-
- Returns:
- (Numpy array): Cropped face.
- """
- lm_type = "retinaface_5" # Options: dlib_5, retinaface_5
-
- if isinstance(shrink_ratio, (float, int)):
- shrink_ratio = (shrink_ratio, shrink_ratio)
- if transform_size is None:
- transform_size = output_size * 4
-
- # Parse landmarks
- lm = np.array(landmarks)
- if lm.shape[0] == 5 and lm_type == "retinaface_5":
- eye_left = lm[0]
- eye_right = lm[1]
- mouth_avg = (lm[3] + lm[4]) * 0.5
- elif lm.shape[0] == 5 and lm_type == "dlib_5":
- lm_eye_left = lm[2:4]
- lm_eye_right = lm[0:2]
- eye_left = np.mean(lm_eye_left, axis=0)
- eye_right = np.mean(lm_eye_right, axis=0)
- mouth_avg = lm[4]
- elif lm.shape[0] == 68:
- lm_eye_left = lm[36:42]
- lm_eye_right = lm[42:48]
- eye_left = np.mean(lm_eye_left, axis=0)
- eye_right = np.mean(lm_eye_right, axis=0)
- mouth_avg = (lm[48] + lm[54]) * 0.5
- elif lm.shape[0] == 98:
- lm_eye_left = lm[60:68]
- lm_eye_right = lm[68:76]
- eye_left = np.mean(lm_eye_left, axis=0)
- eye_right = np.mean(lm_eye_right, axis=0)
- mouth_avg = (lm[76] + lm[82]) * 0.5
-
- eye_avg = (eye_left + eye_right) * 0.5
- eye_to_eye = eye_right - eye_left
- eye_to_mouth = mouth_avg - eye_avg
-
- # Get the oriented crop rectangle
- # x: half width of the oriented crop rectangle
- x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
- # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
- # norm with the hypotenuse: get the direction
- x /= np.hypot(*x) # get the hypotenuse of a right triangle
- rect_scale = 1 # TODO: you can edit it to get larger rect
- x *= max(
- np.hypot(*eye_to_eye) * 2.0 * rect_scale,
- np.hypot(*eye_to_mouth) * 1.8 * rect_scale,
- )
- # y: half height of the oriented crop rectangle
- y = np.flipud(x) * [-1, 1]
-
- x *= shrink_ratio[1] # width
- y *= shrink_ratio[0] # height
-
- # c: center
- c = eye_avg + eye_to_mouth * 0.1
- # quad: (left_top, left_bottom, right_bottom, right_top)
- quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
- # qsize: side length of the square
- qsize = np.hypot(*x) * 2
-
- quad_ori = np.copy(quad)
- # Shrink, for large face
- # TODO: do we really need shrink
- shrink = int(np.floor(qsize / output_size * 0.5))
- if shrink > 1:
- h, w = img.shape[0:2]
- rsize = (int(np.rint(float(w) / shrink)), int(np.rint(float(h) / shrink)))
- img = cv2.resize(img, rsize, interpolation=cv2.INTER_AREA)
- quad /= shrink
- qsize /= shrink
-
- # Crop
- h, w = img.shape[0:2]
- border = max(int(np.rint(qsize * 0.1)), 3)
- crop = (
- int(np.floor(min(quad[:, 0]))),
- int(np.floor(min(quad[:, 1]))),
- int(np.ceil(max(quad[:, 0]))),
- int(np.ceil(max(quad[:, 1]))),
- )
- crop = (
- max(crop[0] - border, 0),
- max(crop[1] - border, 0),
- min(crop[2] + border, w),
- min(crop[3] + border, h),
- )
- if crop[2] - crop[0] < w or crop[3] - crop[1] < h:
- img = img[crop[1] : crop[3], crop[0] : crop[2], :]
- quad -= crop[0:2]
-
- # Pad
- # pad: (width_left, height_top, width_right, height_bottom)
- h, w = img.shape[0:2]
- pad = (
- int(np.floor(min(quad[:, 0]))),
- int(np.floor(min(quad[:, 1]))),
- int(np.ceil(max(quad[:, 0]))),
- int(np.ceil(max(quad[:, 1]))),
- )
- pad = (
- max(-pad[0] + border, 0),
- max(-pad[1] + border, 0),
- max(pad[2] - w + border, 0),
- max(pad[3] - h + border, 0),
- )
- if enable_padding and max(pad) > border - 4:
- pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
- img = np.pad(img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), "reflect")
- h, w = img.shape[0:2]
- y, x, _ = np.ogrid[:h, :w, :1]
- mask = np.maximum(
- 1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]),
- 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]),
- )
- blur = int(qsize * 0.02)
- if blur % 2 == 0:
- blur += 1
- blur_img = cv2.boxFilter(img, 0, ksize=(blur, blur))
-
- img = img.astype("float32")
- img += (blur_img - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
- img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
- img = np.clip(img, 0, 255) # float32, [0, 255]
- quad += pad[:2]
-
- # Transform use cv2
- h_ratio = shrink_ratio[0] / shrink_ratio[1]
- dst_h, dst_w = int(transform_size * h_ratio), transform_size
- template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]])
- # use cv2.LMEDS method for the equivalence to skimage transform
- # ref: https://blog.csdn.net/yichxi/article/details/115827338
- affine_matrix = cv2.estimateAffinePartial2D(quad, template, method=cv2.LMEDS)[0]
- cropped_face = cv2.warpAffine(
- img,
- affine_matrix,
- (dst_w, dst_h),
- borderMode=cv2.BORDER_CONSTANT,
- borderValue=(135, 133, 132),
- ) # gray
-
- if output_size < transform_size:
- cropped_face = cv2.resize(
- cropped_face,
- (output_size, int(output_size * h_ratio)),
- interpolation=cv2.INTER_LINEAR,
- )
-
- if return_inverse_affine:
- dst_h, dst_w = int(output_size * h_ratio), output_size
- template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]])
- # use cv2.LMEDS method for the equivalence to skimage transform
- # ref: https://blog.csdn.net/yichxi/article/details/115827338
- affine_matrix = cv2.estimateAffinePartial2D(
- quad_ori,
- np.array([[0, 0], [0, output_size], [dst_w, dst_h], [dst_w, 0]]),
- method=cv2.LMEDS,
- )[0]
- inverse_affine = cv2.invertAffineTransform(affine_matrix)
- else:
- inverse_affine = None
- return cropped_face, inverse_affine
-
-
-def paste_face_back(img, face, inverse_affine):
- h, w = img.shape[0:2]
- face_h, face_w = face.shape[0:2]
- inv_restored = cv2.warpAffine(face, inverse_affine, (w, h))
- mask = np.ones((face_h, face_w, 3), dtype=np.float32)
- inv_mask = cv2.warpAffine(mask, inverse_affine, (w, h))
- # remove the black borders
- inv_mask_erosion = cv2.erode(inv_mask, np.ones((2, 2), np.uint8))
- inv_restored_remove_border = inv_mask_erosion * inv_restored
- total_face_area = np.sum(inv_mask_erosion) // 3
- # compute the fusion edge based on the area of face
- w_edge = int(total_face_area**0.5) // 20
- erosion_radius = w_edge * 2
- inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
- blur_size = w_edge * 2
- inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
- img = inv_soft_mask * inv_restored_remove_border + (1 - inv_soft_mask) * img
- # float32, [0, 255]
- return img
-
-
-if __name__ == "__main__":
- import os
-
- from hordelib.nodes.facerestore.facelib.detection import init_detection_model
- from hordelib.nodes.facerestore.facelib.utils.face_restoration_helper import get_largest_face
-
- img_path = "/home/wxt/datasets/ffhq/ffhq_wild/00009.png"
- img_name = os.path.splitext(os.path.basename(img_path))[0]
-
- # initialize model
- det_net = init_detection_model("retinaface_resnet50", half=False)
- img_ori = cv2.imread(img_path)
- h, w = img_ori.shape[0:2]
- # if larger than 800, scale it
- scale = max(h / 800, w / 800)
- if scale > 1:
- img = cv2.resize(img_ori, (int(w / scale), int(h / scale)), interpolation=cv2.INTER_LINEAR)
-
- with torch.no_grad():
- bboxes = det_net.detect_faces(img, 0.97)
- if scale > 1:
- bboxes *= scale # the score is incorrect
- bboxes = get_largest_face(bboxes, h, w)[0]
-
- landmarks = np.array([[bboxes[i], bboxes[i + 1]] for i in range(5, 15, 2)])
-
- cropped_face, inverse_affine = align_crop_face_landmarks(
- img_ori,
- landmarks,
- output_size=512,
- transform_size=None,
- enable_padding=True,
- return_inverse_affine=True,
- shrink_ratio=(1, 1),
- )
-
- cv2.imwrite(f"tmp/{img_name}_cropeed_face.png", cropped_face)
- img = paste_face_back(img_ori, cropped_face, inverse_affine)
- cv2.imwrite(f"tmp/{img_name}_back.png", img)
+import cv2
+import numpy as np
+import torch
+
+
+def compute_increased_bbox(bbox, increase_area, preserve_aspect=True):
+ left, top, right, bot = bbox
+ width = right - left
+ height = bot - top
+
+ if preserve_aspect:
+ width_increase = max(increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width))
+ height_increase = max(increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height))
+ else:
+ width_increase = height_increase = increase_area
+ left = int(left - width_increase * width)
+ top = int(top - height_increase * height)
+ right = int(right + width_increase * width)
+ bot = int(bot + height_increase * height)
+ return (left, top, right, bot)
+
+
+def get_valid_bboxes(bboxes, h, w):
+ left = max(bboxes[0], 0)
+ top = max(bboxes[1], 0)
+ right = min(bboxes[2], w)
+ bottom = min(bboxes[3], h)
+ return (left, top, right, bottom)
+
+
+def align_crop_face_landmarks(
+ img,
+ landmarks,
+ output_size,
+ transform_size=None,
+ enable_padding=True,
+ return_inverse_affine=False,
+ shrink_ratio=(1, 1),
+):
+ """Align and crop face with landmarks.
+
+ The output_size and transform_size are based on width. The height is
+    adjusted based on shrink_ratio_h/shrink_ratio_w.
+
+ Modified from:
+ https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py
+
+ Args:
+ img (Numpy array): Input image.
+ landmarks (Numpy array): 5 or 68 or 98 landmarks.
+ output_size (int): Output face size.
+        transform_size (int): Transform size. Usually four times
+            output_size.
+        enable_padding (bool): Default: True.
+        shrink_ratio (float | tuple[float] | list[float]): Shrink the whole
+            face for height and width (crop a larger area). Default: (1, 1).
+
+ Returns:
+ (Numpy array): Cropped face.
+ """
+ lm_type = "retinaface_5" # Options: dlib_5, retinaface_5
+
+ if isinstance(shrink_ratio, (float, int)):
+ shrink_ratio = (shrink_ratio, shrink_ratio)
+ if transform_size is None:
+ transform_size = output_size * 4
+
+ # Parse landmarks
+ lm = np.array(landmarks)
+ if lm.shape[0] == 5 and lm_type == "retinaface_5":
+ eye_left = lm[0]
+ eye_right = lm[1]
+ mouth_avg = (lm[3] + lm[4]) * 0.5
+ elif lm.shape[0] == 5 and lm_type == "dlib_5":
+ lm_eye_left = lm[2:4]
+ lm_eye_right = lm[0:2]
+ eye_left = np.mean(lm_eye_left, axis=0)
+ eye_right = np.mean(lm_eye_right, axis=0)
+ mouth_avg = lm[4]
+ elif lm.shape[0] == 68:
+ lm_eye_left = lm[36:42]
+ lm_eye_right = lm[42:48]
+ eye_left = np.mean(lm_eye_left, axis=0)
+ eye_right = np.mean(lm_eye_right, axis=0)
+ mouth_avg = (lm[48] + lm[54]) * 0.5
+ elif lm.shape[0] == 98:
+ lm_eye_left = lm[60:68]
+ lm_eye_right = lm[68:76]
+ eye_left = np.mean(lm_eye_left, axis=0)
+ eye_right = np.mean(lm_eye_right, axis=0)
+ mouth_avg = (lm[76] + lm[82]) * 0.5
+
+ eye_avg = (eye_left + eye_right) * 0.5
+ eye_to_eye = eye_right - eye_left
+ eye_to_mouth = mouth_avg - eye_avg
+
+ # Get the oriented crop rectangle
+ # x: half width of the oriented crop rectangle
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
+ # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
+ # norm with the hypotenuse: get the direction
+ x /= np.hypot(*x) # get the hypotenuse of a right triangle
+ rect_scale = 1 # TODO: you can edit it to get larger rect
+ x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
+ # y: half height of the oriented crop rectangle
+ y = np.flipud(x) * [-1, 1]
+
+ x *= shrink_ratio[1] # width
+ y *= shrink_ratio[0] # height
+
+ # c: center
+ c = eye_avg + eye_to_mouth * 0.1
+ # quad: (left_top, left_bottom, right_bottom, right_top)
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
+ # qsize: side length of the square
+ qsize = np.hypot(*x) * 2
+
+ quad_ori = np.copy(quad)
+ # Shrink, for large face
+ # TODO: do we really need shrink
+ shrink = int(np.floor(qsize / output_size * 0.5))
+ if shrink > 1:
+ h, w = img.shape[0:2]
+ rsize = (int(np.rint(float(w) / shrink)), int(np.rint(float(h) / shrink)))
+ img = cv2.resize(img, rsize, interpolation=cv2.INTER_AREA)
+ quad /= shrink
+ qsize /= shrink
+
+ # Crop
+ h, w = img.shape[0:2]
+ border = max(int(np.rint(qsize * 0.1)), 3)
+ crop = (
+ int(np.floor(min(quad[:, 0]))),
+ int(np.floor(min(quad[:, 1]))),
+ int(np.ceil(max(quad[:, 0]))),
+ int(np.ceil(max(quad[:, 1]))),
+ )
+ crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, w), min(crop[3] + border, h))
+ if crop[2] - crop[0] < w or crop[3] - crop[1] < h:
+ img = img[crop[1] : crop[3], crop[0] : crop[2], :]
+ quad -= crop[0:2]
+
+ # Pad
+ # pad: (width_left, height_top, width_right, height_bottom)
+ h, w = img.shape[0:2]
+ pad = (
+ int(np.floor(min(quad[:, 0]))),
+ int(np.floor(min(quad[:, 1]))),
+ int(np.ceil(max(quad[:, 0]))),
+ int(np.ceil(max(quad[:, 1]))),
+ )
+ pad = (
+ max(-pad[0] + border, 0),
+ max(-pad[1] + border, 0),
+ max(pad[2] - w + border, 0),
+ max(pad[3] - h + border, 0),
+ )
+ if enable_padding and max(pad) > border - 4:
+ pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
+ img = np.pad(img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), "reflect")
+ h, w = img.shape[0:2]
+ y, x, _ = np.ogrid[:h, :w, :1]
+ mask = np.maximum(
+ 1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]),
+ 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]),
+ )
+ blur = int(qsize * 0.02)
+ if blur % 2 == 0:
+ blur += 1
+ blur_img = cv2.boxFilter(img, 0, ksize=(blur, blur))
+
+ img = img.astype("float32")
+ img += (blur_img - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
+ img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
+ img = np.clip(img, 0, 255) # float32, [0, 255]
+ quad += pad[:2]
+
+ # Transform use cv2
+ h_ratio = shrink_ratio[0] / shrink_ratio[1]
+ dst_h, dst_w = int(transform_size * h_ratio), transform_size
+ template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]])
+ # use cv2.LMEDS method for the equivalence to skimage transform
+ # ref: https://blog.csdn.net/yichxi/article/details/115827338
+ affine_matrix = cv2.estimateAffinePartial2D(quad, template, method=cv2.LMEDS)[0]
+ cropped_face = cv2.warpAffine(
+ img,
+ affine_matrix,
+ (dst_w, dst_h),
+ borderMode=cv2.BORDER_CONSTANT,
+ borderValue=(135, 133, 132),
+ ) # gray
+
+ if output_size < transform_size:
+ cropped_face = cv2.resize(
+ cropped_face,
+ (output_size, int(output_size * h_ratio)),
+ interpolation=cv2.INTER_LINEAR,
+ )
+
+ if return_inverse_affine:
+ dst_h, dst_w = int(output_size * h_ratio), output_size
+ template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]])
+ # use cv2.LMEDS method for the equivalence to skimage transform
+ # ref: https://blog.csdn.net/yichxi/article/details/115827338
+ affine_matrix = cv2.estimateAffinePartial2D(
+ quad_ori,
+ np.array([[0, 0], [0, output_size], [dst_w, dst_h], [dst_w, 0]]),
+ method=cv2.LMEDS,
+ )[0]
+ inverse_affine = cv2.invertAffineTransform(affine_matrix)
+ else:
+ inverse_affine = None
+ return cropped_face, inverse_affine
+
+
+def paste_face_back(img, face, inverse_affine):
+ h, w = img.shape[0:2]
+ face_h, face_w = face.shape[0:2]
+ inv_restored = cv2.warpAffine(face, inverse_affine, (w, h))
+ mask = np.ones((face_h, face_w, 3), dtype=np.float32)
+ inv_mask = cv2.warpAffine(mask, inverse_affine, (w, h))
+ # remove the black borders
+ inv_mask_erosion = cv2.erode(inv_mask, np.ones((2, 2), np.uint8))
+ inv_restored_remove_border = inv_mask_erosion * inv_restored
+ total_face_area = np.sum(inv_mask_erosion) // 3
+ # compute the fusion edge based on the area of face
+ w_edge = int(total_face_area**0.5) // 20
+ erosion_radius = w_edge * 2
+ inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
+ blur_size = w_edge * 2
+ inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
+ img = inv_soft_mask * inv_restored_remove_border + (1 - inv_soft_mask) * img
+ # float32, [0, 255]
+ return img
+
+
+if __name__ == "__main__":
+ import os
+
+ from hordelib.nodes.facerestore_cf.facelib.detection import init_detection_model
+ from hordelib.nodes.facerestore_cf.facelib.utils.face_restoration_helper import get_largest_face
+
+ img_path = "/home/wxt/datasets/ffhq/ffhq_wild/00009.png"
+ img_name = os.path.splitext(os.path.basename(img_path))[0]
+
+ # initialize model
+ det_net = init_detection_model("retinaface_resnet50", half=False)
+ img_ori = cv2.imread(img_path)
+ h, w = img_ori.shape[0:2]
+ # if larger than 800, scale it
+ scale = max(h / 800, w / 800)
+ if scale > 1:
+ img = cv2.resize(img_ori, (int(w / scale), int(h / scale)), interpolation=cv2.INTER_LINEAR)
+
+ with torch.no_grad():
+ bboxes = det_net.detect_faces(img, 0.97)
+ if scale > 1:
+ bboxes *= scale # the score is incorrect
+ bboxes = get_largest_face(bboxes, h, w)[0]
+
+ landmarks = np.array([[bboxes[i], bboxes[i + 1]] for i in range(5, 15, 2)])
+
+ cropped_face, inverse_affine = align_crop_face_landmarks(
+ img_ori,
+ landmarks,
+ output_size=512,
+ transform_size=None,
+ enable_padding=True,
+ return_inverse_affine=True,
+ shrink_ratio=(1, 1),
+ )
+
+ cv2.imwrite(f"tmp/{img_name}_cropeed_face.png", cropped_face)
+ img = paste_face_back(img_ori, cropped_face, inverse_affine)
+ cv2.imwrite(f"tmp/{img_name}_back.png", img)
diff --git a/hordelib/nodes/facerestore/facelib/utils/misc.py b/hordelib/nodes/facerestore_cf/facelib/utils/misc.py
similarity index 70%
rename from hordelib/nodes/facerestore/facelib/utils/misc.py
rename to hordelib/nodes/facerestore_cf/facelib/utils/misc.py
index 68b7fe9d..0eee9f24 100644
--- a/hordelib/nodes/facerestore/facelib/utils/misc.py
+++ b/hordelib/nodes/facerestore_cf/facelib/utils/misc.py
@@ -1,143 +1,132 @@
-import cv2
-import os
-import os.path as osp
-import torch
-from torch.hub import download_url_to_file, get_dir
-from urllib.parse import urlparse
-# from basicsr.utils.download_util import download_file_from_google_drive
-#import gdown
-
-
-ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-
-
-def download_pretrained_models(file_ids, save_path_root):
- os.makedirs(save_path_root, exist_ok=True)
-
- for file_name, file_id in file_ids.items():
- file_url = 'https://drive.google.com/uc?id='+file_id
- save_path = osp.abspath(osp.join(save_path_root, file_name))
- if osp.exists(save_path):
- user_response = input(f'{file_name} already exist. Do you want to cover it? Y/N\n')
- if user_response.lower() == 'y':
- print(f'Covering {file_name} to {save_path}')
- print("skipping gdown in facelib/utils/misc.py "+file_url)
- #gdown.download(file_url, save_path, quiet=False)
- # download_file_from_google_drive(file_id, save_path)
- elif user_response.lower() == 'n':
- print(f'Skipping {file_name}')
- else:
- raise ValueError('Wrong input. Only accepts Y/N.')
- else:
- print(f'Downloading {file_name} to {save_path}')
- print("skipping gdown in facelib/utils/misc.py "+file_url)
- #gdown.download(file_url, save_path, quiet=False)
- # download_file_from_google_drive(file_id, save_path)
-
-
-def imwrite(img, file_path, params=None, auto_mkdir=True):
- """Write image to file.
-
- Args:
- img (ndarray): Image array to be written.
- file_path (str): Image file path.
- params (None or list): Same as opencv's :func:`imwrite` interface.
- auto_mkdir (bool): If the parent folder of `file_path` does not exist,
- whether to create it automatically.
-
- Returns:
- bool: Successful or not.
- """
- if auto_mkdir:
- dir_name = os.path.abspath(os.path.dirname(file_path))
- os.makedirs(dir_name, exist_ok=True)
- return cv2.imwrite(file_path, img, params)
-
-
-def img2tensor(imgs, bgr2rgb=True, float32=True):
- """Numpy array to tensor.
-
- Args:
- imgs (list[ndarray] | ndarray): Input images.
- bgr2rgb (bool): Whether to change bgr to rgb.
- float32 (bool): Whether to change to float32.
-
- Returns:
- list[tensor] | tensor: Tensor images. If returned results only have
- one element, just return tensor.
- """
-
- def _totensor(img, bgr2rgb, float32):
- if img.shape[2] == 3 and bgr2rgb:
- if img.dtype == 'float64':
- img = img.astype('float32')
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
- img = torch.from_numpy(img.transpose(2, 0, 1))
- if float32:
- img = img.float()
- return img
-
- if isinstance(imgs, list):
- return [_totensor(img, bgr2rgb, float32) for img in imgs]
- else:
- return _totensor(imgs, bgr2rgb, float32)
-
-
-def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
- """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
- """
- if model_dir is None:
- hub_dir = get_dir()
- model_dir = os.path.join(hub_dir, 'checkpoints')
-
- os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)
-
- parts = urlparse(url)
- filename = os.path.basename(parts.path)
- if file_name is not None:
- filename = file_name
- cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
- if not os.path.exists(cached_file):
- print(f'Downloading: "{url}" to {cached_file}\n')
- download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
- return cached_file
-
-
-def scandir(dir_path, suffix=None, recursive=False, full_path=False):
- """Scan a directory to find the interested files.
- Args:
- dir_path (str): Path of the directory.
- suffix (str | tuple(str), optional): File suffix that we are
- interested in. Default: None.
- recursive (bool, optional): If set to True, recursively scan the
- directory. Default: False.
- full_path (bool, optional): If set to True, include the dir_path.
- Default: False.
- Returns:
- A generator for all the interested files with relative paths.
- """
-
- if (suffix is not None) and not isinstance(suffix, (str, tuple)):
- raise TypeError('"suffix" must be a string or tuple of strings')
-
- root = dir_path
-
- def _scandir(dir_path, suffix, recursive):
- for entry in os.scandir(dir_path):
- if not entry.name.startswith('.') and entry.is_file():
- if full_path:
- return_path = entry.path
- else:
- return_path = osp.relpath(entry.path, root)
-
- if suffix is None:
- yield return_path
- elif return_path.endswith(suffix):
- yield return_path
- else:
- if recursive:
- yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
- else:
- continue
-
- return _scandir(dir_path, suffix=suffix, recursive=recursive)
+import os
+import os.path as osp
+from urllib.parse import urlparse
+
+import cv2
+import folder_paths
+import torch
+from torch.hub import download_url_to_file, get_dir
+from hordelib.shared_model_manager import SharedModelManager
+
+# from hordelib.nodes.facerestore.basicsr.utils.download_util import download_file_from_google_drive
+# import gdown
+
+
+ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+
+def download_pretrained_models(file_ids, save_path_root):
+ os.makedirs(save_path_root, exist_ok=True)
+
+ for file_name, file_id in file_ids.items():
+ file_url = "https://drive.google.com/uc?id=" + file_id
+ save_path = osp.abspath(osp.join(save_path_root, file_name))
+ if osp.exists(save_path):
+            user_response = input(f"{file_name} already exists. Do you want to overwrite it? Y/N\n")
+            if user_response.lower() == "y":
+                print(f"Overwriting {file_name} at {save_path}")
+ print("skipping gdown in facelib/utils/misc.py " + file_url)
+ # gdown.download(file_url, save_path, quiet=False)
+ # download_file_from_google_drive(file_id, save_path)
+ elif user_response.lower() == "n":
+ print(f"Skipping {file_name}")
+ else:
+ raise ValueError("Wrong input. Only accepts Y/N.")
+ else:
+ print(f"Downloading {file_name} to {save_path}")
+ print("skipping gdown in facelib/utils/misc.py " + file_url)
+ # gdown.download(file_url, save_path, quiet=False)
+ # download_file_from_google_drive(file_id, save_path)
+
+
+def imwrite(img, file_path, params=None, auto_mkdir=True):
+ """Write image to file.
+
+ Args:
+ img (ndarray): Image array to be written.
+ file_path (str): Image file path.
+ params (None or list): Same as opencv's :func:`imwrite` interface.
+ auto_mkdir (bool): If the parent folder of `file_path` does not exist,
+ whether to create it automatically.
+
+ Returns:
+ bool: Successful or not.
+ """
+ if auto_mkdir:
+ dir_name = os.path.abspath(os.path.dirname(file_path))
+ os.makedirs(dir_name, exist_ok=True)
+ return cv2.imwrite(file_path, img, params)
+
+
+def img2tensor(imgs, bgr2rgb=True, float32=True):
+ """Numpy array to tensor.
+
+ Args:
+ imgs (list[ndarray] | ndarray): Input images.
+ bgr2rgb (bool): Whether to change bgr to rgb.
+ float32 (bool): Whether to change to float32.
+
+ Returns:
+ list[tensor] | tensor: Tensor images. If returned results only have
+ one element, just return tensor.
+ """
+
+ def _totensor(img, bgr2rgb, float32):
+ if img.shape[2] == 3 and bgr2rgb:
+ if img.dtype == "float64":
+ img = img.astype("float32")
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img = torch.from_numpy(img.transpose(2, 0, 1))
+ if float32:
+ img = img.float()
+ return img
+
+ if isinstance(imgs, list):
+ return [_totensor(img, bgr2rgb, float32) for img in imgs]
+ else:
+ return _totensor(imgs, bgr2rgb, float32)
+
+
+def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
+ """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py"""
+ return str(SharedModelManager.manager.gfpgan.model_folder_path / file_name)
+
+
+def scandir(dir_path, suffix=None, recursive=False, full_path=False):
+ """Scan a directory to find the interested files.
+ Args:
+ dir_path (str): Path of the directory.
+ suffix (str | tuple(str), optional): File suffix that we are
+ interested in. Default: None.
+ recursive (bool, optional): If set to True, recursively scan the
+ directory. Default: False.
+ full_path (bool, optional): If set to True, include the dir_path.
+ Default: False.
+ Returns:
+ A generator for all the interested files with relative paths.
+ """
+
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+ raise TypeError('"suffix" must be a string or tuple of strings')
+
+ root = dir_path
+
+ def _scandir(dir_path, suffix, recursive):
+ for entry in os.scandir(dir_path):
+ if not entry.name.startswith(".") and entry.is_file():
+ if full_path:
+ return_path = entry.path
+ else:
+ return_path = osp.relpath(entry.path, root)
+
+ if suffix is None:
+ yield return_path
+ elif return_path.endswith(suffix):
+ yield return_path
+ else:
+ if recursive:
+ yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
+ else:
+ continue
+
+ return _scandir(dir_path, suffix=suffix, recursive=recursive)
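
With this change, load_file_from_url no longer downloads anything; it resolves the requested weight inside the worker's GFPGAN model folder. A hedged sketch of that behaviour follows, where MODEL_FOLDER stands in for SharedModelManager.manager.gfpgan.model_folder_path and the URL-derived fallback is an illustrative safeguard, not part of the patch.

    from pathlib import Path

    MODEL_FOLDER = Path("models/gfpgan")  # stand-in for the shared model manager's folder

    def resolve_facelib_weight(url: str, file_name: str | None = None) -> str:
        # The patched function joins file_name directly; deriving it from the URL
        # here avoids a TypeError for callers that omit file_name.
        if file_name is None:
            file_name = url.rsplit("/", 1)[-1]
        return str(MODEL_FOLDER / file_name)

    print(resolve_facelib_weight("https://example.com/detection_Resnet50_Final.pth"))
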
diff --git a/hordelib/nodes/facerestore_cf/r_chainner/README.md b/hordelib/nodes/facerestore_cf/r_chainner/README.md
new file mode 100644
index 00000000..6b254d3f
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/r_chainner/README.md
@@ -0,0 +1,3 @@
+A clean implementation of GFPGAN loading, copied from [this commit](https://github.com/Gourieff/comfyui-reactor-node/commit/a7ae66912f80e8ccd97bb83bf83ab8187b077287#diff-b668993a9f6df352129e883337a4f2c96b31ab61afd82c4ae948d40864962c12) to resolve the loading problems reported against [facerestore_cf](https://github.com/mav-rik/facerestore_cf).
+
+The solution was discovered in this issue: https://github.com/comfyanonymous/ComfyUI/issues/3594
diff --git a/hordelib/nodes/facerestore_cf/r_chainner/__init__.py b/hordelib/nodes/facerestore_cf/r_chainner/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/hordelib/nodes/facerestore_cf/r_chainner/gfpganv1_clean_arch.py b/hordelib/nodes/facerestore_cf/r_chainner/gfpganv1_clean_arch.py
new file mode 100644
index 00000000..7f2f0e75
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/r_chainner/gfpganv1_clean_arch.py
@@ -0,0 +1,370 @@
+# pylint: skip-file
+# type: ignore
+import math
+import random
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from hordelib.nodes.facerestore_cf.r_chainner.stylegan2_clean_arch import StyleGAN2GeneratorClean
+
+
+class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
+ """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
+ It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
+ Args:
+ out_size (int): The spatial size of outputs.
+ num_style_feat (int): Channel number of style features. Default: 512.
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
+ narrow (float): The narrow ratio for channels. Default: 1.
+ sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
+ """
+
+ def __init__(
+ self,
+ out_size,
+ num_style_feat=512,
+ num_mlp=8,
+ channel_multiplier=2,
+ narrow=1,
+ sft_half=False,
+ ):
+ super(StyleGAN2GeneratorCSFT, self).__init__(
+ out_size,
+ num_style_feat=num_style_feat,
+ num_mlp=num_mlp,
+ channel_multiplier=channel_multiplier,
+ narrow=narrow,
+ )
+ self.sft_half = sft_half
+
+ def forward(
+ self,
+ styles,
+ conditions,
+ input_is_latent=False,
+ noise=None,
+ randomize_noise=True,
+ truncation=1,
+ truncation_latent=None,
+ inject_index=None,
+ return_latents=False,
+ ):
+ """Forward function for StyleGAN2GeneratorCSFT.
+ Args:
+ styles (list[Tensor]): Sample codes of styles.
+ conditions (list[Tensor]): SFT conditions to generators.
+ input_is_latent (bool): Whether input is latent style. Default: False.
+ noise (Tensor | None): Input noise or None. Default: None.
+ randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
+ truncation (float): The truncation ratio. Default: 1.
+ truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
+ inject_index (int | None): The injection index for mixing noise. Default: None.
+ return_latents (bool): Whether to return style latents. Default: False.
+ """
+ # style codes -> latents with Style MLP layer
+ if not input_is_latent:
+ styles = [self.style_mlp(s) for s in styles]
+ # noises
+ if noise is None:
+ if randomize_noise:
+ noise = [None] * self.num_layers # for each style conv layer
+ else: # use the stored noise
+ noise = [
+ getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
+ ]
+ # style truncation
+ if truncation < 1:
+ style_truncation = []
+ for style in styles:
+ style_truncation.append(
+ truncation_latent + truncation * (style - truncation_latent)
+ )
+ styles = style_truncation
+ # get style latents with injection
+ if len(styles) == 1:
+ inject_index = self.num_latent
+
+ if styles[0].ndim < 3:
+ # repeat latent code for all the layers
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+ else: # used for encoder with different latent code for each layer
+ latent = styles[0]
+ elif len(styles) == 2: # mixing noises
+ if inject_index is None:
+ inject_index = random.randint(1, self.num_latent - 1)
+ latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+ latent2 = (
+ styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
+ )
+ latent = torch.cat([latent1, latent2], 1)
+
+ # main generation
+ out = self.constant_input(latent.shape[0])
+ out = self.style_conv1(out, latent[:, 0], noise=noise[0])
+ skip = self.to_rgb1(out, latent[:, 1])
+
+ i = 1
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
+ self.style_convs[::2],
+ self.style_convs[1::2],
+ noise[1::2],
+ noise[2::2],
+ self.to_rgbs,
+ ):
+ out = conv1(out, latent[:, i], noise=noise1)
+
+ # the conditions may have fewer levels
+ if i < len(conditions):
+ # SFT part to combine the conditions
+ if self.sft_half: # only apply SFT to half of the channels
+ out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
+ out_sft = out_sft * conditions[i - 1] + conditions[i]
+ out = torch.cat([out_same, out_sft], dim=1)
+ else: # apply SFT to all the channels
+ out = out * conditions[i - 1] + conditions[i]
+
+ out = conv2(out, latent[:, i + 1], noise=noise2)
+ skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
+ i += 2
+
+ image = skip
+
+ if return_latents:
+ return image, latent
+ else:
+ return image, None
+
+
+class ResBlock(nn.Module):
+ """Residual block with bilinear upsampling/downsampling.
+ Args:
+ in_channels (int): Channel number of the input.
+ out_channels (int): Channel number of the output.
+ mode (str): Upsampling/downsampling mode. Options: down | up. Default: down.
+ """
+
+ def __init__(self, in_channels, out_channels, mode="down"):
+ super(ResBlock, self).__init__()
+
+ self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
+ self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
+ self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False)
+ if mode == "down":
+ self.scale_factor = 0.5
+ elif mode == "up":
+ self.scale_factor = 2
+
+ def forward(self, x):
+ out = F.leaky_relu_(self.conv1(x), negative_slope=0.2)
+ # upsample/downsample
+ out = F.interpolate(
+ out, scale_factor=self.scale_factor, mode="bilinear", align_corners=False
+ )
+ out = F.leaky_relu_(self.conv2(out), negative_slope=0.2)
+ # skip
+ x = F.interpolate(
+ x, scale_factor=self.scale_factor, mode="bilinear", align_corners=False
+ )
+ skip = self.skip(x)
+ out = out + skip
+ return out
+
+
+class GFPGANv1Clean(nn.Module):
+ """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
+ It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
+ Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
+ Args:
+ out_size (int): The spatial size of outputs.
+ num_style_feat (int): Channel number of style features. Default: 512.
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
+ decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
+ fix_decoder (bool): Whether to fix the decoder. Default: True.
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
+ input_is_latent (bool): Whether input is latent style. Default: False.
+ different_w (bool): Whether to use different latent w for different layers. Default: False.
+ narrow (float): The narrow ratio for channels. Default: 1.
+ sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
+ """
+
+ def __init__(
+ self,
+ state_dict,
+ ):
+ super(GFPGANv1Clean, self).__init__()
+
+ out_size = 512
+ num_style_feat = 512
+ channel_multiplier = 2
+ decoder_load_path = None
+ fix_decoder = False
+ num_mlp = 8
+ input_is_latent = True
+ different_w = True
+ narrow = 1
+ sft_half = True
+
+ self.model_arch = "GFPGAN"
+ self.sub_type = "Face SR"
+ self.scale = 8
+ self.in_nc = 3
+ self.out_nc = 3
+ self.state = state_dict
+
+ self.supports_fp16 = False
+ self.supports_bf16 = True
+ self.min_size_restriction = 512
+
+ self.input_is_latent = input_is_latent
+ self.different_w = different_w
+ self.num_style_feat = num_style_feat
+
+ unet_narrow = narrow * 0.5 # by default, use a half of input channels
+ channels = {
+ "4": int(512 * unet_narrow),
+ "8": int(512 * unet_narrow),
+ "16": int(512 * unet_narrow),
+ "32": int(512 * unet_narrow),
+ "64": int(256 * channel_multiplier * unet_narrow),
+ "128": int(128 * channel_multiplier * unet_narrow),
+ "256": int(64 * channel_multiplier * unet_narrow),
+ "512": int(32 * channel_multiplier * unet_narrow),
+ "1024": int(16 * channel_multiplier * unet_narrow),
+ }
+
+ self.log_size = int(math.log(out_size, 2))
+ first_out_size = 2 ** (int(math.log(out_size, 2)))
+
+ self.conv_body_first = nn.Conv2d(3, channels[f"{first_out_size}"], 1)
+
+ # downsample
+ in_channels = channels[f"{first_out_size}"]
+ self.conv_body_down = nn.ModuleList()
+ for i in range(self.log_size, 2, -1):
+ out_channels = channels[f"{2**(i - 1)}"]
+ self.conv_body_down.append(ResBlock(in_channels, out_channels, mode="down"))
+ in_channels = out_channels
+
+ self.final_conv = nn.Conv2d(in_channels, channels["4"], 3, 1, 1)
+
+ # upsample
+ in_channels = channels["4"]
+ self.conv_body_up = nn.ModuleList()
+ for i in range(3, self.log_size + 1):
+ out_channels = channels[f"{2**i}"]
+ self.conv_body_up.append(ResBlock(in_channels, out_channels, mode="up"))
+ in_channels = out_channels
+
+ # to RGB
+ self.toRGB = nn.ModuleList()
+ for i in range(3, self.log_size + 1):
+ self.toRGB.append(nn.Conv2d(channels[f"{2**i}"], 3, 1))
+
+ if different_w:
+ linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
+ else:
+ linear_out_channel = num_style_feat
+
+ self.final_linear = nn.Linear(channels["4"] * 4 * 4, linear_out_channel)
+
+ # the decoder: stylegan2 generator with SFT modulations
+ self.stylegan_decoder = StyleGAN2GeneratorCSFT(
+ out_size=out_size,
+ num_style_feat=num_style_feat,
+ num_mlp=num_mlp,
+ channel_multiplier=channel_multiplier,
+ narrow=narrow,
+ sft_half=sft_half,
+ )
+
+ # load pre-trained stylegan2 model if necessary
+ if decoder_load_path:
+ self.stylegan_decoder.load_state_dict(
+ torch.load(
+ decoder_load_path, map_location=lambda storage, loc: storage
+ )["params_ema"]
+ )
+ # fix decoder without updating params
+ if fix_decoder:
+ for _, param in self.stylegan_decoder.named_parameters():
+ param.requires_grad = False
+
+ # for SFT modulations (scale and shift)
+ self.condition_scale = nn.ModuleList()
+ self.condition_shift = nn.ModuleList()
+ for i in range(3, self.log_size + 1):
+ out_channels = channels[f"{2**i}"]
+ if sft_half:
+ sft_out_channels = out_channels
+ else:
+ sft_out_channels = out_channels * 2
+ self.condition_scale.append(
+ nn.Sequential(
+ nn.Conv2d(out_channels, out_channels, 3, 1, 1),
+ nn.LeakyReLU(0.2, True),
+ nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1),
+ )
+ )
+ self.condition_shift.append(
+ nn.Sequential(
+ nn.Conv2d(out_channels, out_channels, 3, 1, 1),
+ nn.LeakyReLU(0.2, True),
+ nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1),
+ )
+ )
+ self.load_state_dict(state_dict)
+
+ def forward(
+ self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs
+ ):
+ """Forward function for GFPGANv1Clean.
+ Args:
+ x (Tensor): Input images.
+ return_latents (bool): Whether to return style latents. Default: False.
+ return_rgb (bool): Whether return intermediate rgb images. Default: True.
+ randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
+ """
+ conditions = []
+ unet_skips = []
+ out_rgbs = []
+
+ # encoder
+ feat = F.leaky_relu_(self.conv_body_first(x), negative_slope=0.2)
+ for i in range(self.log_size - 2):
+ feat = self.conv_body_down[i](feat)
+ unet_skips.insert(0, feat)
+ feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2)
+
+ # style code
+ style_code = self.final_linear(feat.view(feat.size(0), -1))
+ if self.different_w:
+ style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
+
+ # decode
+ for i in range(self.log_size - 2):
+ # add unet skip
+ feat = feat + unet_skips[i]
+ # ResUpLayer
+ feat = self.conv_body_up[i](feat)
+ # generate scale and shift for SFT layers
+ scale = self.condition_scale[i](feat)
+ conditions.append(scale.clone())
+ shift = self.condition_shift[i](feat)
+ conditions.append(shift.clone())
+ # generate rgb images
+ if return_rgb:
+ out_rgbs.append(self.toRGB[i](feat))
+
+ # decoder
+ image, _ = self.stylegan_decoder(
+ [style_code],
+ conditions,
+ return_latents=return_latents,
+ input_is_latent=self.input_is_latent,
+ randomize_noise=randomize_noise,
+ )
+
+ return image, out_rgbs
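
The decoder above injects the U-Net conditions into the StyleGAN2 features via spatial feature transforms (SFT). The toy tensor sketch below reproduces the sft_half=True branch of StyleGAN2GeneratorCSFT.forward in isolation, using made-up shapes so it runs standalone.

    import torch

    out = torch.randn(1, 8, 4, 4)     # decoder feature map (toy sizes)
    scale = torch.randn(1, 4, 4, 4)   # condition_scale output for this level
    shift = torch.randn(1, 4, 4, 4)   # condition_shift output for this level

    # Only the second half of the channels is modulated when sft_half is set.
    out_same, out_sft = torch.split(out, out.size(1) // 2, dim=1)
    out_sft = out_sft * scale + shift  # spatial feature transform
    out = torch.cat([out_same, out_sft], dim=1)
    print(out.shape)  # torch.Size([1, 8, 4, 4])
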
diff --git a/hordelib/nodes/facerestore_cf/r_chainner/model_loading.py b/hordelib/nodes/facerestore_cf/r_chainner/model_loading.py
new file mode 100644
index 00000000..598e605c
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/r_chainner/model_loading.py
@@ -0,0 +1,29 @@
+
+from hordelib.nodes.facerestore_cf.r_chainner.gfpganv1_clean_arch import GFPGANv1Clean
+from hordelib.nodes.facerestore_cf.r_chainner.types import PyTorchModel
+
+
+class UnsupportedModel(Exception):
+ pass
+
+
+def load_state_dict(state_dict) -> PyTorchModel:
+
+ state_dict_keys = list(state_dict.keys())
+
+ if "params_ema" in state_dict_keys:
+ state_dict = state_dict["params_ema"]
+ elif "params-ema" in state_dict_keys:
+ state_dict = state_dict["params-ema"]
+ elif "params" in state_dict_keys:
+ state_dict = state_dict["params"]
+
+ state_dict_keys = list(state_dict.keys())
+
+    # GFPGAN
+    if (
+        "toRGB.0.weight" in state_dict_keys
+        and "stylegan_decoder.style_mlp.1.weight" in state_dict_keys
+    ):
+        return GFPGANv1Clean(state_dict)
+    raise UnsupportedModel("State dict not recognised as a GFPGAN checkpoint")
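
For reference, a minimal usage sketch of this loader. The checkpoint path is a placeholder, and the CPU map_location and eval() call are illustrative choices rather than what the hordelib node does; unrecognised checkpoints raise UnsupportedModel.

    import torch

    from hordelib.nodes.facerestore_cf.r_chainner.model_loading import load_state_dict

    state_dict = torch.load("models/gfpgan/GFPGANv1.4.pth", map_location="cpu")  # placeholder path
    model = load_state_dict(state_dict)  # GFPGANv1Clean for GFPGAN weights
    model = model.eval()
    print(model.model_arch, model.scale)  # "GFPGAN", 8
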
diff --git a/hordelib/nodes/facerestore_cf/r_chainner/stylegan2_clean_arch.py b/hordelib/nodes/facerestore_cf/r_chainner/stylegan2_clean_arch.py
new file mode 100644
index 00000000..c48de9af
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/r_chainner/stylegan2_clean_arch.py
@@ -0,0 +1,453 @@
+# pylint: skip-file
+# type: ignore
+import math
+import random
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.nn import init
+from torch.nn.modules.batchnorm import _BatchNorm
+
+
+@torch.no_grad()
+def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
+ """Initialize network weights.
+ Args:
+ module_list (list[nn.Module] | nn.Module): Modules to be initialized.
+ scale (float): Scale initialized weights, especially for residual
+ blocks. Default: 1.
+ bias_fill (float): The value to fill bias. Default: 0
+ kwargs (dict): Other arguments for initialization function.
+ """
+ if not isinstance(module_list, list):
+ module_list = [module_list]
+ for module in module_list:
+ for m in module.modules():
+ if isinstance(m, nn.Conv2d):
+ init.kaiming_normal_(m.weight, **kwargs)
+ m.weight.data *= scale
+ if m.bias is not None:
+ m.bias.data.fill_(bias_fill)
+ elif isinstance(m, nn.Linear):
+ init.kaiming_normal_(m.weight, **kwargs)
+ m.weight.data *= scale
+ if m.bias is not None:
+ m.bias.data.fill_(bias_fill)
+ elif isinstance(m, _BatchNorm):
+ init.constant_(m.weight, 1)
+ if m.bias is not None:
+ m.bias.data.fill_(bias_fill)
+
+
+class NormStyleCode(nn.Module):
+ def forward(self, x):
+ """Normalize the style codes.
+ Args:
+ x (Tensor): Style codes with shape (b, c).
+ Returns:
+ Tensor: Normalized tensor.
+ """
+ return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
+
+
+class ModulatedConv2d(nn.Module):
+ """Modulated Conv2d used in StyleGAN2.
+ There is no bias in ModulatedConv2d.
+ Args:
+ in_channels (int): Channel number of the input.
+ out_channels (int): Channel number of the output.
+ kernel_size (int): Size of the convolving kernel.
+ num_style_feat (int): Channel number of style features.
+ demodulate (bool): Whether to demodulate in the conv layer. Default: True.
+ sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
+ eps (float): A value added to the denominator for numerical stability. Default: 1e-8.
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ num_style_feat,
+ demodulate=True,
+ sample_mode=None,
+ eps=1e-8,
+ ):
+ super(ModulatedConv2d, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.demodulate = demodulate
+ self.sample_mode = sample_mode
+ self.eps = eps
+
+ # modulation inside each modulated conv
+ self.modulation = nn.Linear(num_style_feat, in_channels, bias=True)
+ # initialization
+ default_init_weights(
+ self.modulation,
+ scale=1,
+ bias_fill=1,
+ a=0,
+ mode="fan_in",
+ nonlinearity="linear",
+ )
+
+ self.weight = nn.Parameter(
+ torch.randn(1, out_channels, in_channels, kernel_size, kernel_size)
+ / math.sqrt(in_channels * kernel_size**2)
+ )
+ self.padding = kernel_size // 2
+
+ def forward(self, x, style):
+ """Forward function.
+ Args:
+ x (Tensor): Tensor with shape (b, c, h, w).
+ style (Tensor): Tensor with shape (b, num_style_feat).
+ Returns:
+ Tensor: Modulated tensor after convolution.
+ """
+ b, c, h, w = x.shape # c = c_in
+ # weight modulation
+ style = self.modulation(style).view(b, 1, c, 1, 1)
+ # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
+ weight = self.weight * style # (b, c_out, c_in, k, k)
+
+ if self.demodulate:
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
+ weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
+
+ weight = weight.view(
+ b * self.out_channels, c, self.kernel_size, self.kernel_size
+ )
+
+ # upsample or downsample if necessary
+ if self.sample_mode == "upsample":
+ x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
+ elif self.sample_mode == "downsample":
+ x = F.interpolate(x, scale_factor=0.5, mode="bilinear", align_corners=False)
+
+ b, c, h, w = x.shape
+ x = x.view(1, b * c, h, w)
+ # weight: (b*c_out, c_in, k, k), groups=b
+ out = F.conv2d(x, weight, padding=self.padding, groups=b)
+ out = out.view(b, self.out_channels, *out.shape[2:4])
+
+ return out
+
+ def __repr__(self):
+ return (
+ f"{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, "
+ f"kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})"
+ )
+
+
+class StyleConv(nn.Module):
+ """Style conv used in StyleGAN2.
+ Args:
+ in_channels (int): Channel number of the input.
+ out_channels (int): Channel number of the output.
+ kernel_size (int): Size of the convolving kernel.
+ num_style_feat (int): Channel number of style features.
+ demodulate (bool): Whether demodulate in the conv layer. Default: True.
+ sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ num_style_feat,
+ demodulate=True,
+ sample_mode=None,
+ ):
+ super(StyleConv, self).__init__()
+ self.modulated_conv = ModulatedConv2d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ num_style_feat,
+ demodulate=demodulate,
+ sample_mode=sample_mode,
+ )
+ self.weight = nn.Parameter(torch.zeros(1)) # for noise injection
+ self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
+ self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+ def forward(self, x, style, noise=None):
+ # modulate
+ out = self.modulated_conv(x, style) * 2**0.5 # for conversion
+ # noise injection
+ if noise is None:
+ b, _, h, w = out.shape
+ noise = out.new_empty(b, 1, h, w).normal_()
+ out = out + self.weight * noise
+ # add bias
+ out = out + self.bias
+ # activation
+ out = self.activate(out)
+ return out
+
+
+class ToRGB(nn.Module):
+ """To RGB (image space) from features.
+ Args:
+ in_channels (int): Channel number of input.
+ num_style_feat (int): Channel number of style features.
+ upsample (bool): Whether to upsample. Default: True.
+ """
+
+ def __init__(self, in_channels, num_style_feat, upsample=True):
+ super(ToRGB, self).__init__()
+ self.upsample = upsample
+ self.modulated_conv = ModulatedConv2d(
+ in_channels,
+ 3,
+ kernel_size=1,
+ num_style_feat=num_style_feat,
+ demodulate=False,
+ sample_mode=None,
+ )
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
+
+ def forward(self, x, style, skip=None):
+ """Forward function.
+ Args:
+ x (Tensor): Feature tensor with shape (b, c, h, w).
+ style (Tensor): Tensor with shape (b, num_style_feat).
+ skip (Tensor): Base/skip tensor. Default: None.
+ Returns:
+ Tensor: RGB images.
+ """
+ out = self.modulated_conv(x, style)
+ out = out + self.bias
+ if skip is not None:
+ if self.upsample:
+ skip = F.interpolate(
+ skip, scale_factor=2, mode="bilinear", align_corners=False
+ )
+ out = out + skip
+ return out
+
+
+class ConstantInput(nn.Module):
+ """Constant input.
+ Args:
+ num_channel (int): Channel number of constant input.
+ size (int): Spatial size of constant input.
+ """
+
+ def __init__(self, num_channel, size):
+ super(ConstantInput, self).__init__()
+ self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
+
+ def forward(self, batch):
+ out = self.weight.repeat(batch, 1, 1, 1)
+ return out
+
+
+class StyleGAN2GeneratorClean(nn.Module):
+ """Clean version of StyleGAN2 Generator.
+ Args:
+ out_size (int): The spatial size of outputs.
+ num_style_feat (int): Channel number of style features. Default: 512.
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
+ narrow (float): Narrow ratio for channels. Default: 1.0.
+ """
+
+ def __init__(
+ self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1
+ ):
+ super(StyleGAN2GeneratorClean, self).__init__()
+ # Style MLP layers
+ self.num_style_feat = num_style_feat
+ style_mlp_layers = [NormStyleCode()]
+ for i in range(num_mlp):
+ style_mlp_layers.extend(
+ [
+ nn.Linear(num_style_feat, num_style_feat, bias=True),
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ ]
+ )
+ self.style_mlp = nn.Sequential(*style_mlp_layers)
+ # initialization
+ default_init_weights(
+ self.style_mlp,
+ scale=1,
+ bias_fill=0,
+ a=0.2,
+ mode="fan_in",
+ nonlinearity="leaky_relu",
+ )
+
+ # channel list
+ channels = {
+ "4": int(512 * narrow),
+ "8": int(512 * narrow),
+ "16": int(512 * narrow),
+ "32": int(512 * narrow),
+ "64": int(256 * channel_multiplier * narrow),
+ "128": int(128 * channel_multiplier * narrow),
+ "256": int(64 * channel_multiplier * narrow),
+ "512": int(32 * channel_multiplier * narrow),
+ "1024": int(16 * channel_multiplier * narrow),
+ }
+ self.channels = channels
+
+ self.constant_input = ConstantInput(channels["4"], size=4)
+ self.style_conv1 = StyleConv(
+ channels["4"],
+ channels["4"],
+ kernel_size=3,
+ num_style_feat=num_style_feat,
+ demodulate=True,
+ sample_mode=None,
+ )
+ self.to_rgb1 = ToRGB(channels["4"], num_style_feat, upsample=False)
+
+ self.log_size = int(math.log(out_size, 2))
+ self.num_layers = (self.log_size - 2) * 2 + 1
+ self.num_latent = self.log_size * 2 - 2
+
+ self.style_convs = nn.ModuleList()
+ self.to_rgbs = nn.ModuleList()
+ self.noises = nn.Module()
+
+ in_channels = channels["4"]
+ # noise
+ for layer_idx in range(self.num_layers):
+ resolution = 2 ** ((layer_idx + 5) // 2)
+ shape = [1, 1, resolution, resolution]
+ self.noises.register_buffer(f"noise{layer_idx}", torch.randn(*shape))
+ # style convs and to_rgbs
+ for i in range(3, self.log_size + 1):
+ out_channels = channels[f"{2**i}"]
+ self.style_convs.append(
+ StyleConv(
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ num_style_feat=num_style_feat,
+ demodulate=True,
+ sample_mode="upsample",
+ )
+ )
+ self.style_convs.append(
+ StyleConv(
+ out_channels,
+ out_channels,
+ kernel_size=3,
+ num_style_feat=num_style_feat,
+ demodulate=True,
+ sample_mode=None,
+ )
+ )
+ self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True))
+ in_channels = out_channels
+
+ def make_noise(self):
+ """Make noise for noise injection."""
+ device = self.constant_input.weight.device
+ noises = [torch.randn(1, 1, 4, 4, device=device)]
+
+ for i in range(3, self.log_size + 1):
+ for _ in range(2):
+ noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
+
+ return noises
+
+ def get_latent(self, x):
+ return self.style_mlp(x)
+
+ def mean_latent(self, num_latent):
+ latent_in = torch.randn(
+ num_latent, self.num_style_feat, device=self.constant_input.weight.device
+ )
+ latent = self.style_mlp(latent_in).mean(0, keepdim=True)
+ return latent
+
+ def forward(
+ self,
+ styles,
+ input_is_latent=False,
+ noise=None,
+ randomize_noise=True,
+ truncation=1,
+ truncation_latent=None,
+ inject_index=None,
+ return_latents=False,
+ ):
+ """Forward function for StyleGAN2GeneratorClean.
+ Args:
+ styles (list[Tensor]): Sample codes of styles.
+ input_is_latent (bool): Whether input is latent style. Default: False.
+ noise (Tensor | None): Input noise or None. Default: None.
+ randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
+ truncation (float): The truncation ratio. Default: 1.
+ truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
+ inject_index (int | None): The injection index for mixing noise. Default: None.
+ return_latents (bool): Whether to return style latents. Default: False.
+ """
+ # style codes -> latents with Style MLP layer
+ if not input_is_latent:
+ styles = [self.style_mlp(s) for s in styles]
+ # noises
+ if noise is None:
+ if randomize_noise:
+ noise = [None] * self.num_layers # for each style conv layer
+ else: # use the stored noise
+ noise = [
+ getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
+ ]
+ # style truncation
+ if truncation < 1:
+ style_truncation = []
+ for style in styles:
+ style_truncation.append(
+ truncation_latent + truncation * (style - truncation_latent)
+ )
+ styles = style_truncation
+ # get style latents with injection
+ if len(styles) == 1:
+ inject_index = self.num_latent
+
+ if styles[0].ndim < 3:
+ # repeat latent code for all the layers
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+ else: # used for encoder with different latent code for each layer
+ latent = styles[0]
+ elif len(styles) == 2: # mixing noises
+ if inject_index is None:
+ inject_index = random.randint(1, self.num_latent - 1)
+ latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+ latent2 = (
+ styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
+ )
+ latent = torch.cat([latent1, latent2], 1)
+
+ # main generation
+ out = self.constant_input(latent.shape[0])
+ out = self.style_conv1(out, latent[:, 0], noise=noise[0])
+ skip = self.to_rgb1(out, latent[:, 1])
+
+ i = 1
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
+ self.style_convs[::2],
+ self.style_convs[1::2],
+ noise[1::2],
+ noise[2::2],
+ self.to_rgbs,
+ ):
+ out = conv1(out, latent[:, i], noise=noise1)
+ out = conv2(out, latent[:, i + 1], noise=noise2)
+ skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
+ i += 2
+
+ image = skip
+
+ if return_latents:
+ return image, latent
+ else:
+ return image, None
diff --git a/hordelib/nodes/facerestore_cf/r_chainner/types.py b/hordelib/nodes/facerestore_cf/r_chainner/types.py
new file mode 100644
index 00000000..20e39f68
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/r_chainner/types.py
@@ -0,0 +1,19 @@
+
+from typing import Union
+
+from hordelib.nodes.facerestore_cf.r_chainner.gfpganv1_clean_arch import GFPGANv1Clean
+
+
+PyTorchFaceModels = (GFPGANv1Clean,)
+PyTorchFaceModel = Union[GFPGANv1Clean]
+
+
+def is_pytorch_face_model(model: object):
+ return isinstance(model, PyTorchFaceModels)
+
+PyTorchModels = (*PyTorchFaceModels, )
+PyTorchModel = Union[PyTorchFaceModel]
+
+
+def is_pytorch_model(model: object):
+ return isinstance(model, PyTorchModels)
diff --git a/hordelib/nodes/node_model_loader.py b/hordelib/nodes/node_model_loader.py
index 4ca17900..0105d019 100644
--- a/hordelib/nodes/node_model_loader.py
+++ b/hordelib/nodes/node_model_loader.py
@@ -9,8 +9,8 @@
import torch
from loguru import logger
-from hordelib.shared_model_manager import SharedModelManager
from hordelib.comfy_horde import log_free_ram
+from hordelib.shared_model_manager import SharedModelManager
# Don't let the name fool you, this class is trying to load all the files that will be necessary
diff --git a/hordelib/pipeline_designs/pipeline_image_facefix.json b/hordelib/pipeline_designs/pipeline_image_facefix.json
index 19b15fbe..42cc8630 100644
--- a/hordelib/pipeline_designs/pipeline_image_facefix.json
+++ b/hordelib/pipeline_designs/pipeline_image_facefix.json
@@ -1,6 +1,6 @@
{
- "last_node_id": 8,
- "last_link_id": 8,
+ "last_node_id": 11,
+ "last_link_id": 14,
"nodes": [
{
"id": 6,
@@ -9,10 +9,10 @@
771,
331
],
- "size": [
- 427.0001220703125,
- 416.33331298828125
- ],
+ "size": {
+ "0": 427.0001220703125,
+ "1": 416.33331298828125
+ },
"flags": {},
"order": 3,
"mode": 0,
@@ -20,7 +20,7 @@
{
"name": "images",
"type": "IMAGE",
- "link": 7
+ "link": 14
}
],
"title": "output_image",
@@ -30,11 +30,11 @@
]
},
{
- "id": 4,
- "type": "UpscaleModelLoader",
+ "id": 10,
+ "type": "FaceRestoreModelLoader",
"pos": [
- 24,
- 73
+ -5,
+ 145
],
"size": {
"0": 315,
@@ -45,133 +45,144 @@
"mode": 0,
"outputs": [
{
- "name": "UPSCALE_MODEL",
- "type": "UPSCALE_MODEL",
+ "name": "FACERESTORE_MODEL",
+ "type": "FACERESTORE_MODEL",
"links": [
- 6
+ 12
],
+ "shape": 3,
"slot_index": 0
}
],
"title": "model_loader",
"properties": {
- "Node name for S&R": "UpscaleModelLoader"
+ "Node name for S&R": "FaceRestoreModelLoader"
},
"widgets_values": [
"CodeFormers.pth"
]
},
{
- "id": 8,
- "type": "FaceRestoreWithModel",
- "title": "face_restore_with_model",
+ "id": 1,
+ "type": "LoadImage",
"pos": [
- 390,
- 194
+ 30,
+ 325
+ ],
+ "size": [
+ 315,
+ 314
],
- "size": {
- "0": 315,
- "1": 78
- },
"flags": {},
- "order": 2,
+ "order": 1,
"mode": 0,
- "inputs": [
- {
- "name": "upscale_model",
- "type": "UPSCALE_MODEL",
- "link": 6
- },
- {
- "name": "image",
- "type": "IMAGE",
- "link": 8
- }
- ],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
- 7
+ 13
],
"slot_index": 0
+ },
+ {
+ "name": "MASK",
+ "type": "MASK",
+ "links": null
}
],
+ "title": "image_loader",
"properties": {
- "Node name for S&R": "FaceRestoreWithModel"
+ "Node name for S&R": "LoadImage"
},
"widgets_values": [
- "retinaface_resnet50"
+ "test_facefix.png",
+ "image"
]
},
{
- "id": 1,
- "type": "LoadImage",
+ "id": 11,
+ "type": "FaceRestoreCFWithModel",
"pos": [
- 30,
- 325
+ 397,
+ 166
],
"size": {
"0": 315,
"1": 102
},
"flags": {},
- "order": 1,
+ "order": 2,
"mode": 0,
+ "inputs": [
+ {
+ "name": "facerestore_model",
+ "type": "FACERESTORE_MODEL",
+ "link": 12
+ },
+ {
+ "name": "image",
+ "type": "IMAGE",
+ "link": 13
+ }
+ ],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
- 8
+ 14
],
+ "shape": 3,
"slot_index": 0
- },
- {
- "name": "MASK",
- "type": "MASK",
- "links": null
}
],
- "title": "image_loader",
+ "title": "face_restore_with_model",
"properties": {
- "Node name for S&R": "LoadImage"
+ "Node name for S&R": "FaceRestoreCFWithModel"
},
"widgets_values": [
- "test_facefix.png",
- "image"
+ "retinaface_resnet50",
+ 0.5
]
}
],
"links": [
[
- 6,
- 4,
+ 12,
+ 10,
0,
- 8,
+ 11,
0,
- "UPSCALE_MODEL"
+ "FACERESTORE_MODEL"
],
[
- 7,
- 8,
- 0,
- 6,
+ 13,
+ 1,
0,
+ 11,
+ 1,
"IMAGE"
],
[
- 8,
- 1,
+ 14,
+ 11,
+ 0,
+ 6,
0,
- 8,
- 1,
"IMAGE"
]
],
"groups": [],
"config": {},
- "extra": {},
+ "extra": {
+ "ds": {
+ "scale": 1,
+ "offset": [
+ 383.39140502393127,
+ 89.69757412917136
+ ]
+ }
+ },
"version": 0.4
}
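The redesigned graph swaps the old UpscaleModelLoader/FaceRestoreWithModel pair for FaceRestoreModelLoader (node 10) feeding FaceRestoreCFWithModel (node 11), which gains a codeformer_fidelity widget (0.5) and sends its IMAGE output to the output_image node. A quick sanity-check sketch (not part of the PR, assumed to run from the repository root) that prints the rewired links in this design file:

    import json

    with open("hordelib/pipeline_designs/pipeline_image_facefix.json") as f:
        design = json.load(f)

    titles = {node["id"]: node.get("title", node["type"]) for node in design["nodes"]}
    # Each litegraph link is [link_id, source_node, source_slot, target_node, target_slot, type].
    for link_id, src, src_slot, dst, dst_slot, link_type in design["links"]:
        print(f"link {link_id}: {titles[src]}[{src_slot}] -> {titles[dst]}[{dst_slot}] ({link_type})")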
diff --git a/hordelib/pipelines/pipeline_image_facefix.json b/hordelib/pipelines/pipeline_image_facefix.json
index a00f479a..b06dcfae 100644
--- a/hordelib/pipelines/pipeline_image_facefix.json
+++ b/hordelib/pipelines/pipeline_image_facefix.json
@@ -9,20 +9,11 @@
"title": "image_loader"
}
},
- "4": {
- "inputs": {
- "model_name": "CodeFormers.pth"
- },
- "class_type": "UpscaleModelLoader",
- "_meta": {
- "title": "model_loader"
- }
- },
"6": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": [
- "8",
+ "11",
0
]
},
@@ -31,11 +22,21 @@
"title": "output_image"
}
},
- "8": {
+ "10": {
+ "inputs": {
+ "model_name": "CodeFormers.pth"
+ },
+ "class_type": "FaceRestoreModelLoader",
+ "_meta": {
+ "title": "model_loader"
+ }
+ },
+ "11": {
"inputs": {
"facedetection": "retinaface_resnet50",
- "upscale_model": [
- "4",
+ "codeformer_fidelity": 0.5,
+ "facerestore_model": [
+ "10",
0
],
"image": [
@@ -43,9 +44,9 @@
0
]
},
+ "class_type": "FaceRestoreCFWithModel",
"_meta": {
"title": "face_restore_with_model"
- },
- "class_type": "FaceRestoreWithModel"
+ }
}
}
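The compiled (ComfyUI API-format) pipeline mirrors the design change: node "8" is replaced by "10" (FaceRestoreModelLoader) and "11" (FaceRestoreCFWithModel), and every list-valued input is a [source_node_id, output_index] reference. A small illustrative check (not project code) that walks those references and confirms they all resolve:

    import json

    with open("hordelib/pipelines/pipeline_image_facefix.json") as f:
        pipeline = json.load(f)

    for node_id, node in pipeline.items():
        for input_name, value in node["inputs"].items():
            if isinstance(value, list) and len(value) == 2:
                source_id, output_index = value
                assert source_id in pipeline, f"{node_id}.{input_name} references missing node {source_id}"
                print(f"{node_id}.{input_name} <- {source_id}[{output_index}] ({pipeline[source_id]['class_type']})")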
diff --git a/mypy.ini b/mypy.ini
index 10a57f6c..fa0778a7 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -1,11 +1,15 @@
[mypy]
-exclude = (build|dist|ComfyUI|comfy_controlnet_preprocessors|facerestore|comfy_horde\.py|examples|diffusers)
+exclude = (build|dist|ComfyUI|comfy_controlnet_preprocessors|facerestore_cf|comfy_horde\.py|examples|diffusers)
[mypy-hordelib.nodes.comfy_controlnet_preprocessors.*]
ignore_errors = True
ignore_missing_imports = True
+[mypy-hordelib.nodes.facerestore_cf.*]
+ignore_errors = True
+ignore_missing_imports = True
+
[mypy-hordelib.nodes.comfyui_layerdiffuse.*]
ignore_errors = True
ignore_missing_imports = True
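Note that mypy's exclude setting is a single regular expression matched against file paths, so renaming the entry to facerestore_cf is what keeps the vendored node out of discovery, while the new per-module section silences errors if those files are imported anyway. A tiny illustrative check of the regex (not part of CI):

    import re

    exclude = r"(build|dist|ComfyUI|comfy_controlnet_preprocessors|facerestore_cf|comfy_horde\.py|examples|diffusers)"
    print(bool(re.search(exclude, "hordelib/nodes/facerestore_cf/r_chainner/types.py")))  # True -> skipped
    print(bool(re.search(exclude, "hordelib/nodes/node_model_loader.py")))                # False -> checked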
diff --git a/pyproject.toml b/pyproject.toml
index 4837a77c..7922ba9d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -78,7 +78,7 @@ exclude = '''
[tool.ruff] # XXX this isn't part of CI yet
line-length=119
-exclude=["ComfyUI", "comfy_controlnet_preprocessors", "facerestore", "comfy_qr", "comfyui_layerdiffuse", "build"]
+exclude=["ComfyUI", "comfy_controlnet_preprocessors", "facerestore_cf", "comfy_qr", "comfyui_layerdiffuse", "build"]
ignore=[
# "F401", # imported but unused
"E402", # Module level import not at top of file
diff --git a/requirements.txt b/requirements.txt
index ab2ca5ec..90cec3c1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -50,3 +50,4 @@ kornia
qrcode
spandrel
spandrel_extra_arches
+lpips
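requirements.txt gains lpips, presumably to support perceptual (LPIPS) image-similarity scoring in the test tooling. A minimal sketch of the library's documented usage (illustrative only; not hordelib code):

    import lpips
    import torch

    loss_fn = lpips.LPIPS(net="alex")  # downloads the pretrained backbone on first use
    # LPIPS expects RGB tensors of shape (N, 3, H, W) scaled to [-1, 1].
    img0 = torch.rand(1, 3, 64, 64) * 2 - 1
    img1 = torch.rand(1, 3, 64, 64) * 2 - 1
    distance = loss_fn(img0, img1)     # lower values mean "more perceptually similar"
    print(distance.item())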
diff --git a/tests/test_horde_pp.py b/tests/test_horde_pp.py
index 72dccc6a..3bb5f442 100644
--- a/tests/test_horde_pp.py
+++ b/tests/test_horde_pp.py
@@ -66,8 +66,8 @@ def post_processor_check(
similarity_constraints = ImageSimilarityConstraints(
cosine_fail_floor=CosineSimilarityResultCode.PERCEPTUALLY_IDENTICAL,
cosine_warn_floor=CosineSimilarityResultCode.EXTREMELY_SIMILAR,
- histogram_fail_threshold=HistogramDistanceResultCode.VERY_DISSIMILAR_DISTRIBUTION,
- histogram_warn_threshold=HistogramDistanceResultCode.SIMILAR_DISTRIBUTION,
+ histogram_fail_threshold=HistogramDistanceResultCode.VERY_SIMILAR_DISTRIBUTION,
+ histogram_warn_threshold=HistogramDistanceResultCode.EXTREMELY_SIMILAR_DISTRIBUTION,
)
assert cls.shared_model_manager.manager.download_model(model_name)
assert cls.shared_model_manager.manager.is_model_available(model_name) is True
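The test change appears to tighten the histogram gates for post-processor outputs: results must now fall within the VERY_SIMILAR / EXTREMELY_SIMILAR distance bands rather than the much looser previous thresholds. As a rough illustration of what a histogram-distance comparison measures (generic OpenCV code, not the ImageSimilarityConstraints implementation used in these tests):

    import cv2
    import numpy as np

    def histogram_distance(img_a: np.ndarray, img_b: np.ndarray) -> float:
        """Bhattacharyya distance between normalized 3D color histograms (0.0 = identical)."""
        histograms = []
        for img in (img_a, img_b):
            hist = cv2.calcHist([img], [0, 1, 2], None, [8, 8, 8], [0, 256] * 3)
            histograms.append(cv2.normalize(hist, hist).flatten())
        return cv2.compareHist(histograms[0], histograms[1], cv2.HISTCMP_BHATTACHARYYA)

    a = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
    b = (a.astype(np.int16) + np.random.randint(-5, 6, a.shape)).clip(0, 255).astype(np.uint8)
    print(histogram_distance(a, b))  # small value: the two distributions are very similar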