diff --git a/.buildinfo b/.buildinfo index bfe1e0e..40a0a27 100644 --- a/.buildinfo +++ b/.buildinfo @@ -1,4 +1,4 @@ # Sphinx build info version 1 # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done. -config: 76e008e4845e9ce31818045e3e2f72f8 +config: dac2eef83a8a45408328865fe5e6d562 tags: d77d1c0d9ca2f4c8421862c7c5a0d620 diff --git a/.doctrees/camera.doctree b/.doctrees/camera.doctree index ad9c4f9..ff07560 100644 Binary files a/.doctrees/camera.doctree and b/.doctrees/camera.doctree differ diff --git a/.doctrees/environment.pickle b/.doctrees/environment.pickle index 06ef481..6fad986 100644 Binary files a/.doctrees/environment.pickle and b/.doctrees/environment.pickle differ diff --git a/.doctrees/mesh.doctree b/.doctrees/mesh.doctree index 5ea9502..b99a881 100644 Binary files a/.doctrees/mesh.doctree and b/.doctrees/mesh.doctree differ diff --git a/.doctrees/misc_api.doctree b/.doctrees/misc_api.doctree index ca99d5d..2fe1f6f 100644 Binary files a/.doctrees/misc_api.doctree and b/.doctrees/misc_api.doctree differ diff --git a/.doctrees/ops.doctree b/.doctrees/ops.doctree index 27db2ba..13a10ba 100644 Binary files a/.doctrees/ops.doctree and b/.doctrees/ops.doctree differ diff --git a/.doctrees/utils.doctree b/.doctrees/utils.doctree index 825dabf..4dcf12c 100644 Binary files a/.doctrees/utils.doctree and b/.doctrees/utils.doctree differ diff --git a/.doctrees/vis.doctree b/.doctrees/vis.doctree index b0b74d9..a4593cd 100644 Binary files a/.doctrees/vis.doctree and b/.doctrees/vis.doctree differ diff --git a/_modules/index.html b/_modules/index.html deleted file mode 100644 index 8d8e0e2..0000000 --- a/_modules/index.html +++ /dev/null @@ -1,311 +0,0 @@ - - -
- - - - -
-import numpy as np
-from scipy.spatial.transform import Rotation
-
-from kiui.op import safe_normalize
-from kiui.typing import *
-
-# convert between different world coordinate systems
-
-[docs]
-def convert(
- pose,
- target: Literal['unity', 'blender', 'opencv', 'colmap', 'opengl'] = 'unity',
- original: Literal['unity', 'blender', 'opencv', 'colmap', 'opengl'] = 'opengl',
-):
- """A method to convert between different world coordinate systems.
-
- Args:
- pose (np.ndarray): camera pose, float [4, 4].
- target (Literal['unity', 'blender', 'opencv', 'colmap', 'opengl'], optional): from convention. Defaults to 'unity'.
- original (Literal['unity', 'blender', 'opencv', 'colmap', 'opengl'], optional): to convention. Defaults to 'opengl'.
-
- Returns:
- np.ndarray: converted camera pose, float [4, 4].
- """
-
- if original == 'opengl':
- if target == 'unity':
- pose[2] *= -1
- elif target == 'blender':
- pose[2] *= -1
- pose[[1, 2]] = pose[[2, 1]]
- elif target in ['opencv', 'colmap']:
- pose[1:3] *= -1
- elif original == 'unity':
- if target == 'opengl':
- pose[2] *= -1
- elif target == 'blender':
- pose[[1, 2]] = pose[[2, 1]]
- elif target in ['opencv', 'colmap']:
- pose[1] *= -1
- elif original == 'blender':
- if target == 'opengl':
- pose[1] *= -1
- pose[[1, 2]] = pose[[2, 1]]
- elif target == 'unity':
- pose[[1, 2]] = pose[[2, 1]]
- elif target in ['opencv', 'colmap']:
- pose[2] *= -1
- pose[[1, 2]] = pose[[2, 1]]
- elif original in ['opencv', 'colmap']:
- if target == 'opengl':
- pose[1:3] *= -1
- elif target == 'unity':
- pose[1] *= -1
- elif target == 'blender':
- pose[1] *= -1
- pose[[1, 2]] = pose[[2, 1]]
- return pose
-
-
-
-
-[docs]
-def look_at(campos, target, opengl=True):
- """construct pose rotation matrix by look-at.
-
- Args:
- campos (np.ndarray): camera position, float [3]
- target (np.ndarray): look at target, float [3]
- opengl (bool, optional): whether use opengl camera convention (forward direction is target --> camera). Defaults to True.
-
- Returns:
- np.ndarray: the camera pose rotation matrix, float [3, 3], normalized.
- """
-
- if not opengl:
- # forward is camera --> target
- forward_vector = safe_normalize(target - campos)
- up_vector = np.array([0, 1, 0], dtype=np.float32)
- right_vector = safe_normalize(np.cross(forward_vector, up_vector))
- up_vector = safe_normalize(np.cross(right_vector, forward_vector))
- else:
- # forward is target --> camera
- forward_vector = safe_normalize(campos - target)
- up_vector = np.array([0, 1, 0], dtype=np.float32)
- right_vector = safe_normalize(np.cross(up_vector, forward_vector))
- up_vector = safe_normalize(np.cross(forward_vector, right_vector))
- R = np.stack([right_vector, up_vector, forward_vector], axis=1)
- return R
-
-
-
-
-[docs]
-def orbit_camera(elevation, azimuth, radius=1, is_degree=True, target=None, opengl=True):
- """construct a camera pose matrix orbiting a target with elevation & azimuth angle.
-
- Args:
- elevation (float): elevation in (-90, 90), from +y to -y is (-90, 90)
- azimuth (float): azimuth in (-180, 180), from +z to +x is (0, 90)
- radius (int, optional): camera radius. Defaults to 1.
- is_degree (bool, optional): if the angles are in degree. Defaults to True.
- target (np.ndarray, optional): look at target position. Defaults to None.
- opengl (bool, optional): whether to use OpenGL camera convention. Defaults to True.
-
- Returns:
- np.ndarray: the camera pose matrix, float [4, 4]
- """
-
- if is_degree:
- elevation = np.deg2rad(elevation)
- azimuth = np.deg2rad(azimuth)
- x = radius * np.cos(elevation) * np.sin(azimuth)
- y = - radius * np.sin(elevation)
- z = radius * np.cos(elevation) * np.cos(azimuth)
- if target is None:
- target = np.zeros([3], dtype=np.float32)
- campos = np.array([x, y, z]) + target # [3]
- T = np.eye(4, dtype=np.float32)
- T[:3, :3] = look_at(campos, target, opengl)
- T[:3, 3] = campos
- return T
-
-
-
-
-[docs]
-def undo_orbit_camera(T, is_degree=True):
- """ undo an orbital camera pose matrix to elevation & azimuth
-
- Args:
- T (np.ndarray): camera pose matrix, float [4, 4], must be an orbital camera targeting at (0, 0, 0)!
- is_degree (bool, optional): whether to return angles in degree. Defaults to True.
-
- Returns:
- Tuple[float]: elevation, azimuth, and radius.
- """
-
- campos = T[:3, 3]
- radius = np.linalg.norm(campos)
- elevation = np.arcsin(-campos[1] / radius)
- azimuth = np.arctan2(campos[0], campos[2])
- if is_degree:
- elevation = np.rad2deg(elevation)
- azimuth = np.rad2deg(azimuth)
- return elevation, azimuth, radius
-
-
-# perspective matrix
-
-[docs]
-def get_perspective(fovy, aspect=1, near=0.01, far=1000):
- """construct a perspective matrix from fovy.
-
- Args:
- fovy (float): field of view in degree along y-axis.
- aspect (int, optional): aspect ratio. Defaults to 1.
- near (float, optional): near clip plane. Defaults to 0.01.
- far (int, optional): far clip plane. Defaults to 1000.
-
- Returns:
- np.ndarray: perspective matrix, float [4, 4]
- """
- # fovy: field of view in degree.
-
- y = np.tan(np.deg2rad(fovy) / 2)
- return np.array(
- [
- [1 / (y * aspect), 0, 0, 0],
- [0, -1 / y, 0, 0],
- [
- 0,
- 0,
- -(far + near) / (far - near),
- -(2 * far * near) / (far - near),
- ],
- [0, 0, -1, 0],
- ],
- dtype=np.float32,
- )
-
-
-
-
-[docs]
-def get_rays(pose, h, w, fovy, opengl=True, normalize_dir=True):
- """ construct rays origin and direction from a camera pose.
-
- Args:
- pose (np.ndarray): camera pose, float [4, 4]
- h (int): image height
- w (int): image width
- fovy (float): field of view in degree along y-axis.
- opengl (bool, optional): whether to use the OpenGL camera convention. Defaults to True.
- normalize_dir (bool, optional): whether to normalize the ray directions. Defaults to True.
-
- Returns:
- Tuple[np.ndarray]: rays_o and rays_d, both are float [h, w, 3]
- """
- # pose: [4, 4]
- # fov: in degree
- # opengl: camera front view convention
-
- x, y = np.meshgrid(np.arange(w), np.arange(h), indexing="xy")
- x = x.reshape(-1)
- y = y.reshape(-1)
-
- cx = w * 0.5
- cy = h * 0.5
-
- # objaverse rendering has fixed focal of 560 at resolution 512 --> fov = 49.1 degree
- focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy))
-
- camera_dirs = np.stack([
- (x - cx + 0.5) / focal,
- (y - cy + 0.5) / focal * (-1.0 if opengl else 1.0),
- np.ones_like(x) * (-1.0 if opengl else 1.0),
- ], axis=-1) # [hw, 3]
-
- rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1) # [hw, 3]
- rays_o = np.expand_dims(pose[:3, 3], 0).repeat(rays_d.shape[0], 0) # [hw, 3]
-
- if normalize_dir:
- rays_d = safe_normalize(rays_d)
-
- rays_o = rays_o.reshape(h, w, 3)
- rays_d = rays_d.reshape(h, w, 3)
-
- return rays_o, rays_d
-
-
-
-[docs]
-class OrbitCamera:
- """ An orbital camera class.
- """
-
-[docs]
- def __init__(self, W, H, r=2, fovy=60, near=0.01, far=100):
- """init function
-
- Args:
- W (int): image width
- H (int): image height
- r (int, optional): camera radius. Defaults to 2.
- fovy (int, optional): camera field of view in degree along y-axis. Defaults to 60.
- near (float, optional): near clip plane. Defaults to 0.01.
- far (int, optional): far clip plane. Defaults to 100.
- """
- self.W = W
- self.H = H
- self.radius = r # camera distance from center
- self.fovy = np.deg2rad(fovy) # deg 2 rad
- self.near = near
- self.far = far
- self.center = np.array([0, 0, 0], dtype=np.float32) # look at this point
- self.rot = Rotation.from_matrix(np.eye(3))
- self.up = np.array([0, 1, 0], dtype=np.float32) # need to be normalized!
-
-
- @property
- def fovx(self):
- """get the field of view in radians along x-axis
-
- Returns:
- float: field of view in radians along x-axis
- """
- return 2 * np.arctan(np.tan(self.fovy / 2) * self.W / self.H)
-
- @property
- def campos(self):
- """get the camera position
-
- Returns:
- np.ndarray: camera position, float [3]
- """
- return self.pose[:3, 3]
-
-
- @property
- def pose(self):
- """get the camera pose matrix (cam2world)
-
- Returns:
- np.ndarray: camera pose, float [4, 4]
- """
- # first move camera to radius
- res = np.eye(4, dtype=np.float32)
- res[2, 3] = self.radius # opengl convention...
- # rotate
- rot = np.eye(4, dtype=np.float32)
- rot[:3, :3] = self.rot.as_matrix()
- res = rot @ res
- # translate
- res[:3, 3] -= self.center
- return res
-
-
- @property
- def view(self):
- """get the camera view matrix (world2cam, inverse of cam2world)
-
- Returns:
- np.ndarray: camera view, float [4, 4]
- """
- return np.linalg.inv(self.pose)
-
-
- @property
- def perspective(self):
- """get the perspective matrix
-
- Returns:
- np.ndarray: camera perspective, float [4, 4]
- """
- y = np.tan(self.fovy / 2)
- aspect = self.W / self.H
- return np.array(
- [
- [1 / (y * aspect), 0, 0, 0],
- [0, -1 / y, 0, 0],
- [
- 0,
- 0,
- -(self.far + self.near) / (self.far - self.near),
- -(2 * self.far * self.near) / (self.far - self.near),
- ],
- [0, 0, -1, 0],
- ],
- dtype=np.float32,
- )
-
- # intrinsics
- @property
- def intrinsics(self):
- """get the camera intrinsics
-
- Returns:
- np.ndarray: intrinsics (fx, fy, cx, cy), float [4]
- """
- focal = self.H / (2 * np.tan(self.fovy / 2))
- return np.array([focal, focal, self.W // 2, self.H // 2], dtype=np.float32)
-
-
- @property
- def mvp(self):
- """get the MVP (model-view-perspective) matrix.
-
- Returns:
- np.ndarray: camera MVP, float [4, 4]
- """
- return self.perspective @ np.linalg.inv(self.pose) # [4, 4]
-
-
-[docs]
- def orbit(self, dx, dy):
- """ rotate along camera up/side axis!
-
- Args:
- dx (float): delta step along x (up).
- dy (float): delta step along y (side).
- """
- side = self.rot.as_matrix()[:3, 0]
- rotvec_x = self.up * np.radians(-0.05 * dx)
- rotvec_y = side * np.radians(-0.05 * dy)
- self.rot = Rotation.from_rotvec(rotvec_x) * Rotation.from_rotvec(rotvec_y) * self.rot
-
-
-
-[docs]
- def scale(self, delta):
- """scale the camera.
-
- Args:
- delta (float): delta step.
- """
- self.radius *= 1.1 ** (-delta)
-
-
-
-[docs]
- def pan(self, dx, dy, dz=0):
- """pan the camera.
-
- Args:
- dx (float): delta step along x.
- dy (float): delta step along y.
- dz (float, optional): delta step along x. Defaults to 0.
- """
- # pan in camera coordinate system (careful on the sensitivity!)
- self.center += 0.0005 * self.rot.as_matrix()[:3, :3] @ np.array([dx, -dy, dz])
-
-
-
-[docs]
- def from_angle(self, elevation, azimuth, is_degree=True):
- """set the camera pose from elevation & azimuth angle.
-
- Args:
- elevation (float): elevation in (-90, 90), from +y to -y is (-90, 90)
- azimuth (float): azimuth in (-180, 180), from +z to +x is (0, 90)
- is_degree (bool, optional): whether the angles are in degree. Defaults to True.
- """
- if is_degree:
- elevation = np.deg2rad(elevation)
- azimuth = np.deg2rad(azimuth)
- x = self.radius * np.cos(elevation) * np.sin(azimuth)
- y = - self.radius * np.sin(elevation)
- z = self.radius * np.cos(elevation) * np.cos(azimuth)
- campos = np.array([x, y, z]) # [N, 3]
- rot_mat = look_at(campos, np.zeros([3], dtype=np.float32))
- self.rot = Rotation.from_matrix(rot_mat)
-
-
-
-import torch
-import torch.nn.functional as F
-
-from kiui.typing import *
-
-def stride_from_shape(shape):
- stride = [1]
- for x in reversed(shape[1:]):
- stride.append(stride[-1] * x)
- return list(reversed(stride))
-
-
-def scatter_add_nd(input, indices, values):
- # input: [..., C], D dimension + C channel
- # indices: [N, D], long
- # values: [N, C]
-
- D = indices.shape[-1]
- C = input.shape[-1]
- size = input.shape[:-1]
- stride = stride_from_shape(size)
-
- assert len(size) == D
-
- input = input.view(-1, C) # [HW, C]
- flatten_indices = (indices * torch.tensor(stride, dtype=torch.long, device=indices.device)).sum(-1) # [N]
-
- input.scatter_add_(0, flatten_indices.unsqueeze(1).repeat(1, C), values)
-
- return input.view(*size, C)
-
-
-def scatter_add_nd_with_count(input, count, indices, values, weights=None):
- # input: [..., C], D dimension + C channel
- # count: [..., 1], D dimension
- # indices: [N, D], long
- # values: [N, C]
-
- D = indices.shape[-1]
- C = input.shape[-1]
- size = input.shape[:-1]
- stride = stride_from_shape(size)
-
- assert len(size) == D
-
- input = input.view(-1, C) # [HW, C]
- count = count.view(-1, 1)
-
- flatten_indices = (indices * torch.tensor(stride, dtype=torch.long, device=indices.device)).sum(-1) # [N]
-
- if weights is None:
- weights = torch.ones_like(values[..., :1])
-
- input.scatter_add_(0, flatten_indices.unsqueeze(1).repeat(1, C), values)
- count.scatter_add_(0, flatten_indices.unsqueeze(1), weights)
-
- return input.view(*size, C), count.view(*size, 1)
-
-def nearest_grid_put_2d(H, W, coords, values, return_count=False):
- # coords: [N, 2], float in [-1, 1]
- # values: [N, C]
-
- C = values.shape[-1]
-
- indices = (coords * 0.5 + 0.5) * torch.tensor(
- [H - 1, W - 1], dtype=torch.float32, device=coords.device
- )
- indices = indices.round().long() # [N, 2]
-
- result = torch.zeros(H, W, C, device=values.device, dtype=values.dtype) # [H, W, C]
- count = torch.zeros(H, W, 1, device=values.device, dtype=values.dtype) # [H, W, 1]
- weights = torch.ones_like(values[..., :1]) # [N, 1]
-
- result, count = scatter_add_nd_with_count(result, count, indices, values, weights)
-
- if return_count:
- return result, count
-
- mask = (count.squeeze(-1) > 0)
- result[mask] = result[mask] / count[mask].repeat(1, C)
-
- return result
-
-
-def linear_grid_put_2d(H, W, coords, values, return_count=False):
- # coords: [N, 2], float in [-1, 1]
- # values: [N, C]
-
- C = values.shape[-1]
-
- indices = (coords * 0.5 + 0.5) * torch.tensor(
- [H - 1, W - 1], dtype=torch.float32, device=coords.device
- )
- indices_00 = indices.floor().long() # [N, 2]
- indices_00[:, 0].clamp_(0, H - 2)
- indices_00[:, 1].clamp_(0, W - 2)
- indices_01 = indices_00 + torch.tensor(
- [0, 1], dtype=torch.long, device=indices.device
- )
- indices_10 = indices_00 + torch.tensor(
- [1, 0], dtype=torch.long, device=indices.device
- )
- indices_11 = indices_00 + torch.tensor(
- [1, 1], dtype=torch.long, device=indices.device
- )
-
- h = indices[..., 0] - indices_00[..., 0].float()
- w = indices[..., 1] - indices_00[..., 1].float()
- w_00 = (1 - h) * (1 - w)
- w_01 = (1 - h) * w
- w_10 = h * (1 - w)
- w_11 = h * w
-
- result = torch.zeros(H, W, C, device=values.device, dtype=values.dtype) # [H, W, C]
- count = torch.zeros(H, W, 1, device=values.device, dtype=values.dtype) # [H, W, 1]
- weights = torch.ones_like(values[..., :1]) # [N, 1]
-
- result, count = scatter_add_nd_with_count(result, count, indices_00, values * w_00.unsqueeze(1), weights* w_00.unsqueeze(1))
- result, count = scatter_add_nd_with_count(result, count, indices_01, values * w_01.unsqueeze(1), weights* w_01.unsqueeze(1))
- result, count = scatter_add_nd_with_count(result, count, indices_10, values * w_10.unsqueeze(1), weights* w_10.unsqueeze(1))
- result, count = scatter_add_nd_with_count(result, count, indices_11, values * w_11.unsqueeze(1), weights* w_11.unsqueeze(1))
-
- if return_count:
- return result, count
-
- mask = (count.squeeze(-1) > 0)
- result[mask] = result[mask] / count[mask].repeat(1, C)
-
- return result
-
-def mipmap_linear_grid_put_2d(H, W, coords, values, min_resolution=32, return_count=False):
- # coords: [N, 2], float in [-1, 1]
- # values: [N, C]
-
- C = values.shape[-1]
-
- result = torch.zeros(H, W, C, device=values.device, dtype=values.dtype) # [H, W, C]
- count = torch.zeros(H, W, 1, device=values.device, dtype=values.dtype) # [H, W, 1]
-
- cur_H, cur_W = H, W
-
- while min(cur_H, cur_W) > min_resolution:
-
- # try to fill the holes
- mask = (count.squeeze(-1) == 0)
- if not mask.any():
- break
-
- cur_result, cur_count = linear_grid_put_2d(cur_H, cur_W, coords, values, return_count=True)
- result[mask] = result[mask] + F.interpolate(cur_result.permute(2,0,1).unsqueeze(0).contiguous(), (H, W), mode='bilinear', align_corners=False).squeeze(0).permute(1,2,0).contiguous()[mask]
- count[mask] = count[mask] + F.interpolate(cur_count.view(1, 1, cur_H, cur_W), (H, W), mode='bilinear', align_corners=False).view(H, W, 1)[mask]
- cur_H //= 2
- cur_W //= 2
-
- if return_count:
- return result, count
-
- mask = (count.squeeze(-1) > 0)
- result[mask] = result[mask] / count[mask].repeat(1, C)
-
- return result
-
-def nearest_grid_put_3d(H, W, D, coords, values, return_count=False):
- # coords: [N, 3], float in [-1, 1]
- # values: [N, C]
-
- C = values.shape[-1]
-
- indices = (coords * 0.5 + 0.5) * torch.tensor(
- [H - 1, W - 1, D - 1], dtype=torch.float32, device=coords.device
- )
- indices = indices.round().long() # [N, 2]
-
- result = torch.zeros(H, W, D, C, device=values.device, dtype=values.dtype) # [H, W, C]
- count = torch.zeros(H, W, D, 1, device=values.device, dtype=values.dtype) # [H, W, 1]
- weights = torch.ones_like(values[..., :1]) # [N, 1]
-
- result, count = scatter_add_nd_with_count(result, count, indices, values, weights)
-
- if return_count:
- return result, count
-
- mask = (count.squeeze(-1) > 0)
- result[mask] = result[mask] / count[mask].repeat(1, C)
-
- return result
-
-
-def linear_grid_put_3d(H, W, D, coords, values, return_count=False):
- # coords: [N, 3], float in [-1, 1]
- # values: [N, C]
-
- C = values.shape[-1]
-
- indices = (coords * 0.5 + 0.5) * torch.tensor(
- [H - 1, W - 1, D - 1], dtype=torch.float32, device=coords.device
- )
- indices_000 = indices.floor().long() # [N, 3]
- indices_000[:, 0].clamp_(0, H - 2)
- indices_000[:, 1].clamp_(0, W - 2)
- indices_000[:, 2].clamp_(0, D - 2)
-
- indices_001 = indices_000 + torch.tensor([0, 0, 1], dtype=torch.long, device=indices.device)
- indices_010 = indices_000 + torch.tensor([0, 1, 0], dtype=torch.long, device=indices.device)
- indices_011 = indices_000 + torch.tensor([0, 1, 1], dtype=torch.long, device=indices.device)
- indices_100 = indices_000 + torch.tensor([1, 0, 0], dtype=torch.long, device=indices.device)
- indices_101 = indices_000 + torch.tensor([1, 0, 1], dtype=torch.long, device=indices.device)
- indices_110 = indices_000 + torch.tensor([1, 1, 0], dtype=torch.long, device=indices.device)
- indices_111 = indices_000 + torch.tensor([1, 1, 1], dtype=torch.long, device=indices.device)
-
- h = indices[..., 0] - indices_000[..., 0].float()
- w = indices[..., 1] - indices_000[..., 1].float()
- d = indices[..., 2] - indices_000[..., 2].float()
-
- w_000 = (1 - h) * (1 - w) * (1 - d)
- w_001 = (1 - h) * w * (1 - d)
- w_010 = h * (1 - w) * (1 - d)
- w_011 = h * w * (1 - d)
- w_100 = (1 - h) * (1 - w) * d
- w_101 = (1 - h) * w * d
- w_110 = h * (1 - w) * d
- w_111 = h * w * d
-
- result = torch.zeros(H, W, D, C, device=values.device, dtype=values.dtype) # [H, W, D, C]
- count = torch.zeros(H, W, D, 1, device=values.device, dtype=values.dtype) # [H, W, D, 1]
- weights = torch.ones_like(values[..., :1]) # [N, 1]
-
- result, count = scatter_add_nd_with_count(result, count, indices_000, values * w_000.unsqueeze(1), weights * w_000.unsqueeze(1))
- result, count = scatter_add_nd_with_count(result, count, indices_001, values * w_001.unsqueeze(1), weights * w_001.unsqueeze(1))
- result, count = scatter_add_nd_with_count(result, count, indices_010, values * w_010.unsqueeze(1), weights * w_010.unsqueeze(1))
- result, count = scatter_add_nd_with_count(result, count, indices_011, values * w_011.unsqueeze(1), weights * w_011.unsqueeze(1))
- result, count = scatter_add_nd_with_count(result, count, indices_100, values * w_100.unsqueeze(1), weights * w_100.unsqueeze(1))
- result, count = scatter_add_nd_with_count(result, count, indices_101, values * w_101.unsqueeze(1), weights * w_101.unsqueeze(1))
- result, count = scatter_add_nd_with_count(result, count, indices_110, values * w_110.unsqueeze(1), weights * w_110.unsqueeze(1))
- result, count = scatter_add_nd_with_count(result, count, indices_111, values * w_111.unsqueeze(1), weights * w_111.unsqueeze(1))
-
- if return_count:
- return result, count
-
- mask = (count.squeeze(-1) > 0)
- result[mask] = result[mask] / count[mask].repeat(1, C)
-
- return result
-
-def mipmap_linear_grid_put_3d(H, W, D, coords, values, min_resolution=32, return_count=False):
- # coords: [N, 3], float in [-1, 1]
- # values: [N, C]
-
- C = values.shape[-1]
-
- result = torch.zeros(H, W, D, C, device=values.device, dtype=values.dtype) # [H, W, D, C]
- count = torch.zeros(H, W, D, 1, device=values.device, dtype=values.dtype) # [H, W, D, 1]
- cur_H, cur_W, cur_D = H, W, D
-
- while min(min(cur_H, cur_W), cur_D) > min_resolution:
-
- # try to fill the holes
- mask = (count.squeeze(-1) == 0)
- if not mask.any():
- break
-
- cur_result, cur_count = linear_grid_put_3d(cur_H, cur_W, cur_D, coords, values, return_count=True)
- result[mask] = result[mask] + F.interpolate(cur_result.permute(3,0,1,2).unsqueeze(0).contiguous(), (H, W, D), mode='trilinear', align_corners=False).squeeze(0).permute(1,2,3,0).contiguous()[mask]
- count[mask] = count[mask] + F.interpolate(cur_count.view(1, 1, cur_H, cur_W, cur_D), (H, W, D), mode='trilinear', align_corners=False).view(H, W, D, 1)[mask]
- cur_H //= 2
- cur_W //= 2
- cur_D //= 2
-
- if return_count:
- return result, count
-
- mask = (count.squeeze(-1) > 0)
- result[mask] = result[mask] / count[mask].repeat(1, C)
-
- return result
-
-
-
-[docs]
-def grid_put(shape: Sequence[int], coords: Tensor, values: Tensor, mode: Literal['nearest', 'linear', 'linear-mipmap']='linear-mipmap', min_resolution: int=32, return_count: bool=False) -> Tensor:
- """ put back values to an image according to the coords. inverse operation of ``F.grid_sample``.
-
- Args:
- shape (Sequence[int]): shape of the image, support 2D image and 3D volume, sequence of [D]
- coords (Tensor): coordinates, float [N, D] in [-1, 1].
- values (Tensor): values, float [N, C].
- mode (str, Literal['nearest', 'linear', 'linear-mipmap']): interpolation mode, see https://github.com/ashawkey/grid_put for examples. Defaults to 'linear-mipmap'.
- min_resolution (int, optional): minimal resolution for mipmap. Defaults to 32.
- return_count (bool, optional): whether to return the summed value and weights, instead of the divided results. Defaults to False.
-
- Returns:
- Tensor: the restored image/volume, float [H, W, C]/[H, W, D, C].
- """
-
- D = len(shape)
- assert D in [2, 3], f'only support D == 2 or 3, but got D == {D}'
-
- if mode == 'nearest':
- if D == 2:
- return nearest_grid_put_2d(*shape, coords, values, return_count)
- else:
- return nearest_grid_put_3d(*shape, coords, values, return_count)
- elif mode == 'linear':
- if D == 2:
- return linear_grid_put_2d(*shape, coords, values, return_count)
- else:
- return linear_grid_put_3d(*shape, coords, values, return_count)
- elif mode == 'linear-mipmap':
- if D == 2:
- return mipmap_linear_grid_put_2d(*shape, coords, values, min_resolution, return_count)
- else:
- return mipmap_linear_grid_put_3d(*shape, coords, values, min_resolution, return_count)
- else:
- raise NotImplementedError(f"got mode {mode}")
-
-
-import os
-import cv2
-import torch
-import trimesh
-import numpy as np
-from packaging import version
-
-from kiui.op import safe_normalize, dot
-from kiui.typing import *
-
-
-[docs]
-class Mesh:
- """
- A torch-native trimesh class, with support for ``ply/obj/glb`` formats.
-
- Note:
- This class only supports one mesh with a single texture image (an albedo texture and a metallic-roughness texture).
- """
-
-[docs]
- def __init__(
- self,
- v: Optional[Tensor] = None,
- f: Optional[Tensor] = None,
- vn: Optional[Tensor] = None,
- fn: Optional[Tensor] = None,
- vt: Optional[Tensor] = None,
- ft: Optional[Tensor] = None,
- vc: Optional[Tensor] = None, # vertex color
- albedo: Optional[Tensor] = None,
- metallicRoughness: Optional[Tensor] = None,
- device: Optional[torch.device] = None,
- ):
- """Init a mesh directly using all attributes.
-
- Args:
- v (Optional[Tensor]): vertices, float [N, 3]. Defaults to None.
- f (Optional[Tensor]): faces, int [M, 3]. Defaults to None.
- vn (Optional[Tensor]): vertex normals, float [N, 3]. Defaults to None.
- fn (Optional[Tensor]): faces for normals, int [M, 3]. Defaults to None.
- vt (Optional[Tensor]): vertex uv coordinates, float [N, 2]. Defaults to None.
- ft (Optional[Tensor]): faces for uvs, int [M, 3]. Defaults to None.
- vc (Optional[Tensor]): vertex colors, float [N, 3]. Defaults to None.
- albedo (Optional[Tensor]): albedo texture, float [H, W, 3], RGB format. Defaults to None.
- metallicRoughness (Optional[Tensor]): metallic-roughness texture, float [H, W, 3], metallic(Blue) = metallicRoughness[..., 2], roughness(Green) = metallicRoughness[..., 1]. Defaults to None.
- device (Optional[torch.device]): torch device. Defaults to None.
- """
- self.device = device
- self.v = v
- self.vn = vn
- self.vt = vt
- self.f = f
- self.fn = fn
- self.ft = ft
- # will first see if there is vertex color to use
- self.vc = vc
- # only support a single albedo image
- self.albedo = albedo
- # pbr extension, metallic(Blue) = metallicRoughness[..., 2], roughness(Green) = metallicRoughness[..., 1]
- # ref: https://registry.khronos.org/glTF/specs/2.0/glTF-2.0.html
- self.metallicRoughness = metallicRoughness
-
- self.ori_center = 0
- self.ori_scale = 1
-
-
- def __repr__(self):
- out = f'<kiui.mesh.Mesh>'
- if self.v is not None: out += f' v={self.v.shape}'
- if self.f is not None: out += f' f={self.f.shape}'
- if self.vc is not None: out += f' vc={self.vc.shape}'
- if self.albedo is not None: out += f' albedo={self.albedo.shape}'
- if self.metallicRoughness is not None: out += f' metallicRoughness={self.metallicRoughness.shape}'
- return out
-
-
-[docs]
- @classmethod
- def load(cls, path, resize=True, clean=False, renormal=True, retex=False, bound=0.9, front_dir='+z', **kwargs):
- """load mesh from path.
-
- Args:
- path (str): path to mesh file, supports ply, obj, glb.
- clean (bool, optional): perform mesh cleaning at load (e.g., merge close vertices). Defaults to False.
- resize (bool, optional): auto resize the mesh using ``bound`` into [-bound, bound]^3. Defaults to True.
- renormal (bool, optional): re-calc the vertex normals. Defaults to True.
- retex (bool, optional): re-calc the uv coordinates, will overwrite the existing uv coordinates. Defaults to False.
- wotex (bool, optional): do not try to load any texture. Defaults to False.
- bound (float, optional): bound to resize. Defaults to 0.9.
- front_dir (str, optional): front-view direction of the mesh, should be [+-][xyz][ 123]. Defaults to '+z'.
- device (torch.device, optional): torch device. Defaults to None.
-
- Note:
- a ``device`` keyword argument can be provided to specify the torch device.
- If it's not provided, we will try to use ``'cuda'`` as the device if it's available.
-
- Returns:
- Mesh: the loaded Mesh object.
- """
- # obj supports face uv
- if path.endswith(".obj"):
- mesh = cls.load_obj(path, **kwargs)
- # trimesh only supports vertex uv, but can load more formats
- else:
- mesh = cls.load_trimesh(path, **kwargs)
-
- # clean
- if clean:
- from kiui.mesh_utils import clean_mesh
- vertices = mesh.v.detach().cpu().numpy()
- triangles = mesh.f.detach().cpu().numpy()
- vertices, triangles = clean_mesh(vertices, triangles, remesh=False)
- mesh.v = torch.from_numpy(vertices).contiguous().float().to(mesh.device)
- mesh.f = torch.from_numpy(triangles).contiguous().int().to(mesh.device)
-
- # auto-normalize
- if resize:
- mesh.auto_size(bound=bound)
- print(f"[INFO] load mesh, v: {mesh.v.shape}, f: {mesh.f.shape}")
-
- # auto-fix normal
- if renormal or mesh.vn is None:
- mesh.auto_normal()
- print(f"[INFO] load mesh, vn: {mesh.vn.shape}, fn: {mesh.fn.shape}")
-
- # auto-fix texcoords
- if retex:
- mesh.auto_uv(cache_path=path)
- if mesh.vt is not None:
- print(f"[INFO] load mesh, vt: {mesh.vt.shape}, ft: {mesh.ft.shape}")
-
- # rotate front dir to +z
- if front_dir != "+z":
- # axis switch
- if "-z" in front_dir:
- T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, -1]], device=mesh.device, dtype=torch.float32)
- elif "+x" in front_dir:
- T = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32)
- elif "-x" in front_dir:
- T = torch.tensor([[0, 0, -1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32)
- elif "+y" in front_dir:
- T = torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]], device=mesh.device, dtype=torch.float32)
- elif "-y" in front_dir:
- T = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]], device=mesh.device, dtype=torch.float32)
- else:
- T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
- # rotation (how many 90 degrees)
- if '1' in front_dir:
- T @= torch.tensor([[0, -1, 0], [1, 0, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
- elif '2' in front_dir:
- T @= torch.tensor([[1, 0, 0], [0, -1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
- elif '3' in front_dir:
- T @= torch.tensor([[0, 1, 0], [-1, 0, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
- mesh.v @= T
- mesh.vn @= T
-
- return mesh
-
-
- # load from obj file
-
-[docs]
- @classmethod
- def load_obj(cls, path, wotex=False, albedo_path=None, device=None):
- """load an ``obj`` mesh.
-
- Args:
- path (str): path to mesh.
- wotex (bool, optional): do not try to load any texture. Defaults to False.
- albedo_path (str, optional): path to the albedo texture image, will overwrite the existing texture path if specified in mtl. Defaults to None.
- device (torch.device, optional): torch device. Defaults to None.
-
- Note:
- We will try to read `mtl` path from `obj`, else we assume the file name is the same as `obj` but with `mtl` extension.
- The `usemtl` statement is ignored, and we only use the last material path in `mtl` file.
-
- Returns:
- Mesh: the loaded Mesh object.
- """
- assert os.path.splitext(path)[-1] == ".obj"
-
- mesh = cls()
-
- # device
- if device is None:
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- mesh.device = device
-
- # load obj
- with open(path, "r") as f:
- lines = f.readlines()
-
- def parse_f_v(fv):
- # pass in a vertex term of a face, return {v, vt, vn} (-1 if not provided)
- # supported forms:
- # f v1 v2 v3
- # f v1/vt1 v2/vt2 v3/vt3
- # f v1/vt1/vn1 v2/vt2/vn2 v3/vt3/vn3
- # f v1//vn1 v2//vn2 v3//vn3
- xs = [int(x) - 1 if x != "" else -1 for x in fv.split("/")]
- xs.extend([-1] * (3 - len(xs)))
- return xs[0], xs[1], xs[2]
-
- vertices, texcoords, normals = [], [], []
- faces, tfaces, nfaces = [], [], []
- mtl_path = None
-
- for line in lines:
- split_line = line.split()
- # empty line
- if len(split_line) == 0:
- continue
- prefix = split_line[0].lower()
- # mtllib
- if prefix == "mtllib":
- mtl_path = split_line[1]
- # usemtl
- elif prefix == "usemtl":
- pass # ignored
- # v/vn/vt
- elif prefix == "v":
- vertices.append([float(v) for v in split_line[1:]])
- elif prefix == "vn":
- normals.append([float(v) for v in split_line[1:]])
- elif prefix == "vt":
- val = [float(v) for v in split_line[1:]]
- texcoords.append([val[0], 1.0 - val[1]])
- elif prefix == "f":
- vs = split_line[1:]
- nv = len(vs)
- v0, t0, n0 = parse_f_v(vs[0])
- for i in range(nv - 2): # triangulate (assume vertices are ordered)
- v1, t1, n1 = parse_f_v(vs[i + 1])
- v2, t2, n2 = parse_f_v(vs[i + 2])
- faces.append([v0, v1, v2])
- tfaces.append([t0, t1, t2])
- nfaces.append([n0, n1, n2])
-
- mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device)
- mesh.vt = (
- torch.tensor(texcoords, dtype=torch.float32, device=device)
- if len(texcoords) > 0
- else None
- )
- mesh.vn = (
- torch.tensor(normals, dtype=torch.float32, device=device)
- if len(normals) > 0
- else None
- )
-
- mesh.f = torch.tensor(faces, dtype=torch.int32, device=device)
- mesh.ft = (
- torch.tensor(tfaces, dtype=torch.int32, device=device)
- if len(texcoords) > 0
- else None
- )
- mesh.fn = (
- torch.tensor(nfaces, dtype=torch.int32, device=device)
- if len(normals) > 0
- else None
- )
-
- # if not loading texture
- if wotex:
- return mesh
-
- # see if there is vertex color
- use_vertex_color = False
- if mesh.v.shape[1] == 6:
- use_vertex_color = True
- mesh.vc = mesh.v[:, 3:]
- mesh.v = mesh.v[:, :3]
- print(f"[INFO] load obj mesh: use vertex color: {mesh.vc.shape}")
-
- # try to load texture image
- if not use_vertex_color:
- # try to retrieve mtl file
- mtl_path_candidates = []
- if mtl_path is not None:
- mtl_path_candidates.append(mtl_path)
- mtl_path_candidates.append(os.path.join(os.path.dirname(path), mtl_path))
- mtl_path_candidates.append(path.replace(".obj", ".mtl"))
-
- mtl_path = None
- for candidate in mtl_path_candidates:
- if os.path.exists(candidate):
- mtl_path = candidate
- break
-
- # if albedo_path is not provided, try retrieve it from mtl
- metallic_path = None
- roughness_path = None
- if mtl_path is not None and albedo_path is None:
- with open(mtl_path, "r") as f:
- lines = f.readlines()
-
- for line in lines:
- split_line = line.split()
- # empty line
- if len(split_line) == 0:
- continue
- prefix = split_line[0]
-
- if "map_Kd" in prefix:
- # assume relative path!
- albedo_path = os.path.join(os.path.dirname(path), split_line[1])
- print(f"[INFO] load obj mesh: use texture from: {albedo_path}")
- elif "map_Pm" in prefix:
- metallic_path = os.path.join(os.path.dirname(path), split_line[1])
- elif "map_Pr" in prefix:
- roughness_path = os.path.join(os.path.dirname(path), split_line[1])
-
- # still not found albedo_path, or the path doesn't exist
- if albedo_path is None or not os.path.exists(albedo_path):
- print(f"[INFO] load obj mesh: failed to load texture!")
- mesh.albedo = None
- else:
- albedo = cv2.imread(albedo_path, cv2.IMREAD_UNCHANGED)
- albedo = cv2.cvtColor(albedo, cv2.COLOR_BGR2RGB)
- albedo = albedo.astype(np.float32) / 255
- print(f"[INFO] load obj mesh: load texture: {albedo.shape}")
- mesh.albedo = torch.tensor(albedo, dtype=torch.float32, device=device)
-
- # try to load metallic and roughness
- if metallic_path is not None and roughness_path is not None:
- print(f"[INFO] load obj mesh: load metallicRoughness from: {metallic_path}, {roughness_path}")
- metallic = cv2.imread(metallic_path, cv2.IMREAD_UNCHANGED)
- metallic = metallic.astype(np.float32) / 255
- roughness = cv2.imread(roughness_path, cv2.IMREAD_UNCHANGED)
- roughness = roughness.astype(np.float32) / 255
- metallicRoughness = np.stack([np.zeros_like(metallic), roughness, metallic], axis=-1)
-
- mesh.metallicRoughness = torch.tensor(metallicRoughness, dtype=torch.float32, device=device).contiguous()
-
- return mesh
-
-
-
-[docs]
- @classmethod
- def load_trimesh(cls, path, wotex=False, device=None):
- """load a mesh using ``trimesh.load()``.
-
- Can load various formats like ``glb`` and serves as a fallback.
-
- Note:
- We will try to merge all meshes if the glb contains more than one,
- but **this may cause the texture to lose**, since we only support one texture image!
-
- Args:
- path (str): path to the mesh file.
- wotex (bool, optional): do not try to load any texture. Defaults to False.
- device (torch.device, optional): torch device. Defaults to None.
-
- Returns:
- Mesh: the loaded Mesh object.
- """
- mesh = cls()
-
- # device
- if device is None:
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- mesh.device = device
-
- # use trimesh to load ply/glb
- _data = trimesh.load(path)
- # always convert scene to mesh, and apply all transforms...
- if isinstance(_data, trimesh.Scene):
- print(f"[INFO] load trimesh: concatenating {len(_data.geometry)} meshes.")
- _concat = []
- # loop the scene graph and apply transform to each mesh
- scene_graph = _data.graph.to_flattened() # dict {name: {transform: 4x4 mat, geometry: str}}
- for k, v in scene_graph.items():
- name = v['geometry']
- if name in _data.geometry and isinstance(_data.geometry[name], trimesh.Trimesh):
- transform = v['transform']
- _concat.append(_data.geometry[name].apply_transform(transform))
- _mesh = trimesh.util.concatenate(_concat)
- else:
- _mesh = _data
-
- if not wotex:
- if _mesh.visual.kind == 'vertex':
- vertex_colors = _mesh.visual.vertex_colors
- vertex_colors = np.array(vertex_colors[..., :3]).astype(np.float32) / 255
- mesh.vc = torch.tensor(vertex_colors, dtype=torch.float32, device=device)
- print(f"[INFO] load trimesh: use vertex color: {mesh.vc.shape}")
- elif _mesh.visual.kind == 'texture':
- try:
- _material = _mesh.visual.material
- if isinstance(_material, trimesh.visual.material.PBRMaterial):
- texture = np.array(_material.baseColorTexture).astype(np.float32) / 255
- # load metallicRoughness if present
- if _material.metallicRoughnessTexture is not None:
- metallicRoughness = np.array(_material.metallicRoughnessTexture).astype(np.float32) / 255
- # NOTE: fix a bug in trimesh that loads metallicRoughness in wrong channels: https://github.com/mikedh/trimesh/issues/2195
- if version.parse(trimesh.__version__) < version.parse('4.2.2'):
- metallicRoughness = metallicRoughness[..., [2, 1, 0]]
- mesh.metallicRoughness = torch.tensor(metallicRoughness, dtype=torch.float32, device=device).contiguous()
- elif isinstance(_material, trimesh.visual.material.SimpleMaterial):
- texture = np.array(_material.to_pbr().baseColorTexture).astype(np.float32) / 255
- else:
- raise NotImplementedError(f"material type {type(_material)} not supported!")
- if len(texture.shape) == 2:
- texture = texture[..., None].repeat(3, axis=-1)
- mesh.albedo = torch.tensor(texture[..., :3], dtype=torch.float32, device=device).contiguous()
- print(f"[INFO] load trimesh: load texture: {texture.shape}")
- # there really can be lots of mysterious errors...
- except Exception as e:
- mesh.albedo = None
- print(f"[INFO] load trimesh: failed to load texture.")
- else:
- mesh.albedo = None
- print(f"[INFO] load trimesh: failed to load texture.")
-
- vertices = _mesh.vertices
-
- try:
- texcoords = _mesh.visual.uv
- texcoords[:, 1] = 1 - texcoords[:, 1]
- except Exception as e:
- texcoords = None
-
- try:
- normals = _mesh.vertex_normals
- except Exception as e:
- normals = None
-
- # trimesh only support vertex uv...
- faces = tfaces = nfaces = _mesh.faces
-
- mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device)
- mesh.vt = (
- torch.tensor(texcoords, dtype=torch.float32, device=device)
- if texcoords is not None
- else None
- )
- mesh.vn = (
- torch.tensor(normals, dtype=torch.float32, device=device)
- if normals is not None
- else None
- )
-
- mesh.f = torch.tensor(faces, dtype=torch.int32, device=device)
- mesh.ft = (
- torch.tensor(tfaces, dtype=torch.int32, device=device)
- if texcoords is not None
- else None
- )
- mesh.fn = (
- torch.tensor(nfaces, dtype=torch.int32, device=device)
- if normals is not None
- else None
- )
-
- return mesh
-
-
- # sample surface (using trimesh)
-
-[docs]
- def sample_surface(self, count: int):
- """sample points on the surface of the mesh.
-
- Args:
- count (int): number of points to sample.
-
- Returns:
- torch.Tensor: the sampled points, float [count, 3].
- """
- _mesh = trimesh.Trimesh(vertices=self.v.detach().cpu().numpy(), faces=self.f.detach().cpu().numpy())
- points, face_idx = trimesh.sample.sample_surface(_mesh, count)
- points = torch.from_numpy(points).float().to(self.device)
- return points
-
-
- # aabb
-
-[docs]
- def aabb(self):
- """get the axis-aligned bounding box of the mesh.
-
- Returns:
- Tuple[torch.Tensor]: the min xyz and max xyz of the mesh.
- """
- return torch.min(self.v, dim=0).values, torch.max(self.v, dim=0).values
-
-
- # unit size
-
-[docs]
- @torch.no_grad()
- def auto_size(self, bound=0.9):
- """auto resize the mesh.
-
- Args:
- bound (float, optional): resizing into ``[-bound, bound]^3``. Defaults to 0.9.
- """
- vmin, vmax = self.aabb()
- self.ori_center = (vmax + vmin) / 2
- self.ori_scale = 2 * bound / torch.max(vmax - vmin).item()
- self.v = (self.v - self.ori_center) * self.ori_scale
-
-
-
-[docs]
- def auto_normal(self):
- """auto calculate the vertex normals.
- """
- i0, i1, i2 = self.f[:, 0].long(), self.f[:, 1].long(), self.f[:, 2].long()
- v0, v1, v2 = self.v[i0, :], self.v[i1, :], self.v[i2, :]
-
- face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
-
- # Splat face normals to vertices
- vn = torch.zeros_like(self.v)
- vn.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
- vn.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
- vn.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)
-
- # Normalize, replace zero (degenerated) normals with some default value
- vn = torch.where(
- dot(vn, vn) > 1e-20,
- vn,
- torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device),
- )
- vn = safe_normalize(vn)
-
- self.vn = vn
- self.fn = self.f
-
-
-
-[docs]
- def auto_uv(self, cache_path=None, vmap=True):
- """auto calculate the uv coordinates.
-
- Args:
- cache_path (str, optional): path to save/load the uv cache as a npz file, this can avoid calculating uv every time when loading the same mesh, which is time-consuming. Defaults to None.
- vmap (bool, optional): remap vertices based on uv coordinates, so each v correspond to a unique vt (necessary for formats like gltf).
- Usually this will duplicate the vertices on the edge of uv atlas. Defaults to True.
- """
- # try to load cache
- if cache_path is not None:
- cache_path = os.path.splitext(cache_path)[0] + "_uv.npz"
- if cache_path is not None and os.path.exists(cache_path):
- data = np.load(cache_path)
- vt_np, ft_np, vmapping = data["vt"], data["ft"], data["vmapping"]
- else:
- import xatlas
-
- v_np = self.v.detach().cpu().numpy()
- f_np = self.f.detach().int().cpu().numpy()
- atlas = xatlas.Atlas()
- atlas.add_mesh(v_np, f_np)
- chart_options = xatlas.ChartOptions()
- # chart_options.max_iterations = 4
- atlas.generate(chart_options=chart_options)
- vmapping, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2]
-
- # save to cache
- if cache_path is not None:
- np.savez(cache_path, vt=vt_np, ft=ft_np, vmapping=vmapping)
-
- vt = torch.from_numpy(vt_np.astype(np.float32)).to(self.device)
- ft = torch.from_numpy(ft_np.astype(np.int32)).to(self.device)
- self.vt = vt
- self.ft = ft
-
- if vmap:
- vmapping = torch.from_numpy(vmapping.astype(np.int64)).long().to(self.device)
- self.align_v_to_vt(vmapping)
-
-
-
-[docs]
- def remap_uv(self, v):
- """ remap uv texture (vt) to other surface.
-
- Args:
- v (torch.Tensor): the target mesh vertices, float [N, 3].
- """
-
- assert self.vt is not None
-
- if self.v.shape[0] != self.vt.shape[0]:
- self.align_v_to_vt()
-
- # find the closest face for each vertex
- import cubvh
- BVH = cubvh.cuBVH(self.v, self.f)
- dist, face_id, uvw = BVH.unsigned_distance(v, return_uvw=True)
-
- # get original uv
- faces = self.f[face_id].long()
- vt0 = self.vt[faces[:, 0]]
- vt1 = self.vt[faces[:, 1]]
- vt2 = self.vt[faces[:, 2]]
-
- # calc new uv
- vt = vt0 * uvw[:, 0:1] + vt1 * uvw[:, 1:2] + vt2 * uvw[:, 2:3]
-
- return vt
-
-
-
-
-[docs]
- def align_v_to_vt(self, vmapping=None):
- """ remap v/f and vn/fn to vt/ft.
-
- Args:
- vmapping (np.ndarray, optional): the mapping relationship from f to ft. Defaults to None.
- """
- if vmapping is None:
- ft = self.ft.view(-1).long()
- f = self.f.view(-1).long()
- vmapping = torch.zeros(self.vt.shape[0], dtype=torch.long, device=self.device)
- vmapping[ft] = f # scatter, randomly choose one if index is not unique
-
- self.v = self.v[vmapping]
- self.f = self.ft
-
- if self.vn is not None:
- self.vn = self.vn[vmapping]
- self.fn = self.ft
-
-
-
-[docs]
- def to(self, device):
- """move all tensor attributes to device.
-
- Args:
- device (torch.device): target device.
-
- Returns:
- Mesh: self.
- """
- self.device = device
- for name in ["v", "f", "vn", "fn", "vt", "ft", "albedo", "vc", "metallicRoughness"]:
- tensor = getattr(self, name)
- if tensor is not None:
- setattr(self, name, tensor.to(device))
- return self
-
-
-
-[docs]
- def write(self, path):
- """write the mesh to a path.
-
- Args:
- path (str): path to write, supports ply, obj and glb.
- """
- if path.endswith(".ply"):
- self.write_ply(path)
- elif path.endswith(".obj"):
- self.write_obj(path)
- elif path.endswith(".glb") or path.endswith(".gltf"):
- self.write_glb(path)
- else:
- raise NotImplementedError(f"format {path} not supported!")
-
-
-
-[docs]
- def write_ply(self, path):
- """write the mesh in ply format. Only for geometry!
-
- Args:
- path (str): path to write.
- """
-
- if self.albedo is not None:
- print(f'[WARN] ply format does not support exporting texture, will ignore!')
-
- v_np = self.v.detach().cpu().numpy()
- f_np = self.f.detach().cpu().numpy()
-
- _mesh = trimesh.Trimesh(vertices=v_np, faces=f_np)
- _mesh.export(path)
-
-
-
-
-[docs]
- def write_glb(self, path):
- """write the mesh in glb/gltf format.
- This will create a scene with a single mesh.
-
- Args:
- path (str): path to write.
- """
-
- # assert self.v.shape[0] == self.vn.shape[0] and self.v.shape[0] == self.vt.shape[0]
- if self.vt is not None and self.v.shape[0] != self.vt.shape[0]:
- self.align_v_to_vt()
-
- import pygltflib
-
- f_np = self.f.detach().cpu().numpy().astype(np.uint32)
- f_np_blob = f_np.flatten().tobytes()
-
- v_np = self.v.detach().cpu().numpy().astype(np.float32)
- v_np_blob = v_np.tobytes()
-
- blob = f_np_blob + v_np_blob
- byteOffset = len(blob)
-
- # base mesh
- gltf = pygltflib.GLTF2(
- scene=0,
- scenes=[pygltflib.Scene(nodes=[0])],
- nodes=[pygltflib.Node(mesh=0)],
- meshes=[pygltflib.Mesh(primitives=[pygltflib.Primitive(
- # indices to accessors (0 is triangles)
- attributes=pygltflib.Attributes(
- POSITION=1,
- ),
- indices=0,
- )])],
- buffers=[
- pygltflib.Buffer(byteLength=len(f_np_blob) + len(v_np_blob))
- ],
- # buffer view (based on dtype)
- bufferViews=[
- # triangles; as flatten (element) array
- pygltflib.BufferView(
- buffer=0,
- byteLength=len(f_np_blob),
- target=pygltflib.ELEMENT_ARRAY_BUFFER, # GL_ELEMENT_ARRAY_BUFFER (34963)
- ),
- # positions; as vec3 array
- pygltflib.BufferView(
- buffer=0,
- byteOffset=len(f_np_blob),
- byteLength=len(v_np_blob),
- byteStride=12, # vec3
- target=pygltflib.ARRAY_BUFFER, # GL_ARRAY_BUFFER (34962)
- ),
- ],
- accessors=[
- # 0 = triangles
- pygltflib.Accessor(
- bufferView=0,
- componentType=pygltflib.UNSIGNED_INT, # GL_UNSIGNED_INT (5125)
- count=f_np.size,
- type=pygltflib.SCALAR,
- max=[int(f_np.max())],
- min=[int(f_np.min())],
- ),
- # 1 = positions
- pygltflib.Accessor(
- bufferView=1,
- componentType=pygltflib.FLOAT, # GL_FLOAT (5126)
- count=len(v_np),
- type=pygltflib.VEC3,
- max=v_np.max(axis=0).tolist(),
- min=v_np.min(axis=0).tolist(),
- ),
- ],
- )
-
- # append texture info
- if self.vt is not None:
-
- vt_np = self.vt.detach().cpu().numpy().astype(np.float32)
- vt_np_blob = vt_np.tobytes()
-
- albedo = self.albedo.detach().cpu().numpy()
- albedo = (albedo * 255).astype(np.uint8)
- albedo = cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR)
- albedo_blob = cv2.imencode('.png', albedo)[1].tobytes()
-
- # update primitive
- gltf.meshes[0].primitives[0].attributes.TEXCOORD_0 = 2
- gltf.meshes[0].primitives[0].material = 0
-
- # update materials
- gltf.materials.append(pygltflib.Material(
- pbrMetallicRoughness=pygltflib.PbrMetallicRoughness(
- baseColorTexture=pygltflib.TextureInfo(index=0, texCoord=0),
- metallicFactor=0.0,
- roughnessFactor=1.0,
- ),
- alphaMode=pygltflib.OPAQUE,
- alphaCutoff=None,
- doubleSided=True,
- ))
-
- gltf.textures.append(pygltflib.Texture(sampler=0, source=0))
- gltf.samplers.append(pygltflib.Sampler(magFilter=pygltflib.LINEAR, minFilter=pygltflib.LINEAR_MIPMAP_LINEAR, wrapS=pygltflib.REPEAT, wrapT=pygltflib.REPEAT))
- gltf.images.append(pygltflib.Image(bufferView=3, mimeType="image/png"))
-
- # update buffers
- gltf.bufferViews.append(
- # index = 2, texcoords; as vec2 array
- pygltflib.BufferView(
- buffer=0,
- byteOffset=byteOffset,
- byteLength=len(vt_np_blob),
- byteStride=8, # vec2
- target=pygltflib.ARRAY_BUFFER,
- )
- )
-
- gltf.accessors.append(
- # 2 = texcoords
- pygltflib.Accessor(
- bufferView=2,
- componentType=pygltflib.FLOAT,
- count=len(vt_np),
- type=pygltflib.VEC2,
- max=vt_np.max(axis=0).tolist(),
- min=vt_np.min(axis=0).tolist(),
- )
- )
-
- blob += vt_np_blob
- byteOffset += len(vt_np_blob)
-
- gltf.bufferViews.append(
- # index = 3, albedo texture; as none target
- pygltflib.BufferView(
- buffer=0,
- byteOffset=byteOffset,
- byteLength=len(albedo_blob),
- )
- )
-
- blob += albedo_blob
- byteOffset += len(albedo_blob)
-
- gltf.buffers[0].byteLength = byteOffset
-
- # append metllic roughness
- if self.metallicRoughness is not None:
- metallicRoughness = self.metallicRoughness.detach().cpu().numpy()
- metallicRoughness = (metallicRoughness * 255).astype(np.uint8)
- metallicRoughness = cv2.cvtColor(metallicRoughness, cv2.COLOR_RGB2BGR)
- metallicRoughness_blob = cv2.imencode('.png', metallicRoughness)[1].tobytes()
-
- # update texture definition
- gltf.materials[0].pbrMetallicRoughness.metallicFactor = 1.0
- gltf.materials[0].pbrMetallicRoughness.roughnessFactor = 1.0
- gltf.materials[0].pbrMetallicRoughness.metallicRoughnessTexture = pygltflib.TextureInfo(index=1, texCoord=0)
-
- gltf.textures.append(pygltflib.Texture(sampler=1, source=1))
- gltf.samplers.append(pygltflib.Sampler(magFilter=pygltflib.LINEAR, minFilter=pygltflib.LINEAR_MIPMAP_LINEAR, wrapS=pygltflib.REPEAT, wrapT=pygltflib.REPEAT))
- gltf.images.append(pygltflib.Image(bufferView=4, mimeType="image/png"))
-
- # update buffers
- gltf.bufferViews.append(
- # index = 4, metallicRoughness texture; as none target
- pygltflib.BufferView(
- buffer=0,
- byteOffset=byteOffset,
- byteLength=len(metallicRoughness_blob),
- )
- )
-
- blob += metallicRoughness_blob
- byteOffset += len(metallicRoughness_blob)
-
- gltf.buffers[0].byteLength = byteOffset
-
-
- # set actual data
- gltf.set_binary_blob(blob)
-
- # glb = b"".join(gltf.save_to_bytes())
- gltf.save(path)
-
-
-
-
-[docs]
- def write_obj(self, path):
- """write the mesh in obj format. Will also write the texture and mtl files.
-
- Args:
- path (str): path to write.
- """
-
- mtl_path = path.replace(".obj", ".mtl")
- albedo_path = path.replace(".obj", "_albedo.png")
- metallic_path = path.replace(".obj", "_metallic.png")
- roughness_path = path.replace(".obj", "_roughness.png")
-
- v_np = self.v.detach().cpu().numpy()
- vt_np = self.vt.detach().cpu().numpy() if self.vt is not None else None
- vn_np = self.vn.detach().cpu().numpy() if self.vn is not None else None
- f_np = self.f.detach().cpu().numpy()
- ft_np = self.ft.detach().cpu().numpy() if self.ft is not None else None
- fn_np = self.fn.detach().cpu().numpy() if self.fn is not None else None
-
- with open(path, "w") as fp:
- fp.write(f"mtllib {os.path.basename(mtl_path)} \n")
-
- for v in v_np:
- fp.write(f"v {v[0]} {v[1]} {v[2]} \n")
-
- if vt_np is not None:
- for v in vt_np:
- fp.write(f"vt {v[0]} {1 - v[1]} \n")
-
- if vn_np is not None:
- for v in vn_np:
- fp.write(f"vn {v[0]} {v[1]} {v[2]} \n")
-
- fp.write(f"usemtl defaultMat \n")
- for i in range(len(f_np)):
- fp.write(
- f'f {f_np[i, 0] + 1}/{ft_np[i, 0] + 1 if ft_np is not None else ""}/{fn_np[i, 0] + 1 if fn_np is not None else ""} \
- {f_np[i, 1] + 1}/{ft_np[i, 1] + 1 if ft_np is not None else ""}/{fn_np[i, 1] + 1 if fn_np is not None else ""} \
- {f_np[i, 2] + 1}/{ft_np[i, 2] + 1 if ft_np is not None else ""}/{fn_np[i, 2] + 1 if fn_np is not None else ""} \n'
- )
-
- with open(mtl_path, "w") as fp:
- fp.write(f"newmtl defaultMat \n")
- fp.write(f"Ka 1 1 1 \n")
- fp.write(f"Kd 1 1 1 \n")
- fp.write(f"Ks 0 0 0 \n")
- fp.write(f"Tr 1 \n")
- fp.write(f"illum 1 \n")
- fp.write(f"Ns 0 \n")
- if self.albedo is not None:
- fp.write(f"map_Kd {os.path.basename(albedo_path)} \n")
- if self.metallicRoughness is not None:
- # ref: https://en.wikipedia.org/wiki/Wavefront_.obj_file#Physically-based_Rendering
- fp.write(f"map_Pm {os.path.basename(metallic_path)} \n")
- fp.write(f"map_Pr {os.path.basename(roughness_path)} \n")
-
- if self.albedo is not None:
- albedo = self.albedo.detach().cpu().numpy()
- albedo = (albedo * 255).astype(np.uint8)
- cv2.imwrite(albedo_path, cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR))
-
- if self.metallicRoughness is not None:
- metallicRoughness = self.metallicRoughness.detach().cpu().numpy()
- metallicRoughness = (metallicRoughness * 255).astype(np.uint8)
- cv2.imwrite(metallic_path, metallicRoughness[..., 2])
- cv2.imwrite(roughness_path, metallicRoughness[..., 1])
-
-
-
-
-import torch
-import numpy as np
-
-from kiui.op import safe_normalize
-
-import pymeshlab as pml
-from importlib.metadata import version
-
-PML_VER = version('pymeshlab')
-
-# the code assumes the latest 2023.12 version, but we can patch older versions
-if PML_VER.startswith('0.2'):
- # monkey patch for 0.2 (only the used functions in this file!)
- pml.MeshSet.meshing_decimation_quadric_edge_collapse = pml.MeshSet.simplification_quadric_edge_collapse_decimation
- pml.MeshSet.meshing_isotropic_explicit_remeshing = pml.MeshSet.remeshing_isotropic_explicit_remeshing
- pml.MeshSet.meshing_remove_unreferenced_vertices = pml.MeshSet.remove_unreferenced_vertices
- pml.MeshSet.meshing_merge_close_vertices = pml.MeshSet.merge_close_vertices
- pml.MeshSet.meshing_remove_duplicate_faces = pml.MeshSet.remove_duplicate_faces
- pml.MeshSet.meshing_remove_null_faces = pml.MeshSet.remove_zero_area_faces
- pml.MeshSet.meshing_remove_connected_component_by_diameter = pml.MeshSet.remove_isolated_pieces_wrt_diameter
- pml.MeshSet.meshing_remove_connected_component_by_face_number = pml.MeshSet.remove_isolated_pieces_wrt_face_num
- pml.MeshSet.meshing_repair_non_manifold_edges = pml.MeshSet.repair_non_manifold_edges_by_removing_faces
- pml.MeshSet.meshing_repair_non_manifold_vertices = pml.MeshSet.repair_non_manifold_vertices_by_splitting
- pml.PercentageValue = pml.Percentage
- pml.PureValue = float
-elif PML_VER.startswith('2022.2'):
- # monkey patch for 2022.2
- pml.PercentageValue = pml.Percentage
- pml.PureValue = pml.AbsoluteValue
-
-
-
-[docs]
-def decimate_mesh(
- verts, faces, target=5e4, backend="pymeshlab", remesh=False, optimalplacement=True, verbose=True
-):
- """ perform mesh decimation.
-
- Args:
- verts (np.ndarray): mesh vertices, float [N, 3]
- faces (np.ndarray): mesh faces, int [M, 3]
- target (int): targeted number of faces
- backend (str, optional): algorithm backend, can be "pymeshlab" or "pyfqmr". Defaults to "pymeshlab".
- remesh (bool, optional): whether to remesh after decimation. Defaults to False.
- optimalplacement (bool, optional): For flat mesh, use False to prevent spikes. Defaults to True.
- verbose (bool, optional): whether to print the decimation process. Defaults to True.
-
- Returns:
- Tuple[np.ndarray]: vertices and faces after decimation.
- """
-
- _ori_vert_shape = verts.shape
- _ori_face_shape = faces.shape
-
- if backend == "pyfqmr":
- import pyfqmr
-
- solver = pyfqmr.Simplify()
- solver.setMesh(verts, faces)
- solver.simplify_mesh(target_count=target, preserve_border=False, verbose=False)
- verts, faces, normals = solver.getMesh()
- else:
- m = pml.Mesh(verts, faces)
- ms = pml.MeshSet()
- ms.add_mesh(m, "mesh") # will copy!
-
- # filters
- # ms.meshing_decimation_clustering(threshold=pml.PercentageValue(1))
- ms.meshing_decimation_quadric_edge_collapse(
- targetfacenum=int(target), optimalplacement=optimalplacement
- )
-
- if remesh:
- # ms.apply_coord_taubin_smoothing()
- ms.meshing_isotropic_explicit_remeshing(
- iterations=3, targetlen=pml.PercentageValue(1)
- )
-
- # extract mesh
- m = ms.current_mesh()
- m.compact()
- verts = m.vertex_matrix()
- faces = m.face_matrix()
-
- if verbose:
- print(f"[INFO] mesh decimation: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}")
-
- return verts, faces
-
-
-
-
-[docs]
-def clean_mesh(
- verts,
- faces,
- v_pct=1,
- min_f=0,
- min_d=0,
- repair=True,
- remesh=False,
- remesh_size=0.01,
- remesh_iters=3,
- verbose=True,
-):
- """ perform mesh cleaning, including floater removal, non manifold repair, and remeshing.
-
- Args:
- verts (np.ndarray): mesh vertices, float [N, 3]
- faces (np.ndarray): mesh faces, int [M, 3]
- v_pct (int, optional): percentage threshold to merge close vertices. Defaults to 1.
- min_f (int, optional): maximal number of faces for isolated component to remove. Defaults to 0.
- min_d (int, optional): maximal diameter percentage of isolated component to remove. Defaults to 0.
- repair (bool, optional): whether to repair non-manifold faces (cannot gurantee). Defaults to True.
- remesh (bool, optional): whether to perform a remeshing after all cleaning. Defaults to True.
- remesh_size (float, optional): the targeted edge length for remeshing. Defaults to 0.01.
- remesh_iters (int, optional): the iterations of remeshing. Defaults to 3.
- verbose (bool, optional): whether to print the cleaning process. Defaults to True.
-
- Returns:
- Tuple[np.ndarray]: vertices and faces after decimation.
- """
- # verts: [N, 3]
- # faces: [N, 3]
-
- _ori_vert_shape = verts.shape
- _ori_face_shape = faces.shape
-
- m = pml.Mesh(verts, faces)
- ms = pml.MeshSet()
- ms.add_mesh(m, "mesh") # will copy!
-
- # filters
- ms.meshing_remove_unreferenced_vertices() # verts not refed by any faces
-
- if v_pct > 0:
- ms.meshing_merge_close_vertices(
- threshold=pml.PercentageValue(v_pct)
- ) # 1/10000 of bounding box diagonal
-
- ms.meshing_remove_duplicate_faces() # faces defined by the same verts
- ms.meshing_remove_null_faces() # faces with area == 0
-
- if min_d > 0:
- ms.meshing_remove_connected_component_by_diameter(
- mincomponentdiag=pml.PercentageValue(min_d)
- )
-
- if min_f > 0:
- ms.meshing_remove_connected_component_by_face_number(mincomponentsize=min_f)
-
- if repair:
- # ms.meshing_remove_t_vertices(method=0, threshold=40, repeat=True)
- ms.meshing_repair_non_manifold_edges(method=0)
- ms.meshing_repair_non_manifold_vertices(vertdispratio=0)
-
- if remesh:
- # ms.apply_coord_taubin_smoothing()
- ms.meshing_isotropic_explicit_remeshing(
- iterations=remesh_iters, targetlen=pml.PureValue(remesh_size)
- )
-
- # extract mesh
- m = ms.current_mesh()
- m.compact()
- verts = m.vertex_matrix()
- faces = m.face_matrix()
-
- if verbose:
- print(f"[INFO] mesh cleaning: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}")
-
- return verts, faces
-
-
-
-
-### mesh related losses
-
-
-[docs]
-def laplacian_uniform(verts, faces):
- """ calculate laplacian uniform matrix
-
- Args:
- verts (torch.Tensor): mesh vertices, float [N, 3]
- faces (torch.Tensor): mesh faces, long [M, 3]
-
- Returns:
- torch.Tensor: sparse laplacian matrix.
- """
-
- V = verts.shape[0]
- F = faces.shape[0]
-
- # Neighbor indices
- ii = faces[:, [1, 2, 0]].flatten()
- jj = faces[:, [2, 0, 1]].flatten()
- adj = torch.stack([torch.cat([ii, jj]), torch.cat([jj, ii])], dim=0).unique(dim=1)
- adj_values = torch.ones(adj.shape[1], device=verts.device, dtype=torch.float)
-
- # Diagonal indices
- diag_idx = adj[0]
-
- # Build the sparse matrix
- idx = torch.cat((adj, torch.stack((diag_idx, diag_idx), dim=0)), dim=1)
- values = torch.cat((-adj_values, adj_values))
-
- # The coalesce operation sums the duplicate indices, resulting in the
- # correct diagonal
- return torch.sparse_coo_tensor(idx, values, (V,V)).coalesce()
-
-
-
-
-[docs]
-def laplacian_smooth_loss(verts, faces):
- """ calculate laplacian smooth loss.
-
- Args:
- verts (torch.Tensor): mesh vertices, float [N, 3]
- faces (torch.Tensor): mesh faces, int [M, 3]
-
- Returns:
- torch.Tensor: loss value.
- """
- with torch.no_grad():
- L = laplacian_uniform(verts, faces.long())
- loss = L.mm(verts)
- loss = loss.norm(dim=1)
- loss = loss.mean()
- return loss
-
-
-
-[docs]
-@torch.no_grad()
-def compute_edge_to_face_mapping(faces):
- """ compute edge to face mapping.
-
- Args:
- faces (torch.Tensor): mesh faces, int [M, 3]
-
- Returns:
- torch.Tensor: indices to faces for each edge, long, [N, 2]
- """
- # Get unique edges
- # Create all edges, packed by triangle
- all_edges = torch.cat((
- torch.stack((faces[:, 0], faces[:, 1]), dim=-1),
- torch.stack((faces[:, 1], faces[:, 2]), dim=-1),
- torch.stack((faces[:, 2], faces[:, 0]), dim=-1),
- ), dim=-1).view(-1, 2)
-
- # Swap edge order so min index is always first
- order = (all_edges[:, 0] > all_edges[:, 1]).long().unsqueeze(dim=1)
- sorted_edges = torch.cat((
- torch.gather(all_edges, 1, order),
- torch.gather(all_edges, 1, 1 - order)
- ), dim=-1)
-
- # Elliminate duplicates and return inverse mapping
- unique_edges, idx_map = torch.unique(sorted_edges, dim=0, return_inverse=True)
-
- tris = torch.arange(faces.shape[0]).repeat_interleave(3).cuda()
-
- tris_per_edge = torch.zeros((unique_edges.shape[0], 2), dtype=torch.int64).cuda()
-
- # Compute edge to face table
- mask0 = order[:,0] == 0
- mask1 = order[:,0] == 1
- tris_per_edge[idx_map[mask0], 0] = tris[mask0]
- tris_per_edge[idx_map[mask1], 1] = tris[mask1]
-
- return tris_per_edge
-
-
-
-
-[docs]
-def normal_consistency(verts, faces, face_normals=None):
- """ calculate normal consistency loss.
-
- Args:
- verts (torch.Tensor): mesh vertices, float [N, 3]
- faces (torch.Tensor): mesh faces, int [M, 3]
- face_normals (Optional[torch.Tensor]): the normal vector for each face, will be calculated if not provided, float [M, 3]
-
- Returns:
- torch.Tensor: loss value.
- """
-
- if face_normals is None:
-
- i0, i1, i2 = faces[:, 0].long(), faces[:, 1].long(), faces[:, 2].long()
- v0, v1, v2 = verts[i0, :], verts[i1, :], verts[i2, :]
-
- face_normals = torch.cross(v1 - v0, v2 - v0)
- face_normals = safe_normalize(face_normals)
-
- tris_per_edge = compute_edge_to_face_mapping(faces)
-
- # Fetch normals for both faces sharind an edge
- n0 = face_normals[tris_per_edge[:, 0], :]
- n1 = face_normals[tris_per_edge[:, 1], :]
-
- # Compute error metric based on normal difference
- term = torch.clamp(torch.sum(n0 * n1, -1, keepdim=True), min=-1.0, max=1.0)
- term = (1.0 - term)
-
- return torch.mean(torch.abs(term))
-
-
-import cv2
-import torch
-import numpy as np
-
-from kiui.typing import *
-from kiui.grid_put import grid_put
-
-# torch / numpy math utils
-
-[docs]
-def dot(x: Union[Tensor, ndarray], y: Union[Tensor, ndarray]) -> Union[Tensor, ndarray]:
- """dot product (along the last dim).
-
- Args:
- x (Union[Tensor, ndarray]): x, [..., C]
- y (Union[Tensor, ndarray]): y, [..., C]
-
- Returns:
- Union[Tensor, ndarray]: x dot y, [..., 1]
- """
- if isinstance(x, np.ndarray):
- return np.sum(x * y, -1, keepdims=True)
- else:
- return torch.sum(x * y, -1, keepdim=True)
-
-
-
def length(x: Union[Tensor, ndarray], eps=1e-20) -> Union[Tensor, ndarray]:
    """Euclidean norm along the last dimension, clamped away from zero.

    Args:
        x (Union[Tensor, ndarray]): input, [..., C]
        eps (float, optional): lower bound on the squared norm. Defaults to 1e-20.

    Returns:
        Union[Tensor, ndarray]: norm, [..., 1]
    """
    if isinstance(x, np.ndarray):
        sq = np.sum(x * x, axis=-1, keepdims=True)
        return np.sqrt(np.maximum(sq, eps))
    # torch path: sum of squares along the last dim (inlined dot(x, x))
    sq = torch.sum(x * x, -1, keepdim=True)
    return torch.sqrt(torch.clamp(sq, min=eps))
-
-
-
def safe_normalize(x: Union[Tensor, ndarray], eps=1e-20) -> Union[Tensor, ndarray]:
    """Normalize an array along the last dimension, safe against zero vectors.

    Args:
        x (Union[Tensor, ndarray]): input, [..., C]
        eps (float, optional): lower bound on the squared norm. Defaults to 1e-20.

    Returns:
        Union[Tensor, ndarray]: normalized x, [..., C]
    """
    # inlined length(x, eps): clamped euclidean norm along the last dim
    if isinstance(x, np.ndarray):
        norm = np.sqrt(np.maximum(np.sum(x * x, axis=-1, keepdims=True), eps))
    else:
        norm = torch.sqrt(torch.clamp(torch.sum(x * x, -1, keepdim=True), min=eps))
    return x / norm
-
-
-
def normalize(x: Union[Tensor, ndarray], eps=1e-20) -> Union[Tensor, ndarray]:
    """Normalize an array along the last dimension (alias of safe_normalize).

    Args:
        x (Union[Tensor, ndarray]): input, [..., C]
        eps (float, optional): lower bound on the squared norm. Defaults to 1e-20.

    Returns:
        Union[Tensor, ndarray]: normalized x, [..., C]
    """
    # inlined length(x, eps): clamped euclidean norm along the last dim
    if isinstance(x, np.ndarray):
        norm = np.sqrt(np.maximum(np.sum(x * x, axis=-1, keepdims=True), eps))
    else:
        norm = torch.sqrt(torch.clamp(torch.sum(x * x, -1, keepdim=True), min=eps))
    return x / norm
-
-
-
def make_divisible(x: int, m: int = 8):
    """Round ``x`` up to the nearest multiple of ``m``.

    Fixes the previous behavior where an already-divisible ``x`` was still
    bumped to the next multiple (e.g. ``make_divisible(8, 8)`` returned 16);
    an already-divisible value is now returned unchanged.

    Args:
        x (int): value to round up.
        m (int, optional): divisor. Defaults to 8.

    Returns:
        int: the smallest multiple of ``m`` that is >= ``x``.
    """
    # (m - x % m) % m is 0 when x is already divisible by m
    return int(x + (m - x % m) % m)
-
-
-
def inverse_sigmoid(x: Tensor, eps=1e-6) -> Tensor:
    """Logit function: inversion of sigmoid.

    Args:
        x (Tensor): input, expected in (0, 1).
        eps (float, optional): clamp margin keeping the log arguments finite. Defaults to 1e-6.

    Returns:
        Tensor: log(x / (1 - x))
    """
    clamped = torch.clamp(x, eps, 1 - eps)
    return torch.log(clamped / (1 - clamped))
-
-
-
-[docs]
-def inverse_softplus(x: Tensor) -> Tensor:
- """inversion of softplus function.
-
- Args:
- x (Tensor): x
-
- Returns:
- Tensor: log(exp(x) - 1)
- """
- # a numerically stable equation (ref: https://github.com/pytorch/pytorch/issues/72759)
- return x + torch.log(-torch.expm1(-x))
-
-
-# torch image scaling
-
def scale_img_nhwc(x: Tensor, size: Sequence[int], mag='bilinear', min='bilinear') -> Tensor:
    """Resize a batched channels-last image tensor.

    Args:
        x (Tensor): input image, float [N, H, W, C]
        size (Sequence[int]): target size, tuple of [H', W']
        mag (str, optional): interpolation mode when upscaling. Defaults to 'bilinear'.
        min (str, optional): interpolation mode when downscaling. Defaults to 'bilinear'.

    Returns:
        Tensor: rescaled image, float [N, H', W', C]
    """
    # both spatial dims must move in the same direction
    assert (x.shape[1] >= size[0]) == (x.shape[2] >= size[1]), "Trying to magnify image in one dimension and minify in the other"
    y = x.permute(0, 3, 1, 2)  # NHWC -> NCHW for interpolate
    if x.shape[1] > size[0] and x.shape[2] > size[1]:
        # target strictly smaller in both dims: minification
        y = torch.nn.functional.interpolate(y, size, mode=min)
    elif mag in ('bilinear', 'bicubic'):
        # these modes accept align_corners
        y = torch.nn.functional.interpolate(y, size, mode=mag, align_corners=True)
    else:
        y = torch.nn.functional.interpolate(y, size, mode=mag)
    return y.permute(0, 2, 3, 1).contiguous()  # NCHW -> NHWC
-
-
-
def scale_img_hwc(x: Tensor, size: Sequence[int], mag='bilinear', min='bilinear') -> Tensor:
    """Resize a single channels-last image by batching it through scale_img_nhwc.

    Args:
        x (Tensor): input image, float [H, W, C]
        size (Sequence[int]): target size, tuple of [H', W']
        mag (str, optional): interpolation mode when upscaling. Defaults to 'bilinear'.
        min (str, optional): interpolation mode when downscaling. Defaults to 'bilinear'.

    Returns:
        Tensor: rescaled image, float [H', W', C]
    """
    batched = x.unsqueeze(0)  # add a batch axis
    return scale_img_nhwc(batched, size, mag, min)[0]
-
-
-
def scale_img_nhw(x: Tensor, size: Sequence[int], mag='bilinear', min='bilinear') -> Tensor:
    """Resize a batched single-channel image by routing through scale_img_nhwc.

    Args:
        x (Tensor): input image, float [N, H, W]
        size (Sequence[int]): target size, tuple of [H', W']
        mag (str, optional): interpolation mode when upscaling. Defaults to 'bilinear'.
        min (str, optional): interpolation mode when downscaling. Defaults to 'bilinear'.

    Returns:
        Tensor: rescaled image, float [N, H', W']
    """
    with_channel = x.unsqueeze(-1)  # add a trailing channel axis
    return scale_img_nhwc(with_channel, size, mag, min).squeeze(-1)
-
-
-
def scale_img_hw(x: Tensor, size: Sequence[int], mag='bilinear', min='bilinear') -> Tensor:
    """Resize a single-channel image by routing through scale_img_nhwc.

    Args:
        x (Tensor): input image, float [H, W]
        size (Sequence[int]): target size, tuple of [H', W']
        mag (str, optional): interpolation mode when upscaling. Defaults to 'bilinear'.
        min (str, optional): interpolation mode when downscaling. Defaults to 'bilinear'.

    Returns:
        Tensor: rescaled image, float [H', W']
    """
    batched = x[None, :, :, None]  # add batch and channel axes
    return scale_img_nhwc(batched, size, mag, min)[0, :, :, 0]
-
-
-
-
def uv_padding(image: Union[Tensor, ndarray], mask: Union[Tensor, ndarray], padding: Optional[int] = None, backend: Literal['knn', 'cv2'] = 'knn'):
    """padding the uv-space texture image to avoid seam artifacts in mipmaps.

    Args:
        image (Union[Tensor, ndarray]): texture image, float, [H, W, C] in [0, 1].
        mask (Union[Tensor, ndarray]): valid uv region, bool, [H, W].
        padding (int, optional): padding size into the unmasked region. Defaults to 0.1 * max(H, W).
        backend (Literal['knn', 'cv2'], optional): algorithm backend, knn is faster. Defaults to 'knn'.

    Raises:
        ValueError: if ``backend`` is not 'knn' or 'cv2'.

    Returns:
        Union[Tensor, ndarray]: padded texture image. float, [H, W, C].
    """

    # work on numpy arrays; convert back to a tensor at the end if needed
    if torch.is_tensor(image):
        image_input = image.detach().cpu().numpy()
    else:
        image_input = image

    if torch.is_tensor(mask):
        mask_input = mask.detach().cpu().numpy()
    else:
        mask_input = mask

    if padding is None:
        H, W = image_input.shape[:2]
        padding = int(0.1 * max(H, W))

    # padding backend
    if backend == 'knn':

        from sklearn.neighbors import NearestNeighbors
        from scipy.ndimage import binary_dilation, binary_erosion

        # region to fill: a band of `padding` pixels just outside the mask
        inpaint_region = binary_dilation(mask_input, iterations=padding)
        inpaint_region[mask_input] = 0

        # source region: a thin band just inside the mask boundary
        search_region = mask_input.copy()
        not_search_region = binary_erosion(search_region, iterations=2)
        search_region[not_search_region] = 0

        search_coords = np.stack(np.nonzero(search_region), axis=-1)
        inpaint_coords = np.stack(np.nonzero(inpaint_region), axis=-1)

        # copy each outside pixel from its nearest boundary pixel
        knn = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(search_coords)
        _, indices = knn.kneighbors(inpaint_coords)

        inpaint_image = image_input.copy()
        inpaint_image[tuple(inpaint_coords.T)] = inpaint_image[tuple(search_coords[indices[:, 0]].T)]

    elif backend == 'cv2':
        # kind of slow
        inpaint_image = cv2.inpaint(
            (image_input * 255).astype(np.uint8),
            (~mask_input * 255).astype(np.uint8),
            padding,
            cv2.INPAINT_TELEA,
        ).astype(np.float32) / 255

    else:
        # fix: an unknown backend previously fell through and crashed later
        # with a confusing NameError on inpaint_image
        raise ValueError(f'unknown uv_padding backend: {backend}')

    if torch.is_tensor(image):
        inpaint_image = torch.from_numpy(inpaint_image).to(image)

    return inpaint_image
-
-
-
-
def recenter(image: ndarray, mask: ndarray, border_ratio: float = 0.2) -> ndarray:
    """ recenter an image to leave some empty space at the image border.

    Args:
        image (ndarray): input image, float/uint8 [H, W, 3/4]
        mask (ndarray): alpha mask, bool [H, W]
        border_ratio (float, optional): border ratio, image will be resized to (1 - border_ratio). Defaults to 0.2.

    Returns:
        ndarray: output image, float/uint8 [H, W, 3/4]
    """

    return_int = False
    if image.dtype == np.uint8:
        image = image.astype(np.float32) / 255
        return_int = True

    H, W, C = image.shape
    size = max(H, W)

    # default to white bg if rgb, but use 0 if rgba
    if C == 3:
        result = np.ones((size, size, C), dtype=np.float32)
    else:
        result = np.zeros((size, size, C), dtype=np.float32)

    coords = np.nonzero(mask)
    x_min, x_max = coords[0].min(), coords[0].max()
    y_min, y_max = coords[1].min(), coords[1].max()
    # fix: the bounding box is inclusive of x_max/y_max, so extend by 1 —
    # previously the last masked row/column was cropped off (off-by-one)
    h = x_max - x_min + 1
    w = y_max - y_min + 1
    desired_size = int(size * (1 - border_ratio))
    scale = desired_size / max(h, w)
    h2 = int(h * scale)
    w2 = int(w * scale)
    x2_min = (size - h2) // 2
    x2_max = x2_min + h2
    y2_min = (size - w2) // 2
    y2_max = y2_min + w2
    result[x2_min:x2_max, y2_min:y2_max] = cv2.resize(image[x_min:x_max + 1, y_min:y_max + 1], (w2, h2), interpolation=cv2.INTER_AREA)

    if return_int:
        result = (result * 255).astype(np.uint8)

    return result
-
-
-import os
-import cv2
-import math
-
-import torch
-import torch.nn as nn
-import torch.nn.init as init
-from torch.nn import functional as F
-from torch.nn.modules.batchnorm import _BatchNorm
-
-import numpy as np
-from PIL import Image
-
-from huggingface_hub import hf_hub_download
-
-from kiui.typing import *
-
# pretrained Real-ESRGAN checkpoints on the HuggingFace hub, keyed by upscale factor
HF_MODELS = {
    factor: dict(
        repo_id='ai-forever/Real-ESRGAN',
        filename=f'RealESRGAN_x{factor}.pth',
    )
    for factor in (2, 4, 8)
}
-
@torch.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
    """Initialize network weights: Kaiming for conv/linear, constants for BN.

    Args:
        module_list (list[nn.Module] | nn.Module): modules to initialize.
        scale (float): factor applied to initialized weights, useful for
            residual blocks. Default: 1.
        bias_fill (float): constant used to fill biases. Default: 0.
        kwargs (dict): extra arguments forwarded to ``init.kaiming_normal_``.
    """
    modules = module_list if isinstance(module_list, list) else [module_list]
    for root in modules:
        for m in root.modules():
            # conv and linear share the exact same init recipe
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, _BatchNorm):
                init.constant_(m.weight, 1)
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
-
-
def make_layer(basic_block, num_basic_block, **kwarg):
    """Stack ``num_basic_block`` copies of ``basic_block`` into a Sequential.

    Args:
        basic_block (nn.module): nn.module class for the basic block.
        num_basic_block (int): number of copies to stack.

    Returns:
        nn.Sequential: the stacked blocks.
    """
    blocks = (basic_block(**kwarg) for _ in range(num_basic_block))
    return nn.Sequential(*blocks)
-
-
class ResidualBlockNoBN(nn.Module):
    """Residual block without batch normalization.

    Structure::

        ---Conv-ReLU-Conv-+-
           |______________|

    Args:
        num_feat (int): Channel number of intermediate features. Default: 64.
        res_scale (float): Residual scale. Default: 1.
        pytorch_init (bool): If True, keep pytorch's default init; otherwise
            apply ``default_init_weights`` with scale 0.1. Default: False.
    """

    def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
        super().__init__()
        self.res_scale = res_scale
        self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.relu = nn.ReLU(inplace=True)

        if not pytorch_init:
            # small init empirically stabilizes residual learning
            default_init_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        residual = self.conv2(self.relu(self.conv1(x)))
        return x + residual * self.res_scale
-
-
class Upsample(nn.Sequential):
    """PixelShuffle-based upsample module.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.
    """

    def __init__(self, scale, num_feat):
        layers = []
        if (scale & (scale - 1)) == 0:
            # scale = 2^n: chain n conv + x2 pixel-shuffle stages
            for _ in range(int(math.log(scale, 2))):
                layers.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                layers.append(nn.PixelShuffle(2))
        elif scale == 3:
            layers.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            layers.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
        super().__init__(*layers)
-
-
def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
    """Warp an image or feature map with optical flow via grid_sample.

    Args:
        x (Tensor): input with size (n, c, h, w).
        flow (Tensor): flow with size (n, h, w, 2), in pixel units.
        interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
        padding_mode (str): 'zeros' or 'border' or 'reflection'. Default: 'zeros'.
        align_corners (bool): kept True by default to match pre-1.3 pytorch
            behavior (the default flipped to False in pytorch 1.3).

    Returns:
        Tensor: warped image or feature map.
    """
    assert x.size()[-2:] == flow.size()[1:3]
    _, _, h, w = x.size()
    # base sampling grid in pixel coordinates
    ys, xs = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
    base_grid = torch.stack((xs, ys), 2).float()  # (h, w, 2) in (x, y) order
    base_grid.requires_grad = False

    # displace by the flow, then normalize to [-1, 1] as grid_sample expects
    displaced = base_grid + flow
    norm_x = 2.0 * displaced[:, :, :, 0] / max(w - 1, 1) - 1.0
    norm_y = 2.0 * displaced[:, :, :, 1] / max(h - 1, 1) - 1.0
    sample_grid = torch.stack((norm_x, norm_y), dim=3)
    # TODO: revisit the normalization above if align_corners=False is used
    return F.grid_sample(x, sample_grid, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
-
-
def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
    """Resize an optical flow map and rescale its values accordingly.

    Args:
        flow (Tensor): precomputed flow, shape [N, 2, H, W].
        size_type (str): 'ratio' or 'shape'.
        sizes (list[int | float]): [ratio_h, ratio_w] when size_type='ratio'
            (< 1.0 to downsample, > 1.0 to upsample), or [out_h, out_w] when
            size_type='shape'.
        interp_mode (str): interpolation mode for resizing. Default: 'bilinear'.
        align_corners (bool): whether to align corners. Default: False.

    Returns:
        Tensor: resized flow.
    """
    _, _, flow_h, flow_w = flow.size()
    if size_type == 'ratio':
        output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
    elif size_type == 'shape':
        output_h, output_w = sizes[0], sizes[1]
    else:
        raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')

    # flow values are displacements in pixels, so they scale with the resize
    scaled = flow.clone()
    scaled[:, 0, :, :] *= output_w / flow_w
    scaled[:, 1, :, :] *= output_h / flow_h
    return F.interpolate(
        input=scaled, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
-
-
def pixel_unshuffle(x, scale):
    """Inverse of pixel shuffle: trade spatial resolution for channels.

    Args:
        x (Tensor): input feature with shape (b, c, hh, hw).
        scale (int): downsample ratio; must divide both spatial dims.

    Returns:
        Tensor: feature with shape (b, c * scale**2, hh // scale, hw // scale).
    """
    b, c, hh, hw = x.size()
    assert hh % scale == 0 and hw % scale == 0
    h, w = hh // scale, hw // scale
    # split each spatial dim into (coarse, fine), then fold the fine parts into channels
    folded = x.view(b, c, h, scale, w, scale)
    return folded.permute(0, 1, 3, 5, 2, 4).reshape(b, c * scale ** 2, h, w)
-
def pad_reflect(image, pad_size):
    """Pad an HWC image by mirroring ``pad_size`` pixels at each border.

    The mirror includes the edge row/column itself (numpy 'symmetric' style).
    Note: output dtype is always uint8, matching the original implementation.
    """
    height, width = image.shape[:2]
    channels = image.shape[2]
    padded = np.zeros([height + pad_size * 2, width + pad_size * 2, channels]).astype(np.uint8)
    padded[pad_size:-pad_size, pad_size:-pad_size, :] = image

    # mirror rows first (top/bottom) ...
    padded[:pad_size, pad_size:-pad_size, :] = np.flip(image[:pad_size, :, :], axis=0)
    padded[-pad_size:, pad_size:-pad_size, :] = np.flip(image[-pad_size:, :, :], axis=0)
    # ... then mirror full columns (left/right), so corners get filled too
    padded[:, :pad_size, :] = np.flip(padded[:, pad_size:pad_size * 2, :], axis=1)
    padded[:, -pad_size:, :] = np.flip(padded[:, -pad_size * 2:-pad_size, :], axis=1)

    return padded
-
def unpad_image(image, pad_size):
    """Crop ``pad_size`` pixels from every spatial border (inverse of pad_reflect)."""
    return image[pad_size:-pad_size, pad_size:-pad_size, :]
-
-
def process_array(image_array, expand=True):
    """Scale a uint8 image array to [0, 1] and optionally add a batch axis of size 1."""
    scaled = image_array / 255.0
    return np.expand_dims(scaled, axis=0) if expand else scaled
-
-
def process_output(output_tensor):
    """Clip a [0, 1] float array to range and convert it to a uint8 image."""
    clipped = output_tensor.clip(0, 1)
    return np.uint8(clipped * 255)
-
-
def pad_patch(image_patch, padding_size, channel_last=True):
    """Pad a patch with ``padding_size`` replicated edge values on the spatial dims.

    Args:
        image_patch: patch array, HWC when ``channel_last`` else CHW.
        padding_size: number of edge pixels to replicate on each side.
        channel_last: whether the channel axis is last. Defaults to True.
    """
    spatial = (padding_size, padding_size)
    if channel_last:
        widths = (spatial, spatial, (0, 0))
    else:
        widths = ((0, 0), spatial, spatial)
    return np.pad(image_patch, widths, 'edge')
-
-
def unpad_patches(image_patches, padding_size):
    """Trim ``padding_size`` pixels from each spatial border of a batch of patches [N, H, W, C]."""
    trim = padding_size
    return image_patches[:, trim:-trim, trim:-trim, :]
-
-
def split_image_into_overlapping_patches(image_array, patch_size, padding_size=2):
    """Split an image into patches that overlap by ``padding_size`` pixels.

    The image is edge-padded twice: first to the next multiple of
    ``patch_size``, then by ``padding_size`` on every border so each patch
    carries overlapping context.

    Args:
        image_array: numpy array of the input image, HWC.
        patch_size: size of the patches (without the overlap padding).
        padding_size: size of the overlapping area.

    Returns:
        tuple: (patches [N, patch+2*pad, patch+2*pad, C], padded image shape)
    """
    height, width, _ = image_array.shape
    # modulo keeps the extension at 0 when already divisible
    grow_h = (patch_size - height % patch_size) % patch_size
    grow_w = (patch_size - width % patch_size) % patch_size

    # make the image divisible into regular patches
    extended = np.pad(image_array, ((0, grow_h), (0, grow_w), (0, 0)), 'edge')

    # surround with the overlap border (inlined pad_patch, channel-last)
    padded = np.pad(
        extended,
        ((padding_size, padding_size), (padding_size, padding_size), (0, 0)),
        'edge',
    )

    max_h, max_w, _ = padded.shape
    patches = [
        padded[x - padding_size:x + patch_size + padding_size,
               y - padding_size:y + patch_size + padding_size, :]
        for x in range(padding_size, max_h - padding_size, patch_size)
        for y in range(padding_size, max_w - padding_size, patch_size)
    ]

    return np.array(patches), padded.shape
-
-
def stich_together(patches, padded_image_shape, target_shape, padding_size=4):
    """Reassemble overlapping patches into a single image.

    If the patches were upscaled, the shapes and padding passed in must be
    scaled by the same factor.

    Args:
        patches: patches produced by split_image_into_overlapping_patches.
        padded_image_shape: shape of the padded source image.
        target_shape: shape of the final image.
        padding_size: size of the overlapping area.
    """
    max_h, max_w, _ = padded_image_shape
    # drop the overlap border of every patch (inlined unpad_patches)
    core = patches[:, padding_size:-padding_size, padding_size:-padding_size, :]
    patch_size = core.shape[1]
    per_row = max_w // patch_size

    canvas = np.zeros((max_h, max_w, 3))

    # lay patches down in row-major order
    for idx in range(len(core)):
        row, col = divmod(idx, per_row)
        canvas[row * patch_size:(row + 1) * patch_size,
               col * patch_size:(col + 1) * patch_size, :] = core[idx]

    # crop away the divisibility padding added during splitting
    return canvas[0:target_shape[0], 0:target_shape[1], :]
-
-
class ResidualDenseBlock(nn.Module):
    """Residual Dense Block, used inside the RRDB block of ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels added by each dense conv.
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        super().__init__()
        # dense connectivity: conv_i sees the input plus all previous growths
        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # small init empirically works better for residual branches
        default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        feats = [x]
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            feats.append(self.lrelu(conv(torch.cat(feats, 1))))
        out = self.conv5(torch.cat(feats, 1))
        # empirically, scaling the residual by 0.2 gives better performance
        return out * 0.2 + x
-
-
class RRDB(nn.Module):
    """Residual in Residual Dense Block, used in RRDB-Net in ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat, num_grow_ch=32):
        super().__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        out = self.rdb3(self.rdb2(self.rdb1(x)))
        # empirically, scaling the residual by 0.2 gives better performance
        return out * 0.2 + x
-
-
class RRDBNet(nn.Module):
    """RRDB-based generator network used in ESRGAN.

    ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.

    Extended to scales x1 and x2: for those, the input is pixel-unshuffled
    (spatial size reduced, channels enlarged) before the main trunk, which
    always upsamples by a factor of 4 (or 8 with an extra stage).

    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        scale (int): Upscale factor. Default: 4.
        num_feat (int): Channel number of intermediate features. Default: 64.
        num_block (int): Block number in the trunk network. Default: 23.
        num_grow_ch (int): Channels for each growth. Default: 32.
    """

    def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
        super().__init__()
        self.scale = scale
        # pixel-unshuffle multiplies input channels for the small scales
        if scale == 2:
            num_in_ch = num_in_ch * 4
        elif scale == 1:
            num_in_ch = num_in_ch * 16
        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # upsample stages (each one doubles the spatial size)
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        if scale == 8:
            self.conv_up3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def _upscale(self, feat, conv):
        # nearest x2 upsample followed by conv + leaky relu
        return self.lrelu(conv(F.interpolate(feat, scale_factor=2, mode='nearest')))

    def forward(self, x):
        if self.scale == 2:
            feat = pixel_unshuffle(x, scale=2)
        elif self.scale == 1:
            feat = pixel_unshuffle(x, scale=4)
        else:
            feat = x
        feat = self.conv_first(feat)
        # trunk with a long skip connection
        feat = feat + self.conv_body(self.body(feat))
        feat = self._upscale(feat, self.conv_up1)
        feat = self._upscale(feat, self.conv_up2)
        if self.scale == 8:
            feat = self._upscale(feat, self.conv_up3)
        return self.conv_last(self.lrelu(self.conv_hr(feat)))
-
-
class RealESRGAN:
    """Convenience wrapper around RRDBNet with pretrained Real-ESRGAN weights.

    Weights are downloaded from the HuggingFace hub (see HF_MODELS) at
    construction time and the model is moved to ``device`` in eval mode.
    """

    def __init__(self, device, scale=4):
        # NOTE(review): scale must be a key of HF_MODELS (2, 4 or 8) — an
        # unsupported scale fails later in load_weights; confirm with callers.
        print(f'[INFO] init RealESRGAN_{scale}x: {device}')
        self.device = device
        self.scale = scale
        self.model = RRDBNet(
            num_in_ch=3, num_out_ch=3, num_feat=64,
            num_block=23, num_grow_ch=32, scale=scale
        )
        self.load_weights()

    def load_weights(self):
        """Download and load the pretrained checkpoint for ``self.scale``."""
        model_path = hf_hub_download(repo_id=HF_MODELS[self.scale]['repo_id'], filename=HF_MODELS[self.scale]['filename'])
        checkpoint = torch.load(model_path)
        # checkpoints may store the weights flat or under 'params'/'params_ema'
        if 'params' in checkpoint:
            self.model.load_state_dict(checkpoint['params'], strict=True)
        elif 'params_ema' in checkpoint:
            self.model.load_state_dict(checkpoint['params_ema'], strict=True)
        else:
            self.model.load_state_dict(checkpoint, strict=True)
        self.model.eval()
        self.model.to(self.device)

    # NOTE(review): cuda-specific autocast — presumably this wrapper is only
    # used on CUDA devices; confirm before running on CPU.
    @torch.cuda.amp.autocast()
    def predict(self, lr_image, batch_size=4, patches_size=192, padding=24, pad_size=15):
        # lr_image: np.ndarray, [h, w, 3], RGB uint8
        # return: np.ndarray, [H, W, 3], RGB uint8

        return_tensor = False
        if torch.is_tensor(lr_image):
            # or Tensor, [1, 3, H, W], RGB float32
            lr_image = (lr_image.detach().permute(0,2,3,1)[0].cpu().numpy() * 255).astype(np.uint8)
            return_tensor = True

        # mirror-pad the borders so edge patches have context; removed at the end
        lr_image = pad_reflect(lr_image, pad_size)

        # tile into overlapping patches and run the model batch by batch
        patches, p_shape = split_image_into_overlapping_patches(lr_image, patch_size=patches_size, padding_size=padding)
        img = torch.from_numpy(patches.astype(np.float32) / 255).permute((0,3,1,2)).to(self.device).detach()

        with torch.no_grad():
            res = self.model(img[0:batch_size])
            for i in range(batch_size, img.shape[0], batch_size):
                res = torch.cat((res, self.model(img[i:i+batch_size])), 0)

        sr_image = res.permute((0,2,3,1)).clamp_(0, 1).cpu()
        np_sr_image = sr_image.numpy()

        # the patch grid and overlap padding scale up together with the image
        padded_size_scaled = tuple(np.multiply(p_shape[0:2], self.scale)) + (3,)
        scaled_image_shape = tuple(np.multiply(lr_image.shape[0:2], self.scale)) + (3,)

        np_sr_image = stich_together(
            np_sr_image, padded_image_shape=padded_size_scaled,
            target_shape=scaled_image_shape, padding_size=padding * self.scale
        )

        sr_img = (np_sr_image * 255).astype(np.uint8)
        sr_img = unpad_image(sr_img, pad_size * self.scale)

        if return_tensor:
            # tensor input round-trips to float32 [1, 3, H, W] in [0, 1]
            sr_img = torch.from_numpy(sr_img.astype(np.float32) / 255).permute((2,0,1)).unsqueeze(0).to(self.device)

        return sr_img
-
-
# lazily-constructed RealESRGAN instances, keyed by upscale factor
MODELS = {}

def sr(image: ndarray, scale: Literal[2, 4, 8] = 2, device=None):
    """ lazy load functional super-resolution API for convenience.

    Args:
        image (ndarray): input image, uint8/float32 [H, W, 3]
        scale (Literal[2, 4, 8], optional): upscale factor. Defaults to 2.
        device (torch.device, optional): device to put SR models, if not provided, will try to use 'cuda'. Defaults to None.

    Returns:
        ndarray: super-resolutioned image, uint8/float32 [H * scale, W * scale, 3]
    """
    global MODELS
    if scale not in MODELS:
        # first use of this scale: build and cache the model
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        MODELS[scale] = RealESRGAN(device, scale=scale)

    # the model consumes uint8; remember whether to convert back to float
    was_float = image.dtype == np.float32
    if was_float:
        image = (image * 255).astype(np.uint8)

    result = MODELS[scale].predict(image)

    if was_float:
        result = result.astype(np.float32) / 255.0

    return result
-
-
def main():
    """CLI entry: super-resolve an image file with RealESRGAN."""
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('input', type=str)
    parser.add_argument('--output', type=str, default=None)
    parser.add_argument('--scale', type=int, default=4)
    args = parser.parse_args()

    # fix: honor the requested --scale (was hard-coded to scale=4, so the
    # output filename suffix and the actual model scale could disagree)
    model = RealESRGAN('cuda', scale=args.scale)

    if args.output is None:
        args.output = os.path.splitext(args.input)[0] + f'_{args.scale}x.jpg'

    # cv2 reads BGR; the model expects RGB
    image = cv2.imread(args.input)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    sr_image = model.predict(image)
    sr_image = cv2.cvtColor(sr_image, cv2.COLOR_RGB2BGR)

    cv2.imwrite(args.output, sr_image)


if __name__ == '__main__':
    main()
-
-import os
-import sys
-import glob
-import tqdm
-import json
-import pickle
-import varname
-from objprint import objstr
-from rich.console import Console
-
-import cv2
-from PIL import Image
-
-import numpy as np
-import torch
-
-from kiui.typing import *
-from kiui.env import is_imported
-
-
def lo(*xs, verbose=0):
    """inspect array like objects and report statistics.

    Args:
        xs (Any): array like objects to inspect.
        verbose (int, optional): level of verbosity, set to 1 to report mean and std, 2 to print the content. Defaults to 0.
    """

    console = Console()

    # inspect one object and print a one-line summary (rich console markup)
    def _lo(x, name):

        if isinstance(x, np.ndarray):
            # general stats
            text = ""
            text += f"[orange1]Array {name}[/orange1] {x.shape} {x.dtype}"
            if x.size > 0:
                text += f" ∈ [{x.min()}, {x.max()}]"
                if verbose >= 1:
                    text += f" μ = {x.mean()} σ = {x.std()}"
            # detect abnormal values
            if np.isnan(x).any():
                text += "[red] NaN![/red]"
            if np.isinf(x).any():
                text += "[red] Inf![/red]"
            console.print(text)

            # show values if shape is small or verbose is high
            if x.size < 50 or verbose >= 2:
                # np.set_printoptions(precision=4)
                print(x)

        elif torch.is_tensor(x):
            # general stats
            text = ""
            text += f"[orange1]Tensor {name}[/orange1] {x.shape} {x.dtype} {x.device}"
            if x.numel() > 0:
                text += f" ∈ [{x.min().item()}, {x.max().item()}]"
                if verbose >= 1:
                    text += f" μ = {x.mean().item()} σ = {x.std().item()}"
            # detect abnormal values
            if torch.isnan(x).any():
                text += "[red] NaN![/red]"
            if torch.isinf(x).any():
                text += "[red] Inf![/red]"
            console.print(text)

            # show values if shape is small or verbose is high
            if x.numel() < 50 or verbose >= 2:
                # np.set_printoptions(precision=4)
                print(x)

        else:  # other type, just print them
            console.print(f"[orange1]{type(x)} {name}[/orange1] {objstr(x)}")

    # inspect names
    for i, x in enumerate(xs):
        try:
            # recover the caller's variable name for a readable report
            name = varname.argname(f"xs[{i}]", func=lo)
        except:
            # varname can fail for unnamed expressions; fall back gracefully
            name = f"UNKNOWN"
        _lo(x, name)
-
-
-
-
def seed_everything(seed=42, verbose=False, strict=False):
    """Seed python's random, numpy and torch RNGs for reproducibility.

    Only libraries that the calling program has already imported are seeded
    (numpy is assumed to be imported under the alias ``np``).

    Args:
        seed (int, optional): random seed. Defaults to 42.
        verbose (bool, optional): whether to report each seed setting. Defaults to False.
        strict (bool, optional): whether to use strict deterministic mode for better torch reproduction. Defaults to False.
    """

    os.environ['PYTHONHASHSEED'] = str(seed)

    if is_imported('random'):
        # local import: we need the module object in this scope
        import random
        random.seed(seed)
        if verbose: print(f'[INFO] set random.seed = {seed}')
    else:
        if verbose: print(f'[INFO] random not imported, skip setting seed')

    # numpy is assumed to be imported as np by the caller
    if is_imported('np'):
        import numpy as np
        np.random.seed(seed)
        if verbose: print(f'[INFO] set np.random.seed = {seed}')
    else:
        if verbose: print(f'[INFO] numpy not imported, skip setting seed')

    if is_imported('torch'):
        import torch
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        if verbose: print(f'[INFO] set torch.manual_seed = {seed}')

        if strict:
            # trades speed for bit-exact reproducibility
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
            torch.use_deterministic_algorithms(True)
            if verbose: print(f'[INFO] set strict deterministic mode for torch.')
    else:
        if verbose: print(f'[INFO] torch not imported, skip setting seed')
-
-
-
-
def read_json(path):
    """Load a json file.

    Args:
        path (str): path to the json file.

    Returns:
        dict: parsed json content.
    """
    with open(path, "r") as f:
        content = json.load(f)
    return content
-
-
-
-
def write_json(path, x):
    """Serialize ``x`` to a json file with a 2-space indent.

    Args:
        path (str): path to write the json file.
        x (dict): object to serialize.
    """
    with open(path, "w") as f:
        json.dump(x, f, indent=2)
-
-
-
-
def read_pickle(path):
    """Read a pickle file.

    Args:
        path (str): path to the pickle file.

    Returns:
        Any: the unpickled content.
    """
    with open(path, "rb") as f:
        content = pickle.load(f)
    return content
-
-
-
-
def write_pickle(path, x):
    """Write ``x`` to a pickle file.

    Args:
        path (str): path to write the pickle file.
        x (Any): content to serialize.
    """
    with open(path, "wb") as f:
        pickle.dump(x, f)
-
-
-
-
def read_image(
    path: str,
    mode: Literal["float", "uint8", "pil", "torch", "tensor"] = "float",
    order: Literal["RGB", "RGBA", "BGR", "BGRA"] = "RGB",
):
    """read an image file into various formats and color mode.

    Args:
        path (str): path to the image file.
        mode (Literal["float", "uint8", "pil", "torch", "tensor"], optional): returned image format. Defaults to "float".
            float: float32 numpy array, range [0, 1];
            uint8: uint8 numpy array, range [0, 255];
            pil: PIL image;
            torch/tensor: float32 torch tensor, range [0, 1];
        order (Literal["RGB", "RGBA", "BGR", "BGRA"], optional): channel order. Defaults to "RGB".

    Raises:
        FileNotFoundError: if the image cannot be read.
        ValueError: if ``mode`` is unknown.

    Note:
        By default this function will convert RGBA image to white-background RGB image. Use ``order="RGBA"`` to keep the alpha channel.

    Returns:
        Union[np.ndarray, PIL.Image, torch.Tensor]: the image array.
    """

    if mode == "pil":
        return Image.open(path).convert(order)

    if path.endswith('.exr'):
        # opencv requires this env var before EXR files can be decoded
        os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
        img = cv2.imread(path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
    else:
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)

    # fix: cv2.imread silently returns None on failure, which previously
    # surfaced as a confusing AttributeError further down; fail loudly here
    if img is None:
        raise FileNotFoundError(f'cannot read image {path}')

    # cvtColor: opencv loads BGR(A); swap to RGB(A) when requested
    if len(img.shape) == 3: # ignore if gray scale
        if order in ["RGB", "RGBA"]:
            if img.shape[-1] == 4:
                img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA)
            elif img.shape[-1] == 3:
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # mix background: alpha-composite onto white when alpha is not wanted
    if img.shape[-1] == 4 and 'A' not in order:
        img = img.astype(np.float32) / 255
        img = img[..., :3] * img[..., 3:] + (1 - img[..., 3:])

    # convert to the requested dtype/container
    if mode == "uint8":
        if img.dtype != np.uint8:
            img = (img * 255).astype(np.uint8)
        return img
    elif mode == "float":
        if img.dtype == np.uint8:
            img = img.astype(np.float32) / 255
        return img
    elif mode in ["tensor", "torch"]:
        if img.dtype == np.uint8:
            img = img.astype(np.float32) / 255
        return torch.from_numpy(img)
    else:
        raise ValueError(f"Unknown read_image mode {mode}")
-
-
-
-
-[docs]
-def write_image(
- path: str,
- img: Union[Tensor, np.ndarray, Image.Image],
- order: Literal["RGB", "BGR"] = "RGB",
- ):
- """write an image to various formats.
-
- Args:
- path (str): path to write the image file.
- img (Union[torch.Tensor, np.ndarray, PIL.Image.Image]): image to write.
- order (str, optional): channel order. Defaults to "RGB".
- """
-
- if isinstance(img, Image.Image):
- img.save(path)
- return
-
- if torch.is_tensor(img):
- img = img.detach().cpu().numpy()
-
- if img.dtype == np.float32 or img.dtype == np.float64:
- img = (img * 255).astype(np.uint8)
-
- if len(img.shape) == 4:
- if img.shape[0] > 1:
- raise ValueError(f'only support saving a single image! current image: {img.shape}')
- img = img[0]
-
- if len(img.shape) == 3:
- # cvtColor
- if order == "RGB":
- if img.shape[-1] == 4:
- img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA)
- elif img.shape[-1] == 3:
- img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
-
- dir_path = os.path.dirname(path)
- if dir_path != '' and not os.path.exists(dir_path):
- os.makedirs(os.path.dirname(path), exist_ok=True)
- cv2.imwrite(path, img)
-
-
-
-
-[docs]
-def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
- """Load file form http url, will download models if necessary.
-
- Args:
- url (str): URL to be downloaded.
- model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
- Default: None.
- progress (bool): Whether to show the download progress. Default: True.
- file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
-
- Returns:
- str: The path to the downloaded file.
- """
-
- from torch.hub import download_url_to_file, get_dir
- from urllib.parse import urlparse
-
- if model_dir is None: # use the pytorch hub_dir
- hub_dir = get_dir()
- model_dir = os.path.join(hub_dir, "checkpoints")
-
- os.makedirs(model_dir, exist_ok=True)
-
- parts = urlparse(url)
- filename = os.path.basename(parts.path)
- if file_name is not None:
- filename = file_name
- cached_file = os.path.abspath(os.path.join(model_dir, filename))
- if not os.path.exists(cached_file):
- print(f'[INFO] Downloading: "{url}" to {cached_file}\n')
- download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
- return cached_file
-
-
-
-
-[docs]
-def is_format(f: str, format: Sequence[str]):
- """if a file's extension is in a set of format
-
- Args:
- f (str): file name.
- format (Sequence[str]): set of extensions (both '.jpg' or 'jpg' is ok).
-
- Returns:
- bool: if the file's extension is in the set.
- """
- ext = os.path.splitext(f)[1].lower() # include the dot
- return ext in format or ext[1:] in format
-
-
-
-[docs]
-def batch_process_files(
- process_fn, path, out_path,
- overwrite=False,
- in_format=[".jpg", ".jpeg", ".png"],
- out_format=None,
- image_mode='uint8',
- image_color_order="RGB",
- **kwargs
-):
- """simple function wrapper to batch processing files.
-
- Args:
- process_fn (Callable): process function.
- path (str): path to a file or a directory containing the files to process.
- out_path (str): output path of a file or a directory.
- overwrite (bool, optional): whether to overwrite existing results. Defaults to False.
- in_format (list, optional): input file formats. Defaults to [".jpg", ".jpeg", ".png"].
- out_format (str, optional): output file format. Defaults to None.
- image_mode (str, optional): for images, the mode to read. Defaults to 'uint8'.
- image_color_order (str, optional): for images, the color order. Defaults to "RGB".
- """
-
- if os.path.isdir(path):
- file_paths = glob.glob(os.path.join(path, "*"))
- file_paths = [f for f in file_paths if is_format(f, in_format)]
- else:
- file_paths = [path]
-
- if os.path.dirname(out_path) != '':
- os.makedirs(os.path.dirname(out_path), exist_ok=True)
-
- for file_path in tqdm.tqdm(file_paths):
- try:
-
- if len(file_paths) == 1:
- file_out_path = out_path
- else:
- file_out_path = os.path.join(out_path, os.path.basename(file_path))
-
- if out_format is not None:
- file_out_path = os.path.splitext(file_out_path)[0] + out_format
-
- if os.path.exists(file_out_path) and not overwrite:
- print(f"[INFO] ignoring {file_path} --> {file_out_path}")
- continue
-
- # dispatch loader
- if is_format(file_path, ['.jpg', '.jpeg', '.png']):
- input = read_image(file_path, mode=image_mode, order=image_color_order)
- elif is_format(file_path, ['.ply', '.obj', '.glb', '.gltf']):
- from kiui.mesh import Mesh
- input = Mesh.load(file_path)
- else:
- with open(file_path, "r") as f:
- input = f.read()
-
- # process
- output = process_fn(input, **kwargs)
-
- # dispatch writer
- if is_format(file_out_path, ['.jpg', '.jpeg', '.png']):
- write_image(file_out_path, output, order=image_color_order)
- elif is_format(file_out_path, ['.ply', '.obj', '.glb', '.gltf']):
- output.write(file_out_path)
- elif is_format(file_out_path, ['.npy']):
- np.save(file_out_path, output)
- else:
- with open(file_out_path, "w") as f:
- f.write(output)
-
- except Exception as e:
- print(f"[ERROR] when processing {file_path} --> {file_out_path}")
- print(e)
-
-
-import time
-import torch
-import numpy as np
-from datetime import datetime
-
-import matplotlib.cm as cm
-import matplotlib.pyplot as plt
-
-from kiui.typing import *
-from kiui.utils import lo, write_image
-
-
-
-[docs]
-def map_color(value: ndarray, cmap_name: str="viridis", vmin: float=None, vmax: float=None):
- """ map a 1D array to continuous color.
-
- Args:
- value (ndarray): array of float, [N]
- cmap_name (str, optional): color map name, ref: https://matplotlib.org/stable/users/explain/colors/colormaps.html#classes-of-colormaps. Defaults to "viridis".
- vmin (float, optional): min value. Defaults to None.
- vmax (float, optional): max value. Defaults to None.
-
- Returns:
- ndarray: array of color, [N, 3] in [0, 1]
- """
- # value: [N], float
- # return: RGB, [N, 3], float in [0, 1]
-
- if vmin is None:
- vmin = value.min()
- if vmax is None:
- vmax = value.max()
- value = (value - vmin) / (vmax - vmin) # range in [0, 1]
- cmap = cm.get_cmap(cmap_name)
- rgb = cmap(value)[:, :3] # will return rgba, we take only first 3 so we get rgb
- return rgb
-
-
-
-
-[docs]
-def plot_image(*xs, normalize=False, save=False, prefix='kiui_vis_plot_image'):
- """ sequentially plot provided images, optionally save to current dir.
-
- Args:
- xs (Sequence[Union[torch.Tensor, numpy.ndarray]]): can be uint8 or float32.
- [B, 4/3/1, H, W], [B, H, W, 4/3/1], [4/3/1, H, W], [H, W, 4/3/1], [H, W] torch.Tensor or numpy.ndarray
- normalize (bool, optional): whether to renormalize the image to [0, 1]. Defaults to False.
- save (bool, optional): whether to save the image to current dir (in case the plot cannot be showed, like in vscode remote). Defaults to False.
- prefix (str, optional): image save name prefix if save=True.
- """
-
- _cnt = 0
- _signature = datetime.now().strftime('%Y_%m_%d_%H_%M_%S_%f')
-
- def _plot_image(image):
-
- nonlocal _cnt
-
- lo(image)
-
- if isinstance(image, torch.Tensor):
- image = image.detach().cpu().numpy()
-
- if image.dtype == np.uint8:
- image = image.astype(np.float32) / 255.0
-
- # empirially to channel-last
- if len(image.shape) == 3 and image.shape[0] < image.shape[-1]:
- image = image.transpose(1, 2, 0)
-
- # normalize
- if normalize:
- image = (image - image.min(axis=0, keepdims=True)) / (
- image.max(axis=0, keepdims=True)
- - image.min(axis=0, keepdims=True)
- + 1e-8
- )
-
- if save:
- _path = f'{prefix}_{_signature}_{_cnt}.png'
- _cnt += 1
- write_image(_path, image.astype(np.float32))
- print(f'[INFO] write image to {_path}')
- else:
- plt.imshow(image.astype(np.float32))
- plt.show()
-
- for x in xs:
- if len(x.shape) == 4:
- for i in range(x.shape[0]):
- _plot_image(x[i])
- else: # 3 or 2
- _plot_image(x)
-
-
-
-
-[docs]
-def plot_matrix(*xs):
- """ visualize some 2D matrix, different from ``kiui.vis.plot_image``, this will keep the original range and plot channel-by-channel.
-
- Args:
- xs (Sequence[Union[torch.Tensor, numpy.ndarray]]): [B, C, H, W], [C, H, W], or [H, W] torch.Tensor or numpy.ndarray
- """
-
- def _plot_matrix(matrix):
-
- lo(matrix)
-
- if isinstance(matrix, torch.Tensor):
- if len(matrix.shape) == 3:
- matrix = matrix.permute(1, 2, 0).squeeze()
- matrix = matrix.detach().cpu().numpy()
-
- if len(matrix.shape) == 3:
- # per channel
- for i in range(matrix.shape[-1]):
- plt.matshow(matrix[..., i])
- plt.show()
- else:
- plt.matshow(matrix.astype(np.float32))
- plt.show()
-
- for x in xs:
- if len(x.shape) == 4:
- for i in range(x.shape[0]):
- _plot_matrix(x[i])
- else: # 3 or 2
- _plot_matrix(x)
-
-
-
-
-[docs]
-def plot_pointcloud(pc, color=None):
- """plot point cloud.
-
- Args:
- pc (ndarray): point cloud positions, float [N, 3].
- color (ndarray, optional): point cloud colors, float/uint8 [N, 3/4]. Defaults to None.
-
- Note:
- This function requires a desktop (cannot be forwarded by ssh)!
- """
-
- lo(pc)
-
- if color is not None:
- lo(color)
- if color.dtype == np.float32:
- color = (color * 255).astype(np.uint8)
-
- if color is None or color.shape[-1] == 3:
- # use o3d as it's better to control
- import open3d as o3d
-
- pcd = o3d.geometry.PointCloud()
- pcd.points = o3d.utility.Vector3dVector(pc)
- if color is not None:
- pcd.colors = o3d.utility.Vector3dVector(color)
- o3d.visualization.draw_geometries([pcd])
-
- else:
- import trimesh
-
- pc = trimesh.PointCloud(pc, color)
- # axis
- axes = trimesh.creation.axis(axis_length=4)
- # sphere
- box = trimesh.primitives.Box(extents=(2, 2, 2)).as_outline()
- box.colors = np.array([[128, 128, 128]] * len(box.entities))
-
- trimesh.Scene([pc, axes, box]).show()
-
-
-
-
-[docs]
-def plot_poses(poses, size=0.1, bound=1, points=None, mesh=None, opengl=True):
- """plot camera poses.
-
- Args:
- poses (ndarray): camera poses, float [N, 4, 4].
- size (float, optional): line width. Defaults to 0.1.
- bound (int, optional): bounding box bound. Defaults to 1.
- points (ndarray, optional): also draw point clouds, float [M, 3]. Defaults to None.
- mesh (trimesh.Trimesh, optional): also draw mesh. Defaults to None.
- opengl (bool, optional): use OpenGL camera convention. Defaults to True.
-
- Note:
- This function requires a desktop (cannot be forwarded by ssh)!
- """
-
- lo(poses)
-
- if torch.is_tensor(poses):
- poses = poses.detach().cpu().numpy()
-
- import trimesh
-
- axes = trimesh.creation.axis(axis_length=4)
- box = trimesh.primitives.Box(extents=[2 * bound] * 3).as_outline()
- box.colors = np.array([[128, 128, 128]] * len(box.entities))
- objects = [axes, box]
-
- if bound > 1:
- unit_box = trimesh.primitives.Box(extents=[2] * 3).as_outline()
- unit_box.colors = np.array([[128, 128, 128]] * len(unit_box.entities))
- objects.append(unit_box)
-
- for pose in poses:
- # a camera is visualized with 8 line segments.
- pos = pose[:3, 3]
- a = pos + size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2] * (-1 if opengl else 1)
- b = pos - size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2] * (-1 if opengl else 1)
- c = pos - size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2] * (-1 if opengl else 1)
- d = pos + size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2] * (-1 if opengl else 1)
-
- # construct 3D paths
- frame = np.array([
- [pos, a],
- [pos, b],
- [pos, c],
- [pos, d],
- [a, b],
- [b, c],
- [c, d],
- [d, a],
- [pos, pos + pose[:3, 2] * (-1 if opengl else 1) * 3], # point to target
- ])
- frame = trimesh.load_path(frame)
- objects.append(frame)
-
- right_line = np.array([[pos, pos + pose[:3, 0] * size]])
- right_line = trimesh.load_path(right_line)
- right_line.colors = np.array([[255, 0, 0, 255]]).repeat(len(right_line.entities), axis=0)
- objects.append(right_line)
-
- up_line = np.array([[pos, pos + pose[:3, 1] * size]])
- up_line = trimesh.load_path(up_line)
- up_line.colors = np.array([[0, 255, 0, 255]]).repeat(len(up_line.entities), axis=0)
- objects.append(up_line)
-
- forward_line = np.array([[pos, pos + pose[:3, 2] * size]])
- forward_line = trimesh.load_path(forward_line)
- forward_line.colors = np.array([[0, 0, 255, 255]]).repeat(len(forward_line.entities), axis=0)
- objects.append(forward_line)
-
- if points is not None:
-
- lo(points)
-
- colors = np.zeros((points.shape[0], 4), dtype=np.uint8)
- colors[:, 2] = 255 # blue
- colors[:, 3] = 30 # transparent
- objects.append(trimesh.PointCloud(points, colors))
-
- if mesh is not None:
- objects.append(mesh)
-
- scene = trimesh.Scene(objects)
- scene.set_camera(distance=bound, center=[0, 0, 0])
- scene.show()
-
-
- |
- | - |
- |
- | - |
- | - |
- | - |
- | - |
- | - |
- | - |
- | - |
|
- - |
- | - |
- | - |
- | - |
- | - |
- | - |
- |
- | - |
- |
- | - |
A torch-native trimesh class, with support for ply/obj/glb
formats.
Note
-This class only supports one mesh with a single texture image (an albedo texture and a metallic-roughness texture).
-Init a mesh directly using all attributes.
-v (Optional[Tensor]) – vertices, float [N, 3]. Defaults to None.
f (Optional[Tensor]) – faces, int [M, 3]. Defaults to None.
vn (Optional[Tensor]) – vertex normals, float [N, 3]. Defaults to None.
fn (Optional[Tensor]) – faces for normals, int [M, 3]. Defaults to None.
vt (Optional[Tensor]) – vertex uv coordinates, float [N, 2]. Defaults to None.
ft (Optional[Tensor]) – faces for uvs, int [M, 3]. Defaults to None.
vc (Optional[Tensor]) – vertex colors, float [N, 3]. Defaults to None.
albedo (Optional[Tensor]) – albedo texture, float [H, W, 3], RGB format. Defaults to None.
metallicRoughness (Optional[Tensor]) – metallic-roughness texture, float [H, W, 3], metallic(Blue) = metallicRoughness[…, 2], roughness(Green) = metallicRoughness[…, 1]. Defaults to None.
device (Optional[torch.device]) – torch device. Defaults to None.
load mesh from path.
-path (str) – path to mesh file, supports ply, obj, glb.
clean (bool, optional) – perform mesh cleaning at load (e.g., merge close vertices). Defaults to False.
resize (bool, optional) – auto resize the mesh using bound
into [-bound, bound]^3. Defaults to True.
renormal (bool, optional) – re-calc the vertex normals. Defaults to True.
retex (bool, optional) – re-calc the uv coordinates, will overwrite the existing uv coordinates. Defaults to False.
wotex (bool, optional) – do not try to load any texture. Defaults to False.
bound (float, optional) – bound to resize. Defaults to 0.9.
front_dir (str, optional) – front-view direction of the mesh, should be [+-][xyz][ 123]. Defaults to ‘+z’.
device (torch.device, optional) – torch device. Defaults to None.
Note
-a device
keyword argument can be provided to specify the torch device.
-If it’s not provided, we will try to use 'cuda'
as the device if it’s available.
the loaded Mesh object.
-load an obj
mesh.
path (str) – path to mesh.
wotex (bool, optional) – do not try to load any texture. Defaults to False.
albedo_path (str, optional) – path to the albedo texture image, will overwrite the existing texture path if specified in mtl. Defaults to None.
device (torch.device, optional) – torch device. Defaults to None.
Note
-We will try to read mtl path from obj, else we assume the file name is the same as obj but with mtl extension. -The usemtl statement is ignored, and we only use the last material path in mtl file.
-the loaded Mesh object.
-load a mesh using trimesh.load()
.
Can load various formats like glb
and serves as a fallback.
Note
-We will try to merge all meshes if the glb contains more than one, -but this may cause the texture to lose, since we only support one texture image!
-path (str) – path to the mesh file.
wotex (bool, optional) – do not try to load any texture. Defaults to False.
device (torch.device, optional) – torch device. Defaults to None.
the loaded Mesh object.
-sample points on the surface of the mesh.
-count (int) – number of points to sample.
-the sampled points, float [count, 3].
-torch.Tensor
-get the axis-aligned bounding box of the mesh.
-the min xyz and max xyz of the mesh.
-Tuple[torch.Tensor]
-auto resize the mesh.
-bound (float, optional) – resizing into [-bound, bound]^3
. Defaults to 0.9.
auto calculate the vertex normals.
-auto calculate the uv coordinates.
-cache_path (str, optional) – path to save/load the uv cache as a npz file, this can avoid calculating uv every time when loading the same mesh, which is time-consuming. Defaults to None.
vmap (bool, optional) – remap vertices based on uv coordinates, so each v correspond to a unique vt (necessary for formats like gltf). -Usually this will duplicate the vertices on the edge of uv atlas. Defaults to True.
remap uv texture (vt) to other surface.
-v (torch.Tensor) – the target mesh vertices, float [N, 3].
-remap v/f and vn/fn to vt/ft.
-vmapping (np.ndarray, optional) – the mapping relationship from f to ft. Defaults to None.
-move all tensor attributes to device.
-device (torch.device) – target device.
-self.
-write the mesh to a path.
-path (str) – path to write, supports ply, obj and glb.
-write the mesh in ply format. Only for geometry!
-path (str) – path to write.
-This will create a scene with a single mesh.
-path (str) – path to write.
-write the mesh in obj format. Will also write the texture and mtl files.
-path (str) – path to write.
-perform mesh decimation.
-verts (np.ndarray) – mesh vertices, float [N, 3]
faces (np.ndarray) – mesh faces, int [M, 3]
target (int) – targeted number of faces
backend (str, optional) – algorithm backend, can be “pymeshlab” or “pyfqmr”. Defaults to “pymeshlab”.
remesh (bool, optional) – whether to remesh after decimation. Defaults to False.
optimalplacement (bool, optional) – For flat mesh, use False to prevent spikes. Defaults to True.
verbose (bool, optional) – whether to print the decimation process. Defaults to True.
vertices and faces after decimation.
-Tuple[np.ndarray]
-perform mesh cleaning, including floater removal, non manifold repair, and remeshing.
-verts (np.ndarray) – mesh vertices, float [N, 3]
faces (np.ndarray) – mesh faces, int [M, 3]
v_pct (int, optional) – percentage threshold to merge close vertices. Defaults to 1.
min_f (int, optional) – maximal number of faces for isolated component to remove. Defaults to 0.
min_d (int, optional) – maximal diameter percentage of isolated component to remove. Defaults to 0.
repair (bool, optional) – whether to repair non-manifold faces (cannot gurantee). Defaults to True.
remesh (bool, optional) – whether to perform a remeshing after all cleaning. Defaults to True.
remesh_size (float, optional) – the targeted edge length for remeshing. Defaults to 0.01.
remesh_iters (int, optional) – the iterations of remeshing. Defaults to 3.
verbose (bool, optional) – whether to print the cleaning process. Defaults to True.
vertices and faces after decimation.
-Tuple[np.ndarray]
-calculate laplacian uniform matrix
-verts (torch.Tensor) – mesh vertices, float [N, 3]
faces (torch.Tensor) – mesh faces, long [M, 3]
sparse laplacian matrix.
-torch.Tensor
-calculate laplacian smooth loss.
-verts (torch.Tensor) – mesh vertices, float [N, 3]
faces (torch.Tensor) – mesh faces, int [M, 3]
loss value.
-torch.Tensor
-compute edge to face mapping.
-faces (torch.Tensor) – mesh faces, int [M, 3]
-indices to faces for each edge, long, [N, 2]
-torch.Tensor
-calculate normal consistency loss.
-verts (torch.Tensor) – mesh vertices, float [N, 3]
faces (torch.Tensor) – mesh faces, int [M, 3]
face_normals (Optional[torch.Tensor]) – the normal vector for each face, will be calculated if not provided, float [M, 3]
loss value.
-torch.Tensor
-Miscellaneous API.
lazy load functional super-resolution API for convenience.
-image (ndarray) – input image, uint8/float32 [H, W, 3]
scale (Literal[2, 4, 8], optional) – upscale factor. Defaults to 2.
device (torch.device, optional) – device to put SR models, if not provided, will try to use ‘cuda’. Defaults to None.
super-resolutioned image, uint8/float32 [H * scale, W * scale, 3]
-ndarray
-put back values to an image according to the coords. inverse operation of F.grid_sample
.
shape (Sequence[int]) – shape of the image, support 2D image and 3D volume, sequence of [D]
coords (Tensor) – coordinates, float [N, D] in [-1, 1].
values (Tensor) – values, float [N, C].
mode (str, Literal[‘nearest’, ‘linear’, ‘linear-mipmap’]) – interpolation mode, see https://github.com/ashawkey/grid_put for examples. Defaults to ‘linear-mipmap’.
min_resolution (int, optional) – minimal resolution for mipmap. Defaults to 32.
return_count (bool, optional) – whether to return the summed value and weights, instead of the divided results. Defaults to False.
the restored image/volume, float [H, W, C]/[H, W, D, C].
-Tensor
-sr()
grid_put()
A collection of operators for numpy
and torch
.
dot product (along the last dim).
-x (Union[Tensor, ndarray]) – x, […, C]
y (Union[Tensor, ndarray]) – y, […, C]
x dot y, […, 1]
-Union[Tensor, ndarray]
-length of an array (along the last dim).
-x (Union[Tensor, ndarray]) – x, […, C]
eps (float, optional) – eps. Defaults to 1e-20.
length, […, 1]
-Union[Tensor, ndarray]
-normalize an array (along the last dim).
-x (Union[Tensor, ndarray]) – x, […, C]
eps (float, optional) – eps. Defaults to 1e-20.
normalized x, […, C]
-Union[Tensor, ndarray]
-normalize an array (along the last dim). alias of safe_normalize.
-x (Union[Tensor, ndarray]) – x, […, C]
eps (float, optional) – eps. Defaults to 1e-20.
normalized x, […, C]
-Union[Tensor, ndarray]
-make an int x divisible by m.
-x (int) – x
m (int, optional) –
Defaults to 8.
x + (m - x % m)
-int
-inversion of sigmoid function.
-x (Tensor) – x
eps (float, optional) – eps. Defaults to 1e-6.
log(x / (1 - x))
-Tensor
-inversion of softplus function.
-x (Tensor) – x
-log(exp(x) - 1)
-Tensor
-image scaling helper.
-x (Tensor) – input image, float [N, H, W, C]
size (Sequence[int]) – target size, tuple of [H’, W’]
mag (str, optional) – upscale interpolation mode. Defaults to ‘bilinear’.
min (str, optional) – downscale interpolation mode. Defaults to ‘bilinear’.
rescaled image, float [N, H’, W’, C]
-Tensor
-image scaling helper.
-x (Tensor) – input image, float [H, W, C]
size (Sequence[int]) – target size, tuple of [H’, W’]
mag (str, optional) – upscale interpolation mode. Defaults to ‘bilinear’.
min (str, optional) – downscale interpolation mode. Defaults to ‘bilinear’.
rescaled image, float [H’, W’, C]
-Tensor
-image scaling helper.
-x (Tensor) – input image, float [N, H, W]
size (Sequence[int]) – target size, tuple of [H’, W’]
mag (str, optional) – upscale interpolation mode. Defaults to ‘bilinear’.
min (str, optional) – downscale interpolation mode. Defaults to ‘bilinear’.
rescaled image, float [N, H’, W’]
-Tensor
-image scaling helper.
-x (Tensor) – input image, float [H, W]
size (Sequence[int]) – target size, tuple of [H’, W’]
mag (str, optional) – upscale interpolation mode. Defaults to ‘bilinear’.
min (str, optional) – downscale interpolation mode. Defaults to ‘bilinear’.
rescaled image, float [H’, W’]
-Tensor
-padding the uv-space texture image to avoid seam artifacts in mipmaps.
-image (Union[Tensor, ndarray]) – texture image, float, [H, W, C] in [0, 1].
mask (Union[Tensor, ndarray]) – valid uv region, bool, [H, W].
padding (int, optional) – padding size into the unmasked region. Defaults to 0.1 * max(H, W).
backend (Literal['knn', 'cv2'], optional) – algorithm backend, knn is faster. Defaults to ‘knn’.
padded texture image. float, [H, W, C].
-Union[Tensor, ndarray]
-recenter an image to leave some empty space at the image border.
-image (ndarray) – input image, float/uint8 [H, W, 3/4]
mask (ndarray) – alpha mask, bool [H, W]
border_ratio (float, optional) – border ratio, image will be resized to (1 - border_ratio). Defaults to 0.2.
output image, float/uint8 [H, W, 3/4]
-ndarray
-- | ||
k | - | |
- |
- kiui | - | -
- |
- kiui.cam | - | -
- |
- kiui.mesh | - | -
- |
- kiui.mesh_utils | - | -
- |
- kiui.op | - | -
- |
- kiui.utils | - | -
- |
- kiui.vis | - | -
inspect array like objects and report statistics.
-xs (Any) – array like objects to inspect.
verbose (int, optional) – level of verbosity, set to 1 to report mean and std, 2 to print the content. Defaults to 0.
auto set seed for random, numpy and torch.
-seed (int, optional) – random seed. Defaults to 42.
verbose (bool, optional) – whether to report each seed setting. Defaults to False.
strict (bool, optional) – whether to use strict deterministic mode for better torch reproduction. Defaults to False.
load a json file.
-path (str) – path to json file.
-json content.
-dict
-write a json file.
-path (str) – path to write json file.
x (dict) – dict to write.
read a pickle file.
-path (str) – path to pickle file.
-pickle content.
-Any
-write a pickle file.
-path (str) – path to write pickle file.
x (Any) – content to write.
read an image file into various formats and color mode.
-path (str) – path to the image file.
mode (Literal[“float”, “uint8”, “pil”, “torch”, “tensor”], optional) – returned image format. Defaults to “float”. -float: float32 numpy array, range [0, 1]; -uint8: uint8 numpy array, range [0, 255]; -pil: PIL image; -torch/tensor: float32 torch tensor, range [0, 1];
order (Literal[“RGB”, “RGBA”, “BGR”, “BGRA”], optional) – channel order. Defaults to “RGB”.
Note
-By default this function will convert RGBA image to white-background RGB image. Use order="RGBA"
to keep the alpha channel.
the image array.
-Union[np.ndarray, PIL.Image, torch.Tensor]
-write an image to various formats.
-path (str) – path to write the image file.
img (Union[torch.Tensor, np.ndarray, PIL.Image.Image]) – image to write.
order (str, optional) – channel order. Defaults to “RGB”.
Load file form http url, will download models if necessary.
-url (str) – URL to be downloaded.
model_dir (str) – The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. -Default: None.
progress (bool) – Whether to show the download progress. Default: True.
file_name (str) – The downloaded file name. If None, use the file name in the url. Default: None.
The path to the downloaded file.
-str
-if a file’s extension is in a set of format
-f (str) – file name.
format (Sequence[str]) – set of extensions (both ‘.jpg’ or ‘jpg’ is ok).
if the file’s extension is in the set.
-bool
-simple function wrapper to batch processing files.
-process_fn (Callable) – process function.
path (str) – path to a file or a directory containing the files to process.
out_path (str) – output path of a file or a directory.
overwrite (bool, optional) – whether to overwrite existing results. Defaults to False.
in_format (list, optional) – input file formats. Defaults to [“.jpg”, “.jpeg”, “.png”].
out_format (str, optional) – output file format. Defaults to None.
image_mode (str, optional) – for images, the mode to read. Defaults to ‘uint8’.
image_color_order (str, optional) – for images, the color order. Defaults to “RGB”.
Visualization tools.
-map a 1D array to continuous color.
-value (ndarray) – array of float, [N]
cmap_name (str, optional) – color map name, ref: https://matplotlib.org/stable/users/explain/colors/colormaps.html#classes-of-colormaps. Defaults to “viridis”.
vmin (float, optional) – min value. Defaults to None.
vmax (float, optional) – max value. Defaults to None.
array of color, [N, 3] in [0, 1]
-ndarray
-sequentially plot provided images, optionally save to current dir.
-xs (Sequence[Union[torch.Tensor, numpy.ndarray]]) – can be uint8 or float32. -[B, 4/3/1, H, W], [B, H, W, 4/3/1], [4/3/1, H, W], [H, W, 4/3/1], [H, W] torch.Tensor or numpy.ndarray
normalize (bool, optional) – whether to renormalize the image to [0, 1]. Defaults to False.
save (bool, optional) – whether to save the image to current dir (in case the plot cannot be showed, like in vscode remote). Defaults to False.
prefix (str, optional) – image save name prefix if save=True.
visualize some 2D matrix, different from kiui.vis.plot_image
, this will keep the original range and plot channel-by-channel.
xs (Sequence[Union[torch.Tensor, numpy.ndarray]]) – [B, C, H, W], [C, H, W], or [H, W] torch.Tensor or numpy.ndarray
-plot point cloud.
-pc (ndarray) – point cloud positions, float [N, 3].
color (ndarray, optional) – point cloud colors, float/uint8 [N, 3/4]. Defaults to None.
Note
-This function requires a desktop (cannot be forwarded by ssh)!
-plot camera poses.
-poses (ndarray) – camera poses, float [N, 4, 4].
size (float, optional) – line width. Defaults to 0.1.
bound (int, optional) – bounding box bound. Defaults to 1.
points (ndarray, optional) – also draw point clouds, float [M, 3]. Defaults to None.
mesh (trimesh.Trimesh, optional) – also draw mesh. Defaults to None.
opengl (bool, optional) – use OpenGL camera convention. Defaults to True.
Note
-This function requires a desktop (cannot be forwarded by ssh)!
-