Skip to content

Commit

Permalink
A bunch of changes that allows benchmarking Taichi GS in large number…
Browse files Browse the repository at this point in the history
… of scenes (#153)

See details at
#151

---------

Co-authored-by: wanmeihuali <[email protected]>
Co-authored-by: Jianbo Ye <[email protected]>
  • Loading branch information
3 people authored Oct 25, 2023
1 parent f7631e3 commit 2447148
Show file tree
Hide file tree
Showing 18 changed files with 992 additions and 319 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,7 @@ data/
Data/
*.egg
*.egg-info
imgui.ini
imgui.ini
*.qdstrm
*.nsys-rep
*.sqlite
2 changes: 1 addition & 1 deletion benchmark/inference_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from taichi_3d_gaussian_splatting.GaussianPointCloudRasterisation import GaussianPointCloudRasterisation
from taichi_3d_gaussian_splatting.ImagePoseDataset import ImagePoseDataset
from taichi_3d_gaussian_splatting.GaussianPointCloudScene import GaussianPointCloudScene
from taichi_3d_gaussian_splatting.utils import torch2ti, se3_to_quaternion_and_translation_torch, quaternion_rotate_torch, quaternion_multiply_torch, quaternion_conjugate_torch
from taichi_3d_gaussian_splatting.utils import torch2ti, SE3_to_quaternion_and_translation_torch, quaternion_rotate_torch, quaternion_multiply_torch, quaternion_conjugate_torch
# %%
ITERATIONS = 100
WARMUP_ITERATIONS = 1000
Expand Down
4 changes: 2 additions & 2 deletions config/tat_truck_every_8_test.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
adaptive-controller-config:
densification-view-space-position-gradients-threshold: 6e-6
densification-view-space-position-gradients-threshold: 3e-6
densification_view_avg_space_position_gradients_threshold: 1e3
densification_multi_frame_view_space_position_gradients_threshold: 1e3
densification_multi_frame_view_pixel_avg_space_position_gradients_threshold: 1e3
Expand Down Expand Up @@ -28,7 +28,7 @@ log-image-interval: 200
log-loss-interval: 10
log-metrics-interval: 100
print-metrics-to-console: False
enable_taichi_kernel_profiler: True
enable_taichi_kernel_profiler: False
log_taichi_kernel_profile_interval: 3000
log_validation_image: False
feature_learning_rate: 0.005
Expand Down
176 changes: 176 additions & 0 deletions gaussian_point_render.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
#!/bin/python3

import argparse
import taichi as ti
from taichi_3d_gaussian_splatting.Camera import CameraInfo
from taichi_3d_gaussian_splatting.GaussianPointCloudRasterisation import GaussianPointCloudRasterisation
from taichi_3d_gaussian_splatting.GaussianPointCloudScene import GaussianPointCloudScene
from taichi_3d_gaussian_splatting.utils import SE3_to_quaternion_and_translation_torch, quaternion_to_rotation_matrix_torch
from dataclasses import dataclass
from taichi_3d_gaussian_splatting.ImagePoseDataset import ImagePoseDataset

import torch
import torchvision
import numpy as np
from PIL import Image
from pathlib import Path
import os
from tqdm import tqdm

class GaussianPointRenderer:
@dataclass
class GaussianPointRendererConfig:
parquet_path: str
cameras: torch.Tensor
device: str = "cuda"
image_height: int = 544
image_width: int = 976
camera_intrinsics: torch.Tensor = torch.tensor(
[[581.743, 0.0, 488.0], [0.0, 581.743, 272.0], [0.0, 0.0, 1.0]],
device="cuda")

def set_portrait_mode(self):
self.image_height = 976
self.image_width = 544
self.camera_intrinsics = torch.tensor(
[[1163.486, 0.0, 272.0], [0.0, 1163.486, 488.0], [0.0, 0.0, 1.0]],
device="cuda")

@dataclass
class ExtraSceneInfo:
start_offset: int
end_offset: int
center: torch.Tensor
visible: bool

def __init__(self, config: GaussianPointRendererConfig) -> None:
self.config = config
self.config.image_height = self.config.image_height - self.config.image_height % 16
self.config.image_width = self.config.image_width - self.config.image_width % 16
scene = GaussianPointCloudScene.from_parquet(
config.parquet_path, config=GaussianPointCloudScene.PointCloudSceneConfig(max_num_points_ratio=None))
self.scene = self._merge_scenes([scene])
self.scene = self.scene.to(self.config.device)
self.cameras = self.config.cameras.to(self.config.device)
self.camera_info = CameraInfo(
camera_intrinsics=self.config.camera_intrinsics.to(
self.config.device),
camera_width=self.config.image_width,
camera_height=self.config.image_height,
camera_id=0,
)
self.rasteriser = GaussianPointCloudRasterisation(
config=GaussianPointCloudRasterisation.GaussianPointCloudRasterisationConfig(
near_plane=0.8,
far_plane=1000.,
depth_to_sort_key_scale=100.))

def _merge_scenes(self, scene_list):
# the config does not matter here, only for training

merged_point_cloud = torch.cat(
[scene.point_cloud for scene in scene_list], dim=0)
merged_point_cloud_features = torch.cat(
[scene.point_cloud_features for scene in scene_list], dim=0)
num_of_points_list = [scene.point_cloud.shape[0]
for scene in scene_list]
start_offset_list = [0] + np.cumsum(num_of_points_list).tolist()[:-1]
end_offset_list = np.cumsum(num_of_points_list).tolist()
self.extra_scene_info_dict = {
idx: self.ExtraSceneInfo(
start_offset=start_offset,
end_offset=end_offset,
center=scene_list[idx].point_cloud.mean(dim=0),
visible=True
) for idx, (start_offset, end_offset) in enumerate(zip(start_offset_list, end_offset_list))
}
point_object_id = torch.zeros(
(merged_point_cloud.shape[0],), dtype=torch.int32, device=self.config.device)
for idx, (start_offset, end_offset) in enumerate(zip(start_offset_list, end_offset_list)):
point_object_id[start_offset:end_offset] = idx
merged_scene = GaussianPointCloudScene(
point_cloud=merged_point_cloud,
point_cloud_features=merged_point_cloud_features,
point_object_id=point_object_id,
config=GaussianPointCloudScene.PointCloudSceneConfig(
max_num_points_ratio=None
))
return merged_scene

def run(self, output_prefix):
num_cameras = self.cameras.shape[0]
for i in tqdm(range(num_cameras)):
c = self.cameras[i, :, :].unsqueeze(0)
q, t = SE3_to_quaternion_and_translation_torch(c)

with torch.no_grad():
image, _, _ = self.rasteriser(
GaussianPointCloudRasterisation.GaussianPointCloudRasterisationInput(
point_cloud=self.scene.point_cloud,
point_cloud_features=self.scene.point_cloud_features,
point_invalid_mask=self.scene.point_invalid_mask,
point_object_id=self.scene.point_object_id,
camera_info=self.camera_info,
q_pointcloud_camera=q,
t_pointcloud_camera=t,
color_max_sh_band=3,
)
)

img = Image.fromarray(torch.clamp(image * 255, 0, 255).byte().cpu().numpy(), 'RGB')
img.save(output_prefix / f'frame_{i:03}.png')

if __name__ == "__main__":

parser = argparse.ArgumentParser()
parser.add_argument("--parquet_path", type=str, required=True)
parser.add_argument("--poses", type=str, required=True, help="could be a .pt file that was saved as torch.save(), or a json dataset file used by Taichi-GS")
parser.add_argument("--output_prefix", type=str, required=True)
parser.add_argument("--gt_prefix", type=str, default="")
parser.add_argument("--portrait_mode", action='store_true', default=False)
args = parser.parse_args()
ti.init(arch=ti.cuda, device_memory_GB=4, kernel_profiler=True)

output_prefix = Path(args.output_prefix)
os.makedirs(output_prefix, exist_ok=True)
if args.gt_prefix:
gt_prefix = Path(args.gt_prefix)
os.makedirs(gt_prefix, exist_ok=True)
else:
gt_prefix = None

if args.poses.endswith(".pt"):
config = GaussianPointRenderer.GaussianPointRendererConfig(
args.parquet_path, torch.load(args.poses))
if args.portrait_mode:
config.set_portrait_mode()
elif args.poses.endswith(".json"):
val_dataset = ImagePoseDataset(
dataset_json_path=args.poses)
val_data_loader = torch.utils.data.DataLoader(
val_dataset, batch_size=None, shuffle=False, pin_memory=True, num_workers=4)

cameras = torch.zeros((len(val_data_loader), 4, 4))
camera_info = None
for idx, val_data in enumerate(tqdm(val_data_loader)):
image_gt, q, t, camera_info = val_data
r = quaternion_to_rotation_matrix_torch(q)
cameras[idx, :3, :3] = r
cameras[idx, :3, 3] = t
cameras[idx, 3, 3] = 1.0
# dump autoscaled GT images at the resolution of training
if gt_prefix is not None:
img = torchvision.transforms.functional.to_pil_image(image_gt)
img.save(gt_prefix / f'frame_{idx:03}.png')
config = GaussianPointRenderer.GaussianPointRendererConfig(
args.parquet_path, cameras
)
# override camera meta data as provided
config.image_width = camera_info.camera_width
config.image_height = camera_info.camera_height
config.camera_intrinsics = camera_info.camera_intrinsics
else:
raise ValueError(f"Unrecognized poses file format: {args.poses}, Must be .pt or .json file")

renderer = GaussianPointRenderer(config)
renderer.run(output_prefix)
15 changes: 15 additions & 0 deletions parquet_to_ply.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import argparse
import pandas as pd
from taichi_3d_gaussian_splatting.GaussianPointCloudScene import GaussianPointCloudScene

def save_ply(pointcloud):
print(pointcloud.head())

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--parquet_path", type=str, required=True)
parser.add_argument("--ply_path", type=str, required=True)
args = parser.parse_args()
scene = GaussianPointCloudScene.from_parquet(
args.parquet_path, config=GaussianPointCloudScene.PointCloudSceneConfig(max_num_points_ratio=None))
scene.to_ply(args.ply_path)
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ pyyaml
pandas[parquet]>=2.0.0
scipy
argparse
tensorboard
tensorboard
plyfile
11 changes: 5 additions & 6 deletions taichi_3d_gaussian_splatting/GaussianPoint3D.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def rotation_matrix_from_quaternion(q: ti.math.vec4) -> ti.math.mat3:


@ti.func
def tranform_matrix_from_quaternion_and_translation(q: ti.math.vec4, t: ti.math.vec3) -> ti.math.mat4:
def transform_matrix_from_quaternion_and_translation(q: ti.math.vec4, t: ti.math.vec3) -> ti.math.mat4:
"""
Convert a quaternion and a translation to a transformation matrix.
"""
Expand Down Expand Up @@ -86,6 +86,7 @@ def get_projective_transform_jacobian(
[0, fy / z, -(fy * y) / (z * z)]
])


