Skip to content

Commit

Permalink
filter Nan and infs
Browse files Browse the repository at this point in the history
  • Loading branch information
MrNeRF committed Oct 16, 2024
1 parent 9a3e606 commit 4c10e88
Showing 1 changed file with 27 additions and 2 deletions.
29 changes: 27 additions & 2 deletions examples/simple_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,33 @@ def save_ply(splats: torch.nn.ParameterDict, dir: str, colors: torch.Tensor = No
quats = numpy_data["quats"]
opacities = numpy_data["opacities"]

sh0 = numpy_data["sh0"].transpose(0, 2, 1).reshape(means.shape[0], -1)
shN = numpy_data["shN"].transpose(0, 2, 1).reshape(means.shape[0], -1)

# Create a mask to identify rows with NaN or Inf in any of the numpy_data arrays
invalid_mask = (
np.isnan(means).any(axis=1)
| np.isinf(means).any(axis=1)
| np.isnan(scales).any(axis=1)
| np.isinf(scales).any(axis=1)
| np.isnan(quats).any(axis=1)
| np.isinf(quats).any(axis=1)
| np.isnan(opacities).any(axis=0)
| np.isinf(opacities).any(axis=0)
| np.isnan(sh0).any(axis=1)
| np.isinf(sh0).any(axis=1)
| np.isnan(shN).any(axis=1)
| np.isinf(shN).any(axis=1)
)

# Filter out rows with NaNs or Infs from all data arrays
means = means[~invalid_mask]
scales = scales[~invalid_mask]
quats = quats[~invalid_mask]
opacities = opacities[~invalid_mask]
sh0 = sh0[~invalid_mask]
shN = shN[~invalid_mask]

num_points = means.shape[0]

with open(dir, "wb") as f:
Expand All @@ -213,8 +240,6 @@ def save_ply(splats: torch.nn.ParameterDict, dir: str, colors: torch.Tensor = No
for j in range(colors.shape[1]):
f.write(f"property float f_dc_{j}\n".encode())
else:
sh0 = numpy_data["sh0"].transpose(0, 2, 1).reshape(means.shape[0], -1)
shN = numpy_data["shN"].transpose(0, 2, 1).reshape(means.shape[0], -1)
for i, data in enumerate([sh0, shN]):
prefix = "f_dc" if i == 0 else "f_rest"
for j in range(data.shape[1]):
Expand Down

0 comments on commit 4c10e88

Please sign in to comment.