Skip to content

Commit

Permalink
feat: refactor gs_model and add Parser class
Browse files Browse the repository at this point in the history
  • Loading branch information
AtticusZeller committed Jun 30, 2024
1 parent b6595f6 commit 6d30853
Show file tree
Hide file tree
Showing 14 changed files with 981 additions and 218 deletions.
2 changes: 1 addition & 1 deletion .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions src/gsplat_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

import torch

from my_gsplat.trainer import Runner
from my_gsplat.gs_trainer import Runner
from pose_estimation import DEVICE

# from my_gsplat.trainer import Runner


def main():
# BUG: failed to show results
runner = Runner()
runner.adjust_steps()
if runner.ckpt is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ def __init__(
self._pose = to_tensor(pose, device=DEVICE, requires_grad=True)
self._pcd = self._project_pcds(include_homogeneous=False)

@property
def size(self) -> int:
    """Number of points in the projected point cloud (first dim of ``_pcd``)."""
    num_points, *_ = self._pcd.shape
    return num_points

@property
def color(self) -> Tensor:
"""
Expand Down
Empty file.
60 changes: 23 additions & 37 deletions src/my_gsplat/base.py → src/my_gsplat/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@

from nerfview import Viewer
from numpy.typing import NDArray
from torch import Tensor
from torchmetrics.image import (
LearnedPerceptualImagePatchSimilarity,
PeakSignalNoiseRatio,
StructuralSimilarityIndexMeasure,
)
from viser import ViserServer

from .datasets.normalize import normalize_dataset_slice, scene_scale
from .structure import Replica, RGBDImage
from .utils import DEVICE
from ..utils import DEVICE
from .Image import RGBDImage


@dataclass
Expand Down Expand Up @@ -41,34 +41,6 @@ class DatasetConfig:
pcd: NDArray | None = None # N,3
color: NDArray | None = None # N,3

def load_data(self, depth_loss, normalize: bool = True):
    """Load a slice of the Replica dataset as the training set.

    Side effects: populates ``self.trainset``, ``self.scene_scale`` and
    ``self.c2w_gts`` (ground-truth camera-to-world poses).

    :param depth_loss: only referenced by the commented-out ``Dataset``
        loader below — presumably whether depths should be loaded;
        TODO confirm intent.
    :param normalize: unused in the current code path — TODO confirm.
    """
    # Load data: Training data should contain initial points and colors.
    # self.parser = Parser(
    #     data_dir=self.data_dir,
    #     factor=self.data_factor,
    #     normalize=True,
    #     test_every=self.test_every,
    # )
    # self.trainset = Dataset(
    #     self.parser,
    #     split="train",
    #     patch_size=self.patch_size,
    #     load_depths=depth_loss,

    # NOTE(review): hard-coded debug slice — frames [1000:1020:8];
    # presumably temporary, confirm before relying on it.
    start = 1000
    step = 20
    self.trainset = normalize_dataset_slice(Replica()[start:start+step:8])
    print(len(self.trainset))
    # self.valset = Dataset(self.parser, split="val")
    # scene_scale(...) returns a tensor; enlarge by 10% as a margin and
    # apply the user-configured global scale.
    self.scene_scale = scene_scale(self.trainset).item() * 1.1 * self.global_scale

    # Snapshot ground-truth camera-to-world poses before any later mutation.
    self.c2w_gts = []
    for rgb_d in self.trainset:
        self.c2w_gts.append(rgb_d.pose)
        # rgb_d.pose = self.trainset[0].pose

    print("Scene scale:", self.scene_scale)

def make_dir(self):
# Where to dump results.
self.res_dir = Path(self.result_dir)
Expand All @@ -86,9 +58,9 @@ def make_dir(self):
@dataclass
class TrainingConfig:
batch_size: int = 1
max_steps: int = 200
eval_steps: list[int] = field(default_factory=lambda: [7_000, 30_000])
save_steps: list[int] = field(default_factory=lambda: [7_000, 30_000])
max_steps: int = 1000
eval_steps: list[int] = field(default_factory=lambda: [200, 30_000])
save_steps: list[int] = field(default_factory=lambda: [1000, 30_000])
steps_scaler: float = 1.0
refine_start_iter: int = 500
refine_stop_iter: int = 15_000
Expand Down Expand Up @@ -134,9 +106,9 @@ class RasterizeConfig:

@dataclass
class CameraConfig:
    """Camera pose-optimization settings.

    The previous version of this block defined each field twice (a stale
    set of defaults followed by the current one); only the later
    definition ever took effect, so the dead duplicates are removed and
    the winning values kept.
    """

    # Jointly optimize camera poses during training.
    pose_opt: bool = True
    # Learning rate for the pose parameters.
    pose_opt_lr: float = 1e-6
    # Regularization weight applied to the pose update.
    pose_opt_reg: float = 1e-3
    # Magnitude of noise injected into initial poses (0.0 disables) —
    # presumably for robustness experiments; confirm against trainer usage.
    pose_noise: float = 0.0


Expand Down Expand Up @@ -202,3 +174,17 @@ def adjust_steps(self, factor: float = 1.0):
self.refine_stop_iter = int(self.refine_stop_iter * factor)
self.reset_every = int(self.reset_every * factor)
self.refine_every = int(self.refine_every * factor)


@dataclass
class AlignData:
    """Normalized (target, source) frame pair packed for GS alignment.

    Produced by ``Parser.__getitem__``: tensors are in the normalized
    world frame, and ``tar_nums`` records how many rows of ``points``
    belong to the target frame so target/source can be sliced apart.
    """

    # for GS
    scene_scale: float
    colors: Tensor  # N,3 — concatenated target+source colors
    pixels: Tensor  # H,W,3 — source frame image
    points: Tensor  # N,3 — concatenated target+source points
    tar_c2w: Tensor  # 4,4 — target camera-to-world pose
    src_c2w: Tensor  # 4,4 — source camera-to-world pose
    tar_nums: int  # for slice tar and src
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@

import cv2
import numpy as np
import torch
from natsort import natsorted

from ..utils import as_intrinsics_matrix, load_camera_cfg
from ..utils import as_intrinsics_matrix, load_camera_cfg, to_tensor
from .base import AlignData
from .Image import RGBDImage
from .normalize import normalize_2, scene_scale


class DataLoaderBase:
Expand Down Expand Up @@ -136,3 +139,36 @@ def _filepaths(self) -> tuple[list[Path], list[Path]]:
f"Number of color and depth images do not match in {self.input_folder}."
)
return color_paths, depth_paths


class Parser(Replica):
    """Replica wrapper that yields (target, source) frame pairs.

    Each item pairs frame ``index`` with frame ``index + _STRIDE``,
    normalizes both into a shared world frame via ``normalize_2`` and
    packs the result into an :class:`AlignData`.
    """

    # Temporal gap between the target frame and its source frame.
    _STRIDE = 5

    def __init__(self):
        super().__init__()
        # Intrinsics as a tensor so downstream torch ops can consume them.
        self.K = to_tensor(self.K)

    def __len__(self) -> int:
        # The last valid target must still have a source at
        # index + _STRIDE. The previous "- 1" under-shortened the range,
        # letting __getitem__ read past the end of the dataset.
        return super().__len__() - self._STRIDE

    def __getitem__(self, index: int) -> AlignData:
        # Raise IndexError (not assert) so the sequence-iteration
        # protocol terminates cleanly and bounds hold under ``python -O``.
        if not 0 <= index < len(self):
            raise IndexError(
                f"index {index} out of range for {len(self)} frame pairs"
            )
        tar = super().__getitem__(index)
        src = super().__getitem__(index + self._STRIDE)
        tar_normed, src_normed = normalize_2(tar, src)
        scene_scale_normed = scene_scale([tar_normed, src_normed])

        # Concatenate both frames into a single point/color set for GS.
        points = torch.cat([tar_normed.points, src_normed.points], dim=0)  # N,3
        rgbs = torch.stack(
            [tar_normed.color / 255.0, src_normed.color / 255.0], dim=0
        ).reshape(
            -1, 3
        )  # N,3
        return AlignData(
            scene_scale_normed,
            rgbs,
            src_normed.color,
            points,
            tar_c2w=tar_normed.pose,
            src_c2w=src_normed.pose,
            tar_nums=tar_normed.points.shape[0],
        )
41 changes: 37 additions & 4 deletions src/my_gsplat/datasets/normalize.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch

from ..structure import RGBDImage
from .Image import RGBDImage


@torch.no_grad()
Expand Down Expand Up @@ -206,14 +206,47 @@ def normalize_dataset_slice(dataset_slice: list[RGBDImage]) -> list[RGBDImage]:
for i, rgb_d in enumerate(dataset_slice):
num_points = len(rgb_d.points)
rgb_d.pose = poses[i]
rgb_d.points = points[start_idx: start_idx + num_points]
rgb_d.points = points[start_idx : start_idx + num_points]
start_idx += num_points

return dataset_slice


@torch.no_grad()
def normalize_2(tar: RGBDImage, src: RGBDImage) -> tuple[RGBDImage, RGBDImage]:
    """Normalize a (target, source) RGB-D pair into one canonical frame.

    Both frames' points are first lifted into a common "world" frame
    using the TARGET pose, then the pair is normalized by camera
    similarity and principal-axis alignment.

    NOTE: a stale ``def scene_scale(...)`` line left between the
    decorator and this ``def`` (a merge/diff artifact that broke the
    decorator) has been removed; the logic is unchanged.

    Mutates ``tar`` and ``src`` in place (``points`` and ``pose``) and
    returns them for convenience.
    """
    # combine as one scene
    poses = torch.stack([tar.pose, src.pose], dim=0)
    # NOTE: transform to world, init with first (target) pose
    points = torch.cat(
        [
            transform_points(tar.pose, tar.points),
            transform_points(tar.pose, src.points),
        ],
        dim=0,
    )
    # normalize: first a similarity transform derived from the cameras...
    T1 = similarity_from_cameras(poses)
    poses = transform_cameras(T1, poses)
    points = transform_points(T1, points)

    # ...then rotate the scene onto its principal axes.
    T2 = align_principle_axes(points)
    poses = transform_cameras(T2, poses)
    points = transform_points(T2, points)

    # transform = T2 @ T1

    # Update the original data with normalized values
    num_points = len(tar.points)
    tar.points, src.points = points[:num_points], points[num_points:]
    tar.pose, src.pose = poses[0], poses[1]

    return tar, src


@torch.no_grad()
def scene_scale(dataset_slice: list[RGBDImage], global_scale: float = 1.0) -> float:
poses = torch.stack([rgb_d.pose for rgb_d in dataset_slice], dim=0)

camera_locations = poses[:, :3, 3]
Expand All @@ -223,4 +256,4 @@ def scene_scale(dataset_slice: list[RGBDImage]) -> torch.Tensor:
dists = torch.norm(camera_locations - scene_center, dim=1)
scale = torch.max(dists)

return scale
return scale.item() * 1.1 * global_scale
Loading

0 comments on commit 6d30853

Please sign in to comment.