@ti.func
def box_muller_transform(u1, u2):
z1 = ti.sqrt(-2 * ti.log(u1)) * ti.cos(2 * 3.141592653589 * u2)
Expand Down Expand Up @@ -128,7 +129,6 @@ def project_to_camera_position_with_extra_translation_and_rotation_and_scale(
translation = quaternion_rotate(extra_rotation, translation)
return project_point_to_camera(translation, T_camera_world, projective_transform)


@ti.func
def project_to_camera_position_jacobian(
self,
Expand Down Expand Up @@ -196,7 +196,7 @@ def project_to_camera_covariance_with_extra_rotation_and_scale(
T_camera_world: ti.math.mat4,
projective_transform: ti.math.mat3,
translation_camera: ti.math.vec3,
extra_rotation_quaternion: ti.math.vec4,
extra_rotation_quaternion: ti.math.vec4,
extra_scale: ti.math.vec3,
):
"""
Expand All @@ -214,7 +214,7 @@ def project_to_camera_covariance_with_extra_rotation_and_scale(
])
# covariance matrix, 3x3, equation (6) in the paper
Sigma = R @ S @ S.transpose() @ R.transpose()

# for inference, we can add extra rotation and scale to the covariance matrix
# e.g. when we want to rotate or resize point cloud for an object in the scene
R_extra = rotation_matrix_from_quaternion(extra_rotation_quaternion)
Expand Down Expand Up @@ -368,7 +368,7 @@ def get_color_with_jacobian_by_ray(
r_jacobian = r_normalized_jacobian * r_jacobian
g_jacobian = g_normalized_jacobian * g_jacobian
b_jacobian = b_normalized_jacobian * b_jacobian

# return ti.math.vec3(r, g, b), r_jacobian, g_jacobian, b_jacobian
return ti.math.vec3(r_normalized, g_normalized, b_normalized), r_jacobian, g_jacobian, b_jacobian

Expand Down Expand Up @@ -404,7 +404,6 @@ def sample(self) -> ti.math.vec3:
])
base = ti.math.vec3(z1, z2, z3)
return self.translation + R @ S @ base



# %%
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def _find_densify_points(self, input_data: GaussianPointCloudRasterisation.Backw
ax.legend()
image_height = input_data.magnitude_grad_viewspace_on_image.shape[0]
image_width = input_data.magnitude_grad_viewspace_on_image.shape[1]
print(image_height, image_width)
# print(image_height, image_width)
ax.set_xlim([0, image_width])
ax.set_ylim([image_height, 0])
self.has_plot = True
Expand Down Expand Up @@ -354,7 +354,8 @@ def _add_densify_points(self):

def reset_alpha(self):
pointcloud_features = self.maintained_parameters.pointcloud_features
pointcloud_features[:, 7] = self.config.reset_alpha_value
pointcloud_features[:, 7] = torch.clamp(pointcloud_features[:, 7],
max=self.config.reset_alpha_value)

def _generate_point_offset(self,
point_to_split: torch.Tensor, # (N, 3)
Expand Down
Loading

0 comments on commit 2447148

Please sign in to comment.