From d10d44db697751bab3f9357966e6451d84d2197f Mon Sep 17 00:00:00 2001 From: HaFred Date: Tue, 1 Oct 2024 16:02:23 +0800 Subject: [PATCH 01/18] instantmesh inference & stage 1 training, also the eval script is provided --- examples/instantmesh/README.md | 104 ++ .../configs/instant-mesh-large.yaml | 17 + .../configs/instant-nerf-large-train.yaml | 50 + examples/instantmesh/data/objaverse.py | 406 +++++ examples/instantmesh/eval.py | 202 +++ examples/instantmesh/inference.py | 146 ++ examples/instantmesh/model_stage1.py | 168 ++ examples/instantmesh/models/__init__.py | 0 .../instantmesh/models/decoder/__init__.py | 0 .../instantmesh/models/decoder/transformer.py | 142 ++ .../instantmesh/models/encoder/__init__.py | 0 examples/instantmesh/models/encoder/dino.py | 539 ++++++ .../models/encoder/dino_wrapper.py | 69 + .../instantmesh/models/geometry/__init__.py | 0 .../models/geometry/camera/__init__.py | 7 + .../geometry/camera/perspective_camera.py | 34 + .../models/geometry/rep_3d/__init__.py | 6 + .../models/geometry/rep_3d/flexicubes.py | 611 +++++++ .../geometry/rep_3d/flexicubes_geometry.py | 120 ++ .../models/geometry/rep_3d/tables.py | 1472 +++++++++++++++++ examples/instantmesh/models/lrm.py | 197 +++ examples/instantmesh/models/lrm_mesh.py | 371 +++++ .../instantmesh/models/renderer/__init__.py | 0 .../models/renderer/synthesizer.py | 225 +++ .../models/renderer/synthesizer_mesh.py | 155 ++ .../models/renderer/utils/__init__.py | 0 .../models/renderer/utils/renderer.py | 403 +++++ examples/instantmesh/requirements.txt | 7 + examples/instantmesh/train.py | 464 ++++++ examples/instantmesh/train.sh | 3 + examples/instantmesh/utils/__init__.py | 0 examples/instantmesh/utils/camera_util.py | 119 ++ examples/instantmesh/utils/eval_util.py | 208 +++ examples/instantmesh/utils/loss_util.py | 164 ++ examples/instantmesh/utils/mesh_util.py | 141 ++ .../instantmesh/utils/ms_callback_util.py | 112 ++ examples/instantmesh/utils/train_util.py | 34 + 37 files changed, 6696 insertions(+) create mode 100644 examples/instantmesh/README.md create mode 100644 examples/instantmesh/configs/instant-mesh-large.yaml create mode 100644 examples/instantmesh/configs/instant-nerf-large-train.yaml create mode 100644 examples/instantmesh/data/objaverse.py create mode 100644 examples/instantmesh/eval.py create mode 100644 examples/instantmesh/inference.py create mode 100644 examples/instantmesh/model_stage1.py create mode 100644 examples/instantmesh/models/__init__.py create mode 100644 examples/instantmesh/models/decoder/__init__.py create mode 100644 examples/instantmesh/models/decoder/transformer.py create mode 100644 examples/instantmesh/models/encoder/__init__.py create mode 100644 examples/instantmesh/models/encoder/dino.py create mode 100644 examples/instantmesh/models/encoder/dino_wrapper.py create mode 100644 examples/instantmesh/models/geometry/__init__.py create mode 100644 examples/instantmesh/models/geometry/camera/__init__.py create mode 100644 examples/instantmesh/models/geometry/camera/perspective_camera.py create mode 100644 examples/instantmesh/models/geometry/rep_3d/__init__.py create mode 100644 examples/instantmesh/models/geometry/rep_3d/flexicubes.py create mode 100644 examples/instantmesh/models/geometry/rep_3d/flexicubes_geometry.py create mode 100644 examples/instantmesh/models/geometry/rep_3d/tables.py create mode 100644 examples/instantmesh/models/lrm.py create mode 100644 examples/instantmesh/models/lrm_mesh.py create mode 100644 examples/instantmesh/models/renderer/__init__.py create 
mode 100644 examples/instantmesh/models/renderer/synthesizer.py create mode 100644 examples/instantmesh/models/renderer/synthesizer_mesh.py create mode 100644 examples/instantmesh/models/renderer/utils/__init__.py create mode 100644 examples/instantmesh/models/renderer/utils/renderer.py create mode 100644 examples/instantmesh/requirements.txt create mode 100644 examples/instantmesh/train.py create mode 100644 examples/instantmesh/train.sh create mode 100644 examples/instantmesh/utils/__init__.py create mode 100644 examples/instantmesh/utils/camera_util.py create mode 100644 examples/instantmesh/utils/eval_util.py create mode 100644 examples/instantmesh/utils/loss_util.py create mode 100644 examples/instantmesh/utils/mesh_util.py create mode 100644 examples/instantmesh/utils/ms_callback_util.py create mode 100644 examples/instantmesh/utils/train_util.py
diff --git a/examples/instantmesh/README.md b/examples/instantmesh/README.md new file mode 100644 index 0000000000..af21872ca2 --- /dev/null +++ b/examples/instantmesh/README.md @@ -0,0 +1,104 @@
+# InstantMesh: 3D Mesh Generation from Multiview Images
+
+- [ ] Elaborate on the design methodology in the intro on Oct 2
+- [ ] Add the training part
+
+This `instantmesh` module under `.../sv3d/models` implements 3D mesh generation from the multiview images extracted by [the sv3d pipeline](https://github.com/mindspore-lab/mindone/pull/574).
+
+A walk-through of the file structure is provided below.
+
+Files Tree
+
+```bash
+├── instantmesh
+│   ├── decoder              # triplane feature transformer decoder
+│   │   └── transformer.py
+│   ├── encoder              # dino vit encoder to extract img feat
+│   │   ├── dino_wrapper.py
+│   │   └── dino.py
+│   ├── renderer             # a wrapper that synthesizes sdf/texture from triplane feat
+│   │   ├── synthesizer_mesh.py  # triplane synthesizer, the triplane feat is decoded thru nerf to predict texture rgb & 3D sdf
+│   │   ├── synthesizer.py       # triplane synthesizer, the triplane feat is decoded thru nerf to predict novel view rgba
+│   │   └── utils
+│   │       └── renderer.py
+│   ├── geometry             # use Flexicubes to extract isosurface
+│   │   ├── rep_3d
+│   │   │   ├── flexicubes_geometry.py
+│   │   │   ├── tables.py
+│   │   │   └── flexicubes.py
+│   │   └── camera
+│   │       └── perspective_camera.py
+│   ├── lrm_mesh.py          # model arch for the instantmesh inference
+│   └── lrm.py               # model arch for the instantmesh stage 1 training
+├── utils
+│   ├── camera_util.py
+│   ├── train_util.py
+│   ├── eval_util.py
+│   ├── loss_util.py
+│   ├── ms_callback_util.py
+│   └── mesh_util.py
+├── data
+│   └── objaverse.py         # training dataset definition and batchify
+├── configs
+│   └── instant-mesh-large.yaml
+├── inference.py             # instantmesh inference
+├── train.py                 # instantmesh stage 1 training
+├── eval.py                  # instantmesh stage 1 evaluation, mview imgs to novel view synthesis
+└── model_stage1.py          # model arch for the stage 1 training
+```
+
+
+## Introduction
+
+InstantMesh [[1]](#acknowledgements) synergizes the strengths of a multiview diffusion model and a sparse-view reconstruction model based on the LRM [[2]](#acknowledgements) architecture. It also adopts FlexiCubes [[3]](#acknowledgements) isosurface extraction for smoother and more elegant mesh extraction.
+
+Using the multiview images produced by [the sv3d pipeline](../../simple_video_sample.py) as input, we extracted the 3D meshes shown below. The corresponding input views are illustrated in the sv3d pipeline linked above.
+
+| akun | anya |
+| :---: | :---: |
+| *(akun mesh rendering)* | *(anya mesh rendering)* |
+
+The illustrations here are better viewed in markdown viewers with HTML support (e.g., the vscode built-in viewer).
+
+## Environments
+
+1. To kickstart:
+
+```bash
+pip install -r requirements.txt
+```
+
+2. Inference is tested on a machine with the following specs using 1x NPU:
+
+```text
+Mindspore Version: 2.3.0.B528
+CANN Version: 7.3
+Ascend Driver: 23.0.rc3.6
+```
+
+## Pretrained Models
+
+To better accommodate the mindone transformers codebase, we provide an out-of-the-box [checkpoints conversion script](../../tools/convert_instantmesh_ckpt.py) that works seamlessly with the mindspore version of transformers.
+
+The image features are extracted with dino-vit, which depends on HuggingFace's transformers package. We reuse [the MindSpore implementation](https://github.com/mindspore-lab/mindone/blob/master/mindone/transformers/modeling_utils.py#L499); the only remaining challenge is that the `.bin` checkpoint of [dino-vit](https://huggingface.co/facebook/dino-vitb16/tree/main) is not supported by MindSpore off-the-shelf. The conversion script above handles this and ensures that dino-vit still builds on `MSPreTrainedModel`. A minimal sketch of the conversion idea is given at the end of this README.
+
+## Inference
+
+```shell
+python inference.py --ckpt PATH_TO_CKPT \
+--input_vid PATH_TO_INPUT_MULTIVIEW_VID
+```
+
+## Training
+
+### Data Curation
+
+## Acknowledgements
+
+1. Xu, Jiale, et al. "InstantMesh: Efficient 3D mesh generation from a single image with sparse-view large reconstruction models." arXiv preprint arXiv:2404.07191 (2024).
+2. Hong, Yicong, et al. "LRM: Large reconstruction model for single image to 3D." arXiv preprint arXiv:2311.04400 (2023).
+3. Shen, Tianchang, et al. "Flexible Isosurface Extraction for Gradient-Based Mesh Optimization." ACM Trans. Graph. 42.4 (2023): 37-1.
+4. Lorensen, William E., and Harvey E. Cline. "Marching cubes: A high resolution 3D surface construction algorithm." Seminal graphics: pioneering efforts that shaped the field. 1998. 347-353.
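+
+## Appendix: Checkpoint Conversion Sketch
+
+Below is a minimal, illustrative sketch of the `.bin`-to-`.ckpt` conversion described in the Pretrained Models section. It is not the shipped `tools/convert_instantmesh_ckpt.py` script: the `rename` mapping is an assumption (MindSpore's `nn.LayerNorm` names its parameters `gamma`/`beta` instead of `weight`/`bias`), and the provided script should be used in practice.
+
+```python
+# Hedged sketch: read the HF torch .bin state dict and re-save it as a MindSpore .ckpt
+# that mindspore.load_checkpoint() can consume. Name remapping here is indicative only.
+import torch
+import mindspore as ms
+
+
+def rename(name: str) -> str:
+    # Indicative rename only; the exact mapping is handled by tools/convert_instantmesh_ckpt.py.
+    return name.replace("layernorm.weight", "layernorm.gamma").replace("layernorm.bias", "layernorm.beta")
+
+
+def convert(bin_path: str, ckpt_path: str) -> None:
+    state_dict = torch.load(bin_path, map_location="cpu")
+    params = [{"name": rename(k), "data": ms.Tensor(v.numpy())} for k, v in state_dict.items()]
+    ms.save_checkpoint(params, ckpt_path)
+
+
+if __name__ == "__main__":
+    convert("pytorch_model.bin", "dino_vitb16_ms.ckpt")  # hypothetical output file name
+```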
diff --git a/examples/instantmesh/configs/instant-mesh-large.yaml b/examples/instantmesh/configs/instant-mesh-large.yaml new file mode 100644 index 0000000000..c44b823b64 --- /dev/null +++ b/examples/instantmesh/configs/instant-mesh-large.yaml @@ -0,0 +1,17 @@ +model: + encoder_model_name: 'facebook/dino-vitb16' + target: models.instantmesh.models3d.lrm_mesh.InstantMesh + params: + encoder_model_name: 'YOUR_PATH_HF/models--facebook--dino-vitb16/snapshots/f205d5d8e640a89a2b8ef0369670dfc37cc07fc2' # coz needs to enforce the is_local flag (with pretrained_model_name_or_path as dir), thus here put in the abs path as a workaround + triplane_low_res: 32 + triplane_high_res: 64 + triplane_dim: 80 + rendering_samples_per_ray: 128 + grid_res: 128 + grid_scale: 2.1 + + +infer: + # model_path: ckpts/instant_mesh_large.ckpt # by default as in torch, the model loaded from hf, let's do conversion insitu + texture_resolution: 1024 + render_resolution: 512 diff --git a/examples/instantmesh/configs/instant-nerf-large-train.yaml b/examples/instantmesh/configs/instant-nerf-large-train.yaml new file mode 100644 index 0000000000..cb848b3556 --- /dev/null +++ b/examples/instantmesh/configs/instant-nerf-large-train.yaml @@ -0,0 +1,50 @@ +model: + base_learning_rate: 4.0e-04 + target: model_stage1.InstantMeshStage1WithLoss + params: + input_size: 320 + render_size: 192 + # render_size: 96 + lrm_generator_config: + openlrm_ckpt: 'YOUR_PATH/openlrm.ckpt' + target: models3d.lrm.InstantNeRF + params: + encoder_feat_dim: 768 + encoder_freeze: false + # encoder_model_name: facebook/dino-vitb16 + encoder_model_name: 'YOUR_PATH_HF/models--facebook--dino-vitb16/snapshots/f205d5d8e640a89a2b8ef0369670dfc37cc07fc2' # coz needs to enforce the is_local flag (with pretrained_model_name_or_path as dir), thus here put in the abs path as a workaround + transformer_dim: 1024 + transformer_layers: 16 + transformer_heads: 16 + triplane_low_res: 32 + triplane_high_res: 64 + triplane_dim: 80 + rendering_samples_per_ray: 64 # official ckpt is 128, if loaded pretrained make sure it's 128 + use_recompute: true + +eval_render_size: 96 # large may have oom + +data: + batch_size: 1 + num_workers: 4 + train: + target: data.objaverse.ObjaverseDataset + params: + root_dir: YOUR_PATH_DATA # for overfitting exp + meta_fname: uid_set.pkl + input_image_dir: input + target_image_dir: input + input_view_num: 3 + target_view_num: 2 + input_size: 320 + render_size: 192 + total_view_n: 16 + fov: 50 + camera_rotation: true + val: + target: data.objaverse.ValidationDataset + params: + root_dir: YOUR_PATH_DATA/target + input_view_num: 6 + input_image_size: 320 + fov: 30 diff --git a/examples/instantmesh/data/objaverse.py b/examples/instantmesh/data/objaverse.py new file mode 100644 index 0000000000..1daea22871 --- /dev/null +++ b/examples/instantmesh/data/objaverse.py @@ -0,0 +1,406 @@ +import json +import math +import os +import pickle +import sys +from pathlib import Path + +import numpy as np +from PIL import Image + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +mindone_lib_path = os.path.abspath( + os.path.join(__dir__, "../../../../../") +) # TODO: remove in future when mindone is ready for install +sys.path.insert(0, mindone_lib_path) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, ".."))) # for loading utils +print(sys.path) + +from transformers import ViTImageProcessor +from utils.camera_util import FOV_to_intrinsics, center_looking_at_camera_pose, get_circular_camera_poses + +import mindspore as ms +from mindspore import Tensor 
+from mindspore.dataset.vision import Inter, Resize, ToPIL + + +def read_pickle(pkl_path): + with open(pkl_path, "rb") as f: + return pickle.load(f) + + +def read_txt2list(txt_path): + list_entry = [] + with open(txt_path, "r") as f: + for line in f: + x = line[:-1] + list_entry.append(x) + return list_entry + + +def read_json(json_path): + with open(json_path) as f: + return json.load(f) + + +def random_crop_return_params(imgs, height, width): + """imgs: (b h w c)""" + assert imgs.shape[1] >= height + assert imgs.shape[2] >= width + top = np.random.randint(0, imgs.shape[1] - height + 1) + left = np.random.randint( + 0, imgs.shape[2] - width + 1 + ) # same as torch left inclusive, right exclusive, caveat: if using random pkg, right is inclusive + imgs = np.array([img[top : top + height, left : left + width] for img in imgs]) + return imgs, (top, left, height, width) + + +def crop_with_param(imgs, top, left, height, width): + return np.array([img[top : top + height, left : left + width] for img in imgs]) + + +class ObjaverseDataset: + def __init__( + self, + root_dir="training_examples/", + meta_fname="uid_set.pkl", + input_image_dir="input", + target_image_dir="input", + input_view_num=6, + target_view_num=4, + input_size=None, + render_size=None, + total_view_n=32, + fov=50, + camera_rotation=True, + ): + self.root_dir = Path(root_dir) + self.input_image_dir = input_image_dir + self.target_image_dir = target_image_dir + self.input_view_num = input_view_num + self.target_view_num = target_view_num + self.total_view_n = total_view_n + self.fov = fov + self.camera_rotation = camera_rotation + self.output_columns = [ + "images", + "cameras", + "render_cameras", + "target_images", + "target_alphas", + "render_size", + "crop_params", + ] + + if meta_fname == "uid_set.pkl": + self.paths = read_pickle(os.path.join(root_dir, meta_fname))[-3:] + # [:1] # only takes the first scene for debugging + print("dataset read pickle") + elif meta_fname.split(".")[-1] == "txt": + self.paths = read_txt2list(os.path.join(root_dir, meta_fname)) + print("reading the fixed pose target list as the dataset") + else: + raise ValueError(f"set up meta_fname {meta_fname} is not matched with the datset proc") + + # dataaug: vit img processor peeled from dino-vit, not learnable thus cannot put in the .construct() unlike torch + self.img_processor = ViTImageProcessor.from_pretrained("facebook/dino-vitb16") + self.topil = ToPIL() + + # make tuple as PIL requires + self.input_size = (input_size, input_size) + self.render_size = render_size + print("============= length of dataset %d =============" % len(self.paths)) + + def __len__(self): + return len(self.paths) + + def load_im(self, path, color, _is_gt=False): + """ + replace background pixel with random color in rendering + """ + pil_img = Image.open(path) # h w c + image = np.asarray(pil_img, dtype=np.float32) / 255.0 + alpha = image[:, :, 3:] + image = image[:, :, :3] * alpha + color * (1 - alpha) + return image, alpha + + def prepare_sample_data(self, sample: dict) -> tuple[dict, dict]: + """ + The prepare_batch_data() in the pl original implmenetaion. Move to here in the dataset, as + 1. let Pil handling and Dino ViT preprocessing input imgs; + 2. ms dataloader only allows tensor data flushing, as defined by the output_columns. Thus cannot put this into construct as pl did. + """ + lrm_generator_input = {} + render_gt = {} + + images = sample["input_images"] # (1 n h w c) in dataloader, (nhwc) if simply _getitem_. 
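+        # Shape/value flow of the preprocessing below (done in numpy/PIL here because the
+        # ViT image processor is not learnable and cannot live inside a MindSpore construct()):
+        #   float32 [0, 1] (n, h, w, c) -> uint8 for ToPIL -> antialias resize to input_size
+        #   -> float32 [0, 1] -> transpose to (n, c, h, w) -> dino-vit ViTImageProcessor normalization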
+ + # requested by topil, which is limited by ms.Resize and pil.resize + images = np.asarray(images * 255, dtype=np.uint8) + + input_antialias_resizer = Resize(self.input_size, interpolation=Inter.ANTIALIAS) + images = np.asarray([input_antialias_resizer(self.topil(img)) for img in images]) # img: n h w c + images = images.astype("float32") / 255.0 + + # images = images / 255.0 # for dino-vit processor, it takes fp32 (0, 1) + images = images.clip(min=0, max=1) + + # requested by vit proc, restore into n c h w + images = images.transpose(0, 3, 1, 2) # nhwc -> nchw for antialias input images + + # dino-vit wrapper forward(), moved from the dino-wrapper + # normalize from fp32 (0, 1) to (-2.1, 2.6) + images = self.img_processor( + images=images, + return_tensors="np", + do_rescale=False, + do_resize=False, + )["pixel_values"] + lrm_generator_input["images"] = images + + input_c2ws = sample["input_c2ws"].reshape((self.input_view_num, 16)) + input_Ks = sample["input_Ks"].reshape((self.input_view_num, 9)) + target_c2ws = sample["target_c2ws"].reshape((self.target_view_num, 16)) + target_Ks = sample["target_Ks"].reshape((self.target_view_num, 9)) + render_cameras_input = np.concatenate([input_c2ws, input_Ks], axis=-1) + render_cameras_target = np.concatenate([target_c2ws, target_Ks], axis=-1) + render_cameras = np.concatenate([render_cameras_input, render_cameras_target], axis=0) # n_in+n_ta, 25 + + input_extrinsics = input_c2ws[:, :12] + input_intrinsics = np.stack( + [ + input_Ks[:, 0], + input_Ks[:, 4], + input_Ks[:, 2], + input_Ks[:, 5], + ], + axis=-1, + ) + cameras = np.concatenate([input_extrinsics, input_intrinsics], axis=-1) + + # add noise to input cameras + cameras = cameras + np.random.rand(*cameras.shape) * 0.04 - 0.02 + + lrm_generator_input["cameras"] = cameras.astype("float32") + lrm_generator_input["render_cameras"] = render_cameras + + # construct target images and alpha channels from input+target + target_images = np.concatenate([sample["input_images"], sample["target_images"]], axis=0) + target_alphas = np.concatenate([sample["input_alphas"], sample["target_alphas"]], axis=0) + + target_images = np.asarray(target_images * 255, dtype=np.uint8) + target_alphas = np.asarray(target_alphas * 255, dtype=np.uint8) + + render_size = np.random.randint(self.render_size, 513) + + # crop and display the correct target img/alpha + target_antialias_resizer = Resize(render_size, interpolation=Inter.ANTIALIAS) + target_images = np.asarray([target_antialias_resizer(self.topil(img)) for img in target_images]) + target_images = target_images.astype("float32") / 255.0 + target_images = target_images.clip(min=0, max=1) + + target_alphas = np.asarray( + [target_antialias_resizer(self.topil(img)) for img in target_alphas] + ) # (n h w), the resizer squeeze the last dim when it's 1 + target_alphas = target_alphas[..., None] + target_alphas = target_alphas.astype("float32") / 255.0 + + # random crop with get_params implementation + target_images, crop_params = random_crop_return_params(target_images, self.render_size, self.render_size) + target_alphas = crop_with_param(target_alphas, *crop_params) + + render_gt["target_images"] = target_images.transpose( + 0, 3, 1, 2 + ) # nhwc -> nchw for calculating loss correctly with the render imgs + render_gt["target_alphas"] = target_alphas.transpose(0, 3, 1, 2) # nhwc -> nchw + + lrm_generator_input["render_size"] = render_size + lrm_generator_input["crop_params"] = crop_params + + return lrm_generator_input, render_gt + + def __getitem__(self, index): + 
input_image_path = os.path.join(self.root_dir, self.input_image_dir, self.paths[index]) + target_image_path = os.path.join(self.root_dir, self.target_image_dir, self.paths[index]) + + indices = np.random.choice(range(self.total_view_n), self.input_view_num + self.target_view_num, replace=False) + input_indices = indices[: self.input_view_num] + target_indices = indices[self.input_view_num :] + + """background color, default: white""" + bg_white = [1.0, 1.0, 1.0] + + image_list = [] + alpha_list = [] + pose_list = [] + + K, azimuths, elevations, distances, cam_poses = read_pickle(os.path.join(input_image_path, "meta.pkl")) + input_cameras = cam_poses + for idx in input_indices: + image, alpha = self.load_im(os.path.join(input_image_path, "%03d.png" % idx), bg_white) + pose = input_cameras[idx] + pose = np.concatenate([pose, np.asarray([[0, 0, 0, 1]])], axis=0) + + image_list.append(image) + alpha_list.append(alpha) + pose_list.append(pose) + + # K, azimuths, elevations, distances, cam_poses = read_pickle(os.path.join(input_image_path, 'meta.pkl')) # duplicate line with above? + target_cameras = cam_poses + for idx in target_indices: + image, alpha = self.load_im(os.path.join(target_image_path, "%03d.png" % idx), bg_white) + pose = target_cameras[idx] + pose = np.concatenate([pose, np.asarray([[0, 0, 0, 1]])], axis=0) + + image_list.append(image) + alpha_list.append(alpha) + pose_list.append(pose) + + images = np.stack( + image_list, axis=0, dtype=np.float32 + ) # (6+V, H, W, C), for PIL proc/ms.resizer/cropper in prepare_sample_data(), thus !=[(6+V, 3, H, W)] and it should be uint8 before pass to topil() + alphas = np.stack(alpha_list, axis=0, dtype=np.float32) # (6+V, H, W, 1) + + w2cs = np.stack(pose_list, axis=0, dtype=np.float32) # (6+V, 4, 4) + c2ws = np.linalg.inv(w2cs).astype(np.float32) + + # random rotation along z axis + if self.camera_rotation: + degree = np.random.uniform(0, math.pi * 2) + rot = np.expand_dims( + np.asarray( + [ + [np.cos(degree), -np.sin(degree), 0, 0], + [np.sin(degree), np.cos(degree), 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1], + ] + ), + axis=0, + ).astype(np.float32) + c2ws = np.matmul(rot, c2ws) + + # random scaling + if np.random.rand() < 0.5: + scale = np.random.uniform(0.7, 1.1) + c2ws[:, :3, 3] *= scale + + # instrinsics of perspective cameras + K = FOV_to_intrinsics(self.fov) + Ks = np.tile(np.expand_dims(K, axis=0), (self.input_view_num + self.target_view_num, 1, 1)).astype(np.float32) + + data = { + "input_images": images[: self.input_view_num], # (6, H, W, 3) + "input_alphas": alphas[: self.input_view_num], # (6, H, W, 1) + "input_c2ws": c2ws[: self.input_view_num], # (6, 4, 4) + "input_Ks": Ks[: self.input_view_num], # (6, 3, 3) + # lrm generator input and supervision + "target_images": images[self.input_view_num :], # (V, H, W, 3) + "target_alphas": alphas[self.input_view_num :], # (V, H, W, 1) + "target_c2ws": c2ws[self.input_view_num :], # (V, 4, 4) + "target_Ks": Ks[self.input_view_num :], # (V, 3, 3) + } + + lrm_generator_input, render_gt = self.prepare_sample_data(data) + + return ( + lrm_generator_input["images"], + lrm_generator_input["cameras"], + lrm_generator_input["render_cameras"], + render_gt["target_images"], + render_gt["target_alphas"], + lrm_generator_input["render_size"], + lrm_generator_input["crop_params"], + ) + + +class ValidationDataset: + def __init__( + self, + root_dir="objaverse/", + input_view_num=6, + input_image_size=320, + fov=30, + ): + self.root_dir = Path(root_dir) + self.input_view_num = input_view_num + 
self.input_image_size = input_image_size + self.fov = fov + + self.paths = sorted(os.listdir(self.root_dir))[-3:] + print("============= length of dataset %d =============" % len(self.paths)) + + cam_distance = 4.0 + azimuths = np.asarray([30, 90, 150, 210, 270, 330]) + elevations = np.asarray([20, -10, 20, -10, 20, -10]) + azimuths = np.deg2rad(azimuths) + elevations = np.deg2rad(elevations) + + x = cam_distance * np.cos(elevations) * np.cos(azimuths) + y = cam_distance * np.cos(elevations) * np.sin(azimuths) + z = cam_distance * np.sin(elevations) + + cam_locations = np.stack([x, y, z], axis=-1) + cam_locations = Tensor.from_numpy(cam_locations).float() + c2ws = center_looking_at_camera_pose(cam_locations) + self.c2ws = c2ws.astype(ms.float32) + K = FOV_to_intrinsics(self.fov) + # .unsqueeze(0).tile((6, 1, 1)).float() + # .astype(np.float32) + self.Ks = Tensor(np.tile(np.expand_dims(K, axis=0), (6, 1, 1)), ms.float32) + + self.render_c2ws = get_circular_camera_poses(M=8, radius=cam_distance, elevation=20.0).float() + # = FOV_to_intrinsics(self.fov) + # .unsqueeze(0).tile((self.render_c2ws.shape[0], 1, 1)).float() + self.render_Ks = Tensor(np.tile(np.expand_dims(K, axis=0), (self.render_c2ws.shape[0], 1, 1)), ms.float32) + + def __len__(self): + return len(self.paths) + + def load_im(self, path, color): + """ + replace background pixel with random color in rendering + """ + pil_img = Image.open(path) + pil_img = pil_img.resize((self.input_image_size, self.input_image_size), resample=Image.BICUBIC) + + image = np.asarray(pil_img, dtype=np.float32) / 255.0 + if image.shape[-1] == 4: + alpha = image[:, :, 3:] + image = image[:, :, :3] * alpha + color * (1 - alpha) + else: + alpha = np.ones_like(image[:, :, :1]) + + # comment below as we need (v h w c) in topil, not (v c h w) + # image = np.asarray(image, dtype=np.float32).transpose(2, 0, 1) + # alpha = np.asarray(alpha, dtype=np.float32).transpose(2, 0, 1) + return image, alpha + + def __getitem__(self, index): + # load data + input_image_path = os.path.join(self.root_dir, self.paths[index]) + + """background color, default: white""" + bkg_color = [1.0, 1.0, 1.0] + + image_list = [] + alpha_list = [] + + for idx in range(self.input_view_num): + image, alpha = self.load_im(os.path.join(input_image_path, f"{idx: 03d}.png"), bkg_color) + image_list.append(image) + alpha_list.append(alpha) + + images = np.stack(image_list, axis=0, dtype=np.float32) # (6+V, 3, H, W) + alphas = np.stack(alpha_list, axis=0, dtype=np.float32) # (6+V, 1, H, W) + + data = { + "input_images": images, + "input_alphas": alphas, + "input_c2ws": self.c2ws, + "input_Ks": self.Ks, + "input_image_path": input_image_path, + "render_c2ws": self.render_c2ws, + "render_Ks": self.render_Ks, + } + return data diff --git a/examples/instantmesh/eval.py b/examples/instantmesh/eval.py new file mode 100644 index 0000000000..60619d4b58 --- /dev/null +++ b/examples/instantmesh/eval.py @@ -0,0 +1,202 @@ +"""Eval script using the model stage 1 trained ckpt to conduct arbitral novel view synthesis. + +Design methdology: Unlike the ms translation that has been done for training, +we make the eval here more similar to the inference script below with np utilization. +Because for training, the np data proc parts should be translated into ms as much as possible +(see ~/examples/opensora_hpcai/opensora/datasets/video_dataset_refactored.py), +but not for the inference. + +Thus we can do np for all data proc here to avoid tedious ms tranlsation of data. 
+Refer to inference.py for the full stage inference. +""" +import argparse +import datetime +import os +import sys +import time + +import mindspore as ms +from mindspore import mint + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../../../"))) # for loading mindone +# from loguru import logger +import logging + +from omegaconf import OmegaConf +from transformers import ViTImageProcessor +from utils.eval_util import init_inference_env, make_grid_ms, save_image_ms, str2bool + +from mindone.utils.config import instantiate_from_config +from mindone.utils.logger import set_logger +from mindone.utils.seed import set_random_seed + +logger = logging.getLogger(__name__) + +from typing import Optional + + +def evaluate(args, epoch_num: Optional[str]): + if args.append_timestr: + time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + save_dir = f"{args.output_path}/{time_str}" + else: + save_dir = f"{args.output_path}" + image_path = save_dir + if not args.debug: + os.makedirs(image_path, exist_ok=True) + + rid, dnum = init_inference_env( + args.mode, + args.seed, + device_target=args.device_target, + device_id=int(os.getenv("DEVICE_ID")), + jit_level=args.jit_level, + debug=args.debug, + ) + set_random_seed(42) + set_logger(name="", output_dir=args.output_path, rank=rid, log_level=eval(args.log_level)) + + # valdata preparation + config = OmegaConf.load(args.datacfg) + data = instantiate_from_config(config.data.val) + valset_len = data.__len__() + + # init model & load ckpt + config = OmegaConf.load(args.modelcfg) + config.model.params.lrm_generator_config.params.dtype = args.dtype + config.model.params.lrm_ckpt_path = args.itmh_ckpt + + model = instantiate_from_config(config.model) + + # create img name + if args.itmh_ckpt.split("/")[-1] == "instant_nerf_large_ms.ckpt": + image_path = os.path.join(image_path, "val_official_instantmesh_ckpt.png") + else: + if not epoch_num: # train_resume.ckpt + epoch_num = ms.load_checkpoint(args.itmh_ckpt).get( + "epoch_num", 0 + ) # 0 means that there is no this key in the resume ckpt file + image_path = os.path.join(image_path, f"val_e{epoch_num}.png") + + validation_step_outputs = [] + batches_time = [] + for index in range(0, valset_len, args.batch_size): + val_batch_np = data.__getitem__(index) + + # [torch] prepare_validation_batch_data(): + # prepare for validation batch/mv2mesh inference(): see raw repo TODO del this comment once this script finishes + val_input = model.prepare_validation_batch_data( + val_batch_np, + render_size=config.eval_render_size, + _use_dataloader=False, + ) + images = val_input["images"] + + # [torch] forward(): + # RGB image with [0,1] scale and properly sized requested by the ViTImgProc + if images.ndim == 5: + (B, N, C, H, W) = images.shape + images = images.reshape(B * N, C, H, W) + + # ViTImageProcessor moved out from dino wrapper to the main here, to avoid being in .construct(), do ImageNetStandard normalization + img_processor = ViTImageProcessor.from_pretrained( + config.model.params.lrm_generator_config.params.encoder_model_name + ) + images = img_processor( + images=images, + return_tensors="np", + do_rescale=False, + do_resize=False, + )["pixel_values"] + val_input["images"] = ms.Tensor(images).reshape(B, N, C, H, W) + + # infer + start_time = time.time() + render_images, render_alphas = model.forward_nocalloss(**val_input) + + batch_time = time.time() - start_time + batches_time.append(batch_time) + logger.info(f"Batch time cost: 
{batch_time: .3f}s.") + # save result both img and alpha, in validation_step() + # render_images = rearrange(render_images, 'b n c h w -> b c h (n w)') + render_images = mint.permute(render_images, dims=(0, 2, 3, 1, 4)).flatten(start_dim=-2) + validation_step_outputs.append(render_images) + + mean_time = sum(batches_time) / len(batches_time) + logger.info(f"Mean Batch time: {mean_time: .3f}s.") + # save mviews outputs + images = mint.cat(validation_step_outputs, dim=0) # enable for multiple batches + + # images = rearrange(images, 'r b c h w -> (r b) c h w') + # images = images.flatten(start_dim=0, end_dim=1) + assert len(images.shape) == 4, "images' shape not matched" + + grid = make_grid_ms(images, nrow=1, normalize=True, value_range=(0, 1)) + if not args.debug: + save_image_ms(grid, image_path) + logger.info(f"Saved image to {image_path}") + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--itmh_ckpt", default="CKPT_PATH") + parser.add_argument( + "--debug", + # default=True, # also setting debug as true will set pynative sync as true as well + default=False, # also setting debug as true will set pynative sync as true as well + help="When debugging, set it true, to avoid saving too many ckpts and burn out the storage.", + ) + parser.add_argument( + "--output_path", + type=str, + default="output", + help="output dir to save the generated videos", + ) + parser.add_argument( + "--append_timestr", + type=str2bool, + default=True, + help="If true, an subfolder named with timestamp under output_path will be created to save the sampling results", + ) + parser.add_argument("--datacfg", default="configs/instant-nerf-large-train.yaml") + parser.add_argument("--modelcfg", default="configs/instant-nerf-large-train.yaml") + parser.add_argument("--device_target", type=str, default="Ascend", help="Ascend or GPU") + parser.add_argument("--mode", type=int, default=1, help="Running in GRAPH_MODE(0) or PYNATIVE_MODE(1) (default=0)") + parser.add_argument("--seed", type=int, default=42, help="Inference seed") + parser.add_argument("--use_parallel", default=False, type=str2bool, help="use parallel") + parser.add_argument( + "--jit_level", + default="O0", + type=str, + choices=["O0", "O1", "O2"], + help="Used to control the compilation optimization level. Supports [“O0”, “O1”, “O2”]." + "O0: Except for optimizations that may affect functionality, all other optimizations are turned off, adopt KernelByKernel execution mode." + "O1: Using commonly used optimizations and automatic operator fusion optimizations, adopt KernelByKernel execution mode." + "O2: Ultimate performance optimization, adopt Sink execution mode.", + ) + parser.add_argument("--batch_size", default=1, type=int, help="infer batch size") + parser.add_argument( + "--dtype", + default="fp32", # if amp level O0/1, must pass fp32 + type=str, + choices=["bf16", "fp16", "fp32"], + help="what computation data type to use for latte. 
Default is `fp16`, which corresponds to ms.float16", + ) + parser.add_argument( + "--log_level", + type=str, + default="logging.INFO", + help="log level, options: logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR", + ) + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = parse_args() + if args.itmh_ckpt.split("/")[-1] == "train_resume.ckpt": + epoch_num = None + else: + epoch_num = args.itmh_ckpt.split("-e")[-1].split(".")[0] + evaluate(args, epoch_num) diff --git a/examples/instantmesh/inference.py b/examples/instantmesh/inference.py new file mode 100644 index 0000000000..585bacf854 --- /dev/null +++ b/examples/instantmesh/inference.py @@ -0,0 +1,146 @@ +""" Inference of InstantMesh which uses multiview images to generate 3D mesh. + +Note that the rendering part of instantmesh contains one of the cuda rasterization extensions, thus not implemented at the moment. +""" +import argparse +import os + +import imageio +import mcubes +import numpy as np +from einops import rearrange +from loguru import logger +from omegaconf import OmegaConf +from PIL import Image +from transformers import ViTImageProcessor +from utils.camera_util import get_sv3d_input_cameras + +# from models.instantmesh.utils.mesh_util import save_obj +from utils.train_util import instantiate_from_config + +import mindspore as ms +from mindspore import Tensor, nn +from mindspore.dataset.vision import ToPIL + +from mindone.utils.seed import set_random_seed + + +def args_parse(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--input_vid", default="INPUT_MULTIVIEW_IMG_VID_PATH", help="it has to be the 21 frames vid from sv3d output" + ) # TODO make sure that read in video will have the exactly same nparr from the sv3d output. OTHERWISE need to dump npz from sv3d and read in. + parser.add_argument("--name", default="anya_ms") + parser.add_argument("--scale", type=float, default=1.0, help="Scale of generated object.") + parser.add_argument( + "--config", type=str, default="models/instantmesh/configs/instant-mesh-large.yaml", help="Path to config file." + ) + parser.add_argument("--output_path", type=str, default="outputs", help="Output directory.") + args = parser.parse_args() + return args + + +class InstantMeshPipeline(nn.Cell): + def __init__(self, infer_config, model): + super().__init__() + self.infer_config = infer_config + self.model = model + + def construct(self, inputs: ms.Tensor, radius: float) -> Tensor: + input_cam = get_sv3d_input_cameras(radius=radius) + logger.info(f"registered cam shape is {input_cam.shape}") + logger.info(f"registered cam dtype is {input_cam.dtype}") + planes = self.model.forward_planes(inputs, input_cam) + # Uncomment this when Flexicubes available for ms + # mesh_out = self.model.extract_mesh_with_texture(planes, **self.infer_config) + logger.info( + "No support for Flexicubes at the moment, due to the MS operator issues. " + "Use a vanilla marching cube to extract meshes from SDF..." 
+ ) + mesh_out = self.model.extract_mesh_triplane_feat_marching_cubes(planes) + return mesh_out + + +if __name__ == "__main__": + args = args_parse() + + ms.set_context( + mode=1, + device_target="Ascend", + # device_target='CPU', + device_id=6, + pynative_synchronize=True, + ) + set_random_seed(42) + config = OmegaConf.load(args.config) + config_name = os.path.basename(args.config).replace(".yaml", "") + model_config = config.model_config + infer_config = config.infer_config + + # read the vid and proc accordingly + input_vid_arr = imageio.mimread(args.input_vid) + images = np.asarray(input_vid_arr, dtype=np.uint8) + + logger.info("loading instantmesh model for multiview to 3d generation...") + model = instantiate_from_config(model_config) + model_ckpt_path = ".ckpts" + state_dict = ms.load_checkpoint(os.path.join(model_ckpt_path, "instant_mesh_large_ms.ckpt")) + state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith("lrm_generator.")} + m, u = ms.load_param_into_net(model, state_dict, strict_load=True) + mesh_path = os.path.join(args.output_path, "meshes") + + # np img preprocessing + name = args.name + logger.info(f"Creating {name} ...") + topil = ToPIL() + logger.info(f"the imgae shape is {images.shape}") + + # note that ms.vision.Resizer only takes int, cannot be float + # therefore the output clip results are slightly different from torch, which causes the triplane transformer output nan features + # thus fc cannot have the mesh... + # the workaround here is to use PIL built-in resize + images = np.array([(topil(img)).resize((320, 320), Image.LANCZOS) for img in images]) # img: n h w c + images = images.astype("float32") / 255.0 + images = images.clip(min=0, max=1) + images = np.expand_dims(images, axis=0) # b n h w c + images = rearrange(images, "b n h w c -> b n c h w") # b n c h w + + _debug_dump_vid = False + _use_torchvision_arr = False + if _debug_dump_vid: + test = rearrange(images[0], "t c h w -> t h w c") * 255 + imageio.mimwrite("resized_antialias_video_jul29.mp4", test.astype(np.uint8)) + if _use_torchvision_arr: # here load in the same image tensor as the torch version + images = np.load("resized_np_arr.npy") # b n c h w + + logger.info(f"the imgae shape is {images.shape}") + + # RGB image with [0,1] scale and properly sized requested by the ViTImgProc + if images.ndim == 5: + (B, N, C, H, W) = images.shape # image: [B, N, C, H, W] + images = images.reshape(B * N, C, H, W) + + # ViTImageProcessor moved out from dino wrapper to the main here, to avoid being in .construct(), do ImageNetStandard normalization + img_processor = ViTImageProcessor.from_pretrained(model_config.params.encoder_model_name) + inputs = img_processor( + images=images, + return_tensors="np", + do_rescale=False, + do_resize=False, + )["pixel_values"] + inputs = ms.Tensor(inputs).reshape(B, N, C, H, W) + pipeline = InstantMeshPipeline(infer_config, model) + mesh_out = pipeline(inputs, radius=4.0 * args.scale) + mesh_path_sample = os.path.join(mesh_path, f"{name}.obj") + # Uncomment this when Flexicubes available for ms + # vertices, faces, vertex_colors = mesh_out + # save_obj( + # vertices, + # faces, + # vertex_colors, + # mesh_path_sample + # ) + sdf = mesh_out + verts, faces = mcubes.marching_cubes(sdf.asnumpy().squeeze(0), 0) + mcubes.export_obj(verts, faces, mesh_path_sample) + logger.info(f"Mesh saved to {mesh_path}") diff --git a/examples/instantmesh/model_stage1.py b/examples/instantmesh/model_stage1.py new file mode 100644 index 0000000000..ba130aaf5e --- /dev/null +++ 
b/examples/instantmesh/model_stage1.py @@ -0,0 +1,168 @@ +# from loguru import logger +import logging +import os + +logger = logging.getLogger(__name__) + +import numpy as np +from einops import rearrange +from PIL import Image +from utils.loss_util import LPIPS + +import mindspore as ms +from mindspore import Tensor, mint, nn, ops +from mindspore.dataset.vision import ToPIL + +from mindone.utils.config import instantiate_from_config + + +class InstantMeshStage1WithLoss(nn.Cell): + """The training pipeline for instant mesh model.""" + + def __init__( + self, + lrm_generator_config=None, + lrm_ckpt_path=None, # these under two args are for loading ckpts + input_size=256, + render_size=192, + ): + super().__init__() + + self.input_size = input_size + self.render_size = render_size + self.lrm_generator = instantiate_from_config(lrm_generator_config) + + # load pretrained model + if lrm_ckpt_path is not None: + logger.info(f"LOADING lrm ckpts from {lrm_ckpt_path} \ninside model_stage1") + lrm_ckpt_sdict = ms.load_checkpoint(lrm_ckpt_path) + start_epoch = int(lrm_ckpt_sdict.get("epoch_num", ms.Tensor(0, ms.int32)).asnumpy().item()) + m, u = ms.load_param_into_net(self.lrm_generator, lrm_ckpt_sdict, strict_load=False) + self.resume_epoch = start_epoch + 1 + else: + logger.info( + "NOT loading ckpt inside model_stage1, will load openlrm model as configured in train.py (if applicable)." + ) + + self.lpips = LPIPS() + self.topil = ToPIL() + self.validation_step_outputs = [] + + def on_fit_start(self): + if self.global_rank == 0: + os.makedirs(os.path.join(self.logdir, "images"), exist_ok=True) + os.makedirs(os.path.join(self.logdir, "images_val"), exist_ok=True) + + def prepare_validation_batch_data(self, batch, render_size, _use_dataloader=False): + """Used during eval/inference, cast all np input into Tensors. + + Args: + batch: np array, img that read in from the val dataset, which is np.fp32. 
+ """ + topil = ToPIL() + lrm_generator_input = {} + + # cast input images np arr from fp32 to as topil() does not take fp32 + images = (batch["input_images"] * 255).astype("uint8") + + # this for: images = v2.functional.resize(images, self.input_size, interpolation=3, antialias=True).clamp(0, 1) + images = np.array( + [(topil(img)).resize((320, 320), Image.LANCZOS) for img in images] + ) # img: 1 h w c; images: n h w c + images = images.astype("float32") / 255.0 + images = images.clip(min=0, max=1) + images = np.expand_dims(images, axis=0) + images = rearrange(images, "b n h w c -> b n c h w") + + input_c2ws = batch["input_c2ws"].flatten(start_dim=-2) + input_Ks = batch["input_Ks"].flatten(start_dim=-2) + + input_extrinsics = input_c2ws[:, :12] + input_intrinsics = ops.stack( + [ + input_Ks[:, 0], + input_Ks[:, 4], + input_Ks[:, 2], + input_Ks[:, 5], + ], + axis=-1, + ) + cameras = ops.cat([input_extrinsics, input_intrinsics], axis=-1) + + lrm_generator_input["cameras"] = cameras + + render_c2ws = batch["render_c2ws"].flatten(start_dim=-2) + render_Ks = batch["render_Ks"].flatten(start_dim=-2) + render_cameras = ops.cat([render_c2ws, render_Ks], axis=-1) + lrm_generator_input["render_cameras"] = render_cameras + + # create batch dim when not using dataloader, presuming bsize==1 + if not _use_dataloader: + for k, v in lrm_generator_input.items(): + lrm_generator_input[k] = v.unsqueeze(0) + + # assign the proc images at last, which is left as np array, for the ViTProc in eval.py + lrm_generator_input["images"] = images + lrm_generator_input["render_size"] = render_size + + return lrm_generator_input + + def construct( + self, + images: Tensor, + cameras: Tensor, + render_cameras: Tensor, + target_images: Tensor, + target_alphas: Tensor, + render_size: Tensor, + crop_params: Tensor, + ) -> Tensor: + """For training, only return loss.""" + images_rgb, images_depth, images_weight = self.lrm_generator( + images, cameras, render_cameras, render_size.item(), crop_params # to int + ) + render_images = ops.clamp(images_rgb, 0.0, 1.0) + render_alphas = ops.clamp(images_weight, 0.0, 1.0) + + loss = self.compute_loss(render_images, render_alphas, target_images, target_alphas) + + return loss + + def forward_nocalloss(self, images: Tensor, cameras: Tensor, render_cameras: Tensor, render_size: int) -> Tensor: + """For evaluate().""" + + images_rgb, images_depth, images_weight = self.lrm_generator( + images, cameras, render_cameras, render_size, crop_params=None + ) + render_images = ops.clamp(images_rgb, 0.0, 1.0) + render_alphas = ops.clamp(images_weight, 0.0, 1.0) + return render_images, render_alphas + + def compute_loss(self, render_images, render_alphas, target_images, target_alphas): + # NOTE: the rgb value range of OpenLRM is [0, 1] + + # render_images = render_out['render_images'] + # # TODO move the transform for gt data target_xx back to the dataset proc, BALANCE cpu/npucore proc for higher eff + # target_images = render_gt['target_images'].to(render_images.dtype) + target_images = target_images.to(render_images.dtype) + # render_images = rearrange(render_images, 'b n ... -> (b n) ...') * 2.0 - 1.0 + # target_images = rearrange(target_images, 'b n ... -> (b n) ...') * 2.0 - 1.0 + b, n, c, h, w = render_images.shape + render_images = render_images.reshape(b * n, c, h, w) * 2.0 - 1.0 + b, n, c, h, w = target_images.shape + target_images = target_images.reshape(b * n, c, h, w) * 2.0 - 1.0 + + loss_mse = ops.mse_loss(render_images, target_images) + + # FIXME current lpips loss wrong, not positive. 
And 2e3 scale here is for the difference of the lpips implemented here vs. torchmetric's + # loss_lpips = 2e2 * mint.mean(self.lpips(render_images, target_images)) + loss_lpips = 2.0 * mint.mean(self.lpips(render_images, target_images)) + + target_alphas = target_alphas.permute((0, 1, 4, 2, 3)) # b n h w c -> b n c h w + loss_mask = ops.mse_loss(render_alphas, target_alphas) + + logger.info(f"loss mse: {loss_mse}, loss mask: {loss_mask}, loss lpips: {loss_lpips}") + + loss = loss_mse + loss_mask + loss_lpips + + return loss diff --git a/examples/instantmesh/models/__init__.py b/examples/instantmesh/models/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/instantmesh/models/decoder/__init__.py b/examples/instantmesh/models/decoder/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/instantmesh/models/decoder/transformer.py b/examples/instantmesh/models/decoder/transformer.py new file mode 100644 index 0000000000..4058104712 --- /dev/null +++ b/examples/instantmesh/models/decoder/transformer.py @@ -0,0 +1,142 @@ +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops + + +class BasicTransformerBlock(nn.Cell): + """ + Transformer block that takes in a cross-attention condition and another modulation vector applied to sub-blocks. + """ + + # use attention from torch.nn.MultiHeadAttention + # Block contains a cross-attention layer, a self-attention layer, and a MLP + def __init__( + self, + inner_dim: int, + cond_dim: int, + num_heads: int, + eps: float, + attn_drop: float = 0.0, + attn_bias: bool = False, + mlp_ratio: float = 4.0, + mlp_drop: float = 0.0, + dtype: ms.dtype = ms.float32, + ): + super().__init__() + + self.norm1 = nn.LayerNorm([inner_dim]) + self.cross_attn = nn.MultiheadAttention( + embed_dim=inner_dim, + num_heads=num_heads, + kdim=cond_dim, + vdim=cond_dim, + dropout=attn_drop, + has_bias=attn_bias, + batch_first=True, + dtype=dtype, + ) + self.norm2 = nn.LayerNorm([inner_dim]) + self.self_attn = nn.MultiheadAttention( + embed_dim=inner_dim, + num_heads=num_heads, + dropout=attn_drop, + has_bias=attn_bias, + batch_first=True, + dtype=dtype, + ) + self.norm3 = nn.LayerNorm([inner_dim]) + self.mlp = nn.SequentialCell( + nn.Dense(inner_dim, int(inner_dim * mlp_ratio)), + nn.GELU(), + nn.Dropout(p=mlp_drop), + nn.Dense(int(inner_dim * mlp_ratio), inner_dim), + nn.Dropout(p=mlp_drop), + ) + + def construct(self, x, cond): + # x: [N, L, D] + # cond: [N, L_cond, D_cond] + x = x + self.cross_attn(self.norm1(x), cond, cond)[0] + before_sa = self.norm2(x) + x = x + self.self_attn(before_sa, before_sa, before_sa)[0] + x = x + self.mlp(self.norm3(x)) + return x + + +class TriplaneTransformer(nn.Cell): + """ + Transformer with condition that generates a triplane representation. 
+ + Reference: + Timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L486 + """ + + def __init__( + self, + inner_dim: int, + image_feat_dim: int, + triplane_low_res: int, + triplane_high_res: int, + triplane_dim: int, + num_layers: int, + num_heads: int, + eps: float = 1e-6, + dtype: ms.dtype = ms.float32, + use_recompute: bool = False, + ): + super().__init__() + + # attributes + self.triplane_low_res = triplane_low_res + self.triplane_high_res = triplane_high_res + self.triplane_dim = triplane_dim + + # modules + # initialize pos_embed with 1/sqrt(dim) * N(0, 1) + self.pos_embed = ms.Parameter( + ops.randn(1, 3 * triplane_low_res**2, inner_dim) * (1.0 / inner_dim) ** 0.5 + ) # [L, D] + self.layers = nn.CellList( + [ + BasicTransformerBlock( + inner_dim=inner_dim, cond_dim=image_feat_dim, num_heads=num_heads, eps=eps, dtype=dtype + ) + for _ in range(num_layers) + ] + ) + self.norm = nn.LayerNorm([inner_dim], epsilon=eps) + self.deconv = nn.Conv2dTranspose( + inner_dim, triplane_dim, kernel_size=2, stride=2, padding=0, pad_mode="valid", has_bias=True + ) + + if use_recompute: + for b in self.layers: + b.recompute() + self.norm.recompute() + self.deconv.recompute() + + # @ms.jit + def construct(self, image_feats): + # image_feats: [N, L_cond, D_cond] + + N = image_feats.shape[0] + H = W = self.triplane_low_res + # L = 3 * H * W + + x = self.pos_embed.tile((N, 1, 1)) # [N, L, D] + for layer in self.layers: + x = layer(x, image_feats) + x = self.norm(x) + + # separate each plane and apply deconv + x = x.view(N, 3, H, W, -1) + # x = ops.einsum('nihwd->indhw', x) # [3, N, D, H, W] + # x = ops.reshape(x, (3, N, -1, H, W)) + x = ops.permute(x, (1, 0, 4, 2, 3)) + x = x.view(3 * N, -1, H, W) # [3*N, D, H, W] + x = self.deconv(x) # [3*N, D', H', W'] + x = x.view(3, N, *x.shape[-3:]) # [3, N, D', H', W'] + # x = ops.einsum('indhw->nidhw', x) # [N, 3, D', H', W'] + x = ops.permute(x, (1, 0, 2, 3, 4)) + + return x diff --git a/examples/instantmesh/models/encoder/__init__.py b/examples/instantmesh/models/encoder/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/instantmesh/models/encoder/dino.py b/examples/instantmesh/models/encoder/dino.py new file mode 100644 index 0000000000..d885f2eefd --- /dev/null +++ b/examples/instantmesh/models/encoder/dino.py @@ -0,0 +1,539 @@ +""" MindSpore Dino ViT model.""" + + +import collections.abc +import math +from typing import Dict, List, Optional, Set, Tuple, Union + +from transformers import ViTConfig + +import mindspore as ms +import mindspore.nn as nn +from mindspore import Parameter, Tensor, ops +from mindspore.common.initializer import TruncatedNormal, initializer +from mindspore.nn import LayerNorm + +from mindone.transformers.activations import ACT2FN +from mindone.transformers.mindspore_utils import find_pruneable_heads_and_indices, prune_linear_layer +from mindone.transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from mindone.transformers.modeling_utils import MSPreTrainedModel as PreTrainedModel + + +class ViTEmbeddings(nn.Cell): + """ + Construct the CLS token, position and patch embeddings. Optionally, also the mask token. 
+ """ + + def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None: + super().__init__() + + self.cls_token = Parameter(ops.randn(1, 1, config.hidden_size)) + self.mask_token = Parameter(ops.zeros((1, 1, config.hidden_size))) if use_mask_token else None + self.patch_embeddings = ViTPatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = Parameter(ops.randn(1, num_patches + 1, config.hidden_size)) + self.dropout = nn.Dropout(p=config.hidden_dropout_prob) + self.config = config + + def interpolate_pos_encoding(self, embeddings: Tensor, height: int, width: int) -> Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return self.position_embeddings + class_pos_embed = self.position_embeddings[:, 0] + patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] + h0 = height // self.config.patch_size + w0 = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + h0, w0 = h0 + 0.1, w0 + 0.1 + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + # FIXME this ops huge difference + patch_pos_embed = ops.interpolate( + patch_pos_embed, + scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), + recompute_scale_factor=True, + mode="bicubic", + align_corners=False, + ) + + assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return ops.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), axis=1) + + def construct( + self, + pixel_values: Tensor, + bool_masked_pos: Optional[Tensor] = None, + interpolate_pos_encoding: bool = False, + ) -> Tensor: + batch_size, num_channels, height, width = pixel_values.shape + embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + + if bool_masked_pos is not None: + seq_length = embeddings.shape[1] + mask_tokens = self.mask_token.expand((batch_size, seq_length, -1)) + # replace the masked visual tokens by mask_tokens + mask = bool_masked_pos.unsqueeze(-1).to(mask_tokens.dtype) + embeddings = embeddings * (1.0 - mask) + mask_tokens * mask + + # add the [CLS] token to the embedded patch tokens + cls_tokens = self.cls_token.to(embeddings.dtype).broadcast_to((batch_size, -1, -1)) + embeddings = ops.cat((cls_tokens, embeddings), axis=1) + + # add positional encoding to each token + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embeddings + + embeddings = self.dropout(embeddings) + + return embeddings + + +class ViTPatchEmbeddings(nn.Cell): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. 
+ """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, has_bias=True) + + def construct(self, pixel_values: Tensor, interpolate_pos_encoding: bool = False) -> Tensor: + batch_size, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + f" Expected {self.num_channels} but got {num_channels}." + ) + if not interpolate_pos_encoding: + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" + f" ({self.image_size[0]}*{self.image_size[1]})." + ) + embeddings = self.projection(pixel_values).flatten(start_dim=2).swapaxes(1, 2) + return embeddings + + +class ViTSelfAttention(nn.Cell): + def __init__(self, config: ViTConfig) -> None: + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size {config.hidden_size} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Dense(config.hidden_size, self.all_head_size, bias_init=config.qkv_bias) + self.key = nn.Dense(config.hidden_size, self.all_head_size, bias_init=config.qkv_bias) + self.value = nn.Dense(config.hidden_size, self.all_head_size, bias_init=config.qkv_bias) + self.dropout = nn.Dropout(p=config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x: Tensor) -> Tensor: + new_x_shape = x.shape[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def construct( + self, hidden_states, head_mask: Optional[Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = ops.matmul(query_layer, key_layer.swapaxes(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + # Normalize the attention scores to probabilities. + attention_probs = ops.Softmax(axis=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = ops.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3) + new_context_layer_shape = context_layer.shape[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +class ViTSelfOutput(nn.Cell): + """ + The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the + layernorm applied before each block. + """ + + def __init__(self, config: ViTConfig) -> None: + super().__init__() + self.dense = nn.Dense(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(p=config.hidden_dropout_prob) + + def construct(self, hidden_states: Tensor, input_tensor: Tensor) -> Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +class ViTAttention(nn.Cell): + def __init__(self, config: ViTConfig) -> None: + super().__init__() + self.attention = ViTSelfAttention(config) + self.output = ViTSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads: Set[int]) -> None: + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def construct( + self, + hidden_states: Tensor, + head_mask: Optional[Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor]]: + self_outputs = self.attention(hidden_states, head_mask, output_attentions) + + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class ViTIntermediate(nn.Cell): + def __init__(self, config: ViTConfig) -> None: + super().__init__() + self.dense = nn.Dense(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def construct(self, hidden_states: Tensor) -> Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + +class ViTOutput(nn.Cell): + def __init__(self, config: ViTConfig) -> None: + super().__init__() + self.dense = nn.Dense(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(p=config.hidden_dropout_prob) + + def construct(self, hidden_states: Tensor, input_tensor: Tensor) -> Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + hidden_states = hidden_states + input_tensor 
+ + return hidden_states + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) # broadcasting + + +class ViTLayer(nn.Cell): + """This corresponds to the Block class in the timm implementation.""" + + def __init__(self, config: ViTConfig) -> None: + super().__init__() + # self.chunk_size_feed_construct = config.chunk_size_feed_construct + self.seq_len_dim = 1 + self.attention = ViTAttention(config) + self.intermediate = ViTIntermediate(config) + self.output = ViTOutput(config) + self.layernorm_before = LayerNorm([config.hidden_size], epsilon=config.layer_norm_eps) + self.layernorm_after = LayerNorm([config.hidden_size], epsilon=config.layer_norm_eps) + + self.adaLN_modulation = nn.SequentialCell( + nn.SiLU(), nn.Dense(config.hidden_size, 4 * config.hidden_size, bias_init=True) + ) + # nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + # nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + self.adaLN_modulation[-1].weight = ops.zeros_like(self.adaLN_modulation[-1].weight) + self.adaLN_modulation[-1].bias = ops.zeros_like(self.adaLN_modulation[-1].bias) + + def construct( + self, + hidden_states: Tensor, + adaln_input: Tensor = None, + head_mask: Optional[Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor]]: + shift_msa, scale_msa, shift_mlp, scale_mlp = self.adaLN_modulation(adaln_input).chunk(4, axis=1) + + self_attention_outputs = self.attention( + modulate( + self.layernorm_before(hidden_states), shift_msa, scale_msa + ), # in ViT, layernorm is applied before self-attention + head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + # first residual connection + hidden_states = attention_output + hidden_states + + # in ViT, layernorm is also applied after self-attention + layer_output = modulate(self.layernorm_after(hidden_states), shift_mlp, scale_mlp) + layer_output = self.intermediate(layer_output) + + # second residual connection is done here + layer_output = self.output(layer_output, hidden_states) + + outputs = (layer_output,) + outputs + + return outputs + + +class ViTEncoder(nn.Cell): + def __init__(self, config: ViTConfig) -> None: + super().__init__() + self.config = config + self.layer = nn.CellList([ViTLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def construct( + self, + hidden_states: Tensor, + adaln_input: Tensor = None, + head_mask: Optional[Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + # return_dict: bool = True, # not returning dict, as graph mode does not support the dict wrapper which is a non-cell class + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module(hidden_states, adaln_input, layer_head_mask, output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # logger.info('in vit encoder, retruning tuple') + return tuple(v for v 
in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + + +class ViTPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = ViTConfig + base_model_prefix = "vit" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _no_split_modules = ["ViTEmbeddings", "ViTLayer"] + + def _init_weights(self, module: Union[nn.Dense, nn.Conv2d, LayerNorm]) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Dense, nn.Conv2d)): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.set_data( + initializer( + TruncatedNormal(sigma=self.config.initializer_range, mean=0.0), module.weight.shape, ms.float32 + ) + ) + # module.weight.data = nn.init.trunc_normal_( # torch references + # module.weight.data.to(ms.float32), mean=0.0, std=self.config.initializer_range + # ).to(module.weight.dtype) + if module.bias is not None: + module.bias.set_data(initializer("Zero", module.bias.shape, ms.float32)) + # module.bias.data.zero_() + elif isinstance(module, LayerNorm): + module.beta.set_data(initializer("Zero", module.beta.shape, ms.float32)) + module.gamma.set_data(initializer("One", module.gamma.shape, ms.float32)) + elif isinstance(module, ViTEmbeddings): + # module.position_embeddings.data = nn.init.trunc_normal_( + # module.position_embeddings.data.to(ms.float32), + # mean=0.0, + # std=self.config.initializer_range, + # ).to(module.position_embeddings.dtype) + module.position_embeddings.set_data( + initializer( + TruncatedNormal(sigma=self.config.initializer_range, mean=0.0), + module.position_embeddings.shape, + ms.float32, + ) + ) + + # module.cls_token.data = nn.init.trunc_normal_( + # module.cls_token.data.to(ms.float32), + # mean=0.0, + # std=self.config.initializer_range, + # ).to(module.cls_token.dtype) + module.cls_token.set_data( + initializer( + TruncatedNormal(sigma=self.config.initializer_range, mean=0.0), module.cls_token.shape, ms.float32 + ) + ) + + +class ViTPooler(nn.Cell): + def __init__(self, config: ViTConfig): + super().__init__() + self.dense = nn.Dense(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def construct(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class ViTModel(ViTPreTrainedModel): + def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False): + super().__init__(config) + self.config = config + self.embeddings = ViTEmbeddings(config, use_mask_token=use_mask_token) + self.encoder = ViTEncoder(config) + + self.layernorm = nn.LayerNorm([config.hidden_size], epsilon=config.layer_norm_eps) + # self.layernorm = mo_LayerNorm([config.hidden_size], eps=config.layer_norm_eps, dtype=ms.float32) + self.pooler = ViTPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> ViTPatchEmbeddings: + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def construct( + self, + pixel_values: Optional[Tensor] = None, + adaln_input: Optional[Tensor] = None, + bool_masked_pos: Optional[Tensor] = None, + head_mask: Optional[Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + # return_dict: Optional[bool] = False, # not returning dict, as graph mode does not support the dict wrapper which is a non-cell class + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + bool_masked_pos (`ms.Tensor` of shape `(batch_size, num_patches)`, *optional*): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?) + expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype + if pixel_values.dtype != expected_dtype: + pixel_values = pixel_values.to(expected_dtype) + + embedding_output = self.embeddings( + pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding + ) + + encoder_outputs = self.encoder( + embedding_output, + adaln_input=adaln_input, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + sequence_output = encoder_outputs[0] + # logger.info(f'seq output shape {sequence_output.shape}, dtype {sequence_output.dtype}') + sequence_output = self.layernorm(sequence_output) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + # print(f'in vit model, retruning tuple. and pooled output is {pooled_output}') + head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,) + return head_outputs + encoder_outputs[1:] diff --git a/examples/instantmesh/models/encoder/dino_wrapper.py b/examples/instantmesh/models/encoder/dino_wrapper.py new file mode 100644 index 0000000000..cb52f370cf --- /dev/null +++ b/examples/instantmesh/models/encoder/dino_wrapper.py @@ -0,0 +1,69 @@ +import mindspore as ms +import mindspore.nn as nn + +# comment when debug +from .dino import ViTModel + +# from dino import ViTModel + + +class DinoWrapper(nn.Cell): + """ + Dino v1 wrapper using huggingface transformer implementation. 
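+
+ A minimal usage sketch (the checkpoint name is illustrative; any locally cached DINO ViT checkpoint
+ loadable via `ViTModel.from_pretrained` works):
+
+     encoder = DinoWrapper("facebook/dino-vitb16", use_recompute=False)
+     features = encoder(images, camera)  # images: [B, N, 3, H, W], camera: [B, N, 16]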
+ """ + + def __init__(self, model_name: str, use_recompute: bool = False): + super().__init__() + self.model = self._build_dino(model_name) + self.camera_embedder = nn.SequentialCell( + nn.Dense(16, self.model.config.hidden_size, has_bias=True), + nn.SiLU(), + nn.Dense(self.model.config.hidden_size, self.model.config.hidden_size, has_bias=True), + ) + if use_recompute: + self.camera_embedder.recompute() + self.model.encoder.recompute() + self.model.layernorm.recompute() + + def construct(self, images: ms.Tensor, camera: ms.Tensor): # because img processor only takes np img + # image: [B, N, C, H, W] + # camera: [B, N, D] + if images.ndim == 5: + (B, N, C, H, W) = images.shape + images = images.reshape(B * N, C, H, W) + + # embed camera + N = camera.shape[1] + camera_embeddings = self.camera_embedder(camera) + # camera_embeddings = rearrange(camera_embeddings, 'b n d -> (b n) d') + cam_emb_shape = camera_embeddings.shape + camera_embeddings = camera_embeddings.reshape(cam_emb_shape[0] * cam_emb_shape[1], cam_emb_shape[2]) + embeddings = camera_embeddings + + # This resampling of positional embedding uses bicubic interpolation + outputs = self.model(pixel_values=images, adaln_input=embeddings, interpolate_pos_encoding=True)[0] + return outputs + + @staticmethod + def _build_dino(model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5): + import requests + + try: + model = ViTModel.from_pretrained( + model_name, + mindspore_dtype=ms.float32, + add_pooling_layer=False, + local_files_only=True, + use_safetensors=True, + ) + + return model + except requests.exceptions.ProxyError as err: + if proxy_error_retries > 0: + print(f"Huggingface ProxyError: Retrying in {proxy_error_cooldown} seconds...") + import time + + time.sleep(proxy_error_cooldown) + return DinoWrapper._build_dino(model_name, proxy_error_retries - 1, proxy_error_cooldown) + else: + raise err diff --git a/examples/instantmesh/models/geometry/__init__.py b/examples/instantmesh/models/geometry/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/instantmesh/models/geometry/camera/__init__.py b/examples/instantmesh/models/geometry/camera/__init__.py new file mode 100644 index 0000000000..66fa0375ad --- /dev/null +++ b/examples/instantmesh/models/geometry/camera/__init__.py @@ -0,0 +1,7 @@ +from mindspore import nn + + +class Camera(nn.Cell): + def __init__(self): + super(Camera, self).__init__() + pass diff --git a/examples/instantmesh/models/geometry/camera/perspective_camera.py b/examples/instantmesh/models/geometry/camera/perspective_camera.py new file mode 100644 index 0000000000..b7800fef15 --- /dev/null +++ b/examples/instantmesh/models/geometry/camera/perspective_camera.py @@ -0,0 +1,34 @@ +import numpy as np + +import mindspore as ms +from mindspore import ops + +from . 
import Camera + + +# the ndc projection +def projection(x=0.1, n=1.0, f=50.0, near_plane=None): + if near_plane is None: + near_plane = n + return np.array( + [ + [n / x, 0, 0, 0], + [0, n / -x, 0, 0], + [0, 0, -(f + near_plane) / (f - near_plane), -(2 * f * near_plane) / (f - near_plane)], + [0, 0, -1, 0], + ] + ).astype(np.float32) + + +class PerspectiveCamera(Camera): + def __init__(self, fovy=49.0, device="cuda"): + super(PerspectiveCamera, self).__init__() + self.device = device + focal = np.tan(fovy / 180.0 * np.pi * 0.5) + self.proj_mtx = ( + ms.Tensor.from_numpy(projection(x=focal, f=1000.0, n=1.0, near_plane=0.1)).to(self.device).unsqueeze(dim=0) + ) + + def project(self, points_bxnx4): + out = ops.matmul(points_bxnx4, ops.transpose(self.proj_mtx, 1, 2)) + return out diff --git a/examples/instantmesh/models/geometry/rep_3d/__init__.py b/examples/instantmesh/models/geometry/rep_3d/__init__.py new file mode 100644 index 0000000000..7f88002edf --- /dev/null +++ b/examples/instantmesh/models/geometry/rep_3d/__init__.py @@ -0,0 +1,6 @@ +class Geometry: + def __init__(self): + pass + + def forward(self): + pass diff --git a/examples/instantmesh/models/geometry/rep_3d/flexicubes.py b/examples/instantmesh/models/geometry/rep_3d/flexicubes.py new file mode 100644 index 0000000000..097c22965d --- /dev/null +++ b/examples/instantmesh/models/geometry/rep_3d/flexicubes.py @@ -0,0 +1,611 @@ +import numpy as np + +import mindspore as ms +import mindspore.nn as nn +from mindspore import mint, ops + +from .tables import check_table, dmc_table, num_vd_table, tet_table + +__all__ = ["FlexiCubes"] + + +class FlexiCubes(nn.Cell): + """ + This class implements the FlexiCubes method for extracting meshes from scalar fields. + It maintains a series of lookup tables and indices to support the mesh extraction process. + FlexiCubes, a differentiable variant of the Dual Marching Cubes (DMC) scheme, enhances + the geometric fidelity and mesh quality of reconstructed meshes by dynamically adjusting + the surface representation through gradient-based optimization. + + During instantiation, the class loads DMC tables from a file and transforms them into + PyTorch tensors on the specified device. + + Attributes: + dmc_table (ms.Tensor): Dual Marching Cubes (DMC) table that encodes the edges + associated with each dual vertex in 256 Marching Cubes (MC) configurations. + num_vd_table (ms.Tensor): Table holding the number of dual vertices in each of + the 256 MC configurations. + check_table (ms.Tensor): Table resolving ambiguity in cases C16 and C19 + of the DMC configurations. + tet_table (ms.Tensor): Lookup table used in tetrahedralizing the isosurface. + quad_split_1 (ms.Tensor): Indices for splitting a quad into two triangles + along one diagonal. + quad_split_2 (ms.Tensor): Alternative indices for splitting a quad into + two triangles along the other diagonal. + quad_split_train (ms.Tensor): Indices for splitting a quad into four triangles + during training by connecting all edges to their midpoints. + cube_corners_idx (ms.Tensor): Cube corners indexed as powers of 2, used + to retrieve the case id. + edge_dir_table (ms.Tensor): A mapping tensor that associates edge indices with + their corresponding axis. For instance, edge_dir_table[0] = 0 indicates that the + first edge is oriented along the x-axis. + dir_faces_table (ms.Tensor): A tensor that maps the corresponding axis of shared edges + across four adjacent cubes to the shared faces of these cubes. 
For instance, + dir_faces_table[0] = [5, 4] implies that for four cubes sharing an edge along + the x-axis, the first and second cubes share faces indexed as 5 and 4, respectively. + This tensor is only utilized during isosurface tetrahedralization. + adj_pairs (ms.Tensor): + A tensor containing index pairs that correspond to neighboring cubes that share the same edge. + qef_reg_scale (float): + The scaling factor applied to the regularization loss to prevent issues with singularity + when solving the QEF. This parameter is only used when a 'grad_func' is specified. + weight_scale (float): + The scale of weights in FlexiCubes. Should be between 0 and 1. + cube_edges (np.array): Edge connections in a cube, listed in pairs. + Used to retrieve edge vertices in DMC. + cube_corners (np.array): Defines the positions of a standard unit cube's + eight corners in 3D space, ordered starting from the origin (0,0,0), + moving along the x-axis, then y-axis, and finally z-axis. + Used as a blueprint for generating a voxel grid. + """ + + def __init__(self, qef_reg_scale=1e-3, weight_scale=0.99): + super().__init__() + self.dmc_table = ms.Tensor(dmc_table, dtype=ms.int64) + self.num_vd_table = ms.Tensor(num_vd_table, dtype=ms.int64) + self.check_table = ms.Tensor(check_table, dtype=ms.int64) + self.tet_table = ms.Tensor(tet_table, dtype=ms.int64) + self.quad_split_1 = ms.Tensor([0, 1, 2, 0, 2, 3], dtype=ms.int64) + self.quad_split_2 = ms.Tensor([0, 1, 3, 3, 1, 2], dtype=ms.int64) + self.quad_split_train = ms.Tensor([0, 1, 1, 2, 2, 3, 3, 0], dtype=ms.int64) + self.cube_corners_idx = ops.pow(2, ops.arange(8)) + self.cube_edges = ms.Tensor( + [0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6, 2, 0, 3, 1, 7, 5, 6, 4], dtype=ms.int8 + ) + self.edge_dir_table = ms.Tensor([0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1], dtype=ms.int64) + self.dir_faces_table = ms.Tensor( + [[[5, 4], [3, 2], [4, 5], [2, 3]], [[5, 4], [1, 0], [4, 5], [0, 1]], [[3, 2], [1, 0], [2, 3], [0, 1]]], + dtype=ms.int64, + ) + self.adj_pairs = ms.Tensor([0, 1, 1, 3, 3, 2, 2, 0], dtype=ms.int64) + self.qef_reg_scale = qef_reg_scale + self.weight_scale = weight_scale + + # non-ms vars: these vars need to be used in init but not constrcut, thus cannot in ms + self.cube_corners = np.array( + [[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1]] + ) + # self.cube_edges = [0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6, 2, 0, 3, 1, 7, 5, 6, 4] + + # an init function, needs to be done with np + def np_construct_voxel_grid(self, res): + """ + Generates a voxel grid based on the specified resolution. + + Args: + res (int or list[int]): The resolution of the voxel grid. If an integer + is provided, it is used for all three dimensions. If a list or tuple + of 3 integers is provided, they define the resolution for the x, + y, and z dimensions respectively. + + Returns: + (ms.Tensor, ms.Tensor): Returns the vertices and the indices of the + cube corners (index into vertices) of the constructed voxel grid. + The vertices are centered at the origin, with the length of each + dimension in the grid being one. 
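+
+ For example, `res=2` gives (2 + 1) ** 3 = 27 unique vertices spanning [-0.5, 0.5]^3 and 8 cubes,
+ each cube indexing its 8 corners into the vertex array.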
+ """ + base_cube_f = np.arange(8) + if isinstance(res, int): + res = (res, res, res) + voxel_grid_template = np.ones(res) + + # res = ms.Tensor([res], dtype=ms.float32) + coords = np.stack([i for i in np.nonzero(voxel_grid_template)], axis=1) / res # N, 3 + verts = (self.cube_corners[None] / res + coords[:, None]).reshape(-1, 3) + cubes = (base_cube_f[None] + np.arange(coords.shape[0])[:, None] * 8).reshape(-1) + + verts_rounded = np.round(verts * 10**5) / (10**5) + # ms implementation, tbr + # unique_func = ops.UniqueConsecutive(axis=0, return_idx=True, return_counts=False) + # verts_unique, inverse_indices, _ = unique_func(verts_rounded) + verts_unique, inverse_indices = np.unique(verts_rounded, axis=0, return_inverse=True) + # verts_unique, inverse_indices = torch.unique(verts_rounded, dim=0, return_inverse=True) + cubes = inverse_indices[cubes.reshape(-1)].reshape(-1, 8) + + return verts_unique - 0.5, cubes # wrapped coz used in flexicubes_geo .init() + + def __call__( + self, + x_nx3: ms.Tensor, + s_n: ms.Tensor, + cube_fx8: ms.Tensor, + res, + beta_fx12: ms.Tensor = None, + alpha_fx8: ms.Tensor = None, + gamma_f: ms.Tensor = None, + training=False, + output_tetmesh=False, + grad_func=None, + ): + r""" + Main function for mesh extraction from scalar field using FlexiCubes. This function converts + discrete signed distance fields, encoded on voxel grids and additional per-cube parameters, + to triangle or tetrahedral meshes using a differentiable operation as described in + `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_. FlexiCubes enhances + mesh quality and geometric fidelity by adjusting the surface representation based on gradient + optimization. The output surface is differentiable with respect to the input vertex positions, + scalar field values, and weight parameters. + + If you intend to extract a surface mesh from a fixed Signed Distance Field without the + optimization of parameters, it is suggested to provide the "grad_func" which should + return the surface gradient at any given 3D position. When grad_func is provided, the process + to determine the dual vertex position adapts to solve a Quadratic Error Function (QEF), as + described in the `Manifold Dual Contouring`_ paper, and employs an smart splitting strategy. + Please note, this approach is non-differentiable. + + For more details and example usage in optimization, refer to the + `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_ SIGGRAPH 2023 paper. + + Args: + x_nx3 (ms.Tensor): Coordinates of the voxel grid vertices, can be deformed. + s_n (ms.Tensor): Scalar field values at each vertex of the voxel grid. Negative values + denote that the corresponding vertex resides inside the isosurface. This affects + the directions of the extracted triangle faces and volume to be tetrahedralized. + cube_fx8 (ms.Tensor): Indices of 8 vertices for each cube in the voxel grid. + res (int or list[int]): The resolution of the voxel grid. If an integer is provided, it + is used for all three dimensions. If a list or tuple of 3 integers is provided, they + specify the resolution for the x, y, and z dimensions respectively. + beta_fx12 (ms.Tensor, optional): Weight parameters for the cube edges to adjust dual + vertices positioning. Defaults to uniform value for all edges. + alpha_fx8 (ms.Tensor, optional): Weight parameters for the cube corners to adjust dual + vertices positioning. Defaults to uniform value for all vertices. 
+ gamma_f (ms.Tensor, optional): Weight parameters to control the splitting of + quadrilaterals into triangles. Defaults to uniform value for all cubes. + training (bool, optional): If set to True, applies differentiable quad splitting for + training. Defaults to False. + output_tetmesh (bool, optional): If set to True, outputs a tetrahedral mesh, otherwise, + outputs a triangular mesh. Defaults to False. + grad_func (callable, optional): A function to compute the surface gradient at specified + 3D positions (input: Nx3 positions). The function should return gradients as an Nx3 + tensor. If None, the original FlexiCubes algorithm is utilized. Defaults to None. + + Returns: + (ms.Tensor, ms.float64Tensor, ms.Tensor): Tuple containing: + - Vertices for the extracted triangular/tetrahedral mesh. + - Faces for the extracted triangular/tetrahedral mesh. + - Regularizer L_dev, computed per dual vertex. + + .. _Flexible Isosurface Extraction for Gradient-Based Mesh Optimization: + https://research.nvidia.com/labs/toronto-ai/flexicubes/ + .. _Manifold Dual Contouring: + https://people.engr.tamu.edu/schaefer/research/dualsimp_tvcg.pdf + """ + + surf_cubes, occ_fx8 = self._identify_surf_cubes(s_n, cube_fx8) + if surf_cubes.sum() == 0: + return ( + ops.zeros((0, 3)), + ops.zeros((0, 4), dtype=ms.int64) if output_tetmesh else ops.zeros((0, 3), dtype=ms.int64), + ops.zeros((0)), + ) + else: + surf_cubes = ( + surf_cubes.bool() + ) # in order to do tensor sum in ms, this masking tensor is casted into uint8, now cast back to bool for masking + beta_fx12, alpha_fx8, gamma_f = self._normalize_weights(beta_fx12, alpha_fx8, gamma_f, surf_cubes) + + case_ids = self._get_case_id(occ_fx8, surf_cubes, res) + + surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges(s_n, cube_fx8, surf_cubes) + + vd, L_dev, vd_gamma, vd_idx_map = self._compute_vd( + x_nx3, cube_fx8[surf_cubes], surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func + ) + vertices, faces, s_edges, edge_indices = self._triangulate( + s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func + ) + if not output_tetmesh: + return vertices, faces, L_dev + else: + vertices, tets = self._tetrahedralize( + x_nx3, + s_n, + cube_fx8, + vertices, + faces, + surf_edges, + s_edges, + vd_idx_map, + case_ids, + edge_indices, + surf_cubes, + training, + ) + return vertices, tets, L_dev + + def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges): + """ + Regularizer L_dev as in Equation 8 + """ + dist = ops.norm(ue - mint.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1) + mean_l2 = ops.zeros_like(vd[:, 0]) + mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).float() + mad = (dist - mint.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs() + return mad + + def _normalize_weights(self, beta_fx12, alpha_fx8, gamma_f, surf_cubes): + """ + Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones. 
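+
+ With the default `weight_scale=0.99`, `beta_fx12` and `alpha_fx8` are squashed into (0.01, 1.99) via
+ tanh, and `gamma_f` into (0.005, 0.995) via sigmoid, so every weight stays strictly positive.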
+ """ + n_cubes = surf_cubes.shape[0] + + if beta_fx12 is not None: + beta_fx12 = ops.tanh(beta_fx12) * self.weight_scale + 1 + else: + beta_fx12 = ops.ones((n_cubes, 12), dtype=ms.float32) + + if alpha_fx8 is not None: + alpha_fx8 = ops.tanh(alpha_fx8) * self.weight_scale + 1 + else: + alpha_fx8 = ops.ones((n_cubes, 8), dtype=ms.float32) + + if gamma_f is not None: + gamma_f = ops.sigmoid(gamma_f) * self.weight_scale + (1 - self.weight_scale) / 2 + else: + gamma_f = ops.ones((n_cubes), dtype=ms.float32) + + return beta_fx12[surf_cubes], alpha_fx8[surf_cubes], gamma_f[surf_cubes] + + def _get_case_id(self, occ_fx8, surf_cubes, res): + """ + Obtains the ID of topology cases based on cell corner occupancy. This function resolves the + ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the + supplementary material. It should be noted that this function assumes a regular grid. + """ + case_ids = (occ_fx8[surf_cubes] * self.cube_corners_idx.unsqueeze(0)).sum(-1) + + problem_config = self.check_table[case_ids] + to_check = problem_config[..., 0] == 1 + problem_config = problem_config[to_check] + if not isinstance(res, (list, tuple)): + res = [res, res, res] + + # The 'problematic_configs' only contain configurations for surface cubes. Next, we construct a 3D array, + # 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes). + # This allows efficient checking on adjacent cubes. + problem_config_full = ops.zeros(list(res) + [5], dtype=ms.int64) + vol_idx = ops.nonzero(problem_config_full[..., 0] == 0) # N, 3 + vol_idx_problem = vol_idx[surf_cubes][to_check] + problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config + vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4] + + within_range = ( + (vol_idx_problem_adj[..., 0] >= 0).to(ms.uint8) + & (vol_idx_problem_adj[..., 0] < res[0]).to(ms.uint8) + & (vol_idx_problem_adj[..., 1] >= 0).to(ms.uint8) + & (vol_idx_problem_adj[..., 1] < res[1]).to(ms.uint8) + & (vol_idx_problem_adj[..., 2] >= 0).to(ms.uint8) + & (vol_idx_problem_adj[..., 2] < res[2]).to(ms.uint8) + ).bool() + + vol_idx_problem = vol_idx_problem[within_range] + vol_idx_problem_adj = vol_idx_problem_adj[within_range] + problem_config = problem_config[within_range] + problem_config_adj = problem_config_full[ + vol_idx_problem_adj[..., 0], vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2] + ] + + # If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted. + to_invert = problem_config_adj[..., 0] == 1 + idx = ops.arange(case_ids.shape[0])[to_check][within_range][to_invert] + if len(idx) > 0: + case_ids[(idx,)] = problem_config[to_invert][..., -1] + return case_ids + + def _identify_surf_edges(self, s_n, cube_fx8, surf_cubes): + """ + Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge + can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge + and marks the cube edges with this index. 
+ """ + occ_n = s_n < 0 + all_edges = cube_fx8[surf_cubes][:, self.cube_edges].reshape(-1, 2) + # unique_edges, _idx_map, counts = ops.unique_consecutive(all_edges, axis=0, return_idx=True, return_counts=True) + # unique_edges, _idx_map, counts = mint.unique(all_edges, axis=0, return_inverse=True, return_counts=True) # ms not supporting this + unique_edges, _idx_map, counts = mint.unique( + all_edges.to(ms.bfloat16), dim=0, return_inverse=True, return_counts=True + ) + unique_edges = unique_edges.long() + mask_edges = occ_n.to(ms.uint8)[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 + + surf_edges_mask = mask_edges.to(ms.uint8)[_idx_map] + counts = counts[_idx_map] + + mapping = ops.ones((unique_edges.shape[0]), dtype=ms.int64) * -1 + mapping[mask_edges] = ops.arange(mask_edges.sum()) + # Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index + # for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1. + idx_map = mapping[_idx_map] + surf_edges = unique_edges[mask_edges] + return surf_edges, idx_map, counts, surf_edges_mask + + def _identify_surf_cubes(self, s_n, cube_fx8): + """ + Identifies grid cubes that intersect with the underlying surface by checking if the signs at + all corners are not identical. + """ + occ_n = ms.Tensor( + s_n < 0, dtype=ms.uint8 + ) # bool type cannot be sampled in the following line, needs to be ms tensor + occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8) + _occ_sum = ops.sum(occ_fx8, -1) # 8 verts per cube + surf_cubes = (_occ_sum > 0).to(ms.uint8) & (_occ_sum < 8).to(ms.uint8) + return surf_cubes, occ_fx8 + + def _linear_interp(self, edges_weight, edges_x): + """ + Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'. 
+ """ + edge_dim = edges_weight.dim() - 2 + assert edges_weight.shape[edge_dim] == 2 + edges_weight = ops.cat( + [ + mint.index_select(input=edges_weight, index=ms.Tensor([1]), dim=edge_dim), + -mint.index_select(input=edges_weight, index=ms.Tensor([0]), dim=edge_dim), + ], + edge_dim, + ) + denominator = edges_weight.sum(edge_dim) + ue = (edges_x * edges_weight).sum(edge_dim) / denominator + return ue + + def _compute_vd( + self, x_nx3, surf_cubes_fx8, surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func + ): + """ + Computes the location of dual vertices as described in Section 4.2 + """ + alpha_nx12x2 = mint.index_select(input=alpha_fx8, index=self.cube_edges, dim=1).reshape(-1, 12, 2) + surf_edges_x = mint.index_select(input=x_nx3, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3) + surf_edges_s = mint.index_select(input=s_n, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1) + zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x) + + idx_map = idx_map.reshape(-1, 12) + num_vd = mint.index_select(input=self.num_vd_table, index=case_ids, dim=0) + edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], [] + + total_num_vd = 0 + vd_idx_map = ops.zeros((case_ids.shape[0], 12), dtype=ms.int64) + + for num in mint.unique(num_vd): + cur_cubes = num_vd == num # consider cubes with the same numbers of vd emitted (for batching) + curr_num_vd = cur_cubes.sum() * num + curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7) + curr_edge_group_to_vd = ops.arange(curr_num_vd).unsqueeze(-1).tile((1, 7)) + total_num_vd + total_num_vd += curr_num_vd + curr_edge_group_to_cube = ( + ops.arange(idx_map.shape[0])[cur_cubes].unsqueeze(-1).tile((1, num * 7)).reshape_as(curr_edge_group) + ) + + curr_mask = curr_edge_group != -1 + edge_group.append(ops.masked_select(curr_edge_group, curr_mask)) + edge_group_to_vd.append(ops.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask)) + edge_group_to_cube.append(ops.masked_select(curr_edge_group_to_cube, curr_mask)) + vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True)) + vd_gamma.append(ops.masked_select(gamma_f, cur_cubes).unsqueeze(-1).tile((1, num)).reshape(-1)) + + edge_group = ops.cat(edge_group) + edge_group_to_vd = ops.cat(edge_group_to_vd) + edge_group_to_cube = ops.cat(edge_group_to_cube) + vd_num_edges = ops.cat(vd_num_edges) + vd_gamma = ops.cat(vd_gamma) + + vd = ops.zeros((total_num_vd, 3)) + beta_sum = ops.zeros((total_num_vd, 1)) + + idx_group = mint.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group) + + x_group = mint.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3) + s_group = mint.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1) + + zero_crossing_group = mint.index_select(input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3) + + alpha_group = mint.index_select( + input=alpha_nx12x2.reshape(-1, 2), dim=0, index=edge_group_to_cube * 12 + edge_group + ).reshape(-1, 2, 1) + ue_group = self._linear_interp(s_group * alpha_group, x_group) + + beta_group = mint.gather( + input=beta_fx12.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group + ).reshape(-1, 1) + beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group) + vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum + L_dev = self._compute_reg_loss(vd, zero_crossing_group, 
edge_group_to_vd, vd_num_edges) + + v_idx = ops.arange(vd.shape[0]) # + total_num_vd + + vd_idx_map = ops.scatter( + input=vd_idx_map.reshape(-1), + axis=0, + index=edge_group_to_cube * 12 + edge_group, + src=v_idx[edge_group_to_vd], + ) + + return vd, L_dev, vd_gamma, vd_idx_map + + def _triangulate( + self, s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func + ): + """ + Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into + triangles based on the gamma parameter, as described in Section 4.3. + """ + # with ops.no_grad(): + group_mask = (edge_counts == 4) & surf_edges_mask # surface edges shared by 4 cubes. + group = idx_map.reshape(-1)[group_mask] + vd_idx = vd_idx_map[group_mask] + edge_indices, indices = ops.sort(group, stable=True) + quad_vd_idx = vd_idx[indices].reshape(-1, 4) + + # Ensure all face directions point towards the positive SDF to maintain consistent winding. + s_edges = s_n[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2) + flip_mask = s_edges[:, 0] > 0 + quad_vd_idx = ops.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]], quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]])) + if grad_func is not None: + # when grad_func is given, split quadrilaterals along the diagonals with more consistent gradients. + # with ops.no_grad(): + vd_gamma = ops.norm(grad_func(vd), dim=-1) + quad_gamma = mint.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3) + gamma_02 = (quad_gamma[:, 0] * quad_gamma[:, 2]).sum(-1, keepdims=True) + gamma_13 = (quad_gamma[:, 1] * quad_gamma[:, 3]).sum(-1, keepdims=True) + else: + quad_gamma = mint.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4) + gamma_02 = mint.index_select(input=quad_gamma, index=ms.Tensor([0]), dim=1) * mint.index_select( + input=quad_gamma, index=ms.Tensor([2]), dim=1 + ) + gamma_13 = mint.index_select(input=quad_gamma, index=ms.Tensor([1]), dim=1) * mint.index_select( + input=quad_gamma, index=ms.Tensor([3]), dim=1 + ) + if not training: + mask = (gamma_02 > gamma_13).squeeze(1) + faces = ops.zeros((quad_gamma.shape[0], 6), dtype=ms.int64) + faces[mask] = quad_vd_idx[mask][:, self.quad_split_1] + faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2] + faces = faces.reshape(-1, 3) + else: + vd_quad = mint.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3) + vd_02 = ( + mint.index_select(input=vd_quad, index=ms.Tensor([0]), dim=1) + + mint.index_select(input=vd_quad, index=ms.Tensor([2]), dim=1) + ) / 2 + vd_13 = ( + mint.index_select(input=vd_quad, index=ms.Tensor([1]), dim=1) + + mint.index_select(input=vd_quad, index=ms.Tensor([3]), dim=1) + ) / 2 + weight_sum = (gamma_02 + gamma_13) + 1e-8 + vd_center = ( + (vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) / weight_sum.unsqueeze(-1) + ).squeeze(1) + vd_center_idx = ops.arange(vd_center.shape[0]) + vd.shape[0] + vd = ops.cat([vd, vd_center]) + faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2) + faces = ops.cat([faces, vd_center_idx.reshape(-1, 1, 1).tile((1, 4, 1))], -1).reshape(-1, 3) + return vd, faces, s_edges, edge_indices + + def _tetrahedralize( + self, + x_nx3, + s_n, + cube_fx8, + vertices, + faces, + surf_edges, + s_edges, + vd_idx_map, + case_ids, + edge_indices, + surf_cubes, + training, + ): + """ + Tetrahedralizes the interior volume to produce a tetrahedral mesh, as described in Section 4.5. 
+ """ + occ_n = s_n < 0 + occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8) + occ_sum = ops.sum(occ_fx8, -1) + + inside_verts = x_nx3[occ_n] + mapping_inside_verts = ops.ones((occ_n.shape[0]), dtype=ms.int64) * -1 + mapping_inside_verts[occ_n] = ops.arange(occ_n.sum()) + vertices.shape[0] + """ + For each grid edge connecting two grid vertices with different + signs, we first form a four-sided pyramid by connecting one + of the grid vertices with four mesh vertices that correspond + to the grid edge and then subdivide the pyramid into two tetrahedra + """ + inside_verts_idx = mapping_inside_verts[ + surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1, 2)[s_edges < 0] + ] + if not training: + inside_verts_idx = inside_verts_idx.unsqueeze(1).expand((-1, 2)).reshape(-1) + else: + inside_verts_idx = inside_verts_idx.unsqueeze(1).expand((-1, 4)).reshape(-1) + + tets_surface = ops.cat([faces, inside_verts_idx.unsqueeze(-1)], -1) + """ + For each grid edge connecting two grid vertices with the + same sign, the tetrahedron is formed by the two grid vertices + and two vertices in consecutive adjacent cells + """ + inside_cubes = occ_sum == 8 + inside_cubes_center = x_nx3[cube_fx8[inside_cubes].reshape(-1)].reshape(-1, 8, 3).mean(1) + inside_cubes_center_idx = ops.arange(inside_cubes_center.shape[0]) + vertices.shape[0] + inside_verts.shape[0] + + surface_n_inside_cubes = surf_cubes | inside_cubes + edge_center_vertex_idx = ops.ones(((surface_n_inside_cubes).sum(), 13), dtype=ms.int64) * -1 + surf_cubes = surf_cubes[surface_n_inside_cubes] + inside_cubes = inside_cubes[surface_n_inside_cubes] + edge_center_vertex_idx[surf_cubes, :12] = vd_idx_map.reshape(-1, 12) + edge_center_vertex_idx[inside_cubes, 12] = inside_cubes_center_idx + + all_edges = cube_fx8[surface_n_inside_cubes][:, self.cube_edges].reshape(-1, 2) + unique_func = ops.UniqueConsecutive(axis=0, return_idx=True, return_counts=True) + unique_edges, _idx_map, counts = unique_func(all_edges) + # unique_edges, _idx_map, counts = mint.unique(all_edges, dim=0, return_inverse=True, return_counts=True) + unique_edges = unique_edges.long() + mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 2 + mask = mask_edges[_idx_map] + counts = counts[_idx_map] + mapping = ops.ones((unique_edges.shape[0]), dtype=ms.int64) * -1 + mapping[mask_edges] = ops.arange(mask_edges.sum()) + idx_map = mapping[_idx_map] + + group_mask = (counts == 4) & mask + group = idx_map.reshape(-1)[group_mask] + edge_indices, indices = ops.sort(group) + cube_idx = ( + ops.arange((_idx_map.shape[0] // 12), dtype=ms.int64).unsqueeze(1).expand((-1, 12)).reshape(-1)[group_mask] + ) + edge_idx = ( + ops.arange((12), dtype=ms.int64).unsqueeze(0).expand((_idx_map.shape[0] // 12, -1)).reshape(-1)[group_mask] + ) + # Identify the face shared by the adjacent cells. + cube_idx_4 = cube_idx[indices].reshape(-1, 4) + edge_dir = self.edge_dir_table[edge_idx[indices]].reshape(-1, 4)[..., 0] + shared_faces_4x2 = self.dir_faces_table[edge_dir].reshape(-1) + cube_idx_4x2 = cube_idx_4[:, self.adj_pairs].reshape(-1) + # Identify an edge of the face with different signs and + # select the mesh vertex corresponding to the identified edge. 
+ case_ids_expand = ops.ones((surface_n_inside_cubes).sum(), dtype=ms.int64) * 255 + case_ids_expand[surf_cubes] = case_ids + cases = case_ids_expand[cube_idx_4x2] + quad_edge = edge_center_vertex_idx[cube_idx_4x2, self.tet_table[cases, shared_faces_4x2]].reshape(-1, 2) + mask = (quad_edge == -1).sum(-1) == 0 + inside_edge = mapping_inside_verts[unique_edges[mask_edges][edge_indices].reshape(-1)].reshape(-1, 2) + tets_inside = ops.cat([quad_edge, inside_edge], -1)[mask] + + tets = ops.cat([tets_surface, tets_inside]) + vertices = ops.cat([vertices, inside_verts, inside_cubes_center]) + return vertices, tets + + def construct(self, *args, **kwargs): + return super().construct(*args, **kwargs) + + +if __name__ == "__main__": + ms.context.set_context(mode=1, device_target="Ascend", device_id=7) + test_fc = FlexiCubes(weight_scale=0.5) + print(test_fc) + v, i = test_fc.construct_voxel_grid(res=64) + print(f"v shape {v.shape}") + print(f"i shape {i.shape}") diff --git a/examples/instantmesh/models/geometry/rep_3d/flexicubes_geometry.py b/examples/instantmesh/models/geometry/rep_3d/flexicubes_geometry.py new file mode 100644 index 0000000000..ee4fd0116a --- /dev/null +++ b/examples/instantmesh/models/geometry/rep_3d/flexicubes_geometry.py @@ -0,0 +1,120 @@ +import numpy as np + +import mindspore as ms +from mindspore import nn + +from .flexicubes import FlexiCubes + + +def get_center_boundary_index(grid_res): + v = np.zeros((grid_res + 1, grid_res + 1, grid_res + 1), dtype=np.bool_) + v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = True + center_indices = np.nonzero(v.reshape(-1)) + + v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = False + v[:2, ...] = True + v[-2:, ...] = True + v[:, :2, ...] = True + v[:, -2:, ...] = True + v[:, :, :2] = True + v[:, :, -2:] = True + boundary_indices = np.nonzero(v.reshape(-1)) + return center_indices, boundary_indices + + +############################################################################### +# Geometry interface +############################################################################### +class FlexiCubesGeometry(nn.Cell): + def __init__(self, grid_res=64, scale=2.0, renderer=None, render_type="neural_render", args=None): + super().__init__() + self.grid_res = grid_res + self.args = args + self.fc = FlexiCubes(weight_scale=0.5) + verts, indices = self.fc.np_construct_voxel_grid(grid_res) + self.verts, self.indices = ms.Tensor(verts, dtype=ms.float32), ms.Tensor(indices, dtype=ms.int32) + if isinstance(scale, list): + self.verts[:, 0] = self.verts[:, 0] * scale[0] + self.verts[:, 1] = self.verts[:, 1] * scale[1] + self.verts[:, 2] = self.verts[:, 2] * scale[1] + else: + self.verts = self.verts * scale + + # all_edges = self.indices[:, self.fc.cube_edges].reshape(-1, 2) + # self.all_edges = np.unique(all_edges) # buggy, this is a huge computatition, if done with np it takes a really long time. 
And it's not used anyway + + # Parameters used for fix boundary sdf + self.center_indices, self.boundary_indices = get_center_boundary_index(self.grid_res) + self.renderer = renderer + self.render_type = render_type + + def getAABB(self): + return np.min(self.verts, dim=0).values, np.max(self.verts, dim=0).values + + def get_mesh(self, v_deformed_nx3, sdf_n, weight_n=None, with_uv=False, indices=None, is_training=False): + if indices is None: + indices = self.indices + + verts, faces, v_reg_loss = self.fc( + v_deformed_nx3, + sdf_n, + indices, + self.grid_res, + beta_fx12=weight_n[:, :12], + alpha_fx8=weight_n[:, 12:20], + gamma_f=weight_n[:, 20], + training=is_training, + ) + return verts, faces, v_reg_loss + + def render_mesh(self, mesh_v_nx3, mesh_f_fx3, camera_mv_bx4x4, resolution=256, hierarchical_mask=False): + return_value = dict() + if self.render_type == "neural_render": + tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal = self.renderer.render_mesh( + mesh_v_nx3.unsqueeze(dim=0), + mesh_f_fx3.int(), + camera_mv_bx4x4, + mesh_v_nx3.unsqueeze(dim=0), + resolution=resolution, + hierarchical_mask=hierarchical_mask, + ) + + return_value["tex_pos"] = tex_pos + return_value["mask"] = mask + return_value["hard_mask"] = hard_mask + return_value["rast"] = rast + return_value["v_pos_clip"] = v_pos_clip + return_value["mask_pyramid"] = mask_pyramid + return_value["depth"] = depth + return_value["normal"] = normal + else: + raise NotImplementedError + + return return_value + + def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256): + # Here I assume a batch of meshes (can be different mesh and geometry), for the other shapes, the batch is 1 + v_list = [] + f_list = [] + n_batch = v_deformed_bxnx3.shape[0] + all_render_output = [] + for i_batch in range(n_batch): + verts_nx3, faces_fx3 = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch]) + v_list.append(verts_nx3) + f_list.append(faces_fx3) + render_output = self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution) + all_render_output.append(render_output) + + # Concatenate all render output + return_keys = all_render_output[0].keys() + return_value = dict() + for k in return_keys: + value = [v[k] for v in all_render_output] + return_value[k] = value + # We can do concatenation outside of the render + return return_value + + +if __name__ == "__main__": + geo = FlexiCubesGeometry() + print(geo) diff --git a/examples/instantmesh/models/geometry/rep_3d/tables.py b/examples/instantmesh/models/geometry/rep_3d/tables.py new file mode 100644 index 0000000000..6deed09c1e --- /dev/null +++ b/examples/instantmesh/models/geometry/rep_3d/tables.py @@ -0,0 +1,1472 @@ +dmc_table = [ + [ + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 8, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 9, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 3, 8, 9, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [4, 7, 8, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 4, 7, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, 
-1, -1], + ], + [[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [4, 5, 9, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [0, 1, 4, 5, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [5, 7, 8, 9, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [1, 3, 5, 7, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [2, 3, 11, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 2, 8, 11, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [1, 2, 8, 9, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [0, 2, 4, 7, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 2, 8, 11, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 1, 4, 5, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[5, 7, 8, 9, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 1, 5, 7, 8, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [1, 2, 5, 7, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 2, 10, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [0, 2, 9, 10, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, 
-1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [2, 3, 8, 9, 10, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 3, 4, 7, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 2, 9, 10, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [0, 2, 4, 5, 10, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[5, 7, 8, 9, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 3, 5, 7, 9, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [2, 3, 5, 7, 10, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 3, 10, 11, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 8, 10, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 9, 10, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [8, 9, 10, 11, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[4, 7, 8, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [0, 1, 4, 7, 10, 11, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[0, 3, 9, 10, 11, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [4, 7, 9, 10, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[4, 5, 9, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 1, 8, 10, 11, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [0, 3, 4, 5, 10, 11, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [4, 5, 8, 10, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[5, 7, 8, 9, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 1, 5, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 3, 5, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [5, 
7, 10, 11, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [6, 7, 11, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[0, 3, 8, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[1, 3, 8, 9, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [4, 6, 8, 11, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 4, 6, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[0, 1, 9, -1, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 1, 4, 5, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[1, 3, 4, 5, 8, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [5, 6, 8, 9, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [1, 3, 5, 6, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [2, 3, 6, 7, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 1, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [0, 2, 4, 6, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[4, 5, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 2, 6, 7, 8, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 1, 4, 5, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[1, 2, 4, 5, 6, 7, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[2, 3, 5, 6, 8, 
9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 1, 2, 3, 5, 6, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [1, 2, 5, 6, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 2, 10, -1, -1, -1, -1], + [6, 7, 11, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 2, 9, 10, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[2, 3, 8, 9, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[4, 6, 8, 11, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 3, 4, 6, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 2, 9, 10, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[2, 3, 4, 6, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1]], + [[0, 2, 4, 5, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[2, 3, 4, 5, 8, 10, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[5, 6, 8, 9, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 3, 5, 6, 9, 11, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 2, 5, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [2, 3, 5, 6, 10, 11, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 3, 6, 7, 10, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [6, 7, 8, 9, 10, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [0, 1, 4, 6, 10, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[0, 3, 4, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [4, 6, 9, 10, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[4, 5, 9, -1, -1, -1, -1], [1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 1, 6, 7, 8, 10, -1], [4, 5, 9, -1, -1, 
-1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 3, 4, 5, 6, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[4, 5, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[1, 3, 5, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 1, 5, 6, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [5, 6, 10, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [5, 6, 10, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 1, 9, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[1, 3, 8, 9, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 3, 4, 7, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[1, 3, 4, 7, 9, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [4, 6, 9, 10, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[0, 3, 8, -1, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [0, 1, 4, 6, 10, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [6, 7, 8, 9, 10, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [1, 3, 6, 7, 10, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [2, 3, 11, -1, -1, -1, -1], + [5, 6, 10, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[0, 2, 8, 11, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[1, 2, 8, 9, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 2, 4, 7, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 
11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1]], + [[1, 2, 4, 7, 9, 11, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[4, 6, 9, 10, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 2, 8, 11, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 1, 4, 6, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[1, 2, 4, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[6, 7, 8, 9, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 2, 6, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 1, 6, 7, 8, 10, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [1, 2, 6, 7, 10, 11, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 2, 5, 6, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[4, 7, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 3, 4, 7, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 2, 5, 6, 9, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[2, 3, 4, 5, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [0, 2, 4, 6, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 1, 2, 3, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [2, 3, 6, 7, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 3, 5, 6, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [5, 6, 8, 9, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], 
+ ], + [[4, 7, 8, -1, -1, -1, -1], [1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 1, 4, 5, 6, 7, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 3, 5, 6, 9, 11, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[4, 5, 6, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 1, 4, 6, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [0, 3, 4, 6, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [4, 6, 8, 11, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[1, 3, 6, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 3, 6, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [6, 7, 11, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [5, 7, 10, 11, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[0, 3, 8, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 1, 9, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[1, 3, 8, 9, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [4, 5, 8, 10, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 4, 5, 10, 11, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[0, 1, 9, -1, -1, -1, -1], [4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[1, 3, 4, 5, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [4, 7, 9, 10, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[0, 3, 8, -1, -1, -1, -1], [4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [0, 1, 4, 7, 10, 11, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[1, 3, 4, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [8, 9, 10, 11, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 9, 10, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 8, 10, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 3, 10, 11, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [2, 3, 5, 7, 10, -1, -1], + [-1, -1, -1, 
-1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 1, 9, -1, -1, -1, -1], [2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[1, 2, 5, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [0, 2, 4, 5, 10, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[1, 2, 4, 5, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 2, 4, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 1, 2, 3, 4, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [2, 3, 8, 9, 10, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 2, 9, 10, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[0, 1, 2, 3, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [1, 2, 10, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 2, 5, 7, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[2, 3, 5, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 1, 2, 3, 4, 5, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 2, 4, 5, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [0, 2, 4, 7, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[2, 3, 4, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [1, 2, 8, 9, 11, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[0, 1, 2, 3, 9, 11, -1], [-1, -1, -1, -1, -1, 
-1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [0, 2, 8, 11, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [2, 3, 11, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 3, 5, 7, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [5, 7, 8, 9, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [0, 1, 4, 5, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[0, 3, 4, 5, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [4, 5, 9, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [[0, 1, 4, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], + [ + [0, 3, 4, 7, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [4, 7, 8, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 3, 8, 9, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 1, 9, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 3, 8, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], + [ + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + ], +] +num_vd_table = [ + 0, + 1, + 1, + 1, + 1, + 1, + 2, + 1, + 1, + 2, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 2, + 1, + 2, + 1, + 3, + 1, + 2, + 2, + 2, + 1, + 2, + 1, + 2, + 1, + 1, + 2, + 1, + 1, + 2, + 2, + 2, + 1, + 2, + 3, + 1, + 1, + 2, + 2, + 1, + 1, + 1, + 1, + 1, + 1, + 2, + 1, + 2, + 1, + 2, + 2, + 1, + 1, + 2, + 1, + 1, + 1, + 1, + 2, + 2, + 2, + 1, + 1, + 2, + 1, + 2, + 3, + 2, + 2, + 1, + 1, + 1, + 1, + 1, + 1, + 2, + 1, + 1, + 1, + 2, + 1, + 2, + 2, + 2, + 1, + 1, + 1, + 1, + 1, + 2, + 3, + 2, + 2, + 2, + 2, + 2, + 1, + 3, + 4, + 2, + 2, + 2, + 2, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 2, + 2, + 1, + 1, + 1, + 1, + 2, + 1, + 1, + 2, + 2, + 2, + 2, + 2, + 3, + 2, + 1, + 2, + 1, + 1, + 1, + 1, + 1, + 1, + 2, + 2, + 3, + 2, + 3, + 2, + 4, + 2, + 2, + 2, + 2, + 1, + 2, + 1, + 2, + 1, + 1, + 2, + 1, + 1, + 2, + 2, + 2, + 1, + 1, + 2, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 2, + 1, + 2, + 1, + 1, + 1, + 1, + 1, + 1, + 2, + 1, + 1, + 1, + 2, + 2, + 2, + 1, + 1, + 2, + 1, + 1, + 2, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 2, + 1, + 1, + 1, + 2, + 1, + 1, + 1, + 1, + 2, + 1, + 1, + 1, + 1, + 1, + 2, + 1, + 1, + 1, 
+ 1, + 1, + 2, + 1, + 2, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 0, +] +check_table = [ + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 1, 0, 0, 194], + [1, -1, 0, 0, 193], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, 1, 0, 164], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, -1, 0, 161], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, 0, 1, 152], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, 0, 1, 145], + [1, 0, 0, 1, 144], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, 0, -1, 137], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, 1, 0, 133], + [1, 0, 1, 0, 132], + [1, 1, 0, 0, 131], + [1, 1, 0, 0, 130], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, 0, 1, 100], + [0, 0, 0, 0, 0], + [1, 0, 0, 1, 98], + [0, 0, 0, 0, 0], + [1, 0, 0, 1, 96], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, 1, 0, 88], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, -1, 0, 82], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + 
[0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, 1, 0, 74], + [0, 0, 0, 0, 0], + [1, 0, 1, 0, 72], + [0, 0, 0, 0, 0], + [1, 0, 0, -1, 70], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, -1, 0, 0, 67], + [0, 0, 0, 0, 0], + [1, -1, 0, 0, 65], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 1, 0, 0, 56], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, -1, 0, 0, 52], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 1, 0, 0, 44], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 1, 0, 0, 40], + [0, 0, 0, 0, 0], + [1, 0, 0, -1, 38], + [1, 0, -1, 0, 37], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, -1, 0, 33], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, -1, 0, 0, 28], + [0, 0, 0, 0, 0], + [1, 0, -1, 0, 26], + [1, 0, 0, -1, 25], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, -1, 0, 0, 20], + [0, 0, 0, 0, 0], + [1, 0, -1, 0, 18], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, 0, -1, 9], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, 0, -1, 6], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], +] +tet_table = [ + [-1, -1, -1, -1, -1, -1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [4, 4, 4, 4, 4, 4], + [0, 0, 0, 0, 0, 0], + [4, 0, 0, 4, 4, -1], + [1, 1, 1, 1, 1, 1], + [4, 4, 4, 4, 4, 4], + [0, 4, 0, 4, 4, -1], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [5, 5, 5, 5, 5, 5], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [2, 2, 2, 2, 2, 2], + [0, 0, 0, 0, 0, 0], + [2, 0, 2, -1, 0, 2], + [1, 1, 1, 1, 1, 1], + [2, -1, 2, 4, 4, 2], + [0, 0, 0, 0, 0, 0], + [2, 0, 2, 4, 4, 2], + [1, 1, 1, 1, 1, 1], + [2, 4, 2, 4, 4, 2], + [0, 4, 0, 4, 4, 0], + [2, 0, 2, 0, 0, 2], + [1, 1, 1, 1, 1, 1], + [2, 5, 2, 5, 5, 2], + [0, 0, 0, 0, 0, 0], + [2, 0, 2, 0, 0, 2], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 1, 1, -1, 0, 1], + [0, 0, 0, 0, 0, 0], + [2, 2, 2, 2, 2, 2], + [4, 1, 1, 4, 4, 1], + [0, 1, 1, 0, 0, 1], + [4, 0, 0, 4, 4, 0], + [2, 2, 2, 2, 2, 2], + [-1, 1, 1, 4, 4, 1], + [0, 1, 1, 4, 4, 1], + [0, 0, 0, 0, 0, 0], + [2, 2, 2, 2, 2, 2], + [5, 1, 1, 5, 5, 1], + [0, 1, 1, 0, 0, 1], + [0, 0, 0, 0, 0, 0], + [2, 2, 2, 2, 2, 2], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [8, 8, 8, 8, 8, 8], + [1, 1, 1, 4, 4, 1], + [0, 0, 0, 0, 0, 0], + [4, 0, 0, 4, 4, 0], + [4, 4, 4, 4, 4, 4], + [1, 1, 1, 4, 4, 1], + [0, 4, 0, 4, 4, 0], + [0, 0, 0, 0, 0, 0], + [4, 4, 4, 4, 4, 4], + [1, 1, 1, 5, 5, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [5, 5, 5, 5, 5, 5], + [6, 6, 6, 6, 6, 6], + [6, -1, 0, 6, 0, 6], + [6, 0, 0, 6, 0, 6], + [6, 1, 1, 6, 1, 6], + [4, 4, 4, 4, 4, 4], + [0, 0, 0, 0, 0, 0], + [4, 0, 0, 4, 4, 4], + [1, 1, 1, 1, 1, 1], + [6, 4, -1, 6, 4, 6], + [6, 4, 0, 6, 4, 6], + [6, 0, 0, 6, 0, 6], + [6, 1, 1, 6, 1, 6], + [5, 5, 5, 5, 5, 5], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [2, 2, 2, 2, 2, 2], + [0, 0, 0, 0, 0, 0], + [2, 0, 2, 2, 0, 2], + [1, 1, 1, 1, 1, 1], + [2, 2, 2, 2, 2, 2], + [0, 0, 0, 0, 0, 0], + [2, 0, 2, 2, 2, 2], + [1, 1, 1, 1, 1, 1], + [2, 4, 2, 2, 4, 2], + [0, 4, 0, 4, 4, 0], + [2, 0, 2, 
2, 0, 2], + [1, 1, 1, 1, 1, 1], + [2, 2, 2, 2, 2, 2], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [6, 1, 1, 6, -1, 6], + [6, 1, 1, 6, 0, 6], + [6, 0, 0, 6, 0, 6], + [6, 2, 2, 6, 2, 6], + [4, 1, 1, 4, 4, 1], + [0, 1, 1, 0, 0, 1], + [4, 0, 0, 4, 4, 4], + [2, 2, 2, 2, 2, 2], + [6, 1, 1, 6, 4, 6], + [6, 1, 1, 6, 4, 6], + [6, 0, 0, 6, 0, 6], + [6, 2, 2, 6, 2, 6], + [5, 1, 1, 5, 5, 1], + [0, 1, 1, 0, 0, 1], + [0, 0, 0, 0, 0, 0], + [2, 2, 2, 2, 2, 2], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [6, 6, 6, 6, 6, 6], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [4, 4, 4, 4, 4, 4], + [1, 1, 1, 1, 4, 1], + [0, 4, 0, 4, 4, 0], + [0, 0, 0, 0, 0, 0], + [4, 4, 4, 4, 4, 4], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 5, 0, 5, 0, 5], + [5, 5, 5, 5, 5, 5], + [5, 5, 5, 5, 5, 5], + [0, 5, 0, 5, 0, 5], + [-1, 5, 0, 5, 0, 5], + [1, 5, 1, 5, 1, 5], + [4, 5, -1, 5, 4, 5], + [0, 5, 0, 5, 0, 5], + [4, 5, 0, 5, 4, 5], + [1, 5, 1, 5, 1, 5], + [4, 4, 4, 4, 4, 4], + [0, 4, 0, 4, 4, 4], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [6, 6, 6, 6, 6, 6], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [2, 5, 2, 5, -1, 5], + [0, 5, 0, 5, 0, 5], + [2, 5, 2, 5, 0, 5], + [1, 5, 1, 5, 1, 5], + [2, 5, 2, 5, 4, 5], + [0, 5, 0, 5, 0, 5], + [2, 5, 2, 5, 4, 5], + [1, 5, 1, 5, 1, 5], + [2, 4, 2, 4, 4, 2], + [0, 4, 0, 4, 4, 4], + [2, 0, 2, 0, 0, 2], + [1, 1, 1, 1, 1, 1], + [2, 6, 2, 6, 6, 2], + [0, 0, 0, 0, 0, 0], + [2, 0, 2, 0, 0, 2], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 1, 1, 1, 0, 1], + [0, 0, 0, 0, 0, 0], + [2, 2, 2, 2, 2, 2], + [4, 1, 1, 1, 4, 1], + [0, 1, 1, 1, 0, 1], + [4, 0, 0, 4, 4, 0], + [2, 2, 2, 2, 2, 2], + [1, 1, 1, 1, 1, 1], + [0, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [2, 2, 2, 2, 2, 2], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [2, 2, 2, 2, 2, 2], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [5, 5, 5, 5, 5, 5], + [1, 1, 1, 1, 4, 1], + [0, 0, 0, 0, 0, 0], + [4, 0, 0, 4, 4, 0], + [4, 4, 4, 4, 4, 4], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [4, 4, 4, 4, 4, 4], + [1, 1, 1, 1, 1, 1], + [6, 0, 0, 6, 0, 6], + [0, 0, 0, 0, 0, 0], + [6, 6, 6, 6, 6, 6], + [5, 5, 5, 5, 5, 5], + [5, 5, 0, 5, 0, 5], + [5, 5, 0, 5, 0, 5], + [5, 5, 1, 5, 1, 5], + [4, 4, 4, 4, 4, 4], + [0, 0, 0, 0, 0, 0], + [4, 4, 0, 4, 4, 4], + [1, 1, 1, 1, 1, 1], + [4, 4, 4, 4, 4, 4], + [4, 4, 0, 4, 4, 4], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [8, 8, 8, 8, 8, 8], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [2, 2, 2, 2, 2, 2], + [0, 0, 0, 0, 0, 0], + [2, 2, 2, 2, 0, 2], + [1, 1, 1, 1, 1, 1], + [2, 2, 2, 2, 2, 2], + [0, 0, 0, 0, 0, 0], + [2, 2, 2, 2, 2, 2], + [1, 1, 1, 1, 1, 1], + [2, 2, 2, 2, 2, 2], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [4, 1, 1, 4, 4, 1], + [2, 2, 2, 2, 2, 2], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 0, 1], + [0, 0, 0, 0, 0, 0], + [2, 2, 2, 2, 2, 2], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [2, 4, 2, 4, 4, 2], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [2, 2, 2, 2, 2, 2], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [2, 2, 2, 2, 2, 2], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [5, 5, 5, 5, 5, 5], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [4, 4, 4, 4, 4, 4], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [4, 4, 4, 4, 4, 
4], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [12, 12, 12, 12, 12, 12], +] diff --git a/examples/instantmesh/models/lrm.py b/examples/instantmesh/models/lrm.py new file mode 100644 index 0000000000..71c2b1cc63 --- /dev/null +++ b/examples/instantmesh/models/lrm.py @@ -0,0 +1,197 @@ +from typing import Optional, Tuple + +import mcubes + +import mindspore as ms +import mindspore.nn as nn +from mindspore import mint + +from .decoder.transformer import TriplaneTransformer +from .encoder.dino_wrapper import DinoWrapper +from .renderer.synthesizer import TriplaneSynthesizer + + +class InstantNeRF(nn.Cell): + """ + Full model for training the LRM with nerf. + """ + + def __init__( + self, + encoder_freeze: bool = False, + encoder_model_name: str = "facebook/dino-vitb16", + encoder_feat_dim: int = 768, + transformer_dim: int = 1024, + transformer_layers: int = 16, + transformer_heads: int = 16, + triplane_low_res: int = 32, + triplane_high_res: int = 64, + triplane_dim: int = 80, + rendering_samples_per_ray: int = 128, + render_size: int = 192, + use_recompute: bool = False, + dtype: Optional[str] = None, + ): + super().__init__() + self.render_size = render_size + self.chunk_size = 1 + + # modules + self.encoder = DinoWrapper( + model_name=encoder_model_name, + freeze=encoder_freeze, + use_recompute=use_recompute, # enable the finest recompute + ) + + dtype_map = {"fp32": ms.float32, "fp16": ms.float16, "bf16": ms.bfloat16} + dtype = dtype_map[dtype] + self.transformer = TriplaneTransformer( + inner_dim=transformer_dim, + num_layers=transformer_layers, + num_heads=transformer_heads, + image_feat_dim=encoder_feat_dim, + triplane_low_res=triplane_low_res, + triplane_high_res=triplane_high_res, + triplane_dim=triplane_dim, + dtype=dtype, + use_recompute=use_recompute, # enable the finest recompute + ) + + # TODO support the random crop before training get started and pass as args here, good for training performance + self.synthesizer = TriplaneSynthesizer( + triplane_dim=triplane_dim, + samples_per_ray=rendering_samples_per_ray, + dtype=dtype, + use_recompute=use_recompute, # enable the finest recompute + ) + + def forward_planes(self, images: ms.Tensor, cameras: ms.Tensor): + # cameras: b n 16 + # images: b n c h w + B = images.shape[0] + + # encode images + image_feats = self.encoder(images, cameras) + + image_feats = image_feats.reshape( + B, int(image_feats.shape[-3] * image_feats.shape[-2] / B), image_feats.shape[-1] + ) + # logger.info(f'the shape in forward plane after reshape: {image_feats.shape}') + + # transformer decode the plane feat + planes = self.transformer(image_feats) + + return planes + + def construct( + self, + images: ms.Tensor, + cameras: ms.Tensor, + render_cameras: ms.Tensor, + render_size: int, + crop_params: Tuple[int], + ): + # images: [B, V, C_img, H_img, W_img] + # cameras: [B, V, 16] + # render_cameras: [B, M, D_cam_render] + B, V = render_cameras.shape[:2] + + planes = self.forward_planes(images, cameras) + + l_rgb_depth_weight = [] + for i in range(0, V, self.chunk_size): + syn_out = self.synthesizer( + planes, + cameras=render_cameras[:, i : i + self.chunk_size], + render_size=render_size, + crop_params=crop_params, + ) + l_rgb_depth_weight.append(syn_out) + images_rgb = mint.cat([view[0] for view in l_rgb_depth_weight], dim=1) + images_depth = mint.cat([view[1] for view in l_rgb_depth_weight], dim=1) + images_weight = mint.cat([view[2] for view in l_rgb_depth_weight], dim=1) + + return images_rgb, images_depth, images_weight + + def 
get_texture_prediction(self, planes, tex_pos, hard_mask=None): + """ + Predict Texture given triplanes + :param planes: the triplane feature map + :param tex_pos: Position we want to query the texture field + :param hard_mask: 2D silhoueete of the rendered image + """ + tex_pos = mint.cat(tex_pos, dim=0) + if hard_mask is not None: + tex_pos = tex_pos * hard_mask.float() + batch_size = tex_pos.shape[0] + tex_pos = tex_pos.reshape(batch_size, -1, 3) + ################### + # We use mask to get the texture location (to save the memory) + if hard_mask is not None: + n_point_list = mint.sum(hard_mask.long().reshape(hard_mask.shape[0], -1), dim=-1) + sample_tex_pose_list = [] + max_point = n_point_list.max() + expanded_hard_mask = hard_mask.reshape(batch_size, -1, 1).expand(-1, -1, 3) > 0.5 + for i in range(tex_pos.shape[0]): + tex_pos_one_shape = tex_pos[i][expanded_hard_mask[i]].reshape(1, -1, 3) + if tex_pos_one_shape.shape[1] < max_point: + tex_pos_one_shape = mint.cat( + [ + tex_pos_one_shape, + mint.zeros((1, max_point - tex_pos_one_shape.shape[1], 3), dtype=ms.float32), + ], + dim=1, + ) + sample_tex_pose_list.append(tex_pos_one_shape) + tex_pos = mint.cat(sample_tex_pose_list, dim=0) + + tex_feat = self.synthesizer.forward_points( + planes, + tex_pos, + )["rgb"] + + if hard_mask is not None: + final_tex_feat = mint.zeros(planes.shape[0], hard_mask.shape[1] * hard_mask.shape[2], tex_feat.shape[-1]) + expanded_hard_mask = ( + hard_mask.reshape(hard_mask.shape[0], -1, 1).expand(-1, -1, final_tex_feat.shape[-1]) > 0.5 + ) + for i in range(planes.shape[0]): + final_tex_feat[i][expanded_hard_mask[i]] = tex_feat[i][: n_point_list[i]].reshape(-1) + tex_feat = final_tex_feat + + return tex_feat.reshape(planes.shape[0], hard_mask.shape[1], hard_mask.shape[2], tex_feat.shape[-1]) + + def extract_mesh( + self, + planes: ms.Tensor, + mesh_resolution: int = 256, + mesh_threshold: int = 10.0, + texture_resolution: int = 1024, + ): + """ + Extract a 3D mesh from triplane nerf. Only support batch_size 1. + :param planes: triplane features + :param mesh_resolution: marching cubes resolution + :param mesh_threshold: iso-surface threshold + :param use_texture_map: use texture map or vertex color + :param texture_resolution: the resolution of texture map + """ + assert planes.shape[0] == 1 + + grid_out = self.synthesizer.forward_grid( + planes=planes, + grid_size=mesh_resolution, + ) + + vertices, faces = mcubes.marching_cubes( + grid_out["sigma"].squeeze(0).squeeze(-1).cpu().numpy(), + mesh_threshold, + ) + vertices = vertices / (mesh_resolution - 1) * 2 - 1 + + # query vertex colors + vertices_tensor = ms.tensor(vertices, dtype=ms.float32).unsqueeze(0) + vertices_colors = self.synthesizer.forward_points(planes, vertices_tensor)["rgb"].squeeze(0) + vertices_colors = (vertices_colors * 255).to(ms.uint8) + + return vertices, faces, vertices_colors diff --git a/examples/instantmesh/models/lrm_mesh.py b/examples/instantmesh/models/lrm_mesh.py new file mode 100644 index 0000000000..0a1822e9a9 --- /dev/null +++ b/examples/instantmesh/models/lrm_mesh.py @@ -0,0 +1,371 @@ +from loguru import logger + +import mindspore as ms +import mindspore.nn as nn +from mindspore import ops + +from .decoder.transformer import TriplaneTransformer +from .encoder.dino_wrapper import DinoWrapper +from .geometry.rep_3d.flexicubes_geometry import FlexiCubesGeometry +from .renderer.synthesizer_mesh import TriplaneSynthesizer + + +class InstantMesh(nn.Cell): + """ + Full model of the large reconstruction model. 
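+    It encodes multi-view images with a DINO ViT, decodes the features into a triplane
+    representation with a transformer, and extracts surfaces from the predicted SDF with FlexiCubes.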
+ """ + + def __init__( + self, + encoder_freeze: bool = False, + encoder_model_name: str = "facebook/dino-vitb16", + encoder_feat_dim: int = 768, + transformer_dim: int = 1024, + transformer_layers: int = 16, + transformer_heads: int = 16, + triplane_low_res: int = 32, + triplane_high_res: int = 64, + triplane_dim: int = 80, + rendering_samples_per_ray: int = 128, + grid_res: int = 128, + grid_scale: float = 2.0, + ): + super().__init__() + + # attributes + self.grid_res = grid_res + self.grid_scale = grid_scale + self.deformation_multiplier = 4.0 + + # modules + self.encoder = DinoWrapper( + model_name=encoder_model_name, + freeze=encoder_freeze, + ) + + self.transformer = TriplaneTransformer( + inner_dim=transformer_dim, + num_layers=transformer_layers, + num_heads=transformer_heads, + image_feat_dim=encoder_feat_dim, + triplane_low_res=triplane_low_res, + triplane_high_res=triplane_high_res, + triplane_dim=triplane_dim, + ) + + self.synthesizer = TriplaneSynthesizer( + triplane_dim=triplane_dim, + samples_per_ray=rendering_samples_per_ray, + ) + + self.geometry = FlexiCubesGeometry(grid_res=self.grid_res, scale=self.grid_scale) + + def forward_planes(self, images, cameras): + # cameras: b n 16 + # images: b n c h w + B = images.shape[0] + + # encode images + image_feats = self.encoder(images, cameras) # FIXME this vit has compute difference with torch + + # image_feats = rearrange(image_feats, '(b v) l d -> b (v l) d', b=B) + image_feats = image_feats.reshape(B, image_feats.shape[-3] * image_feats.shape[-2], image_feats.shape[-1]) + logger.info(f"the shape in forward plane after reshape: {image_feats.shape}") + + # decode triplanes + planes = self.transformer(image_feats) + + return planes + + def get_sdf_deformation_prediction(self, planes): + """ + Predict SDF and deformation for tetrahedron vertices + :param planes: triplane feature map for the geometry + """ + init_position = self.geometry.verts.unsqueeze(0).expand((planes.shape[0], -1, -1)) + + # Step 1: predict the SDF and deformation. FIXME make sure that the whole model ckpt is loaded with the main func outside + sdf, deformation, weight = self.synthesizer.get_geometry_prediction( + planes, + init_position, + self.geometry.indices, + ) + + # Step 2: Normalize the deformation to avoid the flipped triangles. 
+        deformation = 1.0 / (self.grid_res * self.deformation_multiplier) * ops.tanh(deformation)
+        sdf_reg_loss = ops.zeros(sdf.shape[0], dtype=ms.float32)
+
+        ####
+        # Step 3: Fix some sdf if we observe empty shape (full positive or full negative)
+        sdf_bxnxnxn = sdf.reshape((sdf.shape[0], self.grid_res + 1, self.grid_res + 1, self.grid_res + 1))
+        sdf_less_boundary = sdf_bxnxnxn[:, 1:-1, 1:-1, 1:-1].reshape(sdf.shape[0], -1)
+        pos_shape = ops.sum((sdf_less_boundary > 0).int(), dim=-1)
+        neg_shape = ops.sum((sdf_less_boundary < 0).int(), dim=-1)
+        zero_surface = ops.bitwise_or(
+            ms.Tensor(pos_shape == 0, dtype=ms.uint8), ms.Tensor(neg_shape == 0, dtype=ms.uint8)
+        )
+        if ops.sum(zero_surface).item() > 0:
+            update_sdf = ops.zeros_like(sdf[0:1])
+            max_sdf = sdf.max()
+            min_sdf = sdf.min()
+            update_sdf[:, self.geometry.center_indices] += 1.0 - min_sdf  # greater than zero
+            update_sdf[:, self.geometry.boundary_indices] += -1 - max_sdf  # smaller than zero
+            new_sdf = ops.zeros_like(sdf)
+            for i_batch in range(zero_surface.shape[0]):
+                if zero_surface[i_batch]:
+                    new_sdf[i_batch : i_batch + 1] += update_sdf
+            update_mask = (new_sdf == 0).float()
+            # Regularization here is used to push the sdf to be a different sign (make it not fully positive or fully negative)
+            sdf_reg_loss = ops.abs(sdf).mean(axis=-1).mean(axis=-1)
+            sdf_reg_loss = sdf_reg_loss * zero_surface.float()
+            sdf = sdf * update_mask + new_sdf * (1 - update_mask)
+
+        # Step 4: Here we remove the gradient for the bad sdf (full positive or full negative)
+        final_sdf = []
+        final_def = []
+        for i_batch in range(zero_surface.shape[0]):
+            if zero_surface[i_batch]:
+                final_sdf.append(sdf[i_batch : i_batch + 1])
+                final_def.append(deformation[i_batch : i_batch + 1])
+            else:
+                final_sdf.append(sdf[i_batch : i_batch + 1])
+                final_def.append(deformation[i_batch : i_batch + 1])
+        sdf = ops.cat(final_sdf, axis=0)
+        deformation = ops.cat(final_def, axis=0)
+        return sdf, deformation, sdf_reg_loss, weight
+
+    def extract_mesh_triplane_feat_marching_cubes(self, planes):
+        """
+        Takes triplane features, predicts the SDF on the dense grid, and returns it for raw mesh extraction with marching cubes.
+        """
+        init_position = self.geometry.verts.unsqueeze(0).expand((planes.shape[0], -1, -1))
+        sdf, _, _ = self.synthesizer.get_geometry_prediction(
+            planes,
+            init_position,
+            self.geometry.indices,
+        )
+        sdf = sdf.reshape((sdf.shape[0], self.grid_res + 1, self.grid_res + 1, self.grid_res + 1))
+        return sdf
+
+    def get_geometry_prediction(self, planes=None):
+        """
+        Function to generate a mesh with the given triplanes
+        :param planes: triplane features
+        """
+        # Step 1: first get the sdf and deformation value for each vertex in the tetrahedron grid.
+        sdf, deformation, sdf_reg_loss, weight = self.get_sdf_deformation_prediction(planes)
+        v_deformed = self.geometry.verts.unsqueeze(dim=0).expand((sdf.shape[0], -1, -1)) + deformation
+        tets = self.geometry.indices
+        n_batch = planes.shape[0]
+        v_list = []
+        f_list = []
+        flexicubes_surface_reg_list = []
+
+        # Step 2: Using marching tets to obtain the mesh (here done by FlexiCubes)
+        logger.info(f"sdf shape of {sdf.shape}")
+        logger.info(f"n_batch: {n_batch}")
+        # ms squeeze requires the batch dim to exist
+        sdf = sdf.squeeze(axis=-1)
+        for i_batch in range(n_batch):
+            verts, faces, flexicubes_surface_reg = self.geometry.get_mesh(
+                v_deformed[i_batch],
+                sdf[i_batch],
+                with_uv=False,
+                indices=tets,
+                weight_n=weight[i_batch],
+                is_training=self.training,
+            )
+            flexicubes_surface_reg_list.append(flexicubes_surface_reg)
+            v_list.append(verts)
+            f_list.append(faces)
+
+        flexicubes_surface_reg = ops.cat(flexicubes_surface_reg_list).mean()
+        flexicubes_weight_reg = (weight**2).mean()
+
+        return (
+            v_list,
+            f_list,
+            sdf,
+            deformation,
+            v_deformed,
+            (sdf_reg_loss, flexicubes_surface_reg, flexicubes_weight_reg),
+        )
+
+    def get_texture_prediction(self, planes, tex_pos, hard_mask=None):
+        """
+        Predict texture given triplanes
+        :param planes: the triplane feature map
+        :param tex_pos: Position we want to query the texture field
+        :param hard_mask: 2D silhouette of the rendered image
+        """
+        tex_pos = ops.cat(tex_pos, axis=0)
+        if hard_mask is not None:
+            tex_pos = tex_pos * hard_mask.float()
+        batch_size = tex_pos.shape[0]
+        tex_pos = tex_pos.reshape(batch_size, -1, 3)
+        ###################
+        # We use mask to get the texture location (to save the memory)
+        if hard_mask is not None:
+            n_point_list = ops.sum(hard_mask.long().reshape(hard_mask.shape[0], -1), dim=-1)
+            sample_tex_pose_list = []
+            max_point = n_point_list.max()
+            expanded_hard_mask = hard_mask.reshape(batch_size, -1, 1).expand((-1, -1, 3)) > 0.5
+            for i in range(tex_pos.shape[0]):
+                tex_pos_one_shape = tex_pos[i][expanded_hard_mask[i]].reshape((1, -1, 3))
+                if tex_pos_one_shape.shape[1] < max_point:
+                    tex_pos_one_shape = ops.cat(
+                        [tex_pos_one_shape, ops.zeros((1, max_point - tex_pos_one_shape.shape[1], 3), dtype=ms.float32)],
+                        axis=1,
+                    )
+                sample_tex_pose_list.append(tex_pos_one_shape)
+            tex_pos = ops.cat(sample_tex_pose_list, axis=0)
+
+        tex_feat = self.synthesizer.get_texture_prediction(planes, tex_pos)
+
+        if hard_mask is not None:
+            final_tex_feat = ops.zeros((planes.shape[0], hard_mask.shape[1] * hard_mask.shape[2], tex_feat.shape[-1]))
+            expanded_hard_mask = (
+                hard_mask.reshape(hard_mask.shape[0], -1, 1).expand((-1, -1, final_tex_feat.shape[-1])) > 0.5
+            )
+            for i in range(planes.shape[0]):
+                final_tex_feat[i][expanded_hard_mask[i]] = tex_feat[i][: n_point_list[i]].reshape(-1)
+            tex_feat = final_tex_feat
+
+        return tex_feat.reshape(planes.shape[0], hard_mask.shape[1], hard_mask.shape[2], tex_feat.shape[-1])
+
+    def render_mesh(self, mesh_v, mesh_f, cam_mv, render_size=256):
+        """
+        Function to render a generated mesh with nvdiffrast
+        :param mesh_v: List of vertices for the mesh
+        :param mesh_f: List of faces for the mesh
+        :param cam_mv: 4x4 rotation matrix
+        :return:
+        """
+        return_value_list = []
+        for i_mesh in range(len(mesh_v)):
+            return_value = self.geometry.render_mesh(
+                mesh_v[i_mesh], mesh_f[i_mesh].int(), cam_mv[i_mesh], resolution=render_size, hierarchical_mask=False
+            )
+            return_value_list.append(return_value)
+
+        return_keys = return_value_list[0].keys()
+        return_value = dict()
+        for k in
return_keys: + value = [v[k] for v in return_value_list] + return_value[k] = value + + mask = ops.cat(return_value["mask"], axis=0) + hard_mask = ops.cat(return_value["hard_mask"], axis=0) + tex_pos = return_value["tex_pos"] + depth = ops.cat(return_value["depth"], axis=0) + normal = ops.cat(return_value["normal"], axis=0) + return mask, hard_mask, tex_pos, depth, normal + + def forward_geometry(self, planes, render_cameras, render_size=256): + """ + Main function of our Generator. It first generate 3D mesh, then render it into 2D image + with given `render_cameras`. + :param planes: triplane features + :param render_cameras: cameras to render generated 3D shape + """ + B, NV = render_cameras.shape[:2] + + # Generate 3D mesh first + mesh_v, mesh_f, sdf, _, _, sdf_reg_loss = self.get_geometry_prediction(planes) + + # Render the mesh into 2D image (get 3d position of each image plane) + cam_mv = render_cameras + run_n_view = cam_mv.shape[1] + antilias_mask, hard_mask, tex_pos, depth, normal = self.render_mesh( + mesh_v, mesh_f, cam_mv, render_size=render_size + ) + + tex_hard_mask = hard_mask + tex_pos = [ops.cat([pos[i_view : i_view + 1] for i_view in range(run_n_view)], axis=2) for pos in tex_pos] + tex_hard_mask = ops.cat( + [ + ops.cat( + [ + tex_hard_mask[i * run_n_view + i_view : i * run_n_view + i_view + 1] + for i_view in range(run_n_view) + ], + axis=2, + ) + for i in range(planes.shape[0]) + ], + axis=0, + ) + + # Querying the texture field to predict the texture feature for each pixel on the image + tex_feat = self.get_texture_prediction(planes, tex_pos, tex_hard_mask) + background_feature = ops.ones_like(tex_feat) # white background + + # Merge them together + img_feat = tex_feat * tex_hard_mask + background_feature * (1 - tex_hard_mask) + + # We should split it back to the original image shape + img_feat = ops.cat( + [ + ops.cat( + [ + img_feat[i : i + 1, :, render_size * i_view : render_size * (i_view + 1)] + for i_view in range(run_n_view) + ], + axis=0, + ) + for i in range(len(tex_pos)) + ], + axis=0, + ) + + img = img_feat.clamp(0, 1).permute(0, 3, 1, 2).unflatten(0, (B, NV)) + antilias_mask = antilias_mask.permute(0, 3, 1, 2).unflatten(0, (B, NV)) + depth = -depth.permute(0, 3, 1, 2).unflatten(0, (B, NV)) # transform negative depth to positive + normal = normal.permute(0, 3, 1, 2).unflatten(0, (B, NV)) + + out = { + "img": img, + "mask": antilias_mask, + "depth": depth, + "normal": normal, + "sdf": sdf, + "mesh_v": mesh_v, + "mesh_f": mesh_f, + "sdf_reg_loss": sdf_reg_loss, + } + return out + + def construct(self, images, cameras, render_cameras, render_size: int): + # images: [B, V, C_img, H_img, W_img] + # cameras: [B, V, 16] + # render_cameras: [B, M, D_cam_render] + # render_size: int + B, M = render_cameras.shape[:2] + + planes = self.forward_planes(images, cameras) + out = self.forward_geometry(planes, render_cameras, render_size=render_size) + + return {"planes": planes, **out} + + def extract_mesh_with_texture( + self, + planes: ms.Tensor, + **kwargs, + ): + """ + Extract a 3D mesh from FlexiCubes. Only support batch_size 1. 
+ :param planes: triplane features + :param use_texture_map: use texture map or vertex color + """ + assert planes.shape[0] == 1 + + # predict geometry first + mesh_v, mesh_f, sdf, deformation, v_deformed, sdf_reg_loss = self.get_geometry_prediction(planes) + vertices, faces = mesh_v[0], mesh_f[0] + + # query vertex colors + vertices_tensor = vertices.unsqueeze(0) + vertices_colors = self.synthesizer.get_texture_prediction(planes, vertices_tensor).clamp(0, 1).squeeze(0) + vertices_colors = (vertices_colors * 255).astype(ms.uint8) + + return vertices, faces, vertices_colors diff --git a/examples/instantmesh/models/renderer/__init__.py b/examples/instantmesh/models/renderer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/instantmesh/models/renderer/synthesizer.py b/examples/instantmesh/models/renderer/synthesizer.py new file mode 100644 index 0000000000..0e8af648bc --- /dev/null +++ b/examples/instantmesh/models/renderer/synthesizer.py @@ -0,0 +1,225 @@ +import itertools + +import numpy as np + +import mindspore as ms +import mindspore.nn as nn +from mindspore import mint + +from .utils.ray_sampler import RaySampler +from .utils.renderer import ImportanceRenderer + + +class OSGDecoder(nn.Cell): + """ + Triplane decoder that gives RGB and sigma values from sampled features. + Using ReLU here instead of Softplus in the original implementation. + + Reference: + EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112 + """ + + def __init__( + self, + n_features: int, + hidden_dim: int = 64, + num_layers: int = 4, + activation: nn.Cell = nn.ReLU, + use_recompute: bool = True, + ): + super().__init__() + self.net = nn.SequentialCell( + nn.Dense(3 * n_features, hidden_dim), + activation(), + *itertools.chain( + *[ + [ + nn.Dense(hidden_dim, hidden_dim), + activation(), + ] + for _ in range(num_layers - 2) + ] + ), + nn.Dense(hidden_dim, 1 + 3), + ) + # bias init as zero by default, can refer to ~/examples/stable_diffusion_v2/tests/test_lora.py & lora_torch.py for evidence + + # @ms.jit # now has the error: Exceed function call depth limit 1000, (function call depth: 1001, simulate call depth: 508). + def construct(self, sampled_features): + # Aggregate features by mean + # sampled_features = sampled_features.mean(1) + # Aggregate features by concatenation + _N, n_planes, _M, _C = sampled_features.shape + sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes * _C) + x = sampled_features + + N, M, C = x.shape + x = x.view(N * M, C) + + x = self.net(x) + x = x.view(N, M, -1) + rgb = mint.sigmoid(x[..., 1:]) * (1 + 2 * 0.001) - 0.001 # Uses sigmoid clamping from MipNeRF + sigma = x[..., 0:1] + + return rgb, sigma + + +class TriplaneSynthesizer(nn.Cell): + """ + Synthesizer that renders a triplane volume with planes and a camera. 
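+    Each camera is passed as a flattened 25-D vector: the first 16 entries hold the
+    4x4 cam2world matrix and the last 9 entries the 3x3 intrinsics (see construct below).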
+ + Reference: + EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19 + """ + + DEFAULT_RENDERING_KWARGS = { + "ray_start": "auto", + "ray_end": "auto", + "box_warp": 2.0, + "white_back": True, + "disparity_space_sampling": False, + "clamp_mode": "softplus", + "sampler_bbox_min": -1.0, + "sampler_bbox_max": 1.0, + } + + def __init__( + self, triplane_dim: int, samples_per_ray: int, dtype: ms.dtype = ms.float32, use_recompute: bool = False + ): + super().__init__() + + # attributes + self.triplane_dim = triplane_dim + dep_res = int(np.divmod(samples_per_ray, 2)[0]) + dep_res_imp = int(np.divmod(samples_per_ray, 2)[0]) + self.rendering_kwargs = { + **self.DEFAULT_RENDERING_KWARGS, + "depth_resolution": dep_res, + "depth_resolution_importance": dep_res_imp, + } + + # renderings + self.renderer = ImportanceRenderer(self.rendering_kwargs, dtype=dtype) + self.ray_sampler = RaySampler() + + # modules + self.decoder = OSGDecoder(n_features=triplane_dim) + + if use_recompute: + self.renderer.recompute() + self.decoder.recompute() + + # @ms.jit # now has the error in the renderer: Exceed function call depth limit 1000, (function call depth: 1001, simulate call depth: 508). + def construct(self, planes, cameras, render_size, crop_params): + # planes: (N, 3, D', H', W') + # cameras: (N, M, D_cam) + # render_size: int + assert planes.shape[0] == cameras.shape[0], "Batch size mismatch for planes and cameras" + N, M = cameras.shape[:2] + + cam2world_matrix = cameras[..., :16].view(N, M, 4, 4) + intrinsics = cameras[..., 16:25].view(N, M, 3, 3) + + # Create a batch of rays for volume rendering + ray_origins, ray_directions = self.ray_sampler( + cam2world_matrix=cam2world_matrix.reshape(-1, 4, 4), + intrinsics=intrinsics.reshape(-1, 3, 3), + render_size=render_size, + ) + assert N * M == ray_origins.shape[0], "Batch size mismatch for ray_origins" + assert ray_origins.dim() == 3, "ray_origins should be 3-dimensional" + + # Crop rays if crop_params is available + if crop_params is not None: + ray_origins = ray_origins.reshape(N * M, render_size, render_size, 3) + ray_directions = ray_directions.reshape(N * M, render_size, render_size, 3) + i, j, h, w = crop_params.tolist()[0] + ray_origins = ray_origins[:, i : i + h, j : j + w, :].reshape(N * M, -1, 3) + ray_directions = ray_directions[:, i : i + h, j : j + w, :].reshape(N * M, -1, 3) + + # Perform volume rendering + rgb_samples, depth_samples, weights_samples = self.renderer( + planes.repeat_interleave(M, dim=0), ray_origins, ray_directions + ) + + # Reshape into 'raw' neural-rendered image + if crop_params is not None: + Himg, Wimg = crop_params.tolist()[0][2:] + else: + Himg = Wimg = render_size + rgb_images = ( + rgb_samples.permute(0, 2, 1).reshape(N, M, rgb_samples.shape[-1], Himg, Wimg).contiguous() + ) # b n c h w + depth_images = depth_samples.permute(0, 2, 1).reshape(N, M, 1, Himg, Wimg) + weight_images = weights_samples.permute(0, 2, 1).reshape(N, M, 1, Himg, Wimg) + + # out = { + # 'images_rgb': rgb_images, + # 'images_depth': depth_images, + # 'images_weight': weight_images, + # } + return rgb_images, depth_images, weight_images + + # [inference only] below two func shortcuts, not used in graph training: get_texture_prediction() & extract_mesh() + # for run meshing code, not trained + def forward_points(self, planes, points: ms.Tensor, chunk_size: int = 2**20): + # planes: (N, 3, D', H', W') + # points: (N, P, 3) + N, P = points.shape[:2] + + # query triplane in chunks + outs = [] + for i in range(0, points.shape[1], 
chunk_size): + chunk_points = points[:, i : i + chunk_size] + + # query triplane + chunk_out = self.renderer.run_model_activated( + planes=planes, + sample_coordinates=chunk_points, + ) + outs.append(chunk_out) + + # concatenate the outputs + point_features = {k: mint.cat([out[k] for out in outs], dim=1) for k in outs[0].keys()} + return point_features + + def forward_grid(self, planes, grid_size: int, aabb: ms.Tensor = None): + # planes: (N, 3, D', H', W') + # grid_size: int + # aabb: (N, 2, 3) + if aabb is None: + aabb = ( + ms.tensor( + [ + [self.rendering_kwargs["sampler_bbox_min"]] * 3, + [self.rendering_kwargs["sampler_bbox_max"]] * 3, + ], + dtype=planes.dtype, + ) + .unsqueeze(0) + .repeat(planes.shape[0], 1, 1) + ) + assert planes.shape[0] == aabb.shape[0], "Batch size mismatch for planes and aabb" + N = planes.shape[0] + + # create grid points for triplane query + grid_points = [] + for i in range(N): + grid_points.append( + mint.stack( + mint.meshgrid( + mint.linspace(aabb[i, 0, 0], aabb[i, 1, 0], grid_size), + mint.linspace(aabb[i, 0, 1], aabb[i, 1, 1], grid_size), + mint.linspace(aabb[i, 0, 2], aabb[i, 1, 2], grid_size), + indexing="ij", + ), + dim=-1, + ).reshape(-1, 3) + ) + cube_grid = mint.stack(grid_points, dim=0) + + features = self.forward_points(planes, cube_grid) + + # reshape into grid + features = {k: v.reshape(N, grid_size, grid_size, grid_size, -1) for k, v in features.items()} + return features diff --git a/examples/instantmesh/models/renderer/synthesizer_mesh.py b/examples/instantmesh/models/renderer/synthesizer_mesh.py new file mode 100644 index 0000000000..b2f6a95739 --- /dev/null +++ b/examples/instantmesh/models/renderer/synthesizer_mesh.py @@ -0,0 +1,155 @@ +import itertools + +import mindspore.nn as nn +from mindspore import ops + +from .utils.renderer import generate_planes, sample_from_planes + + +class OSGDecoder(nn.Cell): + """ + Triplane decoder that gives RGB and sigma values from sampled features. + Using ReLU here instead of Softplus in the original implementation. 
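+
+    This mesh variant exposes four heads: `net_sdf` (1 channel), `net_rgb` (3), `net_deformation` (3) and
+    `net_weight` (21 FlexiCubes weights per cube, predicted from the concatenated 8 cube-corner features).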
+ + Reference: + EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112 + """ + + def __init__(self, n_features: int, hidden_dim: int = 64, num_layers: int = 4, activation: nn.Cell = nn.ReLU): + super().__init__() + + self.net_sdf = nn.SequentialCell( + nn.Dense(3 * n_features, hidden_dim, bias_init="zeros"), + activation(), + *itertools.chain( + *[ + [ + nn.Dense(hidden_dim, hidden_dim, bias_init="zeros"), + activation(), + ] + for _ in range(num_layers - 2) + ] + ), + nn.Dense(hidden_dim, 1, bias_init="zeros"), + ) + self.net_rgb = nn.SequentialCell( + nn.Dense(3 * n_features, hidden_dim, bias_init="zeros"), + activation(), + *itertools.chain( + *[ + [ + nn.Dense(hidden_dim, hidden_dim, bias_init="zeros"), + activation(), + ] + for _ in range(num_layers - 2) + ] + ), + nn.Dense(hidden_dim, 3, bias_init="zeros"), + ) + self.net_deformation = nn.SequentialCell( + nn.Dense(3 * n_features, hidden_dim, bias_init="zeros"), + activation(), + *itertools.chain( + *[ + [ + nn.Dense(hidden_dim, hidden_dim, bias_init="zeros"), + activation(), + ] + for _ in range(num_layers - 2) + ] + ), + nn.Dense(hidden_dim, 3, bias_init="zeros"), + ) + self.net_weight = nn.SequentialCell( + nn.Dense(8 * 3 * n_features, hidden_dim, bias_init="zeros"), + activation(), + *itertools.chain( + *[ + [ + nn.Dense(hidden_dim, hidden_dim, bias_init="zeros"), + activation(), + ] + for _ in range(num_layers - 2) + ] + ), + nn.Dense(hidden_dim, 21, bias_init="zeros"), + ) + + def get_geometry_prediction(self, sampled_features, flexicubes_indices): + _N, n_planes, _M, _C = sampled_features.shape + sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes * _C) + + sdf = self.net_sdf(sampled_features) + deformation = self.net_deformation(sampled_features) + + grid_features = ops.index_select(input=sampled_features, index=flexicubes_indices.reshape(-1), axis=1) + grid_features = grid_features.reshape( + sampled_features.shape[0], + flexicubes_indices.shape[0], + flexicubes_indices.shape[1] * sampled_features.shape[-1], + ) + weight = self.net_weight(grid_features) * 0.1 + + return sdf, deformation, weight + + def get_texture_prediction(self, sampled_features): + _N, n_planes, _M, _C = sampled_features.shape + sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes * _C) + + rgb = self.net_rgb(sampled_features) + rgb = ops.sigmoid(rgb) * (1 + 2 * 0.001) - 0.001 # Uses sigmoid clamping from MipNeRF + + return rgb + + +class TriplaneSynthesizer(nn.Cell): + """ + Synthesizer that renders a triplane volume with planes and a camera. 
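+
+    A rough call sketch (names and shapes are illustrative): given triplane features `planes` of shape
+    (N, 3, C, H', W') plus FlexiCubes grid coordinates and cube indices,
+
+        sdf, deformation, weight = synthesizer.get_geometry_prediction(planes, grid_coords, flexicubes_indices)
+        rgb = synthesizer.get_texture_prediction(planes, surface_points)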
+ + Reference: + EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19 + """ + + DEFAULT_RENDERING_KWARGS = { + "ray_start": "auto", + "ray_end": "auto", + "box_warp": 2.0, + "white_back": True, + "disparity_space_sampling": False, + "clamp_mode": "softplus", + "sampler_bbox_min": -1.0, + "sampler_bbox_max": 1.0, + } + + def __init__(self, triplane_dim: int, samples_per_ray: int): + super().__init__() + + # attributes + self.triplane_dim = triplane_dim + self.rendering_kwargs = { + **self.DEFAULT_RENDERING_KWARGS, + "depth_resolution": samples_per_ray // 2, + "depth_resolution_importance": samples_per_ray // 2, + } + + # modules + self.plane_axes = generate_planes() + self.decoder = OSGDecoder(n_features=triplane_dim) + + def get_geometry_prediction(self, planes, sample_coordinates, flexicubes_indices): + plane_axes = self.plane_axes + sampled_features = sample_from_planes( + plane_axes, planes, sample_coordinates, padding_mode="zeros", box_warp=self.rendering_kwargs["box_warp"] + ) + + sdf, deformation, weight = self.decoder.get_geometry_prediction(sampled_features, flexicubes_indices) + return sdf, deformation, weight + + def get_texture_prediction(self, planes, sample_coordinates): + plane_axes = self.plane_axes + sampled_features = sample_from_planes( + plane_axes, planes, sample_coordinates, padding_mode="zeros", box_warp=self.rendering_kwargs["box_warp"] + ) + + rgb = self.decoder.get_texture_prediction(sampled_features) + return rgb diff --git a/examples/instantmesh/models/renderer/utils/__init__.py b/examples/instantmesh/models/renderer/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/instantmesh/models/renderer/utils/renderer.py b/examples/instantmesh/models/renderer/utils/renderer.py new file mode 100644 index 0000000000..be1d98049f --- /dev/null +++ b/examples/instantmesh/models/renderer/utils/renderer.py @@ -0,0 +1,403 @@ +""" The renderer is a module that takes in rays, decides where to sample along each +ray, and computes pixel colors using the volume rendering equation. +""" +import itertools +from typing import Dict + +from sgm.util import Inverse + +import mindspore as ms +import mindspore.nn as nn +from mindspore import _no_grad, mint, ops + +from . import math_utils +from .ray_marcher import MipRayMarcher2 + + +class OSGDecoder(nn.Cell): + """ + Triplane decoder that gives RGB and sigma values from sampled features. + Using ReLU here instead of Softplus in the original implementation. 
+
+    Reference:
+    EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112
+    """
+
+    def __init__(self, n_features: int, hidden_dim: int = 64, num_layers: int = 4, activation: nn.Cell = nn.ReLU):
+        super().__init__()
+        self.net = nn.SequentialCell(
+            nn.Dense(3 * n_features, hidden_dim),
+            activation(),
+            *itertools.chain(
+                *[
+                    [
+                        nn.Dense(hidden_dim, hidden_dim),
+                        activation(),
+                    ]
+                    for _ in range(num_layers - 2)
+                ]
+            ),
+            nn.Dense(hidden_dim, 1 + 3),
+        )
+        # bias is initialized as zero by default; see ~/examples/stable_diffusion_v2/tests/test_lora.py & lora_torch.py for evidence
+
+    def construct(self, sampled_features):
+        # Aggregate features by mean
+        # sampled_features = sampled_features.mean(1)
+        # Aggregate features by concatenation
+        _N, n_planes, _M, _C = sampled_features.shape
+        x = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes * _C)
+
+        N, M, color = x.shape
+        x = x.view((N * M, color))
+
+        x = self.net(x)
+        x = x.view((N, M, -1))
+        rgb = mint.sigmoid(x[..., 1:]) * (1 + 2 * 0.001) - 0.001  # Uses sigmoid clamping from MipNeRF
+        sigma = x[..., 0:1]
+
+        return rgb, sigma
+
+
+class ImportanceRenderer(nn.Cell):
+    """Importance renderer, modified to filter out-of-box samples as TensoRF does."""
+
+    def __init__(self, opts: Dict, dtype: ms.dtype, debug: bool = False):
+        super().__init__()
+        self.rendering_options = opts
+
+        # MS graph-mode paradigm: do NOT pass non-Tensor hyperparameters as construct args; make them class attributes instead
+        self.disparity_space_sampling = opts["disparity_space_sampling"]
+        self.depth_resolution = opts["depth_resolution"]
+        self.N_importance = opts["depth_resolution_importance"]
+        self.ray_marcher = MipRayMarcher2(opts)
+
+        self.plane_axes = generate_planes().astype(dtype)
+        self.max_pool1d_layer = nn.MaxPool1d(2, 1, pad_mode="pad", padding=1)
+        self.avg_pool1d_layer = nn.AvgPool1d(2, 1)
+
+        self.inverse_operator = Inverse()  # workaround for the case that the current amp does not work on mint/ops.inverse
+        self.decoder = OSGDecoder(n_features=80)  # triplane_dim
+        self.debug_logging = debug
+
+        self.path_to_save_grad = "./renderer_save_grad/"
+
+    def project_onto_planes(
+        self,
+        planes: ms.Tensor,  # when calling this from outside, it is typically sampling on the unit plane axes
+        coordinates: ms.Tensor,
+    ) -> ms.Tensor:
+        """
+        Does a projection of a 3D point onto a batch of 2D planes,
+        returning 2D plane coordinates.
+
+        Takes plane axes of shape (n_planes, 3, 3)
+        and coordinates of shape (N, M, 3);
+        returns projections of shape (N * n_planes, M, 2).
+        """
+        N, M, C = coordinates.shape
+        n_planes, _, _ = planes.shape
+        coordinates = coordinates.unsqueeze(1)
+        coordinates = coordinates.broadcast_to((-1, n_planes, -1, -1)).reshape(N * n_planes, M, 3)
+
+        # cast to fp32 as the inverse operator requests; ops.inverse does not support amp
+        # May report an issue to ms/mindone about the amp blacklist:
+        # TypeError: For primitive[MatrixInverse], the input argument[x] must be a type of {Tensor[Complex128],
+        # Tensor[Complex64], Tensor[Float32], Tensor[Float64]}, but got Tensor[Float16].
+ # inv_planes = ops.inverse(planes.to(ms.float32)).unsqueeze(0) + inv_planes = self.inverse_operator(planes.to(ms.float32)).unsqueeze(0) + inv_planes = inv_planes.broadcast_to((N, -1, -1, -1)).reshape(N * n_planes, 3, 3) + + projections = mint.bmm(coordinates, inv_planes.to(planes.dtype)) + return projections[..., :2] + + def sample_from_planes(self, plane_features, coordinates): + mode = "bilinear" + padding_mode = "zeros" + assert padding_mode == "zeros" + N, n_planes, C, H, W = plane_features.shape + _, M, _ = coordinates.shape + plane_features = plane_features.view(N * n_planes, C, H, W) + # dtype = plane_features.dtype + + coordinates = (2 / self.rendering_options["box_warp"]) * coordinates # add specific box bounds + + # debug + # plane_features += self.plane_feat_grad + + projected_coordinates = self.project_onto_planes(self.plane_axes, coordinates).unsqueeze(1) + + # ts.save(self.path_to_save_grad + 'pf_input', plane_features) + # plane_features = ts.save_grad(self.path_to_save_grad + 'pf_input_grad', plane_features) + + # output_features = ops.grid_sample( + # plane_features, + # projected_coordinates, + # mode=mode, + # padding_mode=padding_mode, + # align_corners=False, + # ).permute(0, 3, 2, 1).reshape(N, n_planes, M, C) + + output_features = ( + mint.nn.functional.grid_sample( + plane_features, + projected_coordinates, + mode=mode, + padding_mode=padding_mode, + align_corners=False, + ) + .permute(0, 3, 2, 1) + .reshape(N, n_planes, M, C) + ) + + # ts.save(self.path_to_save_grad + 'pf_output', plane_features) + # output_features = ts.save_grad(self.path_to_save_grad + 'output_features_output_grad', output_features) + + return output_features + + def run_model( + self, + planes, + sample_coordinates, + ): + """Run triplane sampler & nerf decoder model""" + sampled_features = self.sample_from_planes(planes, sample_coordinates) + out = self.decoder(sampled_features) + return out + + def _forward_pass( + self, + depths: ms.Tensor, + ray_directions: ms.Tensor, + ray_origins: ms.Tensor, + planes: ms.Tensor, + ): + """ + Additional filtering is applied to filter out-of-box samples. 
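+        Samples outside [sampler_bbox_min, sampler_bbox_max] keep zero RGB and -inf density; the masking
+        is applied with `mint.where` because tensor-index assignment is not supported in graph mode.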
+ """ + + # context related variables + batch_size, num_rays, samples_per_ray, _ = depths.shape + + # define sample points + sample_coordinates = (ray_origins.unsqueeze(-2) + depths * ray_directions.unsqueeze(-2)).reshape( + batch_size, -1, 3 + ) + + # filter out-of-box samples + mask_inbox = ops.logical_and( + self.rendering_options["sampler_bbox_min"] <= sample_coordinates, + sample_coordinates <= self.rendering_options["sampler_bbox_max"], + ) + mask_inbox = mask_inbox.all(-1) + + # forward model according to all samples + _rgb, _sigma = self.run_model(planes, sample_coordinates) + + # set out-of-box samples to zeros(rgb) & -inf(sigma) + SAFE_GUARD = 3 + DATA_TYPE = _sigma.dtype + colors_pass = ops.zeros((batch_size, num_rays * samples_per_ray, 3), dtype=DATA_TYPE) + densities_pass = ( + ops.nan_to_num(mint.full((batch_size, num_rays * samples_per_ray, 1), -float("inf"), dtype=DATA_TYPE)) + / SAFE_GUARD + ) + + # colors_pass[mask_inbox] = _rgb[mask_inbox] + mask_inbox = mask_inbox[..., None] + colors_pass = mint.where(mask_inbox, _rgb, colors_pass) # Tensor indexing assignment in G mode + # densities_pass[mask_inbox] = _sigma[mask_inbox] # GRAPH MODE: index val assignment using tensor cannot be mul dims + densities_pass = mint.where(mask_inbox, _sigma, densities_pass) + + # reshape back + colors_pass = colors_pass.reshape(batch_size, num_rays, samples_per_ray, colors_pass.shape[-1]) + densities_pass = densities_pass.reshape(batch_size, num_rays, samples_per_ray, densities_pass.shape[-1]) + + return colors_pass, densities_pass + + def sort_samples(self, all_depths, all_colors, all_densities): + _, indices = ops.sort(all_depths, axis=-2) + all_depths = mint.gather(all_depths, -2, indices) + all_colors = mint.gather(all_colors, -2, indices.broadcast_to((-1, -1, -1, all_colors.shape[-1]))) + all_densities = mint.gather(all_densities, -2, indices.broadcast_to((-1, -1, -1, 1))) + return all_depths, all_colors, all_densities + + def unify_samples(self, depths1, colors1, densities1, depths2, colors2, densities2, normals1=None, normals2=None): + all_depths = ops.cat([depths1, depths2], axis=-2) + all_colors = ops.cat([colors1, colors2], axis=-2) + all_densities = ops.cat([densities1, densities2], axis=-2) + + if normals1 is not None and normals2 is not None: + all_normals = ops.cat([normals1, normals2], axis=-2) + else: + all_normals = None + + _, indices = ops.sort(all_depths, axis=-2) + all_depths = mint.gather(all_depths, -2, indices) + all_colors = mint.gather(all_colors, -2, indices.broadcast_to((-1, -1, -1, all_colors.shape[-1]))) + all_densities = mint.gather(all_densities, -2, indices.broadcast_to((-1, -1, -1, 1))) + + if all_normals is not None: + all_normals = mint.gather(all_normals, -2, indices.broadcast_to((-1, -1, -1, all_normals.shape[-1]))) + return all_depths, all_colors, all_normals, all_densities + + return all_depths, all_colors, all_densities + + def sample_stratified( + self, + ray_origins: ms.Tensor, # b n 3 + ray_start: ms.Tensor, # b n 1 + ray_end: ms.Tensor, # b n 1 + ): + """ + Return depths of approximately uniformly spaced samples along rays. 
+ """ + N, M, _ = ray_origins.shape + if self.disparity_space_sampling: + depths_coarse = ( + mint.linspace(0, 1, self.depth_resolution).reshape(1, 1, self.depth_resolution, 1).tile((N, M, 1, 1)) + ) + depth_delta = 1 / (self.depth_resolution - 1) + depths_coarse += mint.rand_like(depths_coarse) * depth_delta + depths_coarse = 1.0 / (1.0 / ray_start * (1.0 - depths_coarse) + 1.0 / ray_end * depths_coarse) + else: + # print(f'shape: ray start: {ray_start.shape}, ray end: {ray_end.shape}') + depths_coarse = math_utils.linspace(ray_start, ray_end, self.depth_resolution).permute(1, 2, 0, 3) + depth_delta = (ray_end - ray_start) / (self.depth_resolution - 1) + depths_coarse += mint.rand_like(depths_coarse) * depth_delta[..., None] + + return depths_coarse + + def sample_pdf(self, bins, weights): + """ + Sample @N_importance samples from @bins with distribution defined by @weights. + Inputs: + bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2" + weights: (N_rays, N_samples_) + N_importance: the number of samples to draw from the distribution + Not-tensor Inputs: + det: deterministic or not + eps: a small number to prevent division by zero + Outputs: + samples: the sampled samples + """ + det = (False,) + eps = 1e-5 + N_rays, N_samples_ = weights.shape + weights = weights + eps # prevent division by zero (don't do inplace op!) + pdf = weights / mint.sum(weights, -1, keepdim=True) # (N_rays, N_samples_) + cdf = ops.cumsum(pdf, -1) # (N_rays, N_samples), cumulative distribution function + cdf = ops.cat([mint.zeros_like(cdf[:, :1]), cdf], -1) # (N_rays, N_samples_+1) + # padded to 0~1 inclusive + + if det: + u = mint.linspace(0, 1, self.N_importance) + u = u.broadcast_to((N_rays, self.N_importance)) + else: + u = ops.rand(N_rays, self.N_importance) + + inds = mint.searchsorted(cdf, u, right=True) + below = mint.clamp(inds - 1, min=0) + above = mint.clamp(inds, max=N_samples_) + + inds_sampled = mint.stack([below, above], -1).view(N_rays, 2 * self.N_importance) + cdf_g = mint.gather(cdf, 1, inds_sampled).view(N_rays, self.N_importance, 2) + bins_g = mint.gather(bins, 1, inds_sampled).view(N_rays, self.N_importance, 2) + + denom = cdf_g[..., 1] - cdf_g[..., 0] + denom[denom < eps] = 1 # denom equals 0 means a bin has weight 0, in which case it will not be sampled + # anyway, therefore any value for it is fine (set to 1 here) + + samples = bins_g[..., 0] + (u - cdf_g[..., 0]) / denom * (bins_g[..., 1] - bins_g[..., 0]) + return samples + + def sample_importance( + self, + z_vals, + weights, + ): + """ + Return depths of importance sampled points along rays. See NeRF importance sampling for more. 
+ """ + with _no_grad(): + batch_size, num_rays, samples_per_ray, _ = z_vals.shape + + z_vals = z_vals.reshape(batch_size * num_rays, samples_per_ray) + weights = weights.reshape(batch_size * num_rays, -1) # -1 to account for loss of 1 sample in MipRayMarcher + + # smooth weights + weights = self.max_pool1d_layer(weights.unsqueeze(1)) + weights = self.avg_pool1d_layer(weights).squeeze() + weights = weights + 0.01 + + z_vals_mid = 0.5 * (z_vals[:, :-1] + z_vals[:, 1:]) + importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1]).reshape( + batch_size, num_rays, self.N_importance, 1 + ) + return importance_z_vals + + def construct( + self, + planes: ms.Tensor, + ray_origins: ms.Tensor, + ray_directions: ms.Tensor, + ): + ray_start, ray_end = math_utils.get_ray_limits_box(ray_origins, ray_directions) + is_ray_valid = ray_end > ray_start + + # FIXME below take item may degrade the shape, potentially into unknown errors... + if ops.any(is_ray_valid).item(): + ray_start[~is_ray_valid] = ray_start[is_ray_valid].min() + ray_end[~is_ray_valid] = ray_start[is_ray_valid].max() + depths_coarse = self.sample_stratified(ray_origins, ray_start, ray_end) + + # Coarse Pass + colors_coarse, densities_coarse = self._forward_pass( + depths=depths_coarse, + ray_directions=ray_directions, + ray_origins=ray_origins, + planes=planes, + ) + + # print(f'input below cc: {colors_coarse}\n weights: {densities_coarse}\n depth color: {densities_coarse}') + # print(f'input below cc: {colors_coarse}\n weights: {densities_coarse}\n depth color: {densities_coarse}') + # ops.print_('n importance is', self.N_importance) + + # Fine Pass + _, _, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse) + + depths_fine = self.sample_importance(depths_coarse, weights) + + colors_fine, densities_fine = self._forward_pass( + depths=depths_fine, + ray_directions=ray_directions, + ray_origins=ray_origins, + planes=planes, + ) + + all_depths, all_colors, all_densities = self.unify_samples( + depths_coarse, colors_coarse, densities_coarse, depths_fine, colors_fine, densities_fine + ) + rgb_final, depth_final, weights = self.ray_marcher(all_colors, all_densities, all_depths) + + return rgb_final, depth_final, weights.sum(2) + + # [inference] for run meshing code, not trained + def run_model_activated(self, planes, sample_coordinates, options=None): + _rgb, _sigma = self.run_model(planes, sample_coordinates) + _sigma = self.activation_factory(options)(_sigma) + return _rgb, _sigma + + +def generate_planes(): + """ + Defines planes by the three vectors that form the "axes" of the + plane. Should work with arbitrary number of planes and planes of + arbitrary orientation. 
+ + Bugfix reference: https://github.com/NVlabs/eg3d/issues/67 + """ + return ms.Tensor( + [[[1, 0, 0], [0, 1, 0], [0, 0, 1]], [[1, 0, 0], [0, 0, 1], [0, 1, 0]], [[0, 0, 1], [0, 1, 0], [1, 0, 0]]], + dtype=ms.float32, + ) diff --git a/examples/instantmesh/requirements.txt b/examples/instantmesh/requirements.txt new file mode 100644 index 0000000000..ed08290f0f --- /dev/null +++ b/examples/instantmesh/requirements.txt @@ -0,0 +1,7 @@ +pycubes +trimesh +imageio +loguru +einops +huggingface_hub +omegaconf diff --git a/examples/instantmesh/train.py b/examples/instantmesh/train.py new file mode 100644 index 0000000000..4cb4ebc968 --- /dev/null +++ b/examples/instantmesh/train.py @@ -0,0 +1,464 @@ +""" InstantMesh Stage-1 Training Script """ +import argparse +import datetime + +# from loguru import logger +import logging +import math +import os + +import yaml +from utils.train_util import str2bool + +import mindspore as ms +from mindspore import Model, nn +from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell +from mindspore.train.callback import TimeMonitor + +logger = logging.getLogger(__name__) + +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../../.."))) # for mindone +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../.."))) # for sgm + +from model_stage1 import InstantMeshStage1WithLoss + +# from eval import ValidationCallback, LossMonitor +from omegaconf import OmegaConf +from utils.ms_callback_util import SaveCkptCallback + +from mindone.data import create_dataloader +from mindone.env import init_train_env +from mindone.trainers.callback import EvalSaveCallback, OverflowMonitor, ProfilerCallbackEpoch +from mindone.trainers.checkpoint import resume_train_network +from mindone.trainers.ema import EMA +from mindone.trainers.lr_schedule import create_scheduler +from mindone.trainers.optim import create_optimizer +from mindone.trainers.train_step import TrainOneStepWrapper +from mindone.utils.amp import auto_mixed_precision +from mindone.utils.config import instantiate_from_config +from mindone.utils.logger import set_logger +from mindone.utils.params import count_params +from mindone.utils.seed import set_random_seed + + +def parse_args(**parser_kwargs): + parser = argparse.ArgumentParser(**parser_kwargs) + parser = parse_train_args(parser) + parser.add_argument( + "--resume", + type=str, + default=None, + help="resume from checkpoint with path", + ) + parser.add_argument( + "--base", + type=str, + default="configs/instant-nerf-large-train.yaml", + help="path to base configs", + ) + parser.add_argument( + "--log_interval", + default=1, + type=int, + help="log interval in the unit of data sink size.. E.g. if data sink size = 10, log_inteval=2, log every 20 steps", + ) + parser.add_argument( + "--debug", + default=False, # also setting debug as true will set pynative sync as true as well + help="When debugging, set it true. 
Dumping files will overlap to avoid trashing your storage.", + ) + args = parser.parse_args() + return args + + +def parse_train_args(parser): + parser.add_argument("--mode", default=1, type=int, help="Specify the mode: 0 for graph mode, 1 for pynative mode") + parser.add_argument("--use_parallel", default=False, help="use parallel") + parser.add_argument( + "--parallel_mode", default="data", type=str, choices=["data", "optim"], help="parallel mode: data, optim" + ) + parser.add_argument("--device_target", type=str, default="Ascend", help="Ascend or GPU") + parser.add_argument("--max_device_memory", type=str, default=None, help="e.g. `30GB` for 910a, `59GB` for 910b") + parser.add_argument( + "--dtype", + default="fp32", # if amp level O0/1, must pass fp32 + type=str, + choices=["bf16", "fp16", "fp32"], + help="what computation data type to use for latte. Default is `fp16`, which corresponds to ms.float16", + ) + parser.add_argument( + "--global_bf16", + default=False, + help="Experimental. If True, dtype will be overrided, operators will be computered in bf16 if they are supported by CANN", + ) + parser.add_argument( + "--amp_level", + default="O0", # cannot amp for InstantMesh training, easily grad nan + type=str, + help="mindspore amp level, O1: most fp32, only layers in whitelist compute in fp16 (dense, conv, etc); \ + O2: most fp16, only layers in blacklist compute in fp32 (batch norm etc)", + ) + parser.add_argument("--ckpt_save_interval", default=1, type=int, help="save checkpoint every this epochs") + parser.add_argument( + "--profile", + default=False, # deactivate as profiler says NOT supporting PyNative + type=str2bool, + help="Profile or not", + ) + parser.add_argument( + "--loss_scaler_type", default=None, type=str, help="dynamic or static" # loss scale only used in amp O1/O2 + ) + parser.add_argument("--init_loss_scale", default=65536, type=float, help="loss scale") + parser.add_argument("--loss_scale_factor", default=2, type=float, help="loss scale factor") + parser.add_argument("--scale_window", default=1000, type=float, help="scale window") + parser.add_argument("--ckpt_max_keep", default=5, type=int, help="Maximum number of checkpoints to keep") + parser.add_argument("--output_path", default="output/", type=str, help="output directory to save training results") + parser.add_argument( + "--log_level", + type=str, + default="logging.INFO", + help="log level, options: logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR", + ) + parser.add_argument("--num_parallel_workers", default=12, type=int, help="num workers for data loading") + parser.add_argument( + "--data_multiprocessing", + default=False, + help="If True, use multiprocessing for data processing. Default: multithreading.", + ) + parser.add_argument("--max_rowsize", default=64, type=int, help="max rowsize for data loading") + parser.add_argument( + "--train_steps", default=-1, type=int, help="If not -1, limit the number of training steps to the set value" + ) + parser.add_argument( + "--epochs", + # default=3, + default=7000, + type=int, + help="epochs. If dataset_sink_mode is on, epochs is with respect to dataset sink size. Otherwise, it's w.r.t the dataset size.", + ) + parser.add_argument( + "--ckpt_save_steps", + default=-1, + type=int, + help="save checkpoint every this steps. If -1, use ckpt_save_interval will be used.", + ) + parser.add_argument("--step_mode", default=False, help="whether save ckpt by steps. 
If False, save ckpt by epochs.") + # optimizer param + parser.add_argument("--use_ema", default=True, help="whether use EMA") + parser.add_argument("--drop_overflow_update", default=True, type=str2bool, help="drop overflow update") + parser.add_argument("--gradient_accumulation_steps", default=1, type=int, help="gradient accumulation steps") + parser.add_argument("--clip_grad", default=True, type=str2bool, help="whether apply gradient clipping") + parser.add_argument( + "--max_grad_norm", + default=1.0, + type=float, + help="max gradient norm for clipping, effective when `clip_grad` enabled.", + ) + parser.add_argument("--start_learning_rate", default=4e-4, type=float, help="The initial learning rate for Adam.") + parser.add_argument("--end_learning_rate", default=1e-7, type=float, help="The end learning rate for Adam.") + parser.add_argument("--decay_steps", default=0, type=int, help="lr decay steps.") + parser.add_argument("--scheduler", default="cosine_decay", type=str, help="scheduler.") + parser.add_argument("--optim", default="adamw", type=str, help="optimizer") + parser.add_argument( + "--betas", + type=float, + nargs="+", + default=[0.9, 0.95], + help="Specify the [beta1, beta2] parameter for the AdamW optimizer.", + ) + parser.add_argument( + "--optim_eps", type=float, default=1e-6, help="Specify the eps parameter for the AdamW optimizer." + ) + parser.add_argument( + "--group_strategy", + type=str, + default="norm_and_bias", + help="Grouping strategy for weight decay. If `norm_and_bias`, weight decay filter list is [beta, gamma, bias]. \ + If None, filter list is [layernorm, bias]. Default: norm_and_bias", + ) + parser.add_argument("--weight_decay", default=0.01, type=float, help="Weight decay.") + parser.add_argument("--seed", default=42, type=int, help="data path") + parser.add_argument("--warmup_steps", default=1000, type=int, help="warmup steps") + # dataloader param + parser.add_argument("--dataset_sink_mode", default=False, help="sink mode") + parser.add_argument("--sink_size", default=-1, type=int, help="dataset sink size. If -1, sink size = dataset size.") + + return parser + + +def main(args): + time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + if not args.resume: + args.output_path = os.path.join(args.output_path, time_str) if not args.debug else args.output_path + else: + # FIXME now when resume the output path overwrites the debug path, fix it... + pass + + # 1. init + did, rank_id, device_num = init_train_env( + args.mode, + seed=args.seed, + distributed=args.use_parallel, + device_target=args.device_target, + max_device_memory=args.max_device_memory, + debug=args.debug, + ) + set_random_seed(42) + set_logger(name="", output_dir=args.output_path, rank=rank_id, log_level=eval(args.log_level)) + + # load config yaml file + config = OmegaConf.load(args.base) + ckpt_dir = os.path.join(args.output_path, "ckpt") + + # 2. model initiate + # 2.1 instantmesh model stage 1 + img_size = config.model.params.input_size + config.model.params.lrm_generator_config.params.dtype = args.dtype + + lrm_model_with_loss = InstantMeshStage1WithLoss(config.model.params.lrm_generator_config) + lrm_model_with_loss.set_train(True) + + if not args.global_bf16: + lrm_model_with_loss.lrm_generator = auto_mixed_precision( + lrm_model_with_loss.lrm_generator, + amp_level=args.amp_level, + ) + + # 3. 
create dataset
+    dataset = instantiate_from_config(config.data.train)
+    nw = config.data.num_workers
+    dataloader = create_dataloader(
+        dataset,
+        batch_size=config.data.batch_size,
+        shuffle=False,
+        device_num=device_num,
+        rank_id=rank_id,
+        num_workers=nw,
+        python_multiprocessing=args.data_multiprocessing,
+        max_rowsize=args.max_rowsize,
+        debug=False,  # ms2.4.0: this must be False, otherwise the loader errors out
+        # Sort output columns to match DiffusionWithLoss input
+        # project_columns=project_columns,  # not sure whether input/target frames data should use this option or not
+    )
+
+    dataset_size = dataloader.get_dataset_size()
+
+    # compute total steps and data epochs (in unit of data sink size)
+    if args.train_steps == -1:
+        assert args.epochs != -1
+        total_train_steps = args.epochs * dataset_size
+    else:
+        total_train_steps = args.train_steps
+
+    if args.dataset_sink_mode and args.sink_size != -1:
+        steps_per_sink = args.sink_size
+    else:
+        steps_per_sink = dataset_size
+    sink_epochs = math.ceil(total_train_steps / steps_per_sink)
+
+    if args.ckpt_save_steps == -1:
+        ckpt_save_interval = args.ckpt_save_interval
+        step_mode = False
+    else:
+        step_mode = not args.dataset_sink_mode
+        if not args.dataset_sink_mode:
+            ckpt_save_interval = args.ckpt_save_steps
+        else:
+            # still need to count the interval in sink epochs
+            ckpt_save_interval = max(1, args.ckpt_save_steps // steps_per_sink)
+            if args.ckpt_save_steps % steps_per_sink != 0:
+                logger.warning(
+                    "'ckpt_save_steps' should be a multiple of the sink size (or dataset size) in dataset sink mode. "
+                    f"Checkpoint will be saved every {ckpt_save_interval * steps_per_sink} steps."
+                )
+        step_mode = step_mode if args.step_mode is None else args.step_mode
+
+    logger.info(f"train_steps: {total_train_steps}, train_epochs: {args.epochs}, sink_size: {args.sink_size}")
+    logger.info(f"total train steps: {total_train_steps}, sink epochs: {sink_epochs}")
+    logger.info(
+        "ckpt_save_interval: {} {}".format(
+            ckpt_save_interval, "steps" if (not args.dataset_sink_mode and step_mode) else "sink epochs"
+        )
+    )
+
+    # 4. build training utils: lr, optim, callbacks, trainer
+    # build learning rate scheduler
+    if not args.decay_steps:
+        args.decay_steps = total_train_steps - args.warmup_steps  # fix lr scheduling
+        if args.decay_steps <= 0:
+            logger.warning(
+                f"decay_steps is {args.decay_steps}, please check epochs, dataset_size and warmup_steps. "
+                f"Will force decay_steps to be set to 1."
+            )
+            args.decay_steps = 1
+
+    # TODO is the warmup + cosine-decay scheduler already the same as mindspore.experimental.optim.lr_scheduler?
+    # Do we need a warm restart for the lr scheduler to be 100% aligned with the vanilla version?
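+    # Note: with the defaults above this is a linear warmup to `start_learning_rate` over `warmup_steps`
+    # steps, followed by a single cosine decay down to `end_learning_rate` over `decay_steps` (no warm restarts).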
+ lr = create_scheduler( + steps_per_epoch=dataset_size, + name=args.scheduler, + lr=args.start_learning_rate, + end_lr=args.end_learning_rate, + warmup_steps=args.warmup_steps, + decay_steps=args.decay_steps, + total_steps=total_train_steps, + ) + + # 4.1 build optimizer + optimizer = create_optimizer( + lrm_model_with_loss.trainable_params(), + name=args.optim, + betas=args.betas, + eps=args.optim_eps, + group_strategy=args.group_strategy, + weight_decay=args.weight_decay, + lr=lr, + ) + + if args.loss_scaler_type == "dynamic": # for the case when there is an overflow during training + loss_scaler = DynamicLossScaleUpdateCell( + loss_scale_value=args.init_loss_scale, scale_factor=args.loss_scale_factor, scale_window=args.scale_window + ) + elif args.loss_scaler_type == "static": + loss_scaler = nn.FixedLossScaleUpdateCell(args.init_loss_scale) + else: + loss_scaler = ms.Tensor([1.0], dtype=ms.float32) + + # 4.2 weight loading: load checkpoint when resume + lrm_model = lrm_model_with_loss.lrm_generator + if args.resume: + logger.info(f"Loading Fred's own ckpt {args.resume}'s 'train_resume.ckpt'") + resume_ckpt = os.path.join(args.resume, "train_resume.ckpt") + start_epoch, loss_scale, cur_iter, last_overflow_iter = resume_train_network( + lrm_model, optimizer, resume_ckpt + ) # refer to hpcai train script about the input usage of this func + loss_scaler.loss_scale_value = loss_scale + loss_scaler.cur_iter = cur_iter + loss_scaler.last_overflow_iter = last_overflow_iter + else: + logger.info( + f"Resuming is turned off, with args {args.resume}.\n\t" + "Following original itmh implementation by initializing the model using the pretrained weights from openlrm," + "see Sec. 3.2 of the paper for details.\n" + ) + start_epoch = 0 + resume_param = ms.load_checkpoint(config.model.params.lrm_generator_config.openlrm_ckpt) + ms.load_param_into_net(lrm_model, resume_param) + # logger.info("Use random initialization for lrm, NO ckpt loading") # NOT converge + + ema = ( + EMA( + lrm_model_with_loss, + ema_decay=0.9999, + ) + if args.use_ema + else None + ) + + net_with_grads = TrainOneStepWrapper( + lrm_model_with_loss, + optimizer=optimizer, + scale_sense=loss_scaler, + drop_overflow_update=args.drop_overflow_update, + gradient_accumulation_steps=args.gradient_accumulation_steps, + clip_grad=args.clip_grad, + clip_norm=args.max_grad_norm, + ema=ema, + ) + + if args.global_bf16: + model = Model(net_with_grads, amp_level="O0") + else: + model = Model(net_with_grads) + + # 4.3 callbacks + callback = [ + TimeMonitor(), + OverflowMonitor(), + # LossMonitor(log_interval=args.log_interval), + SaveCkptCallback( + rank_id=rank_id, + output_dir=os.path.join(args.output_path, "ckpt"), + ckpt_max_keep=args.ckpt_max_keep, + ckpt_save_interval=args.ckpt_save_interval, + save_ema=args.use_ema, + ckpt_save_policy="top_k", + ), + # ValidationCallback(output_dir=args.output_path) + ] + + if rank_id == 0: + save_cb = EvalSaveCallback( + network=lrm_model_with_loss, + rank_id=rank_id, + ckpt_save_dir=ckpt_dir, + ema=ema, + ckpt_save_policy="latest_k", + ckpt_max_keep=args.ckpt_max_keep, + step_mode=step_mode, + use_step_unit=(args.ckpt_save_steps != -1), + ckpt_save_interval=ckpt_save_interval, + log_interval=args.log_interval, + start_epoch=start_epoch, + model_name="instantmesh_stage1", + record_lr=False, + ) + callback.append(save_cb) + + if args.profile: + callback.append(ProfilerCallbackEpoch(2, 3, "./profile_data")) + + # 5. 
log and save config + if rank_id == 0: + num_params_lrm, num_params_lrm_trainable = count_params(lrm_model_with_loss) + key_info = "Key Settings:\n" + "=" * 50 + "\n" + key_info += "\n".join( + [ + f"\tMindSpore mode[GRAPH(0)/PYNATIVE(1)]: {args.mode}", + f"\tDistributed mode: {args.use_parallel}", + f"\tNum params: {num_params_lrm} (lrm: {num_params_lrm})", + f"\tNum trainable params: {num_params_lrm_trainable}", + f"\tLearning rate: {args.start_learning_rate}", + f"\tBatch size: {config.data.batch_size}", + f"\tImage size: {img_size}", + f"\tWeight decay: {args.weight_decay}", + f"\tGrad accumulation steps: {args.gradient_accumulation_steps}", + f"\tNum epochs: {args.epochs}", + f"\tUse model dtype: {args.dtype}", + f"\tMixed precision level: {args.amp_level}", + f"\tLoss scaler: {args.loss_scaler_type}", + f"\tInit loss scale: {args.init_loss_scale}", + f"\tGrad clipping: {args.clip_grad}", + f"\tMax grad norm: {args.max_grad_norm}", + f"\tEMA: {args.use_ema}", + f"\tUse recompute: {config.model.params.lrm_generator_config.params.use_recompute}", + f"\tDataset sink: {args.dataset_sink_mode}", + ] + ) + key_info += "\n" + "=" * 50 + logger.info(key_info) + logger.info("Start training...") + with open(os.path.join(args.output_path, "args.yaml"), "w") as f: + yaml.safe_dump(vars(args), stream=f, default_flow_style=False, sort_keys=False) + OmegaConf.save(config, os.path.join(args.output_path, "cfg.yaml")) + + # 6. train + logger.info("using the standard fitting api") + model.fit( + sink_epochs, + dataloader, + # valid_dataset=val_dataloader, + callbacks=callback, + dataset_sink_mode=args.dataset_sink_mode, + sink_size=args.sink_size, + initial_epoch=start_epoch, + ) + + +if __name__ == "__main__": + logger.debug("process id:", os.getpid()) + args = parse_args() + main(args) diff --git a/examples/instantmesh/train.sh b/examples/instantmesh/train.sh new file mode 100644 index 0000000000..5bc1e0d079 --- /dev/null +++ b/examples/instantmesh/train.sh @@ -0,0 +1,3 @@ +python train.py --base configs/instant-nerf-large-train.yaml + # --resume YOUR_PATH/last.ckpt + # --resume_lrm_weights \ diff --git a/examples/instantmesh/utils/__init__.py b/examples/instantmesh/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/instantmesh/utils/camera_util.py b/examples/instantmesh/utils/camera_util.py new file mode 100644 index 0000000000..a8a161bcc6 --- /dev/null +++ b/examples/instantmesh/utils/camera_util.py @@ -0,0 +1,119 @@ +import numpy as np + +import mindspore as ms +from mindspore import Tensor, ops + + +def pad_camera_extrinsics_4x4(extrinsics): + if extrinsics.shape[-2] == 4: + return extrinsics + padding = Tensor([[0, 0, 0, 1]]).to(extrinsics.dtype) + if extrinsics.ndim == 3: + padding = padding.unsqueeze(0).tile((extrinsics.shape[0], 1, 1)) + extrinsics = ops.cat([extrinsics, padding], axis=-2) + return extrinsics + + +def center_looking_at_camera_pose(camera_position: Tensor, look_at=None, up_world=None): + """ + Create OpenGL camera extrinsics from camera locations and look-at position. 
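+
+    The result is a camera-to-world matrix in the OpenGL convention (x-right, y-up, z-backward),
+    i.e. columns [x_axis | y_axis | z_axis | camera_position], padded to 4x4 by `pad_camera_extrinsics_4x4`.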
+ + camera_position: (M, 3) or (3,) + look_at: (3) + up_world: (3) + return: (M, 3, 4) or (3, 4) + """ + # by default, looking at the origin and world up is z-axis + if look_at is None: + look_at = Tensor([0, 0, 0], dtype=ms.float32) + if up_world is None: + up_world = Tensor([0, 0, 1], dtype=ms.float32) + if camera_position.ndim == 2: + look_at = look_at.unsqueeze(0).tile((camera_position.shape[0], 1)) + up_world = up_world.unsqueeze(0).tile((camera_position.shape[0], 1)) + + # OpenGL camera: z-backward, x-right, y-up + z_axis = camera_position - look_at + norm = ops.L2Normalize(axis=-1) + z_axis = norm(z_axis) + x_axis = ops.cross(up_world, z_axis, dim=-1) + x_axis = norm(x_axis) + y_axis = ops.cross(z_axis, x_axis, dim=-1) + y_axis = norm(y_axis) + print(f"zshape: {z_axis.shape}, xshape: {x_axis.shape}, yshape: {y_axis.shape}") + + extrinsics = ops.stack([x_axis, y_axis, z_axis, camera_position], axis=-1) + print(f"fred: the extrinsics shape of {extrinsics.shape}") + extrinsics = pad_camera_extrinsics_4x4(extrinsics) + return extrinsics + + +def spherical_camera_pose(azimuths: np.ndarray, elevations: np.ndarray, radius=2.5): + azimuths = np.deg2rad(azimuths) + elevations = np.deg2rad(elevations) + + xs = radius * np.cos(elevations) * np.cos(azimuths) + ys = radius * np.cos(elevations) * np.sin(azimuths) + zs = radius * np.sin(elevations) + + cam_locations = np.stack([xs, ys, zs], axis=-1) + cam_locations = Tensor.from_numpy(cam_locations).float() + print(f"fred: camloc shape {cam_locations.shape}") + + c2ws = center_looking_at_camera_pose(cam_locations) + return c2ws + + +def get_circular_camera_poses(M=120, radius=2.5, elevation=30.0): + # M: number of circular views + # radius: camera dist to center + # elevation: elevation degrees of the camera + # return: (M, 4, 4) + assert M > 0 and radius > 0 + + elevation = np.deg2rad(elevation) + + camera_positions = [] + for i in range(M): + azimuth = 2 * np.pi * i / M + x = radius * np.cos(elevation) * np.cos(azimuth) + y = radius * np.cos(elevation) * np.sin(azimuth) + z = radius * np.sin(elevation) + camera_positions.append([x, y, z]) + camera_positions = np.array(camera_positions) + camera_positions = Tensor.from_numpy(camera_positions).float() + extrinsics = center_looking_at_camera_pose(camera_positions) + return extrinsics + + +def FOV_to_intrinsics(fov): + """ + Creates a 3x3 camera intrinsics matrix from the camera field of view, specified in degrees. + Note the intrinsics are returned as normalized by image size, rather than in pixel units. + Assumes principal point is at image center. + """ + focal_length = 0.5 / np.tan(np.deg2rad(fov) * 0.5) + intrinsics = Tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]]) + return intrinsics + + +def get_sv3d_input_cameras(bs=1, radius=4.0, fov=30.0): + """ + Get the input camera parameters. 
+ """ + azimuths = np.array([360 / 21 * i for i in range(21)]) + elevations = np.array([0] * 21).astype(float) + + # tensor + pose_cam2world = spherical_camera_pose(azimuths, elevations, radius) + pose_cam2world = pose_cam2world.float().flatten(start_dim=-2) + + Ks = FOV_to_intrinsics(fov).unsqueeze(0).tile((21, 1, 1)).float().flatten(start_dim=-2) + + extrinsics = pose_cam2world[:, :12] + intrinsics = ops.stack([Ks[:, 0], Ks[:, 4], Ks[:, 2], Ks[:, 5]], axis=-1) + cameras = ops.cat([extrinsics, intrinsics], axis=-1) + + print(f"cameras dtype is {cameras.dtype}") + + return cameras.unsqueeze(0).tile((int(bs), 1, 1)) diff --git a/examples/instantmesh/utils/eval_util.py b/examples/instantmesh/utils/eval_util.py new file mode 100644 index 0000000000..1c5ebb39b8 --- /dev/null +++ b/examples/instantmesh/utils/eval_util.py @@ -0,0 +1,208 @@ +import math +import os +import sys +from typing import List, Optional, Tuple, Union + +from loguru import logger +from PIL import Image + +import mindspore as ms +from mindspore import mint, ops + +from mindone.utils.seed import set_random_seed + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../../../.."))) # for loading mindone + + +def str2bool(b): + if b.lower() not in ["false", "true"]: + raise Exception("Invalid Bool Value") + if b.lower() in ["false"]: + return False + return True + + +def init_inference_env( + mode: int = ms.GRAPH_MODE, + seed: int = 42, + max_device_memory: str = None, + device_target: str = "Ascend", + device_id: int = 7, + jit_level: str = "O0", + debug: bool = False, +): + """ + Init the MS env for inference. + """ + set_random_seed(seed) + if max_device_memory is not None: + ms.set_context(max_device_memory=max_device_memory) + + if debug and mode == ms.GRAPH_MODE: # force PyNative mode when debugging + logger.warning("Debug mode is on, switching execution mode to PyNative.") + mode = ms.PYNATIVE_MODE + + device_num = 1 + rank_id = 0 + ms.set_context( + mode=mode, + device_target=device_target, + pynative_synchronize=debug, + jit_config={"jit_level": jit_level}, + device_id=device_id, + ) + return rank_id, device_num + + +def make_grid_ms( + tensor: ms.Tensor, + nrow: int = 8, + padding: int = 2, + normalize: bool = False, + value_range: Optional[Tuple[int, int]] = None, + scale_each: bool = False, + pad_value: float = 0.0, +) -> ms.Tensor: + """ + Make a grid of images. + + Args: + tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W) + or a list of images all of the same size. + nrow (int, optional): Number of images displayed in each row of the grid. + The final grid size is ``(B / nrow, nrow)``. Default: ``8``. + padding (int, optional): amount of padding. Default: ``2``. + normalize (bool, optional): If True, shift the image to the range (0, 1), + by the min and max values specified by ``value_range``. Default: ``False``. + value_range (tuple, optional): tuple (min, max) where min and max are numbers, + then these numbers are used to normalize the image. By default, min and max + are computed from the tensor. + scale_each (bool, optional): If ``True``, scale each image in the batch of + images separately rather than the (min, max) over all images. Default: ``False``. + pad_value (float, optional): Value for the padded pixels. Default: ``0``. + + Returns: + grid (Tensor): the tensor containing grid of images. 
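+
+    Example sketch (sizes are illustrative):
+        grid = make_grid_ms(ops.rand(8, 3, 64, 64), nrow=4)  # -> shape (3, 134, 266) with default padding=2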
+ """ + # if list of tensors, convert to a 4D mini-batch Tensor + if isinstance(tensor, list): + tensor = mint.stack(tensor, axis=0) + + if tensor.dim() == 2: # single image H x W + tensor = tensor.unsqueeze(0) + if tensor.dim() == 3: # single image + if tensor.shape[0] == 1: # if single-channel, convert to 3-channel + tensor = mint.cat((tensor, tensor, tensor), 0) + tensor = tensor.unsqueeze(0) + + if tensor.dim() == 4 and tensor.shape[1] == 1: # single-channel images + tensor = mint.cat((tensor, tensor, tensor), 1) + + if normalize is True: + # tensor = tensor.clone() # avoid modifying tensor in-place + if value_range is not None and not isinstance(value_range, tuple): + raise TypeError("value_range has to be a tuple (min, max) if specified. min and max are numbers") + + def norm_ip(img, low, high): + img = mint.clamp(img, min=low, max=high) + img = mint.sub(img, low) + img = mint.div(img, max(high - low, 1e-5)) + + def norm_range(t, value_range): + if value_range is not None: + norm_ip(t, value_range[0], value_range[1]) + else: + norm_ip(t, float(t.min()), float(t.max())) + + if scale_each is True: + for t in tensor: # loop over mini-batch dimension + norm_range(t, value_range) + else: + norm_range(tensor, value_range) + + if not isinstance(tensor, ms.Tensor): + raise TypeError("tensor should be of type ms Tensor") + if tensor.shape[0] == 1: + return tensor.squeeze(0) + + # make the mini-batch of images into a grid + nmaps = tensor.shape[0] + xmaps = min(nrow, nmaps) + ymaps = int(math.ceil(float(nmaps) / xmaps)) + height, width = int(tensor.shape[2] + padding), int(tensor.shape[3] + padding) + num_channels = tensor.shape[1] + # grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value) + grid = mint.full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value) + k = 0 + for y in range(ymaps): + for x in range(xmaps): + if k >= nmaps: + break + # # Tensor.copy_() is a valid method but seems to be missing from the stubs + # # https://pyms.org/docs/stable/tensors.html#ms.Tensor.copy_ + # # grid.narrow(1, y * height + padding, height - padding).narrow( # type: ignore[attr-defined] + # 2, x * width + padding, width - padding + # # ).copy_(tensor[k]) + + # TODO mint.narrow operation different from torch: the below will only have the last sample passed to grid, comment for now + # grid_portion = mint.narrow(grid, 1, y * height + padding, height - padding) + # grid_portion = mint.narrow(grid_portion, 2, x * width + padding, width - padding) + # grid_portion.copy_(tensor[k]) # assign the kth sample to the grid + + grid[:, y * height + padding : y * height + height, x * width + padding : x * width + width].copy_( + tensor[k] + ) + + k = k + 1 + return grid + + +def save_image_ms( + tensor: Union[ms.Tensor, List[ms.Tensor]], + fp: str, + format: Optional[str] = None, + **kwargs, +) -> None: + """ + Save a given Tensor into an image file. + + Args: + tensor (Tensor or list): Image to be saved. If given a mini-batch tensor, + saves the tensor as a grid of images by calling ``make_grid``. + fp (string or file object): A filename or a file object + format(Optional): If omitted, the format to use is determined from the filename extension. + If a file object was used instead of a filename, this parameter should always be used. + **kwargs: Other arguments are documented in ``make_grid``. 
+ """ + + grid = make_grid_ms(tensor, **kwargs) + # Add 0.5 after unnormalizing to [0, 255] to round to the nearest integer + ndarr = grid.mul(255).add(0.5).clamp(0, 255).permute(1, 2, 0).to(ms.uint8).numpy() + + im = Image.fromarray(ndarr) + im.save(fp, format=format) + + +def get_params_ms(img: ms.Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]: + """Get parameters for ``crop`` for a random crop. + + Args: + img (PIL Image or Tensor): Image to be cropped. + output_size (tuple): Expected output size of the crop. + + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. + """ + _, h, w = img.shape + th, tw = output_size + + if h < th or w < tw: + raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}") + + if w == tw and h == th: + return 0, 0, h, w + + i = ops.randint(0, h - th + 1, size=(1,)).item() + j = ops.randint(0, w - tw + 1, size=(1,)).item() + return i, j, th, tw diff --git a/examples/instantmesh/utils/loss_util.py b/examples/instantmesh/utils/loss_util.py new file mode 100644 index 0000000000..4daa83c453 --- /dev/null +++ b/examples/instantmesh/utils/loss_util.py @@ -0,0 +1,164 @@ +import logging +import os + +import mindcv + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops + +_logger = logging.getLogger(__name__) + + +class LPIPS(nn.Cell): + # Learned perceptual metric + def __init__( + self, + use_dropout=True, + normalize=False, + ): + super().__init__() + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] # vgg16 features + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + self.normalize = normalize + # load NetLin metric layers + + # loading HERE before init vgg16 avoids warning during loading this little ckpt + self.load_from_pretrained("YOUR_PATH/lpips_vgg-426bf45c.ckpt") + + # create vision backbone and load pretrained weights + self.net = vgg16(pretrained=True, requires_grad=False) + + # ensure that lpips's param not tuned, but lpips loss still supervises + self.set_train(False) + for param in self.trainable_params(): + param.requires_grad = False + + self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + self.lins = nn.CellList(self.lins) + + def load_from_pretrained(self, ckpt_path): + # TODO: just load ms ckpt + if not os.path.exists(ckpt_path): + raise ValueError( + f"{ckpt_path} not exists. 
Please download it from https: //download-mindspore.osinfra.cn/toolkits/mindone/autoencoders/lpips_vgg-426bf45c.ckpt" + ) + + state_dict = ms.load_checkpoint(ckpt_path) + m, u = ms.load_param_into_net(self, state_dict) + if len(m) > 0: + print("missing keys:") + print(m) + if len(u) > 0: + print("unexpected keys:") + print(u) + + _logger.info("loaded pretrained LPIPS loss from {}".format(ckpt_path)) + + def construct(self, input, target): + if self.normalize: + in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) + else: + in0_input, in1_input = input, target + outs0, outs1 = self.net(in0_input), self.net(in1_input) + val = 0 # ms.Tensor(0, dtype=input.dtype) + for kk in range(len(self.chns)): + diff = (normalize_tensor(outs0[kk]) - normalize_tensor(outs1[kk])) ** 2 + # res += spatial_average(lins[kk](diff), keepdim=True) + # lin_layer = lins[kk] + val += ops.mean(self.lins[kk](diff), axis=[2, 3], keep_dims=True) + return val + + +class ScalingLayer(nn.Cell): + def __init__(self): + super(ScalingLayer, self).__init__() + self.shift = ms.Tensor([-0.030, -0.088, -0.188])[None, :, None, None] + self.scale = ms.Tensor([0.458, 0.448, 0.450])[None, :, None, None] + + def construct(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Cell): + """A single linear layer which does a 1x1 conv""" + + def __init__(self, chn_in, chn_out=1, use_dropout=False, dtype=ms.float32): + super(NetLinLayer, self).__init__() + # TODO: can parse dtype=dtype in ms2.3 + layers = ( + [ + nn.Dropout(p=0.5).to_float(dtype), + ] + if (use_dropout) + else [] + ) + layers += [ + nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, has_bias=False).to_float(dtype), + ] + self.model = nn.SequentialCell(layers) + + def construct(self, x): + return self.model(x) + + +class vgg16(nn.Cell): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + # TODO: add bias in vgg. use the same model weights in PT. 
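+        # The backbone below is split into five slices whose outputs correspond to relu1_2, relu2_2,
+        # relu3_3, relu4_3 and relu5_3, matching the channel list LPIPS.chns = [64, 128, 256, 512, 512].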
+ model = mindcv.create_model( + "vgg16", + pretrained=pretrained # FIXME the mindcv api here seems to be not working, workaround as below + # checkpoint_path='YOUR_PATH/vgg16-95697531.ckpt' # 1GB ckpt + ) + model.set_train(False) + vgg_pretrained_features = model.features + self.slice1 = nn.SequentialCell() + self.slice2 = nn.SequentialCell() + self.slice3 = nn.SequentialCell() + self.slice4 = nn.SequentialCell() + self.slice5 = nn.SequentialCell() + self.N_slices = 5 + for x in range(4): + self.slice1.append(vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.append(vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.append(vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.append(vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.append(vgg_pretrained_features[x]) + if not requires_grad: + for param in self.trainable_params(): + param.requires_grad = False + for param in model.trainable_params(): + param.requires_grad = False + + def construct(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + out = (h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + return out + + +def normalize_tensor(x, eps=1e-10): + norm_factor = ops.sqrt((x**2).sum(1, keepdims=True)) + return x / (norm_factor + eps) + + +def spatial_average(x, keepdim=True): + return x.mean([2, 3], keep_dims=keepdim) diff --git a/examples/instantmesh/utils/mesh_util.py b/examples/instantmesh/utils/mesh_util.py new file mode 100644 index 0000000000..05144d5f2d --- /dev/null +++ b/examples/instantmesh/utils/mesh_util.py @@ -0,0 +1,141 @@ +import cv2 +import numpy as np +import trimesh +from PIL import Image + + +def save_obj(pointnp_px3, facenp_fx3, colornp_px3, fpath): + pointnp_px3 = pointnp_px3 @ np.array([[1, 0, 0], [0, 1, 0], [0, 0, -1]]) + facenp_fx3 = facenp_fx3[:, [2, 1, 0]] + + mesh = trimesh.Trimesh( + vertices=pointnp_px3, + faces=facenp_fx3, + vertex_colors=colornp_px3, + ) + mesh.export(fpath, "obj") + + +def save_glb(pointnp_px3, facenp_fx3, colornp_px3, fpath): + pointnp_px3 = pointnp_px3 @ np.array([[-1, 0, 0], [0, 1, 0], [0, 0, -1]]) + + mesh = trimesh.Trimesh( + vertices=pointnp_px3, + faces=facenp_fx3, + vertex_colors=colornp_px3, + ) + mesh.export(fpath, "glb") + + +def save_obj_with_mtl(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, texmap_hxwx3, fname): + import os + + fol, na = os.path.split(fname) + na, _ = os.path.splitext(na) + + matname = "%s/%s.mtl" % (fol, na) + fid = open(matname, "w") + fid.write("newmtl material_0\n") + fid.write("Kd 1 1 1\n") + fid.write("Ka 0 0 0\n") + fid.write("Ks 0.4 0.4 0.4\n") + fid.write("Ns 10\n") + fid.write("illum 2\n") + fid.write("map_Kd %s.png\n" % na) + fid.close() + #### + + fid = open(fname, "w") + fid.write("mtllib %s.mtl\n" % na) + + for pidx, p in enumerate(pointnp_px3): + pp = p + fid.write("v %f %f %f\n" % (pp[0], pp[1], pp[2])) + + for pidx, p in enumerate(tcoords_px2): + pp = p + fid.write("vt %f %f\n" % (pp[0], pp[1])) + + fid.write("usemtl material_0\n") + for i, f in enumerate(facenp_fx3): + f1 = f + 1 + f2 = facetex_fx3[i] + 1 + fid.write("f %d/%d %d/%d %d/%d\n" % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2])) + fid.close() + + # save texture map + lo, hi = 0, 1 + img = np.asarray(texmap_hxwx3, dtype=np.float32) + img = (img - lo) * (255 / (hi - lo)) + img = img.clip(0, 255) + mask = np.sum(img.astype(np.float32), axis=-1, keepdims=True) + mask = 
(mask <= 3.0).astype(np.float32) + kernel = np.ones((3, 3), "uint8") + dilate_img = cv2.dilate(img, kernel, iterations=1) + img = img * (1 - mask) + dilate_img * mask + img = img.clip(0, 255).astype(np.uint8) + Image.fromarray(np.ascontiguousarray(img[::-1, :, :]), "RGB").save(f"{fol}/{na}.png") + + +def loadobj(meshfile): + v = [] + f = [] + meshfp = open(meshfile, "r") + for line in meshfp.readlines(): + data = line.strip().split(" ") + data = [da for da in data if len(da) > 0] + if len(data) != 4: + continue + if data[0] == "v": + v.append([float(d) for d in data[1:]]) + if data[0] == "f": + data = [da.split("/")[0] for da in data] + f.append([int(d) for d in data[1:]]) + meshfp.close() + + # torch need int64 + facenp_fx3 = np.array(f, dtype=np.int64) - 1 + pointnp_px3 = np.array(v, dtype=np.float32) + return pointnp_px3, facenp_fx3 + + +def loadobjtex(meshfile): + v = [] + vt = [] + f = [] + ft = [] + meshfp = open(meshfile, "r") + for line in meshfp.readlines(): + data = line.strip().split(" ") + data = [da for da in data if len(da) > 0] + if not ((len(data) == 3) or (len(data) == 4) or (len(data) == 5)): + continue + if data[0] == "v": + assert len(data) == 4 + + v.append([float(d) for d in data[1:]]) + if data[0] == "vt": + if len(data) == 3 or len(data) == 4: + vt.append([float(d) for d in data[1:3]]) + if data[0] == "f": + data = [da.split("/") for da in data] + if len(data) == 4: + f.append([int(d[0]) for d in data[1:]]) + ft.append([int(d[1]) for d in data[1:]]) + elif len(data) == 5: + idx1 = [1, 2, 3] + data1 = [data[i] for i in idx1] + f.append([int(d[0]) for d in data1]) + ft.append([int(d[1]) for d in data1]) + idx2 = [1, 3, 4] + data2 = [data[i] for i in idx2] + f.append([int(d[0]) for d in data2]) + ft.append([int(d[1]) for d in data2]) + meshfp.close() + + # torch need int64 + facenp_fx3 = np.array(f, dtype=np.int64) - 1 + ftnp_fx3 = np.array(ft, dtype=np.int64) - 1 + pointnp_px3 = np.array(v, dtype=np.float32) + uvs = np.array(vt, dtype=np.float32) + return pointnp_px3, facenp_fx3, uvs, ftnp_fx3 diff --git a/examples/instantmesh/utils/ms_callback_util.py b/examples/instantmesh/utils/ms_callback_util.py new file mode 100644 index 0000000000..2ecf9f8e51 --- /dev/null +++ b/examples/instantmesh/utils/ms_callback_util.py @@ -0,0 +1,112 @@ +import logging +import os +from typing import List, Optional + +from mindspore import Parameter, ParameterTuple, RunContext, Tensor, mint +from mindspore.train import Callback + +from mindone.trainers.checkpoint import CheckpointManager + +__all__ = ["LossMonitor", "SaveCkptCallback"] + +logger = logging.getLogger(__name__) + + +class LossMonitor(Callback): + def __init__(self, log_interval: int = 1) -> None: + self.log_interval = log_interval + + def on_train_step_end(self, run_context: RunContext) -> None: + cb_params = run_context.original_args() + cur_step = cb_params.cur_step_num + step_num = cb_params.batch_num * cb_params.epoch_num + + if (cur_step % self.log_interval == 0) or (cur_step == step_num): + cur_lr = self._fetch_optimizer_lr(cb_params) + cur_loss = self._fetch_loss(cb_params) + cur_loss_scale = self._fetch_loss_scale(cb_params) + + logger.info( + "epoch: %d step: %d, lr: %.7f, loss: %.6f, loss scale: %d.", + cb_params.cur_epoch_num, + (cb_params.cur_step_num - 1) % cb_params.batch_num + 1, + cur_lr.item(), + cur_loss.item(), + cur_loss_scale.item(), + ) + + def _get_optimizer_from_cbp(self, cb_params): + if cb_params.optimizer is not None: + optimizer = cb_params.optimizer + elif cb_params.dataset_sink_mode: + optimizer 
= cb_params.train_network.network.optimizer + else: + optimizer = cb_params.train_network.optimizer + return optimizer + + def _fetch_loss_scale(self, cb_params) -> Tensor: + if cb_params.dataset_sink_mode: + return cb_params.train_network.network.scale_sense + else: + return cb_params.train_network.scale_sense + + def _fetch_optimizer_lr(self, cb_params) -> Tensor: + opt = self._get_optimizer_from_cbp(cb_params) + lr = opt.learning_rate + if opt.dynamic_lr: + lr = opt.learning_rate(mint.clamp(opt.global_step - 1, min=0))[0] + return lr + + def _fetch_loss(self, cb_params) -> Tensor: + loss = cb_params.net_outputs[0] + return loss + + +class SaveCkptCallback(Callback): + def __init__( + self, + rank_id: Optional[int] = None, + output_dir: str = "./output", + ckpt_max_keep: int = 5, + ckpt_save_interval: int = 1, + save_ema: bool = False, + ckpt_save_policy: str = "latest_k", + ) -> None: + self.rank_id = 0 if rank_id is None else rank_id + if self.rank_id != 0: + return + + self.ckpt_save_interval = ckpt_save_interval + self.save_ema = save_ema + + ckpt_save_dir = os.path.join(output_dir, f"rank_{rank_id}") + if not os.path.isdir(ckpt_save_dir): + os.makedirs(ckpt_save_dir) + self.ckpt_manager = CheckpointManager(ckpt_save_dir, ckpt_save_policy=ckpt_save_policy, k=ckpt_max_keep) + + if self.save_ema: + self.ema_ckpt_manager = CheckpointManager(ckpt_save_dir, ckpt_save_policy=ckpt_save_policy, k=ckpt_max_keep) + + def on_train_epoch_end(self, run_context: RunContext) -> None: + if self.rank_id != 0: + return + + cb_params = run_context.original_args() + cur_epoch = cb_params.cur_epoch_num + epoch_num = cb_params.epoch_num + + if (cur_epoch % self.ckpt_save_interval == 0) or (cur_epoch == epoch_num): + ckpt_name = f"epoch_{cur_epoch}.ckpt" + network_weight = cb_params.train_network.network + self.ckpt_manager.save(network=network_weight, ckpt_name=ckpt_name, perf=cb_params.net_outputs) + if self.save_ema: + ckpt_name = f"epoch_{cur_epoch}_ema.ckpt" + ema_weight = self._drop_ema_prefix(cb_params.train_network.ema.ema_weight) + self.ema_ckpt_manager.save(network=ema_weight, ckpt_name=ckpt_name, perf=cb_params.net_outputs) + + def _drop_ema_prefix(self, weight: ParameterTuple) -> List[Parameter]: + new_weight = list() + for x in weight: + x.name = x.name.replace("ema.", "") + new_weight.append(x) + return new_weight diff --git a/examples/instantmesh/utils/train_util.py b/examples/instantmesh/utils/train_util.py new file mode 100644 index 0000000000..2802304bc4 --- /dev/null +++ b/examples/instantmesh/utils/train_util.py @@ -0,0 +1,34 @@ +import importlib + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params*1.e-6: .2f} M params.") + return total_params + + +def instantiate_from_config(config): + if "target" not in config: + if config == "__is_first_stage__": + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def str2bool(b): + if b.lower() not in ["false", "true"]: + raise Exception("Invalid Bool Value") + if b.lower() in ["false"]: + return False + return True From 
57b66d7fa021b0b415cf3246f316cfc8b5d1dd7d Mon Sep 17 00:00:00 2001 From: HaFred Date: Tue, 1 Oct 2024 16:38:43 +0800 Subject: [PATCH 02/18] update readme --- examples/instantmesh/README.md | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/examples/instantmesh/README.md b/examples/instantmesh/README.md index af21872ca2..72b6e0e53d 100644 --- a/examples/instantmesh/README.md +++ b/examples/instantmesh/README.md @@ -12,7 +12,7 @@ A walk-through of the file structure is provided here as below. ```bash -├── instantmesh +├── models │ ├── decoder # triplane feature transformer decoder │ │ └── transformer.py │ ├── encoder # dino vit decoder to extract img feat @@ -21,9 +21,8 @@ A walk-through of the file structure is provided here as below. │ ├── renderer # a wrapper that synthesizes sdf/texture from triplane feat │ │ ├── synthesizer_mesh.py # triplane synthesizer, the triplane feat is decoded thru nerf to predict texture rgb & 3D sdf │ │ ├── synthesizer.py # triplane synthesizer, the triplane feat is decoded thru nerf to predict novel view rgba -│ │ ├── utils -│ │ │ └── renderer.py -│ │ └── synthesizer.py +│ │ └── utils +│ │ └── renderer.py │ ├── geometry # use Flexicubes to extract isosurface │ │ ├── rep_3d │ │ │ ├── flexicubes_geometry.py @@ -75,7 +74,7 @@ pip install -r requirements.txt 2. Inference is tested on the machine with the following specs using 1x NPU: ```text -Mindspore Version: 2.3.0.B528 +Mindspore Version: 2.3.1 release CANN Version: 7.3 Ascend Driver: 23.0.rc3.6 ``` From b4054ed98c9a37b347eb39971b27f7c7741c3fd8 Mon Sep 17 00:00:00 2001 From: HaFred Date: Wed, 2 Oct 2024 17:14:00 +0800 Subject: [PATCH 03/18] putting on renderer utils --- .../models/renderer/utils/math_utils.py | 97 +++++++++++++++++++ .../models/renderer/utils/ray_marcher.py | 53 ++++++++++ .../models/renderer/utils/ray_sampler.py | 90 +++++++++++++++++ examples/instantmesh/train.py | 5 +- 4 files changed, 244 insertions(+), 1 deletion(-) create mode 100644 examples/instantmesh/models/renderer/utils/math_utils.py create mode 100644 examples/instantmesh/models/renderer/utils/ray_marcher.py create mode 100644 examples/instantmesh/models/renderer/utils/ray_sampler.py diff --git a/examples/instantmesh/models/renderer/utils/math_utils.py b/examples/instantmesh/models/renderer/utils/math_utils.py new file mode 100644 index 0000000000..c3e189627b --- /dev/null +++ b/examples/instantmesh/models/renderer/utils/math_utils.py @@ -0,0 +1,97 @@ +import mindspore as ms +from mindspore import mint, ops + + +def transform_vectors(matrix: ms.Tensor, vectors4: ms.Tensor) -> ms.Tensor: + """ + Left-multiplies MxM @ NxM. Returns NxM. + """ + res = mint.matmul(vectors4, matrix.T) + return res + + +def normalize_vecs(vectors: ms.Tensor) -> ms.Tensor: + """ + Normalize vector lengths. 
+ """ + return vectors / (ops.norm(vectors, dim=-1, keepdim=True)) + + +def get_ray_limits_box( + rays_o: ms.Tensor, + rays_d: ms.Tensor, +): + """ + https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection + """ + box_side_length = 2 + o_shape = rays_o.shape + rays_o = ops.stop_gradient(rays_o.reshape(-1, 3)) + rays_d = ops.stop_gradient(rays_d.reshape(-1, 3)) + + bb_min = [-1 * (box_side_length / 2), -1 * (box_side_length / 2), -1 * (box_side_length / 2)] + bb_max = [1 * (box_side_length / 2), 1 * (box_side_length / 2), 1 * (box_side_length / 2)] + bounds = ms.Tensor((bb_min, bb_max), dtype=rays_o.dtype) + is_valid = ops.ones(rays_o.shape[:-1], dtype=ms.bool_) + + # Precompute inverse for stability. + invdir = 1 / rays_d + sign = (invdir < 0).to(ms.int64) + + # Intersect with YZ plane. + tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0] + tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0] + + # Intersect with XZ plane. + tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1] + tymax = (bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1] + + # Resolve parallel rays. + is_valid[mint.logical_or(tmin > tymax, tymin > tmax)] = False + + # Use the shortest intersection. + tmin = mint.maximum(tmin, tymin) + tmax = mint.minimum(tmax, tymax) + + # Intersect with XY plane. + tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2] + tzmax = (bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2] + + # Resolve parallel rays. + is_valid[mint.logical_or(tmin > tzmax, tzmin > tmax)] = False + + # Use the shortest intersection. + tmin = mint.maximum(tmin, tzmin) + tmax = mint.minimum(tmax, tzmax) + + # Mark invalid. + tmin[mint.logical_not(is_valid)] = -1 + tmax[mint.logical_not(is_valid)] = -2 + + return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1) + + +def linspace( + start: ms.Tensor, + stop: ms.Tensor, + num: int, +): + """ + Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to end, inclusive. + Replicates but the multi-dimensional bahaviour of numpy.linspace in PyTorch. + """ + # create a tensor of 'num' steps from 0 to 1 + steps = mint.arange(num) / (num - 1) + + # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcastings + # - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript + # "cannot statically infer the expected size of a list in this contex", + # hence the code below + # print(f'in math utils, start ndim is {start.ndim}') + for i in range(start.ndim): + steps = steps.unsqueeze(-1) + + # the output starts at 'start' and increments until 'stop' in each dimension + out = start[None] + steps * (stop - start)[None] + + return out diff --git a/examples/instantmesh/models/renderer/utils/ray_marcher.py b/examples/instantmesh/models/renderer/utils/ray_marcher.py new file mode 100644 index 0000000000..1cdef88044 --- /dev/null +++ b/examples/instantmesh/models/renderer/utils/ray_marcher.py @@ -0,0 +1,53 @@ +""" The ray marcher takes the raw output of the implicit representation and uses the volume rendering equation to produce composited colors and depths. +Based off of the implementation in MipNeRF (this one doesn't do any cone tracing though!) 
+""" +from typing import Dict, Optional + +import mindspore as ms +import mindspore.nn as nn +from mindspore import mint, ops + + +class MipRayMarcher2(nn.Cell): + def __init__(self, rendering_options: Optional[Dict] = None): + super().__init__() + self.white_back = rendering_options.get("white_back", False) + + def construct( + self, + colors: ms.Tensor, + densities: ms.Tensor, + depths: ms.Tensor, + ): + dtype = colors.dtype + deltas = depths[:, :, 1:] - depths[:, :, :-1] + colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2 + densities_mid = (densities[:, :, :-1] + densities[:, :, 1:]) / 2 + depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2 + + # using factory mode for better usability + densities_mid = mint.nn.functional.softplus(densities_mid - 1).to(dtype) + + density_delta = densities_mid * deltas + + alpha = 1 - mint.exp(-density_delta).to(dtype) + + alpha_shifted = ops.cat([mint.ones_like(alpha[:, :, :1]), 1 - alpha + 1e-10], -2) + weights = alpha * ops.cumprod(alpha_shifted, -2)[:, :, :-1] + weights = weights.to(dtype) + + composite_rgb = mint.sum(weights * colors_mid, -2) + weight_total = weights.sum(2) + # composite_depth = torch.sum(weights * depths_mid, -2) / weight_total + composite_depth = mint.sum(weights * depths_mid, -2) + + # clip the composite to min/max range of depths + composite_depth = ops.nan_to_num(composite_depth, float("inf")).to(dtype) + + if self.white_back: + composite_rgb = composite_rgb + 1 - weight_total + + # rendered value scale is 0-1, comment out original mipnerf scaling + # composite_rgb = composite_rgb * 2 - 1 # Scale to (-1, 1) + + return composite_rgb, composite_depth, weights diff --git a/examples/instantmesh/models/renderer/utils/ray_sampler.py b/examples/instantmesh/models/renderer/utils/ray_sampler.py new file mode 100644 index 0000000000..9c586b9b79 --- /dev/null +++ b/examples/instantmesh/models/renderer/utils/ray_sampler.py @@ -0,0 +1,90 @@ +import mindspore as ms +import mindspore.nn as nn +from mindspore import mint, ops + + +class RaySampler(nn.Cell): + def __init__(self): + super().__init__() + self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = ( + None, + None, + None, + None, + None, + ) + + def construct(self, cam2world_matrix: ms.Tensor, intrinsics: ms.Tensor, render_size: int): + """ + Create batches of rays and return origins and directions. 
+ + cam2world_matrix: (N, 4, 4) + intrinsics: (N, 3, 3) + + ray_origins: (N, M, 3) + ray_dirs: (N, M, 2) + """ + dtype = cam2world_matrix.dtype + N, M = cam2world_matrix.shape[0], render_size**2 + cam_locs_world = cam2world_matrix[:, :3, 3] + fx = intrinsics[:, 0, 0] + fy = intrinsics[:, 1, 1] + cx = intrinsics[:, 0, 2] + cy = intrinsics[:, 1, 2] + sk = intrinsics[:, 0, 1] + + uv = ops.stack( + ops.meshgrid( + mint.arange(render_size, dtype=dtype), + mint.arange(render_size, dtype=dtype), + indexing="ij", + ) + ) # FIXME mint.stack() builds graph mode fail, bypass with ops.stack first + uv = mint.flip(uv, dims=(0,)).reshape((2, -1)).swapaxes(1, 0) + uv = uv.unsqueeze(0).tile((cam2world_matrix.shape[0], 1, 1)) + + x_cam = uv[:, :, 0].view(N, -1) * (1.0 / render_size) + (0.5 / render_size) + y_cam = uv[:, :, 1].view(N, -1) * (1.0 / render_size) + (0.5 / render_size) + z_cam = ops.ones((N, M), dtype=dtype) + + x_lift = ( + ( + x_cam + - cx.unsqueeze(-1) + + cy.unsqueeze(-1) * sk.unsqueeze(-1) / fy.unsqueeze(-1) + - sk.unsqueeze(-1) * y_cam / fy.unsqueeze(-1) + ) + / fx.unsqueeze(-1) + * z_cam + ) + y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam + + cam_rel_points = mint.stack((x_lift, y_lift, z_cam, mint.ones_like(z_cam)), dim=-1).to(dtype) + + _opencv2blender = ( + ms.tensor( + [ + [1, 0, 0, 0], + [0, -1, 0, 0], + [0, 0, -1, 0], + [0, 0, 0, 1], + ], + dtype=dtype, + ) + .unsqueeze(0) + .tile((N, 1, 1)) + ) + + cam2world_matrix = mint.bmm(cam2world_matrix, _opencv2blender) + + world_rel_points = mint.bmm(cam2world_matrix, cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3] + + ray_dirs = world_rel_points - cam_locs_world[:, None, :] + # ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2).to(dtype) + # l2 norm + ray_dirs_denom = ray_dirs.norm(2.0, 2, keepdim=True).clip(min=1e-12).broadcast_to(ray_dirs.shape) + ray_dirs = ray_dirs / ray_dirs_denom + + ray_origins = cam_locs_world.unsqueeze(1).tile((1, ray_dirs.shape[1], 1)) + + return ray_origins, ray_dirs diff --git a/examples/instantmesh/train.py b/examples/instantmesh/train.py index 4cb4ebc968..7e8fac5bd9 100644 --- a/examples/instantmesh/train.py +++ b/examples/instantmesh/train.py @@ -1,4 +1,7 @@ """ InstantMesh Stage-1 Training Script """ + +# TODO training: oct 2: 1. turn off ema; 2. 
align cosdecay wrestart + import argparse import datetime @@ -30,7 +33,6 @@ from utils.ms_callback_util import SaveCkptCallback from mindone.data import create_dataloader -from mindone.env import init_train_env from mindone.trainers.callback import EvalSaveCallback, OverflowMonitor, ProfilerCallbackEpoch from mindone.trainers.checkpoint import resume_train_network from mindone.trainers.ema import EMA @@ -39,6 +41,7 @@ from mindone.trainers.train_step import TrainOneStepWrapper from mindone.utils.amp import auto_mixed_precision from mindone.utils.config import instantiate_from_config +from mindone.utils.env import init_train_env from mindone.utils.logger import set_logger from mindone.utils.params import count_params from mindone.utils.seed import set_random_seed From 03ece27bcafbfbf06056187aa400deead286eaf7 Mon Sep 17 00:00:00 2001 From: HaFred Date: Tue, 8 Oct 2024 15:50:45 +0800 Subject: [PATCH 04/18] fixes about fmt and f-string issue in precommit check and mindone import --- examples/instantmesh/eval.py | 8 ++++---- examples/instantmesh/utils/train_util.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/instantmesh/eval.py b/examples/instantmesh/eval.py index 60619d4b58..6d747b9b28 100644 --- a/examples/instantmesh/eval.py +++ b/examples/instantmesh/eval.py @@ -19,7 +19,7 @@ from mindspore import mint __dir__ = os.path.dirname(os.path.abspath(__file__)) -sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../../../"))) # for loading mindone +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../.."))) # for loading mindone # from loguru import logger import logging @@ -77,7 +77,7 @@ def evaluate(args, epoch_num: Optional[str]): epoch_num = ms.load_checkpoint(args.itmh_ckpt).get( "epoch_num", 0 ) # 0 means that there is no this key in the resume ckpt file - image_path = os.path.join(image_path, f"val_e{epoch_num}.png") + image_path = os.path.join(image_path, f"val_e{epoch_num}.png") validation_step_outputs = [] batches_time = [] @@ -117,14 +117,14 @@ def evaluate(args, epoch_num: Optional[str]): batch_time = time.time() - start_time batches_time.append(batch_time) - logger.info(f"Batch time cost: {batch_time: .3f}s.") + logger.info("Batch time cost: %.3fs.", batch_time) # save result both img and alpha, in validation_step() # render_images = rearrange(render_images, 'b n c h w -> b c h (n w)') render_images = mint.permute(render_images, dims=(0, 2, 3, 1, 4)).flatten(start_dim=-2) validation_step_outputs.append(render_images) mean_time = sum(batches_time) / len(batches_time) - logger.info(f"Mean Batch time: {mean_time: .3f}s.") + logger.info("Mean Batch time: %.3fs.", mean_time) # save mviews outputs images = mint.cat(validation_step_outputs, dim=0) # enable for multiple batches diff --git a/examples/instantmesh/utils/train_util.py b/examples/instantmesh/utils/train_util.py index 2802304bc4..ae8c2e31f1 100644 --- a/examples/instantmesh/utils/train_util.py +++ b/examples/instantmesh/utils/train_util.py @@ -4,7 +4,7 @@ def count_params(model, verbose=False): total_params = sum(p.numel() for p in model.parameters()) if verbose: - print(f"{model.__class__.__name__} has {total_params*1.e-6: .2f} M params.") + print(f"{model.__class__.__name__} has {total_params*1.e-6} M params.") return total_params From 35d27c3009287d8b7fe6d77b70a23ced68cf7f23 Mon Sep 17 00:00:00 2001 From: HaFred Date: Tue, 8 Oct 2024 16:13:18 +0800 Subject: [PATCH 05/18] supporting cosine_annealing_warm_restarts_lr and top_k saving ckptcallback, and some refactoring --- 
.../instantmesh/configs/instant-nerf-large-train.yaml | 5 ++++- .../instantmesh/models/renderer/utils/renderer.py | 11 +---------- examples/instantmesh/utils/loss_util.py | 6 +----- mindone/trainers/lr_schedule.py | 6 ++++++ mindone/trainers/train_step.py | 4 ++++ 5 files changed, 16 insertions(+), 16 deletions(-) diff --git a/examples/instantmesh/configs/instant-nerf-large-train.yaml b/examples/instantmesh/configs/instant-nerf-large-train.yaml index cb848b3556..a8829335ab 100644 --- a/examples/instantmesh/configs/instant-nerf-large-train.yaml +++ b/examples/instantmesh/configs/instant-nerf-large-train.yaml @@ -1,5 +1,8 @@ model: base_learning_rate: 4.0e-04 + scheduler: cosine_annealing_warm_restarts_lr + optimizer: adamw + weight_decay: 0.01 target: model_stage1.InstantMeshStage1WithLoss params: input_size: 320 @@ -7,7 +10,7 @@ model: # render_size: 96 lrm_generator_config: openlrm_ckpt: 'YOUR_PATH/openlrm.ckpt' - target: models3d.lrm.InstantNeRF + target: models.lrm.InstantNeRF params: encoder_feat_dim: 768 encoder_freeze: false diff --git a/examples/instantmesh/models/renderer/utils/renderer.py b/examples/instantmesh/models/renderer/utils/renderer.py index be1d98049f..fd4de5fb5e 100644 --- a/examples/instantmesh/models/renderer/utils/renderer.py +++ b/examples/instantmesh/models/renderer/utils/renderer.py @@ -4,8 +4,6 @@ import itertools from typing import Dict -from sgm.util import Inverse - import mindspore as ms import mindspore.nn as nn from mindspore import _no_grad, mint, ops @@ -75,8 +73,6 @@ def __init__(self, opts: Dict, dtype: ms.dtype, debug: bool = False): self.plane_axes = generate_planes().astype(dtype) self.max_pool1d_layer = nn.MaxPool1d(2, 1, pad_mode="pad", padding=1) self.avg_pool1d_layer = nn.AvgPool1d(2, 1) - - self.inverse_operator = Inverse() # workaournd for the case that current amp not working on mint/ops.inverse self.decoder = OSGDecoder(n_features=80) # triplane_dim self.debug_logging = debug @@ -100,12 +96,7 @@ def project_onto_planes( coordinates = coordinates.unsqueeze(1) coordinates = coordinates.broadcast_to((-1, n_planes, -1, -1)).reshape(N * n_planes, M, 3) - # cast to fp32 as the inverse operator requests, ops.inverse does not support amp - # May report an issue to the ms/mindone about the amp blacklist: - # TypeError: For primitive[MatrixInverse], the input argument[x] must be a type of {Tensor[Complex128], - # Tensor[Complex64], Tensor[Float32], Tensor[Float64]}, but got Tensor[Float16]. - # inv_planes = ops.inverse(planes.to(ms.float32)).unsqueeze(0) - inv_planes = self.inverse_operator(planes.to(ms.float32)).unsqueeze(0) + inv_planes = mint.inverse(planes.to(ms.float32)).unsqueeze(0) inv_planes = inv_planes.broadcast_to((N, -1, -1, -1)).reshape(N * n_planes, 3, 3) projections = mint.bmm(coordinates, inv_planes.to(planes.dtype)) diff --git a/examples/instantmesh/utils/loss_util.py b/examples/instantmesh/utils/loss_util.py index 4daa83c453..035bf72f41 100644 --- a/examples/instantmesh/utils/loss_util.py +++ b/examples/instantmesh/utils/loss_util.py @@ -111,11 +111,7 @@ class vgg16(nn.Cell): def __init__(self, requires_grad=False, pretrained=True): super(vgg16, self).__init__() # TODO: add bias in vgg. use the same model weights in PT. 
- model = mindcv.create_model( - "vgg16", - pretrained=pretrained # FIXME the mindcv api here seems to be not working, workaround as below - # checkpoint_path='YOUR_PATH/vgg16-95697531.ckpt' # 1GB ckpt - ) + model = mindcv.create_model("vgg16", pretrained=pretrained) model.set_train(False) vgg_pretrained_features = model.features self.slice1 = nn.SequentialCell() diff --git a/mindone/trainers/lr_schedule.py b/mindone/trainers/lr_schedule.py index f3a2a1649f..a895a06725 100644 --- a/mindone/trainers/lr_schedule.py +++ b/mindone/trainers/lr_schedule.py @@ -15,6 +15,8 @@ """Learning Rate Scheduler Factory""" import logging +from mindcv.scheduler.dynamic_lr import cosine_annealing_warm_restarts_lr + from .dynamic_lr import cosine_decay_refined_lr, linear_refined_lr, multi_step_lr, polynomial_refined_lr _logger = logging.getLogger(__name__) @@ -97,6 +99,10 @@ def create_scheduler( main_lr_scheduler = multi_step_lr(milestones=milestones, gamma=decay_rate, lr=lr, total_steps=main_steps) elif name == "constant": main_lr_scheduler = [lr for _ in range(main_steps)] + elif name == "cosine_annealing_warm_restarts_lr": + main_lr_scheduler = cosine_annealing_warm_restarts_lr( + te=3000, tm=1, eta_min=lr / 10, eta_max=lr, steps_per_epoch=steps_per_epoch, epochs=main_steps + ) else: raise ValueError(f"Invalid scheduler: {name}") diff --git a/mindone/trainers/train_step.py b/mindone/trainers/train_step.py index 384d70eed5..ecbb2bc75a 100644 --- a/mindone/trainers/train_step.py +++ b/mindone/trainers/train_step.py @@ -99,6 +99,10 @@ def __init__( def construct(self, *inputs): # compute loss weights = self.weights + + # for solving grad err PyNative mode in ms2.4.0: RuntimeErr: value is not supported, valBaseRef + self.network.set_inputs(*inputs) + loss = self.network(*inputs) # mini-batch loss scaling_sens = self.scale_sense From 1fc421494951647f6f03f1c9f4a81e3f590ceba0 Mon Sep 17 00:00:00 2001 From: HaFred Date: Tue, 15 Oct 2024 09:45:54 +0800 Subject: [PATCH 06/18] housekeeping --- examples/instantmesh/eval.py | 5 +- examples/instantmesh/train.py | 4 -- .../instantmesh/utils/ms_callback_util.py | 54 +------------------ mindone/trainers/callback.py | 10 ++-- 4 files changed, 12 insertions(+), 61 deletions(-) diff --git a/examples/instantmesh/eval.py b/examples/instantmesh/eval.py index 6d747b9b28..fc3ffbdf66 100644 --- a/examples/instantmesh/eval.py +++ b/examples/instantmesh/eval.py @@ -85,7 +85,10 @@ def evaluate(args, epoch_num: Optional[str]): val_batch_np = data.__getitem__(index) # [torch] prepare_validation_batch_data(): - # prepare for validation batch/mv2mesh inference(): see raw repo TODO del this comment once this script finishes + # prepare for validation batch/mv2mesh inference() + # TODO mov this to data + # although in torch this method is under model.py, but for torch lightning, + # it's actually implemented as a part of dataset val_input = model.prepare_validation_batch_data( val_batch_np, render_size=config.eval_render_size, diff --git a/examples/instantmesh/train.py b/examples/instantmesh/train.py index 7e8fac5bd9..beb86b3f13 100644 --- a/examples/instantmesh/train.py +++ b/examples/instantmesh/train.py @@ -27,8 +27,6 @@ sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../.."))) # for sgm from model_stage1 import InstantMeshStage1WithLoss - -# from eval import ValidationCallback, LossMonitor from omegaconf import OmegaConf from utils.ms_callback_util import SaveCkptCallback @@ -381,7 +379,6 @@ def main(args): callback = [ TimeMonitor(), OverflowMonitor(), - # 
LossMonitor(log_interval=args.log_interval), SaveCkptCallback( rank_id=rank_id, output_dir=os.path.join(args.output_path, "ckpt"), @@ -390,7 +387,6 @@ def main(args): save_ema=args.use_ema, ckpt_save_policy="top_k", ), - # ValidationCallback(output_dir=args.output_path) ] if rank_id == 0: diff --git a/examples/instantmesh/utils/ms_callback_util.py b/examples/instantmesh/utils/ms_callback_util.py index 2ecf9f8e51..556eb9f83b 100644 --- a/examples/instantmesh/utils/ms_callback_util.py +++ b/examples/instantmesh/utils/ms_callback_util.py @@ -2,66 +2,16 @@ import os from typing import List, Optional -from mindspore import Parameter, ParameterTuple, RunContext, Tensor, mint +from mindspore import Parameter, ParameterTuple, RunContext from mindspore.train import Callback from mindone.trainers.checkpoint import CheckpointManager -__all__ = ["LossMonitor", "SaveCkptCallback"] +__all__ = ["SaveCkptCallback"] logger = logging.getLogger(__name__) -class LossMonitor(Callback): - def __init__(self, log_interval: int = 1) -> None: - self.log_interval = log_interval - - def on_train_step_end(self, run_context: RunContext) -> None: - cb_params = run_context.original_args() - cur_step = cb_params.cur_step_num - step_num = cb_params.batch_num * cb_params.epoch_num - - if (cur_step % self.log_interval == 0) or (cur_step == step_num): - cur_lr = self._fetch_optimizer_lr(cb_params) - cur_loss = self._fetch_loss(cb_params) - cur_loss_scale = self._fetch_loss_scale(cb_params) - - logger.info( - "epoch: %d step: %d, lr: %.7f, loss: %.6f, loss scale: %d.", - cb_params.cur_epoch_num, - (cb_params.cur_step_num - 1) % cb_params.batch_num + 1, - cur_lr.item(), - cur_loss.item(), - cur_loss_scale.item(), - ) - - def _get_optimizer_from_cbp(self, cb_params): - if cb_params.optimizer is not None: - optimizer = cb_params.optimizer - elif cb_params.dataset_sink_mode: - optimizer = cb_params.train_network.network.optimizer - else: - optimizer = cb_params.train_network.optimizer - return optimizer - - def _fetch_loss_scale(self, cb_params) -> Tensor: - if cb_params.dataset_sink_mode: - return cb_params.train_network.network.scale_sense - else: - return cb_params.train_network.scale_sense - - def _fetch_optimizer_lr(self, cb_params) -> Tensor: - opt = self._get_optimizer_from_cbp(cb_params) - lr = opt.learning_rate - if opt.dynamic_lr: - lr = opt.learning_rate(mint.clamp(opt.global_step - 1, min=0))[0] - return lr - - def _fetch_loss(self, cb_params) -> Tensor: - loss = cb_params.net_outputs[0] - return loss - - class SaveCkptCallback(Callback): def __init__( self, diff --git a/mindone/trainers/callback.py b/mindone/trainers/callback.py index 08b39c0751..b376e59a9b 100755 --- a/mindone/trainers/callback.py +++ b/mindone/trainers/callback.py @@ -267,7 +267,9 @@ def on_train_epoch_end(self, run_context): self.ema.swap_before_eval() # save history checkpoints - self.ckpt_manager.save(self.net_to_save, None, ckpt_name=ckpt_name, append_dict=append_dict) + self.ckpt_manager.save( + self.net_to_save, perf=cb_params["net_outputs"], ckpt_name=ckpt_name, append_dict=append_dict + ) if self.save_training_resume: ms.save_checkpoint( @@ -289,16 +291,16 @@ def on_train_epoch_end(self, run_context): def on_train_end(self, run_context): if self.is_main_device: if self.ckpt_save_policy == "top_k": - log_str = f"Top K checkpoints:\n{self.main_indicator}\tcheckpoint\n" + log_str = "Top K checkpoints:\n %s\tcheckpoint\n" % self.main_indicator for p, ckpt_name in self.ckpt_manager.get_ckpt_queue(): - log_str += 
f"{p:.4f}\t{os.path.join(self.ckpt_save_dir, ckpt_name)}\n" + log_str += "%.4f\t\n%s" % p, {os.path.join(self.ckpt_save_dir, ckpt_name)} def on_eval_end(self, run_context): if self.is_main_device: cb_params = run_context.original_args() metrics = cb_params.get("metrics") if metrics is not None: - metrics = {k: f"{v:.4f}" for k, v in metrics.items()} + metrics = {k: "%.4f" % v for k, v in metrics.items()} _logger.info(f"Eval result epoch {cb_params.cur_epoch_num}: {metrics}") def _get_optimizer_from_cbp(self, cb_params): From 70663daf0b4adb3f330a2d162a64cc2cef02159b Mon Sep 17 00:00:00 2001 From: HaFred Date: Tue, 15 Oct 2024 10:06:43 +0800 Subject: [PATCH 07/18] revert to f-string while meeting flake8 constraints --- mindone/trainers/callback.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mindone/trainers/callback.py b/mindone/trainers/callback.py index b376e59a9b..6e8fe9f50a 100755 --- a/mindone/trainers/callback.py +++ b/mindone/trainers/callback.py @@ -291,16 +291,16 @@ def on_train_epoch_end(self, run_context): def on_train_end(self, run_context): if self.is_main_device: if self.ckpt_save_policy == "top_k": - log_str = "Top K checkpoints:\n %s\tcheckpoint\n" % self.main_indicator + log_str = f"Top K checkpoints: \n{self.main_indicator}\tcheckpoint\n" for p, ckpt_name in self.ckpt_manager.get_ckpt_queue(): - log_str += "%.4f\t\n%s" % p, {os.path.join(self.ckpt_save_dir, ckpt_name)} + log_str += f"{p: .4f}\t{os.path.join(self.ckpt_save_dir, ckpt_name)}\n" def on_eval_end(self, run_context): if self.is_main_device: cb_params = run_context.original_args() metrics = cb_params.get("metrics") if metrics is not None: - metrics = {k: "%.4f" % v for k, v in metrics.items()} + metrics = {k: f"{v: .4f}" for k, v in metrics.items()} _logger.info(f"Eval result epoch {cb_params.cur_epoch_num}: {metrics}") def _get_optimizer_from_cbp(self, cb_params): From 7073b97b4e79a0f4742cd7c98a71e77c812a0e33 Mon Sep 17 00:00:00 2001 From: Frederick Hong Date: Fri, 18 Oct 2024 16:00:20 +0800 Subject: [PATCH 08/18] Update README.md --- examples/instantmesh/README.md | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/examples/instantmesh/README.md b/examples/instantmesh/README.md index 72b6e0e53d..a390c199cd 100644 --- a/examples/instantmesh/README.md +++ b/examples/instantmesh/README.md @@ -1,9 +1,14 @@ # InstantMesh: 3D Mesh Generation from Multiview Images -- [ ] Elaborate on the design methdology in the intro on Oct 2 -- [ ] Training part +- [x] Elaborate on the design methodology in the intro on Oct 2 +- [x] Training part -This `instantmesh` module under `.../sv3d/models` is implemented for the 3D mesh generation using the multiview images extracted from [the sv3d pipeline](https://github.com/mindspore-lab/mindone/pull/574). +We support [instantmesh](https://github.com/TencentARC/InstantMesh) for the 3D mesh generation using the multiview images extracted from [the sv3d pipeline](https://github.com/mindspore-lab/mindone/pull/574). +

+  *[figure: Capture]*

+ +The model consists of a Dino-ViT feature extractor, a triplane feature extraction transformer, and a triplane-to-NeRF synthesizer which also conducts rendering. A walk-through of the file structure is provided here as below. @@ -93,7 +98,12 @@ python inference.py --ckpt PATH_TO_CKPT \ ``` ## Training +```shell +python train.py --base configs/YOUR_CFG +``` + ### Data Curation +We used Blender to render multiview frames for a 3D object in `.obj` for training. ## Acknowledgements From f939579c01e74af021e01e8c60bef0d8141f06a3 Mon Sep 17 00:00:00 2001 From: HaFred Date: Fri, 25 Oct 2024 10:29:09 +0800 Subject: [PATCH 09/18] lpips loss alignment --- examples/instantmesh/README.md | 5 +++++ examples/instantmesh/utils/loss_util.py | 16 ++++++++-------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/examples/instantmesh/README.md b/examples/instantmesh/README.md index a390c199cd..2ba981621e 100644 --- a/examples/instantmesh/README.md +++ b/examples/instantmesh/README.md @@ -101,6 +101,11 @@ python inference.py --ckpt PATH_TO_CKPT \ ```shell python train.py --base configs/YOUR_CFG ``` +One needs to patch `mindcv.models.vgg` in L62 to enable conv kernel bias to align with the torchmetric implementation of lpips loss. +```diff +- conv2d = nn.Conv2d(in_channels, v, kernel_size=3, pad_mode="pad", padding=1) ++ conv2d = nn.Conv2d(in_channels, v, kernel_size=3, pad_mode="pad", padding=1, has_bias=True) +``` ### Data Curation We used Blender to render multiview frames for a 3D object in `.obj` for training. diff --git a/examples/instantmesh/utils/loss_util.py b/examples/instantmesh/utils/loss_util.py index 035bf72f41..630bcc9684 100644 --- a/examples/instantmesh/utils/loss_util.py +++ b/examples/instantmesh/utils/loss_util.py @@ -12,11 +12,7 @@ class LPIPS(nn.Cell): # Learned perceptual metric - def __init__( - self, - use_dropout=True, - normalize=False, - ): + def __init__(self, use_dropout=True, normalize=False, pretrained_vgg_mindcv=False): super().__init__() self.scaling_layer = ScalingLayer() self.chns = [64, 128, 256, 512, 512] # vgg16 features @@ -32,7 +28,11 @@ def __init__( self.load_from_pretrained("YOUR_PATH/lpips_vgg-426bf45c.ckpt") # create vision backbone and load pretrained weights - self.net = vgg16(pretrained=True, requires_grad=False) + self.net = vgg16( + pretrained=pretrained_vgg_mindcv, + ckpt_path="./th-vgg16-397923af.ckpt", # torch ckpt different from mindcv ckpt + requires_grad=False, + ) # ensure that lpips's param not tuned, but lpips loss still supervises self.set_train(False) @@ -108,10 +108,10 @@ def construct(self, x): class vgg16(nn.Cell): - def __init__(self, requires_grad=False, pretrained=True): + def __init__(self, requires_grad=False, pretrained=True, ckpt_path=None): super(vgg16, self).__init__() # TODO: add bias in vgg. use the same model weights in PT. 
- model = mindcv.create_model("vgg16", pretrained=pretrained) + model = mindcv.create_model("vgg16", pretrained=pretrained, checkpoint_path=ckpt_path if not pretrained else "") model.set_train(False) vgg_pretrained_features = model.features self.slice1 = nn.SequentialCell() From 0eb76416370f779af7a909783dd76756c2d2d787 Mon Sep 17 00:00:00 2001 From: HaFred Date: Fri, 25 Oct 2024 11:13:14 +0800 Subject: [PATCH 10/18] put on mindcv version --- examples/instantmesh/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/instantmesh/requirements.txt b/examples/instantmesh/requirements.txt index ed08290f0f..0f5ec83a77 100644 --- a/examples/instantmesh/requirements.txt +++ b/examples/instantmesh/requirements.txt @@ -5,3 +5,4 @@ loguru einops huggingface_hub omegaconf +mindcv==0.3.0 From efdd4237b2b0199324471cb87097fb2dc6df8b73 Mon Sep 17 00:00:00 2001 From: HaFred Date: Fri, 25 Oct 2024 16:23:04 +0800 Subject: [PATCH 11/18] update the --- examples/instantmesh/README.md | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/examples/instantmesh/README.md b/examples/instantmesh/README.md index 2ba981621e..c2ab0e96cb 100644 --- a/examples/instantmesh/README.md +++ b/examples/instantmesh/README.md @@ -68,7 +68,7 @@ Using the multiview images input from 3D mesh extracted from [the sv3d pipeline] The illustrations here are better viewed in viewers than with HTML support (e.g., the vscode built-in viewer). -## Environments +## Environment Requirements 1. To kickstart: @@ -78,11 +78,9 @@ pip install -r requirements.txt 2. Inference is tested on the machine with the following specs using 1x NPU: -```text -Mindspore Version: 2.3.1 release -CANN Version: 7.3 -Ascend Driver: 23.0.rc3.6 -``` +| mindspore | ascend driver | firmware | cann toolkit/kernel | +| :--- | :--- | :--- | :--- | +| 2.3.1 | 24.1.RC2 |7.3.0.1.231 | 8.0.RC2.beta1 | ## Pretrained Models From f48983c1bf7a828377a7fc7e34328ca09c8605f6 Mon Sep 17 00:00:00 2001 From: HaFred Date: Tue, 29 Oct 2024 12:36:01 +0800 Subject: [PATCH 12/18] eval output to the same path as the loaded ckpt, also some housekeeping for loggers --- examples/instantmesh/eval.py | 34 ++++--------------- examples/instantmesh/model_stage1.py | 2 +- examples/instantmesh/train.py | 2 +- examples/instantmesh/utils/eval_util.py | 17 ++-------- examples/instantmesh/utils/loss_util.py | 2 +- .../instantmesh/utils/ms_callback_util.py | 2 +- 6 files changed, 12 insertions(+), 47 deletions(-) diff --git a/examples/instantmesh/eval.py b/examples/instantmesh/eval.py index fc3ffbdf66..2cc1dddebb 100644 --- a/examples/instantmesh/eval.py +++ b/examples/instantmesh/eval.py @@ -1,14 +1,3 @@ -"""Eval script using the model stage 1 trained ckpt to conduct arbitral novel view synthesis. - -Design methdology: Unlike the ms translation that has been done for training, -we make the eval here more similar to the inference script below with np utilization. -Because for training, the np data proc parts should be translated into ms as much as possible -(see ~/examples/opensora_hpcai/opensora/datasets/video_dataset_refactored.py), -but not for the inference. - -Thus we can do np for all data proc here to avoid tedious ms tranlsation of data. -Refer to inference.py for the full stage inference. 
-""" import argparse import datetime import os @@ -20,8 +9,7 @@ __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../.."))) # for loading mindone -# from loguru import logger -import logging +from typing import Optional from omegaconf import OmegaConf from transformers import ViTImageProcessor @@ -31,10 +19,6 @@ from mindone.utils.logger import set_logger from mindone.utils.seed import set_random_seed -logger = logging.getLogger(__name__) - -from typing import Optional - def evaluate(args, epoch_num: Optional[str]): if args.append_timestr: @@ -55,7 +39,7 @@ def evaluate(args, epoch_num: Optional[str]): debug=args.debug, ) set_random_seed(42) - set_logger(name="", output_dir=args.output_path, rank=rid, log_level=eval(args.log_level)) + logger = set_logger(name="", output_dir=args.output_path, rank=rid, log_level=eval(args.log_level)) # valdata preparation config = OmegaConf.load(args.datacfg) @@ -74,9 +58,7 @@ def evaluate(args, epoch_num: Optional[str]): image_path = os.path.join(image_path, "val_official_instantmesh_ckpt.png") else: if not epoch_num: # train_resume.ckpt - epoch_num = ms.load_checkpoint(args.itmh_ckpt).get( - "epoch_num", 0 - ) # 0 means that there is no this key in the resume ckpt file + epoch_num = ms.load_checkpoint(args.itmh_ckpt)["epoch_num"].item() image_path = os.path.join(image_path, f"val_e{epoch_num}.png") validation_step_outputs = [] @@ -84,7 +66,6 @@ def evaluate(args, epoch_num: Optional[str]): for index in range(0, valset_len, args.batch_size): val_batch_np = data.__getitem__(index) - # [torch] prepare_validation_batch_data(): # prepare for validation batch/mv2mesh inference() # TODO mov this to data # although in torch this method is under model.py, but for torch lightning, @@ -96,7 +77,6 @@ def evaluate(args, epoch_num: Optional[str]): ) images = val_input["images"] - # [torch] forward(): # RGB image with [0,1] scale and properly sized requested by the ViTImgProc if images.ndim == 5: (B, N, C, H, W) = images.shape @@ -130,11 +110,7 @@ def evaluate(args, epoch_num: Optional[str]): logger.info("Mean Batch time: %.3fs.", mean_time) # save mviews outputs images = mint.cat(validation_step_outputs, dim=0) # enable for multiple batches - - # images = rearrange(images, 'r b c h w -> (r b) c h w') - # images = images.flatten(start_dim=0, end_dim=1) assert len(images.shape) == 4, "images' shape not matched" - grid = make_grid_ms(images, nrow=1, normalize=True, value_range=(0, 1)) if not args.debug: save_image_ms(grid, image_path) @@ -198,8 +174,10 @@ def parse_args(): if __name__ == "__main__": args = parse_args() - if args.itmh_ckpt.split("/")[-1] == "train_resume.ckpt": + ckpt_str_l = args.itmh_ckpt.split("/") + if ckpt_str_l[-1] == "train_resume.ckpt": epoch_num = None + args.output_path = os.path.join("output", ckpt_str_l[-3]) else: epoch_num = args.itmh_ckpt.split("-e")[-1].split(".")[0] evaluate(args, epoch_num) diff --git a/examples/instantmesh/model_stage1.py b/examples/instantmesh/model_stage1.py index ba130aaf5e..2d04907086 100644 --- a/examples/instantmesh/model_stage1.py +++ b/examples/instantmesh/model_stage1.py @@ -2,7 +2,7 @@ import logging import os -logger = logging.getLogger(__name__) +logger = logging.getLogger("") import numpy as np from einops import rearrange diff --git a/examples/instantmesh/train.py b/examples/instantmesh/train.py index beb86b3f13..c97d4ae8fc 100644 --- a/examples/instantmesh/train.py +++ b/examples/instantmesh/train.py @@ -18,7 +18,7 @@ from 
mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell from mindspore.train.callback import TimeMonitor -logger = logging.getLogger(__name__) +logger = logging.getLogger("") import sys diff --git a/examples/instantmesh/utils/eval_util.py b/examples/instantmesh/utils/eval_util.py index 1c5ebb39b8..cf5fad7fcb 100644 --- a/examples/instantmesh/utils/eval_util.py +++ b/examples/instantmesh/utils/eval_util.py @@ -123,8 +123,6 @@ def norm_range(t, value_range): if not isinstance(tensor, ms.Tensor): raise TypeError("tensor should be of type ms Tensor") - if tensor.shape[0] == 1: - return tensor.squeeze(0) # make the mini-batch of images into a grid nmaps = tensor.shape[0] @@ -137,19 +135,8 @@ def norm_range(t, value_range): k = 0 for y in range(ymaps): for x in range(xmaps): - if k >= nmaps: - break - # # Tensor.copy_() is a valid method but seems to be missing from the stubs - # # https://pyms.org/docs/stable/tensors.html#ms.Tensor.copy_ - # # grid.narrow(1, y * height + padding, height - padding).narrow( # type: ignore[attr-defined] - # 2, x * width + padding, width - padding - # # ).copy_(tensor[k]) - - # TODO mint.narrow operation different from torch: the below will only have the last sample passed to grid, comment for now - # grid_portion = mint.narrow(grid, 1, y * height + padding, height - padding) - # grid_portion = mint.narrow(grid_portion, 2, x * width + padding, width - padding) - # grid_portion.copy_(tensor[k]) # assign the kth sample to the grid - + # if k >= nmaps: + # break grid[:, y * height + padding : y * height + height, x * width + padding : x * width + width].copy_( tensor[k] ) diff --git a/examples/instantmesh/utils/loss_util.py b/examples/instantmesh/utils/loss_util.py index 630bcc9684..51c93d5eb0 100644 --- a/examples/instantmesh/utils/loss_util.py +++ b/examples/instantmesh/utils/loss_util.py @@ -7,7 +7,7 @@ import mindspore.nn as nn import mindspore.ops as ops -_logger = logging.getLogger(__name__) +_logger = logging.getLogger("") class LPIPS(nn.Cell): diff --git a/examples/instantmesh/utils/ms_callback_util.py b/examples/instantmesh/utils/ms_callback_util.py index 556eb9f83b..08ba80f8ae 100644 --- a/examples/instantmesh/utils/ms_callback_util.py +++ b/examples/instantmesh/utils/ms_callback_util.py @@ -9,7 +9,7 @@ __all__ = ["SaveCkptCallback"] -logger = logging.getLogger(__name__) +logger = logging.getLogger("") class SaveCkptCallback(Callback): From 0dfbf99ddf40a2d44cda9665b0d29a2edf6d2df2 Mon Sep 17 00:00:00 2001 From: HaFred Date: Tue, 29 Oct 2024 14:09:28 +0800 Subject: [PATCH 13/18] fix the ckpt saving path cfg --- examples/instantmesh/eval.py | 1 - examples/instantmesh/inference.py | 1 - examples/instantmesh/train.py | 13 +++++-------- 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/examples/instantmesh/eval.py b/examples/instantmesh/eval.py index 2cc1dddebb..99b27bf8fa 100644 --- a/examples/instantmesh/eval.py +++ b/examples/instantmesh/eval.py @@ -122,7 +122,6 @@ def parse_args(): parser.add_argument("--itmh_ckpt", default="CKPT_PATH") parser.add_argument( "--debug", - # default=True, # also setting debug as true will set pynative sync as true as well default=False, # also setting debug as true will set pynative sync as true as well help="When debugging, set it true, to avoid saving too many ckpts and burn out the storage.", ) diff --git a/examples/instantmesh/inference.py b/examples/instantmesh/inference.py index 585bacf854..9c66093ea1 100644 --- a/examples/instantmesh/inference.py +++ b/examples/instantmesh/inference.py @@ -67,7 +67,6 
@@ def construct(self, inputs: ms.Tensor, radius: float) -> Tensor: ms.set_context( mode=1, device_target="Ascend", - # device_target='CPU', device_id=6, pynative_synchronize=True, ) diff --git a/examples/instantmesh/train.py b/examples/instantmesh/train.py index c97d4ae8fc..f7444b1370 100644 --- a/examples/instantmesh/train.py +++ b/examples/instantmesh/train.py @@ -135,7 +135,6 @@ def parse_train_args(parser): ) parser.add_argument( "--epochs", - # default=3, default=7000, type=int, help="epochs. If dataset_sink_mode is on, epochs is with respect to dataset sink size. Otherwise, it's w.r.t the dataset size.", @@ -192,11 +191,12 @@ def parse_train_args(parser): def main(args): time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") - if not args.resume: - args.output_path = os.path.join(args.output_path, time_str) if not args.debug else args.output_path + if args.resume: + args.output_path = args.resume + elif not args.debug: + args.output_path = os.path.join(args.output_path, time_str) else: - # FIXME now when resume the output path overwrites the debug path, fix it... - pass + print("make sure you are debugging now, as no ckpt will be saved.") # 1. init did, rank_id, device_num = init_train_env( @@ -241,8 +241,6 @@ def main(args): python_multiprocessing=args.data_multiprocessing, max_rowsize=args.max_rowsize, debug=False, # ms240_sept4: THIS CANNOT BE TRUE, OTHERWISE loader error - # Sort output columns to match DiffusionWithLoss input - # project_columns=project_columns, # not sure input/target frames data should use this option ornot ) dataset_size = dataloader.get_dataset_size() @@ -449,7 +447,6 @@ def main(args): model.fit( sink_epochs, dataloader, - # valid_dataset=val_dataloader, callbacks=callback, dataset_sink_mode=args.dataset_sink_mode, sink_size=args.sink_size, From 53e2216319a4d24434bc511e75ec1f51551b4aba Mon Sep 17 00:00:00 2001 From: HaFred Date: Tue, 29 Oct 2024 17:19:16 +0800 Subject: [PATCH 14/18] update cfg --- examples/instantmesh/configs/instant-nerf-large-train.yaml | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/instantmesh/configs/instant-nerf-large-train.yaml b/examples/instantmesh/configs/instant-nerf-large-train.yaml index a8829335ab..1465930540 100644 --- a/examples/instantmesh/configs/instant-nerf-large-train.yaml +++ b/examples/instantmesh/configs/instant-nerf-large-train.yaml @@ -7,14 +7,12 @@ model: params: input_size: 320 render_size: 192 - # render_size: 96 lrm_generator_config: openlrm_ckpt: 'YOUR_PATH/openlrm.ckpt' target: models.lrm.InstantNeRF params: encoder_feat_dim: 768 encoder_freeze: false - # encoder_model_name: facebook/dino-vitb16 encoder_model_name: 'YOUR_PATH_HF/models--facebook--dino-vitb16/snapshots/f205d5d8e640a89a2b8ef0369670dfc37cc07fc2' # coz needs to enforce the is_local flag (with pretrained_model_name_or_path as dir), thus here put in the abs path as a workaround transformer_dim: 1024 transformer_layers: 16 @@ -22,10 +20,10 @@ model: triplane_low_res: 32 triplane_high_res: 64 triplane_dim: 80 - rendering_samples_per_ray: 64 # official ckpt is 128, if loaded pretrained make sure it's 128 + rendering_samples_per_ray: 64 # for the vanilla ckpt is 128, if loaded pretrained make sure it's 128 use_recompute: true -eval_render_size: 96 # large may have oom +eval_render_size: 96 # larger may lead to oom in eval.py data: batch_size: 1 From e2c29e146334e23ebb24c251bd142eb69510a687 Mon Sep 17 00:00:00 2001 From: HaFred Date: Wed, 30 Oct 2024 18:07:55 +0800 Subject: [PATCH 15/18] update arch to support 
loading vanilla stage 1 ckpt and ms-trained stage 1 ckpt with eval.py seamlessly --- examples/instantmesh/eval.py | 2 - .../models/renderer/synthesizer.py | 10 ++- .../models/renderer/utils/renderer.py | 71 +------------------ 3 files changed, 7 insertions(+), 76 deletions(-) diff --git a/examples/instantmesh/eval.py b/examples/instantmesh/eval.py index 99b27bf8fa..16f7e36b7b 100644 --- a/examples/instantmesh/eval.py +++ b/examples/instantmesh/eval.py @@ -66,8 +66,6 @@ def evaluate(args, epoch_num: Optional[str]): for index in range(0, valset_len, args.batch_size): val_batch_np = data.__getitem__(index) - # prepare for validation batch/mv2mesh inference() - # TODO mov this to data # although in torch this method is under model.py, but for torch lightning, # it's actually implemented as a part of dataset val_input = model.prepare_validation_batch_data( diff --git a/examples/instantmesh/models/renderer/synthesizer.py b/examples/instantmesh/models/renderer/synthesizer.py index 0e8af648bc..859a23a5e3 100644 --- a/examples/instantmesh/models/renderer/synthesizer.py +++ b/examples/instantmesh/models/renderer/synthesizer.py @@ -25,7 +25,6 @@ def __init__( hidden_dim: int = 64, num_layers: int = 4, activation: nn.Cell = nn.ReLU, - use_recompute: bool = True, ): super().__init__() self.net = nn.SequentialCell( @@ -98,16 +97,15 @@ def __init__( "depth_resolution_importance": dep_res_imp, } + # nerf decoder + self.decoder = OSGDecoder(n_features=triplane_dim) + # renderings - self.renderer = ImportanceRenderer(self.rendering_kwargs, dtype=dtype) + self.renderer = ImportanceRenderer(self.rendering_kwargs, dtype=dtype, decoder=self.decoder) self.ray_sampler = RaySampler() - # modules - self.decoder = OSGDecoder(n_features=triplane_dim) - if use_recompute: self.renderer.recompute() - self.decoder.recompute() # @ms.jit # now has the error in the renderer: Exceed function call depth limit 1000, (function call depth: 1001, simulate call depth: 508). def construct(self, planes, cameras, render_size, crop_params): diff --git a/examples/instantmesh/models/renderer/utils/renderer.py b/examples/instantmesh/models/renderer/utils/renderer.py index fd4de5fb5e..98ced4bd97 100644 --- a/examples/instantmesh/models/renderer/utils/renderer.py +++ b/examples/instantmesh/models/renderer/utils/renderer.py @@ -1,8 +1,7 @@ """ The renderer is a module that takes in rays, decides where to sample along each ray, and computes pixel colors using the volume rendering equation. """ -import itertools -from typing import Dict +from typing import Dict, Optional import mindspore as ms import mindspore.nn as nn @@ -12,55 +11,10 @@ from .ray_marcher import MipRayMarcher2 -class OSGDecoder(nn.Cell): - """ - Triplane decoder that gives RGB and sigma values from sampled features. - Using ReLU here instead of Softplus in the original implementation. 
- - Reference: - EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112 - """ - - def __init__(self, n_features: int, hidden_dim: int = 64, num_layers: int = 4, activation: nn.Cell = nn.ReLU): - super().__init__() - self.net = nn.SequentialCell( - nn.Dense(3 * n_features, hidden_dim), - activation(), - *itertools.chain( - *[ - [ - nn.Dense(hidden_dim, hidden_dim), - activation(), - ] - for _ in range(num_layers - 2) - ] - ), - nn.Dense(hidden_dim, 1 + 3), - ) - # bias init as zero by default, can refer to ~/examples/stable_diffusion_v2/tests/test_lora.py & lora_torch.py for evidence - - def construct(self, sampled_features): - # Aggregate features by mean - # sampled_features = sampled_features.mean(1) - # Aggregate features by concatenation - _N, n_planes, _M, _C = sampled_features.shape - x = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes * _C) - - N, M, color = x.shape - x = x.view((N * M, color)) - - x = self.net(x) - x = x.view((N, M, -1)) - rgb = mint.sigmoid(x[..., 1:]) * (1 + 2 * 0.001) - 0.001 # Uses sigmoid clamping from MipNeRF - sigma = x[..., 0:1] - - return rgb, sigma - - class ImportanceRenderer(nn.Cell): """Modified version of the filtering the out-of-box sampels as TensorRF does.""" - def __init__(self, opts: Dict, dtype: ms.dtype, debug: bool = False): + def __init__(self, opts: Dict, dtype: ms.dtype, debug: bool = False, decoder: Optional[nn.Cell] = None): super().__init__() self.rendering_options = opts @@ -73,11 +27,9 @@ def __init__(self, opts: Dict, dtype: ms.dtype, debug: bool = False): self.plane_axes = generate_planes().astype(dtype) self.max_pool1d_layer = nn.MaxPool1d(2, 1, pad_mode="pad", padding=1) self.avg_pool1d_layer = nn.AvgPool1d(2, 1) - self.decoder = OSGDecoder(n_features=80) # triplane_dim + self.decoder = decoder self.debug_logging = debug - self.path_to_save_grad = "./renderer_save_grad/" - def project_onto_planes( self, planes: ms.Tensor, # when calling this here from outside, particually it's sampling on the unit plane axes @@ -113,22 +65,8 @@ def sample_from_planes(self, plane_features, coordinates): coordinates = (2 / self.rendering_options["box_warp"]) * coordinates # add specific box bounds - # debug - # plane_features += self.plane_feat_grad - projected_coordinates = self.project_onto_planes(self.plane_axes, coordinates).unsqueeze(1) - # ts.save(self.path_to_save_grad + 'pf_input', plane_features) - # plane_features = ts.save_grad(self.path_to_save_grad + 'pf_input_grad', plane_features) - - # output_features = ops.grid_sample( - # plane_features, - # projected_coordinates, - # mode=mode, - # padding_mode=padding_mode, - # align_corners=False, - # ).permute(0, 3, 2, 1).reshape(N, n_planes, M, C) - output_features = ( mint.nn.functional.grid_sample( plane_features, @@ -141,9 +79,6 @@ def sample_from_planes(self, plane_features, coordinates): .reshape(N, n_planes, M, C) ) - # ts.save(self.path_to_save_grad + 'pf_output', plane_features) - # output_features = ts.save_grad(self.path_to_save_grad + 'output_features_output_grad', output_features) - return output_features def run_model( From 6ac74a97a2cddbb6129f98fc3319fa47ae8621d2 Mon Sep 17 00:00:00 2001 From: HaFred Date: Fri, 1 Nov 2024 18:26:11 +0800 Subject: [PATCH 16/18] swtich ops to mint AMAP --- examples/instantmesh/model_stage1.py | 26 +++++------- .../instantmesh/models/decoder/transformer.py | 14 +++---- examples/instantmesh/models/encoder/dino.py | 24 +++++------ .../models/encoder/dino_wrapper.py | 10 +++-- examples/instantmesh/models/lrm.py | 2 +- 
.../models/renderer/utils/math_utils.py | 4 +- .../models/renderer/utils/ray_marcher.py | 6 ++- .../models/renderer/utils/ray_sampler.py | 4 +- .../models/renderer/utils/renderer.py | 40 ++++++++++++------- examples/instantmesh/utils/camera_util.py | 18 ++++----- examples/instantmesh/utils/loss_util.py | 6 +-- 11 files changed, 83 insertions(+), 71 deletions(-) diff --git a/examples/instantmesh/model_stage1.py b/examples/instantmesh/model_stage1.py index 2d04907086..bf9f4d848c 100644 --- a/examples/instantmesh/model_stage1.py +++ b/examples/instantmesh/model_stage1.py @@ -10,7 +10,7 @@ from utils.loss_util import LPIPS import mindspore as ms -from mindspore import Tensor, mint, nn, ops +from mindspore import Tensor, mint, nn from mindspore.dataset.vision import ToPIL from mindone.utils.config import instantiate_from_config @@ -78,22 +78,22 @@ def prepare_validation_batch_data(self, batch, render_size, _use_dataloader=Fals input_Ks = batch["input_Ks"].flatten(start_dim=-2) input_extrinsics = input_c2ws[:, :12] - input_intrinsics = ops.stack( + input_intrinsics = mint.stack( [ input_Ks[:, 0], input_Ks[:, 4], input_Ks[:, 2], input_Ks[:, 5], ], - axis=-1, + dim=-1, ) - cameras = ops.cat([input_extrinsics, input_intrinsics], axis=-1) + cameras = mint.cat([input_extrinsics, input_intrinsics], dim=-1) lrm_generator_input["cameras"] = cameras render_c2ws = batch["render_c2ws"].flatten(start_dim=-2) render_Ks = batch["render_Ks"].flatten(start_dim=-2) - render_cameras = ops.cat([render_c2ws, render_Ks], axis=-1) + render_cameras = mint.cat([render_c2ws, render_Ks], dim=-1) lrm_generator_input["render_cameras"] = render_cameras # create batch dim when not using dataloader, presuming bsize==1 @@ -121,8 +121,8 @@ def construct( images_rgb, images_depth, images_weight = self.lrm_generator( images, cameras, render_cameras, render_size.item(), crop_params # to int ) - render_images = ops.clamp(images_rgb, 0.0, 1.0) - render_alphas = ops.clamp(images_weight, 0.0, 1.0) + render_images = mint.clamp(images_rgb, 0.0, 1.0) + render_alphas = mint.clamp(images_weight, 0.0, 1.0) loss = self.compute_loss(render_images, render_alphas, target_images, target_alphas) @@ -134,8 +134,8 @@ def forward_nocalloss(self, images: Tensor, cameras: Tensor, render_cameras: Ten images_rgb, images_depth, images_weight = self.lrm_generator( images, cameras, render_cameras, render_size, crop_params=None ) - render_images = ops.clamp(images_rgb, 0.0, 1.0) - render_alphas = ops.clamp(images_weight, 0.0, 1.0) + render_images = mint.clamp(images_rgb, 0.0, 1.0) + render_alphas = mint.clamp(images_weight, 0.0, 1.0) return render_images, render_alphas def compute_loss(self, render_images, render_alphas, target_images, target_alphas): @@ -152,14 +152,10 @@ def compute_loss(self, render_images, render_alphas, target_images, target_alpha b, n, c, h, w = target_images.shape target_images = target_images.reshape(b * n, c, h, w) * 2.0 - 1.0 - loss_mse = ops.mse_loss(render_images, target_images) - - # FIXME current lpips loss wrong, not positive. And 2e3 scale here is for the difference of the lpips implemented here vs. 
torchmetric's - # loss_lpips = 2e2 * mint.mean(self.lpips(render_images, target_images)) + loss_mse = mint.nn.functional.mse_loss(render_images, target_images) loss_lpips = 2.0 * mint.mean(self.lpips(render_images, target_images)) - target_alphas = target_alphas.permute((0, 1, 4, 2, 3)) # b n h w c -> b n c h w - loss_mask = ops.mse_loss(render_alphas, target_alphas) + loss_mask = mint.nn.functional.mse_loss(render_alphas, target_alphas) logger.info(f"loss mse: {loss_mse}, loss mask: {loss_mask}, loss lpips: {loss_lpips}") diff --git a/examples/instantmesh/models/decoder/transformer.py b/examples/instantmesh/models/decoder/transformer.py index 4058104712..d1bf3baf84 100644 --- a/examples/instantmesh/models/decoder/transformer.py +++ b/examples/instantmesh/models/decoder/transformer.py @@ -1,6 +1,6 @@ import mindspore as ms import mindspore.nn as nn -import mindspore.ops as ops +from mindspore import mint class BasicTransformerBlock(nn.Cell): @@ -94,7 +94,7 @@ def __init__( # modules # initialize pos_embed with 1/sqrt(dim) * N(0, 1) self.pos_embed = ms.Parameter( - ops.randn(1, 3 * triplane_low_res**2, inner_dim) * (1.0 / inner_dim) ** 0.5 + mint.normal(size=(1, 3 * triplane_low_res**2, inner_dim)) * (1.0 / inner_dim) ** 0.5 ) # [L, D] self.layers = nn.CellList( [ @@ -130,13 +130,13 @@ def construct(self, image_feats): # separate each plane and apply deconv x = x.view(N, 3, H, W, -1) - # x = ops.einsum('nihwd->indhw', x) # [3, N, D, H, W] - # x = ops.reshape(x, (3, N, -1, H, W)) - x = ops.permute(x, (1, 0, 4, 2, 3)) + # x = mint.einsum('nihwd->indhw', x) # [3, N, D, H, W] + # x = mint.reshape(x, (3, N, -1, H, W)) + x = mint.permute(x, (1, 0, 4, 2, 3)) x = x.view(3 * N, -1, H, W) # [3*N, D, H, W] x = self.deconv(x) # [3*N, D', H', W'] x = x.view(3, N, *x.shape[-3:]) # [3, N, D', H', W'] - # x = ops.einsum('indhw->nidhw', x) # [N, 3, D', H', W'] - x = ops.permute(x, (1, 0, 2, 3, 4)) + # x = mint.einsum('indhw->nidhw', x) # [N, 3, D', H', W'] + x = mint.permute(x, (1, 0, 2, 3, 4)) return x diff --git a/examples/instantmesh/models/encoder/dino.py b/examples/instantmesh/models/encoder/dino.py index d885f2eefd..0b64aa7f21 100644 --- a/examples/instantmesh/models/encoder/dino.py +++ b/examples/instantmesh/models/encoder/dino.py @@ -9,7 +9,7 @@ import mindspore as ms import mindspore.nn as nn -from mindspore import Parameter, Tensor, ops +from mindspore import Parameter, Tensor, mint, ops from mindspore.common.initializer import TruncatedNormal, initializer from mindspore.nn import LayerNorm @@ -27,11 +27,11 @@ class ViTEmbeddings(nn.Cell): def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None: super().__init__() - self.cls_token = Parameter(ops.randn(1, 1, config.hidden_size)) - self.mask_token = Parameter(ops.zeros((1, 1, config.hidden_size))) if use_mask_token else None + self.cls_token = Parameter(mint.normal(size=(1, 1, config.hidden_size))) + self.mask_token = Parameter(mint.zeros((1, 1, config.hidden_size))) if use_mask_token else None self.patch_embeddings = ViTPatchEmbeddings(config) num_patches = self.patch_embeddings.num_patches - self.position_embeddings = Parameter(ops.randn(1, num_patches + 1, config.hidden_size)) + self.position_embeddings = Parameter(mint.normal(size=(1, num_patches + 1, config.hidden_size))) self.dropout = nn.Dropout(p=config.hidden_dropout_prob) self.config = config @@ -70,7 +70,7 @@ def interpolate_pos_encoding(self, embeddings: Tensor, height: int, width: int) assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == 
patch_pos_embed.shape[-1] patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - return ops.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), axis=1) + return mint.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) def construct( self, @@ -90,7 +90,7 @@ def construct( # add the [CLS] token to the embedded patch tokens cls_tokens = self.cls_token.to(embeddings.dtype).broadcast_to((batch_size, -1, -1)) - embeddings = ops.cat((cls_tokens, embeddings), axis=1) + embeddings = mint.cat((cls_tokens, embeddings), dim=1) # add positional encoding to each token if interpolate_pos_encoding: @@ -147,7 +147,7 @@ def __init__(self, config: ViTConfig) -> None: super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( - f"The hidden size {config.hidden_size} is not a multiple of the number of attention " + f"The hidden size {config.hidden_size, } is not a multiple of the number of attention " f"heads {config.num_attention_heads}." ) @@ -175,12 +175,12 @@ def construct( query_layer = self.transpose_for_scores(mixed_query_layer) # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = ops.matmul(query_layer, key_layer.swapaxes(-1, -2)) + attention_scores = mint.matmul(query_layer, key_layer.swapaxes(-1, -2)) attention_scores = attention_scores / math.sqrt(self.attention_head_size) # Normalize the attention scores to probabilities. - attention_probs = ops.Softmax(axis=-1)(attention_scores) + attention_probs = mint.nn.Softmax(dim=-1)(attention_scores) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. @@ -190,7 +190,7 @@ def construct( if head_mask is not None: attention_probs = attention_probs * head_mask - context_layer = ops.matmul(attention_probs, value_layer) + context_layer = mint.matmul(attention_probs, value_layer) context_layer = context_layer.permute(0, 2, 1, 3) new_context_layer_shape = context_layer.shape[:-2] + (self.all_head_size,) @@ -311,8 +311,8 @@ def __init__(self, config: ViTConfig) -> None: ) # nn.init.constant_(self.adaLN_modulation[-1].weight, 0) # nn.init.constant_(self.adaLN_modulation[-1].bias, 0) - self.adaLN_modulation[-1].weight = ops.zeros_like(self.adaLN_modulation[-1].weight) - self.adaLN_modulation[-1].bias = ops.zeros_like(self.adaLN_modulation[-1].bias) + self.adaLN_modulation[-1].weight = mint.zeros_like(self.adaLN_modulation[-1].weight) + self.adaLN_modulation[-1].bias = mint.zeros_like(self.adaLN_modulation[-1].bias) def construct( self, diff --git a/examples/instantmesh/models/encoder/dino_wrapper.py b/examples/instantmesh/models/encoder/dino_wrapper.py index cb52f370cf..408f54c2d0 100644 --- a/examples/instantmesh/models/encoder/dino_wrapper.py +++ b/examples/instantmesh/models/encoder/dino_wrapper.py @@ -1,11 +1,8 @@ import mindspore as ms import mindspore.nn as nn -# comment when debug from .dino import ViTModel -# from dino import ViTModel - class DinoWrapper(nn.Cell): """ @@ -23,12 +20,15 @@ def __init__(self, model_name: str, use_recompute: bool = False): if use_recompute: self.camera_embedder.recompute() self.model.encoder.recompute() - self.model.layernorm.recompute() + self.model.layernorm.recompute() # recompute layernorm causes gram leackage? 
+ # @ms.jit, for now don't make it graph mode, as the vit encoder output dict will be none weirdly def construct(self, images: ms.Tensor, camera: ms.Tensor): # because img processor only takes np img # image: [B, N, C, H, W] # camera: [B, N, D] + # logger.info(f'input np image shape is {images.shape}') if images.ndim == 5: + # image = rearrange(image, 'b n c h w -> (b n) c h w') # NOW ITS ALREADY NCHW (B, N, C, H, W) = images.shape images = images.reshape(B * N, C, H, W) @@ -39,9 +39,11 @@ def construct(self, images: ms.Tensor, camera: ms.Tensor): # because img proces cam_emb_shape = camera_embeddings.shape camera_embeddings = camera_embeddings.reshape(cam_emb_shape[0] * cam_emb_shape[1], cam_emb_shape[2]) embeddings = camera_embeddings + # logger.info(f'emd shape {embeddings.shape}') # This resampling of positional embedding uses bicubic interpolation outputs = self.model(pixel_values=images, adaln_input=embeddings, interpolate_pos_encoding=True)[0] + # last_hidden_states = outputs.last_hidden_state return outputs @staticmethod diff --git a/examples/instantmesh/models/lrm.py b/examples/instantmesh/models/lrm.py index 71c2b1cc63..f79becee5c 100644 --- a/examples/instantmesh/models/lrm.py +++ b/examples/instantmesh/models/lrm.py @@ -89,7 +89,7 @@ def construct( cameras: ms.Tensor, render_cameras: ms.Tensor, render_size: int, - crop_params: Tuple[int], + crop_params: Optional[Tuple[int]], ): # images: [B, V, C_img, H_img, W_img] # cameras: [B, V, 16] diff --git a/examples/instantmesh/models/renderer/utils/math_utils.py b/examples/instantmesh/models/renderer/utils/math_utils.py index c3e189627b..1eb65fa000 100644 --- a/examples/instantmesh/models/renderer/utils/math_utils.py +++ b/examples/instantmesh/models/renderer/utils/math_utils.py @@ -14,7 +14,7 @@ def normalize_vecs(vectors: ms.Tensor) -> ms.Tensor: """ Normalize vector lengths. """ - return vectors / (ops.norm(vectors, dim=-1, keepdim=True)) + return vectors / (mint.norm(vectors, dim=-1, keepdim=True)) def get_ray_limits_box( @@ -32,7 +32,7 @@ def get_ray_limits_box( bb_min = [-1 * (box_side_length / 2), -1 * (box_side_length / 2), -1 * (box_side_length / 2)] bb_max = [1 * (box_side_length / 2), 1 * (box_side_length / 2), 1 * (box_side_length / 2)] bounds = ms.Tensor((bb_min, bb_max), dtype=rays_o.dtype) - is_valid = ops.ones(rays_o.shape[:-1], dtype=ms.bool_) + is_valid = mint.ones(rays_o.shape[:-1], dtype=ms.bool_) # Precompute inverse for stability. 
invdir = 1 / rays_d diff --git a/examples/instantmesh/models/renderer/utils/ray_marcher.py b/examples/instantmesh/models/renderer/utils/ray_marcher.py index 1cdef88044..30342635f3 100644 --- a/examples/instantmesh/models/renderer/utils/ray_marcher.py +++ b/examples/instantmesh/models/renderer/utils/ray_marcher.py @@ -32,7 +32,7 @@ def construct( alpha = 1 - mint.exp(-density_delta).to(dtype) - alpha_shifted = ops.cat([mint.ones_like(alpha[:, :, :1]), 1 - alpha + 1e-10], -2) + alpha_shifted = mint.cat([mint.ones_like(alpha[:, :, :1]), 1 - alpha + 1e-10], -2) weights = alpha * ops.cumprod(alpha_shifted, -2)[:, :, :-1] weights = weights.to(dtype) @@ -44,6 +44,10 @@ def construct( # clip the composite to min/max range of depths composite_depth = ops.nan_to_num(composite_depth, float("inf")).to(dtype) + min_val = mint.min(depths) + max_val = mint.max(depths) + composite_depth = mint.clamp(composite_depth, min_val, max_val) + if self.white_back: composite_rgb = composite_rgb + 1 - weight_total diff --git a/examples/instantmesh/models/renderer/utils/ray_sampler.py b/examples/instantmesh/models/renderer/utils/ray_sampler.py index 9c586b9b79..986bdf5827 100644 --- a/examples/instantmesh/models/renderer/utils/ray_sampler.py +++ b/examples/instantmesh/models/renderer/utils/ray_sampler.py @@ -33,7 +33,7 @@ def construct(self, cam2world_matrix: ms.Tensor, intrinsics: ms.Tensor, render_s cy = intrinsics[:, 1, 2] sk = intrinsics[:, 0, 1] - uv = ops.stack( + uv = mint.stack( ops.meshgrid( mint.arange(render_size, dtype=dtype), mint.arange(render_size, dtype=dtype), @@ -45,7 +45,7 @@ def construct(self, cam2world_matrix: ms.Tensor, intrinsics: ms.Tensor, render_s x_cam = uv[:, :, 0].view(N, -1) * (1.0 / render_size) + (0.5 / render_size) y_cam = uv[:, :, 1].view(N, -1) * (1.0 / render_size) + (0.5 / render_size) - z_cam = ops.ones((N, M), dtype=dtype) + z_cam = mint.ones((N, M), dtype=dtype) x_lift = ( ( diff --git a/examples/instantmesh/models/renderer/utils/renderer.py b/examples/instantmesh/models/renderer/utils/renderer.py index 98ced4bd97..7840140fa5 100644 --- a/examples/instantmesh/models/renderer/utils/renderer.py +++ b/examples/instantmesh/models/renderer/utils/renderer.py @@ -1,15 +1,20 @@ """ The renderer is a module that takes in rays, decides where to sample along each ray, and computes pixel colors using the volume rendering equation. """ +import logging from typing import Dict, Optional +import math_utils + import mindspore as ms import mindspore.nn as nn from mindspore import _no_grad, mint, ops -from . 
import math_utils +# comment below for debugging insitu from .ray_marcher import MipRayMarcher2 +logger = logging.getLogger(__name__) + class ImportanceRenderer(nn.Cell): """Modified version of the filtering the out-of-box sampels as TensorRF does.""" @@ -27,6 +32,7 @@ def __init__(self, opts: Dict, dtype: ms.dtype, debug: bool = False, decoder: Op self.plane_axes = generate_planes().astype(dtype) self.max_pool1d_layer = nn.MaxPool1d(2, 1, pad_mode="pad", padding=1) self.avg_pool1d_layer = nn.AvgPool1d(2, 1) + self.decoder = decoder self.debug_logging = debug @@ -61,10 +67,8 @@ def sample_from_planes(self, plane_features, coordinates): N, n_planes, C, H, W = plane_features.shape _, M, _ = coordinates.shape plane_features = plane_features.view(N * n_planes, C, H, W) - # dtype = plane_features.dtype coordinates = (2 / self.rendering_options["box_warp"]) * coordinates # add specific box bounds - projected_coordinates = self.project_onto_planes(self.plane_axes, coordinates).unsqueeze(1) output_features = ( @@ -111,7 +115,7 @@ def _forward_pass( ) # filter out-of-box samples - mask_inbox = ops.logical_and( + mask_inbox = mint.logical_and( self.rendering_options["sampler_bbox_min"] <= sample_coordinates, sample_coordinates <= self.rendering_options["sampler_bbox_max"], ) @@ -123,12 +127,18 @@ def _forward_pass( # set out-of-box samples to zeros(rgb) & -inf(sigma) SAFE_GUARD = 3 DATA_TYPE = _sigma.dtype - colors_pass = ops.zeros((batch_size, num_rays * samples_per_ray, 3), dtype=DATA_TYPE) + colors_pass = mint.zeros((batch_size, num_rays * samples_per_ray, 3), dtype=DATA_TYPE) densities_pass = ( ops.nan_to_num(mint.full((batch_size, num_rays * samples_per_ray, 1), -float("inf"), dtype=DATA_TYPE)) / SAFE_GUARD ) + if self.debug_logging: + logger.info( + f"shape] depths: {mask_inbox.shape}, rd: {_rgb.shape}, ro: {colors_pass.shape}, planes: {planes.shape}" + ) + logger.info(f"shape] mi: {mask_inbox.shape}, rgb: {_rgb.shape}, colorpass: {colors_pass.shape}") + # colors_pass[mask_inbox] = _rgb[mask_inbox] mask_inbox = mask_inbox[..., None] colors_pass = mint.where(mask_inbox, _rgb, colors_pass) # Tensor indexing assignment in G mode @@ -142,23 +152,23 @@ def _forward_pass( return colors_pass, densities_pass def sort_samples(self, all_depths, all_colors, all_densities): - _, indices = ops.sort(all_depths, axis=-2) + _, indices = mint.sort(all_depths, dim=-2) all_depths = mint.gather(all_depths, -2, indices) all_colors = mint.gather(all_colors, -2, indices.broadcast_to((-1, -1, -1, all_colors.shape[-1]))) all_densities = mint.gather(all_densities, -2, indices.broadcast_to((-1, -1, -1, 1))) return all_depths, all_colors, all_densities def unify_samples(self, depths1, colors1, densities1, depths2, colors2, densities2, normals1=None, normals2=None): - all_depths = ops.cat([depths1, depths2], axis=-2) - all_colors = ops.cat([colors1, colors2], axis=-2) - all_densities = ops.cat([densities1, densities2], axis=-2) + all_depths = mint.cat([depths1, depths2], dim=-2) + all_colors = mint.cat([colors1, colors2], dim=-2) + all_densities = mint.cat([densities1, densities2], dim=-2) if normals1 is not None and normals2 is not None: - all_normals = ops.cat([normals1, normals2], axis=-2) + all_normals = mint.cat([normals1, normals2], dim=-2) else: all_normals = None - _, indices = ops.sort(all_depths, axis=-2) + _, indices = mint.sort(all_depths, dim=-2) all_depths = mint.gather(all_depths, -2, indices) all_colors = mint.gather(all_colors, -2, indices.broadcast_to((-1, -1, -1, all_colors.shape[-1]))) all_densities = 
mint.gather(all_densities, -2, indices.broadcast_to((-1, -1, -1, 1))) @@ -212,15 +222,15 @@ def sample_pdf(self, bins, weights): N_rays, N_samples_ = weights.shape weights = weights + eps # prevent division by zero (don't do inplace op!) pdf = weights / mint.sum(weights, -1, keepdim=True) # (N_rays, N_samples_) - cdf = ops.cumsum(pdf, -1) # (N_rays, N_samples), cumulative distribution function - cdf = ops.cat([mint.zeros_like(cdf[:, :1]), cdf], -1) # (N_rays, N_samples_+1) + cdf = mint.cumsum(pdf, -1) # (N_rays, N_samples), cumulative distribution function + cdf = mint.cat([mint.zeros_like(cdf[:, :1]), cdf], -1) # (N_rays, N_samples_+1) # padded to 0~1 inclusive if det: u = mint.linspace(0, 1, self.N_importance) u = u.broadcast_to((N_rays, self.N_importance)) else: - u = ops.rand(N_rays, self.N_importance) + u = mint.rand(N_rays, self.N_importance) inds = mint.searchsorted(cdf, u, right=True) below = mint.clamp(inds - 1, min=0) @@ -272,7 +282,7 @@ def construct( is_ray_valid = ray_end > ray_start # FIXME below take item may degrade the shape, potentially into unknown errors... - if ops.any(is_ray_valid).item(): + if mint.any(is_ray_valid).item(): ray_start[~is_ray_valid] = ray_start[is_ray_valid].min() ray_end[~is_ray_valid] = ray_start[is_ray_valid].max() depths_coarse = self.sample_stratified(ray_origins, ray_start, ray_end) diff --git a/examples/instantmesh/utils/camera_util.py b/examples/instantmesh/utils/camera_util.py index a8a161bcc6..f78af5548a 100644 --- a/examples/instantmesh/utils/camera_util.py +++ b/examples/instantmesh/utils/camera_util.py @@ -1,7 +1,7 @@ import numpy as np import mindspore as ms -from mindspore import Tensor, ops +from mindspore import Tensor, mint, ops def pad_camera_extrinsics_4x4(extrinsics): @@ -10,7 +10,7 @@ def pad_camera_extrinsics_4x4(extrinsics): padding = Tensor([[0, 0, 0, 1]]).to(extrinsics.dtype) if extrinsics.ndim == 3: padding = padding.unsqueeze(0).tile((extrinsics.shape[0], 1, 1)) - extrinsics = ops.cat([extrinsics, padding], axis=-2) + extrinsics = mint.cat([extrinsics, padding], dim=-2) return extrinsics @@ -21,7 +21,7 @@ def center_looking_at_camera_pose(camera_position: Tensor, look_at=None, up_worl camera_position: (M, 3) or (3,) look_at: (3) up_world: (3) - return: (M, 3, 4) or (3, 4) + return: Tensor, (M, 3, 4) or (3, 4) """ # by default, looking at the origin and world up is z-axis if look_at is None: @@ -36,13 +36,13 @@ def center_looking_at_camera_pose(camera_position: Tensor, look_at=None, up_worl z_axis = camera_position - look_at norm = ops.L2Normalize(axis=-1) z_axis = norm(z_axis) - x_axis = ops.cross(up_world, z_axis, dim=-1) + x_axis = mint.cross(up_world, z_axis, dim=-1) x_axis = norm(x_axis) - y_axis = ops.cross(z_axis, x_axis, dim=-1) + y_axis = mint.cross(z_axis, x_axis, dim=-1) y_axis = norm(y_axis) print(f"zshape: {z_axis.shape}, xshape: {x_axis.shape}, yshape: {y_axis.shape}") - extrinsics = ops.stack([x_axis, y_axis, z_axis, camera_position], axis=-1) + extrinsics = mint.stack([x_axis, y_axis, z_axis, camera_position], dim=-1) print(f"fred: the extrinsics shape of {extrinsics.shape}") extrinsics = pad_camera_extrinsics_4x4(extrinsics) return extrinsics @@ -108,11 +108,11 @@ def get_sv3d_input_cameras(bs=1, radius=4.0, fov=30.0): pose_cam2world = spherical_camera_pose(azimuths, elevations, radius) pose_cam2world = pose_cam2world.float().flatten(start_dim=-2) - Ks = FOV_to_intrinsics(fov).unsqueeze(0).tile((21, 1, 1)).float().flatten(start_dim=-2) + Ks = Tensor(FOV_to_intrinsics(fov)).unsqueeze(0).tile((21, 1, 
1)).float().flatten(start_dim=-2) extrinsics = pose_cam2world[:, :12] - intrinsics = ops.stack([Ks[:, 0], Ks[:, 4], Ks[:, 2], Ks[:, 5]], axis=-1) - cameras = ops.cat([extrinsics, intrinsics], axis=-1) + intrinsics = mint.stack([Ks[:, 0], Ks[:, 4], Ks[:, 2], Ks[:, 5]], dim=-1) + cameras = mint.cat([extrinsics, intrinsics], dim=-1) print(f"cameras dtype is {cameras.dtype}") diff --git a/examples/instantmesh/utils/loss_util.py b/examples/instantmesh/utils/loss_util.py index 51c93d5eb0..7746ec5219 100644 --- a/examples/instantmesh/utils/loss_util.py +++ b/examples/instantmesh/utils/loss_util.py @@ -5,7 +5,7 @@ import mindspore as ms import mindspore.nn as nn -import mindspore.ops as ops +from mindspore import mint _logger = logging.getLogger("") @@ -71,7 +71,7 @@ def construct(self, input, target): diff = (normalize_tensor(outs0[kk]) - normalize_tensor(outs1[kk])) ** 2 # res += spatial_average(lins[kk](diff), keepdim=True) # lin_layer = lins[kk] - val += ops.mean(self.lins[kk](diff), axis=[2, 3], keep_dims=True) + val += mint.mean(self.lins[kk](diff), dim=[2, 3], keep_dims=True) return val @@ -152,7 +152,7 @@ def construct(self, X): def normalize_tensor(x, eps=1e-10): - norm_factor = ops.sqrt((x**2).sum(1, keepdims=True)) + norm_factor = mint.sqrt((x**2).sum(1, keepdims=True)) return x / (norm_factor + eps) From 59697a8107fdae10ef5854fdae48620fe199733b Mon Sep 17 00:00:00 2001 From: HaFred Date: Wed, 13 Nov 2024 16:38:43 +0800 Subject: [PATCH 17/18] upload the safetensor conversion snippet mentioned in the readme --- examples/instantmesh/README.md | 13 +- .../tools/convert_dinovit_bin2st.py | 256 ++++++++++++++++++ examples/instantmesh/tools/convert_pt2ms.py | 49 ++++ 3 files changed, 313 insertions(+), 5 deletions(-) create mode 100644 examples/instantmesh/tools/convert_dinovit_bin2st.py create mode 100644 examples/instantmesh/tools/convert_pt2ms.py diff --git a/examples/instantmesh/README.md b/examples/instantmesh/README.md index c2ab0e96cb..920c31a9e7 100644 --- a/examples/instantmesh/README.md +++ b/examples/instantmesh/README.md @@ -1,8 +1,5 @@ # InstantMesh: 3D Mesh Generation from Multiview Images -- [x] Elaborate on the design methodology in the intro on Oct 2 -- [x] Training part - We support [instantmesh](https://github.com/TencentARC/InstantMesh) for the 3D mesh generation using the multiview images extracted from [the sv3d pipeline](https://github.com/mindspore-lab/mindone/pull/574).

Capture @@ -83,11 +80,17 @@ pip install -r requirements.txt | 2.3.1 | 24.1.RC2 |7.3.0.1.231 | 8.0.RC2.beta1 | ## Pretrained Models - -To better accommodate the mindone transformer codebase, we provide an out-of-the-box [checkpoints conversion script](../../tools/convert_instantmesh_ckpt.py) that works seamlessly with the mindspore version of transformers. +### ViT Pretrained Checkpoint +To better accommodate the mindone transformer codebase, we provide an out-of-the-box [checkpoints conversion script](./tools/convert_dinovit_bin2st.py) that works seamlessly with the mindspore version of transformers. The image features are extracted with dino-vit, which depends on HuggingFace's transformer package. We reuse [the MindSpore's implementation](https://github.com/mindspore-lab/mindone/blob/master/mindone/transformers/modeling_utils.py#L499) and the only challenge remains to be that `.bin` checkpoint of [dino-vit](https://huggingface.co/facebook/dino-vitb16/tree/main) is not supported by MindSpore off-the-shelf. The checkpoint script above serves easy conversion purposes and ensures that dino-vit is still based on `MSPreTrainedModel` safe and sound. +### InstantMesh Checkpoint +To convert checkpoints, we prepare the following snippet. +```bash +python tools/convert_pt2ms.py --trgt PATH_TO_CKPT +``` + ## Inference ```shell diff --git a/examples/instantmesh/tools/convert_dinovit_bin2st.py b/examples/instantmesh/tools/convert_dinovit_bin2st.py new file mode 100644 index 0000000000..c5cafa4d21 --- /dev/null +++ b/examples/instantmesh/tools/convert_dinovit_bin2st.py @@ -0,0 +1,256 @@ +"""convert pytorch_model.bin to model.safetensors""" + +import argparse +import os +from collections import defaultdict +from typing import Dict, List, Optional, Set, Tuple + +import torch +from huggingface_hub import CommitInfo, CommitOperationAdd, HfApi, hf_hub_download +from huggingface_hub.file_download import repo_folder_name +from safetensors.torch import _find_shared_tensors, _is_complete, load_file, save_file + +ConversionResult = Tuple[List["CommitOperationAdd"], List[Tuple[str, "Exception"]]] + + +def _remove_duplicate_names( + state_dict: Dict[str, torch.Tensor], + *, + preferred_names: List[str] = None, + discard_names: List[str] = None, +) -> Dict[str, List[str]]: + if preferred_names is None: + preferred_names = [] + preferred_names = set(preferred_names) + if discard_names is None: + discard_names = [] + discard_names = set(discard_names) + + shareds = _find_shared_tensors(state_dict) + to_remove = defaultdict(list) + for shared in shareds: + complete_names = set([name for name in shared if _is_complete(state_dict[name])]) + if not complete_names: + if len(shared) == 1: + # Force contiguous + name = list(shared)[0] + state_dict[name] = state_dict[name].clone() + complete_names = {name} + else: + raise RuntimeError( + f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}." 
+ ) + + keep_name = sorted(list(complete_names))[0] + + # Mecanism to preferentially select keys to keep + # coming from the on-disk file to allow + # loading models saved with a different choice + # of keep_name + preferred = complete_names.difference(discard_names) + if preferred: + keep_name = sorted(list(preferred))[0] + + if preferred_names: + preferred = preferred_names.intersection(complete_names) + if preferred: + keep_name = sorted(list(preferred))[0] + for name in sorted(shared): + if name != keep_name: + to_remove[keep_name].append(name) + return to_remove + + +def get_discard_names(model_id: str, revision: Optional[str], folder: str, token: Optional[str]) -> List[str]: + try: + import json + + import transformers + + config_filename = hf_hub_download( + model_id, revision=revision, filename="config.json", token=token, cache_dir=folder + ) + with open(config_filename, "r") as f: + config = json.load(f) + architecture = config["architectures"][0] + + class_ = getattr(transformers, architecture) + + # Name for this varible depends on transformers version. + discard_names = getattr(class_, "_tied_weights_keys", []) + + except Exception: + discard_names = [] + return discard_names + + +class AlreadyExists(Exception): + pass + + +def check_file_size(sf_filename: str, pt_filename: str): + sf_size = os.stat(sf_filename).st_size + pt_size = os.stat(pt_filename).st_size + + if (sf_size - pt_size) / pt_size > 0.01: + raise RuntimeError( + f"""The file size different is more than 1%: \n- {sf_filename}: {sf_size} \n- {pt_filename}: {pt_size}""" + ) + + +def rename(pt_filename: str) -> str: + filename, ext = os.path.splitext(pt_filename) + local = f"{filename}.safetensors" + local = local.replace("pytorch_model", "model") + return local + + +def convert_single( + model_id: str, *, revision: Optional[str], folder: str, token: Optional[str], discard_names: List[str] +) -> ConversionResult: + pt_filename = hf_hub_download( + repo_id=model_id, revision=revision, filename="pytorch_model.bin", token=token, cache_dir=folder + ) + + sf_name = "model.safetensors" + sf_filename = os.path.join(folder, sf_name) + convert_file(pt_filename, sf_filename, discard_names) + operations = [CommitOperationAdd(path_in_repo=sf_name, path_or_fileobj=sf_filename)] + errors: List[Tuple[str, "Exception"]] = [] + return operations, errors + + +def convert_file( + pt_filename: str, + sf_filename: str, + discard_names: List[str], +): + loaded = torch.load(pt_filename, map_location="cpu") + if "state_dict" in loaded: + loaded = loaded["state_dict"] + to_removes = _remove_duplicate_names(loaded, discard_names=discard_names) + + metadata = {"format": "pt"} + for kept_name, to_remove_group in to_removes.items(): + for to_remove in to_remove_group: + if to_remove not in metadata: + metadata[to_remove] = kept_name + del loaded[to_remove] + # Force tensors to be contiguous + loaded = {k: v.contiguous() for k, v in loaded.items()} + + dirname = os.path.dirname(sf_filename) + os.makedirs(dirname, exist_ok=True) + save_file(loaded, sf_filename, metadata=metadata) + check_file_size(sf_filename, pt_filename) + reloaded = load_file(sf_filename) + for k in loaded: + pt_tensor = loaded[k] + sf_tensor = reloaded[k] + if not torch.equal(pt_tensor, sf_tensor): + raise RuntimeError(f"The output tensors do not match for key {k}") + + +def convert_generic( + model_id: str, *, revision=Optional[str], folder: str, filenames: Set[str], token: Optional[str] +) -> ConversionResult: + operations = [] + errors = [] + + extensions = set([".bin", 
".ckpt"]) + for filename in filenames: + prefix, ext = os.path.splitext(filename) + if ext in extensions: + pt_filename = hf_hub_download(model_id, revision=revision, filename=filename, token=token, cache_dir=folder) + dirname, raw_filename = os.path.split(filename) + if raw_filename == "pytorch_model.bin": + sf_in_repo = os.path.join(dirname, "model.safetensors") + else: + sf_in_repo = f"{prefix}.safetensors" + sf_filename = os.path.join(folder, sf_in_repo) + try: + convert_file(pt_filename, sf_filename, discard_names=[]) + operations.append(CommitOperationAdd(path_in_repo=sf_in_repo, path_or_fileobj=sf_filename)) + except Exception as e: + errors.append((pt_filename, e)) + return operations, errors + + +def convert( + api: "HfApi", model_id: str, revision: Optional[str] = None, force: bool = False +) -> Tuple["CommitInfo", List[Tuple[str, "Exception"]]]: + info = api.model_info(model_id, revision=revision) + filenames = set(s.rfilename for s in info.siblings) + + # with TemporaryDirectory() as d: + d = os.environ["HF_HOME"] + folder = os.path.join(d, "hub", repo_folder_name(repo_id=model_id, repo_type="model")) + print(f"current folder is {folder}") + os.makedirs(folder, exist_ok=True) + library_name = getattr(info, "library_name", None) + if any(filename.endswith(".safetensors") for filename in filenames) and not force: + raise AlreadyExists(f"Model {model_id} is already converted, skipping..") + elif library_name == "transformers": + discard_names = get_discard_names(model_id, revision=revision, folder=folder, token=api.token) + if "pytorch_model.bin" in filenames: + operations, errors = convert_single( + model_id, revision=revision, folder=folder, token=api.token, discard_names=discard_names + ) + else: + raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert") + else: + operations, errors = convert_generic( + model_id, revision=revision, folder=folder, filenames=filenames, token=api.token + ) + return errors + + +if __name__ == "__main__": + DESCRIPTION = """ + Simple utility tool to convert automatically some weights on the hub to `safetensors` format. + It is PyTorch exclusive for now. + It works by downloading the weights (PT), converting them locally, and uploading them back + as a PR on the hub. + """ + parser = argparse.ArgumentParser(description=DESCRIPTION) + parser.add_argument( + "model_id", + type=str, + help="The name of the model on the hub to convert. E.g. `gpt2` or `facebook/wav2vec2-base-960h`", + ) + parser.add_argument( + "--revision", + type=str, + help="The revision to convert", + ) + parser.add_argument( + "--force", + action="store_true", + help="Create the PR even if it already exists of if the model was already converted.", + ) + parser.add_argument( + "-y", + action="store_true", + help="Ignore safety prompt", + ) + args = parser.parse_args() + + model_id = args.model_id + api = HfApi() + if args.y: + txt = "y" + else: + txt = input( + "This conversion script will unpickle a pickled file, which is inherently unsafe. If you do not trust this file, we invite you to use" + " https://huggingface.co/spaces/safetensors/convert or google colab or other hosted solution to avoid potential issues with this file." + " Continue [Y/n] ?" + ) + if txt.lower() in {"", "y"}: + commit_info, errors = convert(api, model_id, revision=args.revision, force=args.force) + string = """### Success 🔥 Yay! 
This model was successfully converted""" + if errors: + string += "\nErrors during conversion:\n" + string += "\n".join(f"Error while converting {filename}: {e}, skipped conversion" for filename, e in errors) + print(string) + else: + print(f"Answer was '{txt}' aborting.") diff --git a/examples/instantmesh/tools/convert_pt2ms.py b/examples/instantmesh/tools/convert_pt2ms.py new file mode 100644 index 0000000000..b7375f5223 --- /dev/null +++ b/examples/instantmesh/tools/convert_pt2ms.py @@ -0,0 +1,49 @@ +"""convert pt ckpt to model.safetensors""" + +import argparse + +import numpy as np +import torch +from safetensors import safe_open + +import mindspore as ms + + +def convert(pt_ckpt, target_fp, pick_ema=True): + if pt_ckpt.endswith(".ckpt") or pt_ckpt.endswith(".pt") or pt_ckpt.endswith(".pth"): + state_dict = torch.load(pt_ckpt, torch.device("cpu")) + if "ema" in state_dict and pick_ema: + print("WARNING: find EMA weights in source checkpoint. Will pick it!") + state_dict = state_dict["ema"] + elif "state_dict" in state_dict: + state_dict = state_dict["state_dict"] + else: + state_dict = {} + with safe_open(pt_ckpt, framework="pt", device="cpu") as f: + for key in f.keys(): + state_dict[key] = f.get_tensor(key) + target_data = [] + for k in state_dict: + print(k) + if "." not in k: + # only for GroupNorm + ms_name = k.replace("weight", "gamma").replace("bias", "beta") + else: + if "norm" in k: + ms_name = k.replace(".weight", ".gamma").replace(".bias", ".beta") + else: + ms_name = k + + val = state_dict[k].detach().numpy().astype(np.float32) + target_data.append({"name": ms_name, "data": ms.Tensor(val, dtype=ms.float32)}) + + ms.save_checkpoint(target_data, target_fp) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--src", type=str, default="YOUR_PT_CKPT_PATH", help="path to torch checkpoint path") + parser.add_argument("--trgt", type=str, help="target file path to save the converted checkpoint") + args = parser.parse_args() + + convert(args.src, args.trgt) From 968a03b44612534e2d1669aef87a89c52ec42743 Mon Sep 17 00:00:00 2001 From: HaFred Date: Wed, 13 Nov 2024 17:04:26 +0800 Subject: [PATCH 18/18] update link --- examples/instantmesh/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/instantmesh/README.md b/examples/instantmesh/README.md index 920c31a9e7..5a30595bad 100644 --- a/examples/instantmesh/README.md +++ b/examples/instantmesh/README.md @@ -57,7 +57,7 @@ A walk-through of the file structure is provided here as below. InstantMesh [[1]](#acknowledgements) synergizes the strengths of a multiview diffusion model and a sparse-view reconstruction model based on the LRM [[2]](#acknowledgements) architecture. It also adopts FlexiCubes [[3]](#acknowledgements) isosurface extraction for a smoother and more elegant mesh extraction. -Using the multiview images input from 3D mesh extracted from [the sv3d pipeline](../../simple_video_sample.py), we extracted 3D meshes as below. Please kindly find the input illustrated by following the link to the sv3d pipeline above. +Using the multiview images input from 3D mesh extracted from [the sv3d pipeline](../sv3d/simple_video_sample.py), we extracted 3D meshes as below. Please kindly find the input illustrated by following the link to the sv3d pipeline above. |

| akun | anya |
| --- | --- |
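
For quick reference, the two conversion utilities added in [PATCH 17/18] can be invoked roughly as follows. This is a minimal sketch based only on the argument parsers shown above (`model_id`, `--revision`, `--force`, `-y` for `convert_dinovit_bin2st.py`; `--src`/`--trgt` for `convert_pt2ms.py`); all paths are placeholders, and note that `convert_dinovit_bin2st.py` reads its cache directory from the `HF_HOME` environment variable.

```bash
# Convert the HuggingFace dino-vitb16 pytorch_model.bin to model.safetensors locally.
# The script looks up its download/cache folder via HF_HOME; -y skips the unpickling safety prompt.
export HF_HOME=YOUR_HF_CACHE_DIR
python tools/convert_dinovit_bin2st.py facebook/dino-vitb16 -y

# Convert a PyTorch InstantMesh checkpoint (.ckpt/.pt/.pth, or .safetensors) into a MindSpore .ckpt.
python tools/convert_pt2ms.py --src PATH_TO_TORCH_CKPT --trgt PATH_TO_SAVE_MS_CKPT
```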