diff --git a/NeRF/tensorflow/.gitignore b/NeRF/tensorflow/.gitignore new file mode 100644 index 0000000..398d29f --- /dev/null +++ b/NeRF/tensorflow/.gitignore @@ -0,0 +1,2 @@ +data +temp \ No newline at end of file diff --git a/NeRF/tensorflow/colmap.py b/NeRF/tensorflow/colmap.py new file mode 100644 index 0000000..1acbace --- /dev/null +++ b/NeRF/tensorflow/colmap.py @@ -0,0 +1,440 @@ +import collections +import os +import struct +import subprocess + +import numpy as np + + +# The following three functions are adapted from https://github.com/fyusion/llff +# And the rest is from COLMAP +def load_colmap_data(realdir): + + camerasfile = os.path.join(realdir, 'sparse/0/cameras.bin') + camdata = read_cameras_binary(camerasfile) + + # cam = camdata[camdata.keys()[0]] + list_of_keys = list(camdata.keys()) + cam = camdata[list_of_keys[0]] + print('Cameras', len(cam)) + + h, w, f = cam.height, cam.width, cam.params[0] + # w, h, f = factor * w, factor * h, factor * f + hwf = np.array([h, w, f]).reshape([3, 1]) + + imagesfile = os.path.join(realdir, 'sparse/0/images.bin') + imdata = read_images_binary(imagesfile) + + w2c_mats = [] + bottom = np.array([0, 0, 0, 1.]).reshape([1, 4]) + + names = [imdata[k].name for k in imdata] + print('Images #', len(names)) + perm = np.argsort(names) + for k in imdata: + im = imdata[k] + R = im.qvec2rotmat() + t = im.tvec.reshape([3, 1]) + m = np.concatenate([np.concatenate([R, t], 1), bottom], 0) + w2c_mats.append(m) + + w2c_mats = np.stack(w2c_mats, 0) + c2w_mats = np.linalg.inv(w2c_mats) + + poses = c2w_mats[:, :3, :4].transpose([1, 2, 0]) + poses = np.concatenate( + [poses, np.tile(hwf[..., np.newaxis], [1, 1, poses.shape[-1]])], 1) + + points3dfile = os.path.join(realdir, 'sparse/0/points3D.bin') + pts3d = read_points3d_binary(points3dfile) + + # must switch to [-u, r, -t] from [r, -u, t], NOT [r, u, -t] + poses = np.concatenate([ + poses[:, 1:2, :], poses[:, 0:1, :], -poses[:, 2:3, :], + poses[:, 3:4, :], poses[:, 4:5, :] + ], 1) + + return poses, pts3d, perm + + +def run_colmap(basedir, colmap_location, match_type): + reconstructor_args = [ + colmap_location, + 'automatic_reconstructor', + '--workspace_path', + basedir, + '--image_path', + os.path.join(basedir, 'images'), + ] + + print(' '.join(reconstructor_args)) + + subprocess.run(reconstructor_args) + + print('Finished running COLMAP.') + + +def save_poses(basedir, poses, pts3d, perm): + pts_arr = [] + vis_arr = [] + for k in pts3d: + pts_arr.append(pts3d[k].xyz) + cams = [0] * poses.shape[-1] + for ind in pts3d[k].image_ids: + if len(cams) < ind - 1: + print( + 'ERROR: the correct camera poses for current points cannot be accessed' + ) + return + cams[ind - 1] = 1 + vis_arr.append(cams) + + pts_arr = np.array(pts_arr) + vis_arr = np.array(vis_arr) + print('Points', pts_arr.shape, 'Visibility', vis_arr.shape) + + zvals = np.sum( + -(pts_arr[:, np.newaxis, :].transpose([2, 0, 1]) - poses[:3, 3:4, :]) * + poses[:3, 2:3, :], 0) + valid_z = zvals[vis_arr == 1] + print('Depth stats', valid_z.min(), valid_z.max(), valid_z.mean()) + + save_arr = [] + for i in perm: + vis = vis_arr[:, i] + zs = zvals[:, i] + zs = zs[vis == 1] + close_depth, inf_depth = np.percentile(zs, .1), np.percentile(zs, 99.9) + # print( i, close_depth, inf_depth ) + + save_arr.append( + np.concatenate( + [poses[..., i].ravel(), + np.array([close_depth, inf_depth])], 0)) + save_arr = np.array(save_arr) + print(save_arr.shape) + + np.save(os.path.join(basedir, 'poses_bounds.npy'), save_arr) + + +# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of +# its contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# +# Author: Johannes L. Schoenberger (jsch at inf.ethz.ch) + +CameraModel = collections.namedtuple("CameraModel", + ["model_id", "model_name", "num_params"]) +Camera = collections.namedtuple("Camera", + ["id", "model", "width", "height", "params"]) +BaseImage = collections.namedtuple( + "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) +Point3D = collections.namedtuple( + "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) + + +class Image(BaseImage): + def qvec2rotmat(self): + return qvec2rotmat(self.qvec) + + +CAMERA_MODELS = { + CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), + CameraModel(model_id=1, model_name="PINHOLE", num_params=4), + CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), + CameraModel(model_id=3, model_name="RADIAL", num_params=5), + CameraModel(model_id=4, model_name="OPENCV", num_params=8), + CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), + CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), + CameraModel(model_id=7, model_name="FOV", num_params=5), + CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), + CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), + CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) +} +CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) \ + for camera_model in CAMERA_MODELS]) + + +def read_next_bytes(fid, + num_bytes, + format_char_sequence, + endian_character="<"): + """Read and unpack the next bytes from a binary file. + :param fid: + :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. + :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. + :param endian_character: Any of {@, =, <, >, !} + :return: Tuple of read and unpacked values. + """ + data = fid.read(num_bytes) + return struct.unpack(endian_character + format_char_sequence, data) + + +def read_cameras_text(path): + """ + see: src/base/reconstruction.cc + void Reconstruction::WriteCamerasText(const std::string& path) + void Reconstruction::ReadCamerasText(const std::string& path) + """ + cameras = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + camera_id = int(elems[0]) + model = elems[1] + width = int(elems[2]) + height = int(elems[3]) + params = np.array(tuple(map(float, elems[4:]))) + cameras[camera_id] = Camera(id=camera_id, + model=model, + width=width, + height=height, + params=params) + return cameras + + +def read_cameras_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::WriteCamerasBinary(const std::string& path) + void Reconstruction::ReadCamerasBinary(const std::string& path) + """ + cameras = {} + with open(path_to_model_file, "rb") as fid: + num_cameras = read_next_bytes(fid, 8, "Q")[0] + for camera_line_index in range(num_cameras): + camera_properties = read_next_bytes(fid, + num_bytes=24, + format_char_sequence="iiQQ") + camera_id = camera_properties[0] + model_id = camera_properties[1] + model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name + width = camera_properties[2] + height = camera_properties[3] + num_params = CAMERA_MODEL_IDS[model_id].num_params + params = read_next_bytes(fid, + num_bytes=8 * num_params, + format_char_sequence="d" * num_params) + cameras[camera_id] = Camera(id=camera_id, + model=model_name, + width=width, + height=height, + params=np.array(params)) + assert len(cameras) == num_cameras + return cameras + + +def read_images_text(path): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesText(const std::string& path) + void Reconstruction::WriteImagesText(const std::string& path) + """ + images = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + image_id = int(elems[0]) + qvec = np.array(tuple(map(float, elems[1:5]))) + tvec = np.array(tuple(map(float, elems[5:8]))) + camera_id = int(elems[8]) + image_name = elems[9] + elems = fid.readline().split() + xys = np.column_stack([ + tuple(map(float, elems[0::3])), + tuple(map(float, elems[1::3])) + ]) + point3D_ids = np.array(tuple(map(int, elems[2::3]))) + images[image_id] = Image(id=image_id, + qvec=qvec, + tvec=tvec, + camera_id=camera_id, + name=image_name, + xys=xys, + point3D_ids=point3D_ids) + return images + + +def read_images_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesBinary(const std::string& path) + void Reconstruction::WriteImagesBinary(const std::string& path) + """ + images = {} + with open(path_to_model_file, "rb") as fid: + num_reg_images = read_next_bytes(fid, 8, "Q")[0] + for image_index in range(num_reg_images): + binary_image_properties = read_next_bytes( + fid, num_bytes=64, format_char_sequence="idddddddi") + image_id = binary_image_properties[0] + qvec = np.array(binary_image_properties[1:5]) + tvec = np.array(binary_image_properties[5:8]) + camera_id = binary_image_properties[8] + image_name = "" + current_char = read_next_bytes(fid, 1, "c")[0] + while current_char != b"\x00": # look for the ASCII 0 entry + image_name += current_char.decode("utf-8") + current_char = read_next_bytes(fid, 1, "c")[0] + num_points2D = read_next_bytes(fid, + num_bytes=8, + format_char_sequence="Q")[0] + x_y_id_s = read_next_bytes(fid, + num_bytes=24 * num_points2D, + format_char_sequence="ddq" * + num_points2D) + xys = np.column_stack([ + tuple(map(float, x_y_id_s[0::3])), + tuple(map(float, x_y_id_s[1::3])) + ]) + point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) + images[image_id] = Image(id=image_id, + qvec=qvec, + tvec=tvec, + camera_id=camera_id, + name=image_name, + xys=xys, + point3D_ids=point3D_ids) + return images + + +def read_points3D_text(path): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DText(const std::string& path) + void Reconstruction::WritePoints3DText(const std::string& path) + """ + points3D = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + point3D_id = int(elems[0]) + xyz = np.array(tuple(map(float, elems[1:4]))) + rgb = np.array(tuple(map(int, elems[4:7]))) + error = float(elems[7]) + image_ids = np.array(tuple(map(int, elems[8::2]))) + point2D_idxs = np.array(tuple(map(int, elems[9::2]))) + points3D[point3D_id] = Point3D(id=point3D_id, + xyz=xyz, + rgb=rgb, + error=error, + image_ids=image_ids, + point2D_idxs=point2D_idxs) + return points3D + + +def read_points3d_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DBinary(const std::string& path) + void Reconstruction::WritePoints3DBinary(const std::string& path) + """ + points3D = {} + with open(path_to_model_file, "rb") as fid: + num_points = read_next_bytes(fid, 8, "Q")[0] + for point_line_index in range(num_points): + binary_point_line_properties = read_next_bytes( + fid, num_bytes=43, format_char_sequence="QdddBBBd") + point3D_id = binary_point_line_properties[0] + xyz = np.array(binary_point_line_properties[1:4]) + rgb = np.array(binary_point_line_properties[4:7]) + error = np.array(binary_point_line_properties[7]) + track_length = read_next_bytes(fid, + num_bytes=8, + format_char_sequence="Q")[0] + track_elems = read_next_bytes(fid, + num_bytes=8 * track_length, + format_char_sequence="ii" * + track_length) + image_ids = np.array(tuple(map(int, track_elems[0::2]))) + point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) + points3D[point3D_id] = Point3D(id=point3D_id, + xyz=xyz, + rgb=rgb, + error=error, + image_ids=image_ids, + point2D_idxs=point2D_idxs) + return points3D + + +def read_model(path, ext): + if ext == ".txt": + cameras = read_cameras_text(os.path.join(path, "cameras" + ext)) + images = read_images_text(os.path.join(path, "images" + ext)) + points3D = read_points3D_text(os.path.join(path, "points3D") + ext) + else: + cameras = read_cameras_binary(os.path.join(path, "cameras" + ext)) + images = read_images_binary(os.path.join(path, "images" + ext)) + points3D = read_points3d_binary(os.path.join(path, "points3D") + ext) + return cameras, images, points3D + + +def qvec2rotmat(qvec): + return np.array([[ + 1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, + 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], + 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2] + ], + [ + 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], + 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, + 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1] + ], + [ + 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], + 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], + 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2 + ]]) + + +def rotmat2qvec(R): + Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat + K = np.array([[Rxx - Ryy - Rzz, 0, 0, 0], [ + Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0 + ], [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], + [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 + eigvals, eigvecs = np.linalg.eigh(K) + qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] + if qvec[0] < 0: + qvec *= -1 + return qvec \ No newline at end of file diff --git a/NeRF/tensorflow/data.py b/NeRF/tensorflow/data.py new file mode 100644 index 0000000..4abb7bd --- /dev/null +++ b/NeRF/tensorflow/data.py @@ -0,0 +1,138 @@ +import os + +import numpy as np +from numpy.linalg.linalg import norm +import tensorflow as tf + + +def normalize(x): + return x / np.linalg.norm(x) + + +def viewmatrix(z, up, pos): + vec2 = normalize(z) + vec1_avg = up + vec0 = normalize(np.cross(vec1_avg, vec2)) + vec1 = normalize(np.cross(vec2, vec0)) + m = np.stack([vec0, vec1, vec2, pos], 1) + return m + + +def poses_avg(poses): + hwf = poses[0, :3, -1:] + center = poses[:, :3, 3].mean(0) + vec2 = normalize(poses[:, :3, 2].sum(0)) + up = poses[:, :3, 1].sum(0) + camera2world = np.concatenate([viewmatrix(vec2, up, center), hwf], 1) + return camera2world + + +def recenter_poses(poses): + poses_ = poses + 0 + bottom = np.reshape([0, 0, 0, 1.], [1, 4]) + camera2world = poses_avg(poses) + camera2world = np.concatenate([camera2world[:3, :4], bottom], -2) + bottom = np.tile(np.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1]) + poses = np.concatenate([poses[:, :3, :4], bottom], -2) + poses = np.linalg.inv(camera2world) @ poses + poses_[:, :3, :4] = poses[:, :3, :4] + poses = poses_ + return poses + + +def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N): + render_poses = [] + rads = np.array(list(rads) + [1.0]) + hwf = c2w[:, 4:5] + + for theta in np.linspace(0., 2. * np.pi * rots, N + 1)[:-1]: + c = np.dot( + c2w[:3, :4], + np.array( + [np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]) * + rads) + z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.]))) + render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1)) + return render_poses + + +def load_data(basedir, factor=8, bound_factor=0.75): + poses_array = np.load(os.path.join(basedir, 'poses_bounds.npy')) + + # for M images, poses -> [3, 5, M] + poses = poses_array[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0]) + + # bounds -> [2, M] + bounds = poses_array[:, -2:].transpose([1, 0]) + + print('Loaded', basedir, bounds.min(), bounds.max()) + + imgdir = os.path.join(basedir, 'images') + + images = [] + sh = None + for f in sorted(os.listdir(imgdir)): + if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png'): + image_path = os.path.join(imgdir, f) + image = tf.io.decode_image(tf.io.read_file(image_path)) + images.append(image) + if sh is None: + sh = np.array([image.shape[0], image.shape[1]]) / 8 + poses[:2, 4, :] = sh[:2].reshape([2, 1]) + poses[2, 4, :] = poses[2, 4, :] * 1. / factor + images = np.stack(images, -1) + + poses = np.concatenate( + [poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1) + poses = np.moveaxis(poses, -1, 0).astype(np.float32) + images = np.moveaxis(images, -1, 0).astype(np.float32) + bounds = np.moveaxis(bounds, -1, 0).astype(np.float32) + + scale = 1. if bound_factor is None else 1. / (bounds.min() * bound_factor) + poses[:, :3, 3] *= scale + bounds *= scale + + poses = recenter_poses(poses) + + camera2world = poses_avg(poses) + print('recentered', camera2world.shape) + print(camera2world[:3, :4]) + + up = normalize(poses[:, :3, 1].sum(0)) + + close_depth, inf_depth = bounds.min() * 0.9, bounds.max() * 0.5 + dt = 0.75 + mean_dz = 1 / ((1 - dt) / close_depth + dt / inf_depth) + focal = mean_dz + + shrink_factor = 0.8 + zdelta = close_depth * 0.2 + tt = poses[:, :3, 3] + rads = np.percentile(np.abs(tt), 90, 0) + c2w_path = camera2world + N_views = 120 + N_rots = 2 + + render_poses = render_path_spiral(c2w_path, + up, + rads, + focal, + zdelta, + zrate=.5, + rots=N_rots, + N=N_views) + + render_poses = np.array(render_poses).astype(np.float32) + + c2w = poses_avg(poses) + print('Data:') + print(poses.shape, images.shape, bounds.shape) + + dists = np.sum(np.square(c2w[:3, 3] - poses[:, :3, 3]), -1) + i_test = np.argmin(dists) + + return images, poses, bounds, render_poses, i_test + + +load_data( + '/Users/yanjia.li/Snapchat/Dev/deep-vision/NeRF/tensorflow/data/frames') diff --git a/NeRF/tensorflow/extract_frames.py b/NeRF/tensorflow/extract_frames.py new file mode 100644 index 0000000..4856b50 --- /dev/null +++ b/NeRF/tensorflow/extract_frames.py @@ -0,0 +1,113 @@ +import os +import shutil +import subprocess + +import click +import numpy as np +import tensorflow as tf + +from colmap import load_colmap_data, run_colmap, save_poses + + +def get_blur_score(image_path): + """Calculate bluriness score with FFT + https://www.pyimagesearch.com/2020/06/15/opencv-fast-fourier-transform-fft-for-blur-detection-in-images-and-video-streams/ + """ + size = 60 + encoded = tf.io.read_file(image_path) + image = tf.io.decode_image(encoded) + image = tf.image.resize(image, (500, 500), preserve_aspect_ratio=True) + image = tf.squeeze(tf.image.rgb_to_grayscale(image), -1).numpy() + h = image.shape[0] + w = image.shape[1] + (cX, cY) = (int(w / 2.0), int(h / 2.0)) + fft = np.fft.fft2(image) + fftShift = np.fft.fftshift(fft) + fftShift[cY - size:cY + size, cX - size:cX + size] = 0 + fftShift = np.fft.ifftshift(fftShift) + recon = np.fft.ifft2(fftShift) + magnitude = 20 * np.log(np.abs(recon).clip(min=1e-8)) + mean = np.mean(magnitude) + + return mean, image_path + + +def filter_blurry_images(frames_dir, output_dir, threshold): + images = os.listdir(frames_dir) + images_dir = os.path.join(output_dir, 'images') + os.makedirs(images_dir, exist_ok=True) + + results = [] + for image_name in images: + if image_name.endswith('.jpg') or image_name.endswith('.png'): + image_path = os.path.join(frames_dir, image_name) + results.append(get_blur_score(image_path)) + + cnt = 0 + for score, image_path in results: + if score < threshold: + cnt += 1 + continue + output_path = os.path.join(images_dir, os.path.basename(image_path)) + shutil.copy(image_path, output_path) + print( + f'Filtered out {cnt} blurry images out of {len(results)} total images.' + ) + + +def generate_poses(basedir, colmap_location, match_type='exhaustive_matcher'): + + files_needed = [ + '{}.bin'.format(f) for f in ['cameras', 'images', 'points3D'] + ] + if os.path.exists(os.path.join(basedir, 'sparse/0')): + files_had = os.listdir(os.path.join(basedir, 'sparse/0')) + else: + files_had = [] + if not all([f in files_had for f in files_needed]): + print('Need to run COLMAP') + run_colmap(basedir, colmap_location, match_type) + else: + print('Don\'t need to run COLMAP') + + print('Post-colmap') + + poses, pts3d, perm = load_colmap_data(basedir) + + save_poses(basedir, poses, pts3d, perm) + + print('Done with imgs2poses') + + return True + + +def extract_raw_frames(video_path, frames_dir, fps): + subprocess.run([ + 'ffmpeg', '-i', video_path, '-r', f'{fps}/1', + f'{frames_dir}/frame%03d.png' + ]) + + +@click.command() +@click.option('--video', help='Video path.') +@click.option('--output-dir', help='Directory for output images.') +@click.option('--threshold', default=15, help='Bluriness threshold.') +@click.option('--fps', default=5, help='Frames per second.') +@click.option('--colmap', + default='/Applications/COLMAP.app/Contents/MacOS/colmap', + help='The path to COLMAP executable.') +def main(video, output_dir, colmap, threshold, fps): + frames_dir = './temp/frames' + os.makedirs(frames_dir, exist_ok=True) + + try: + extract_raw_frames(video, frames_dir, fps) + filter_blurry_images(frames_dir, output_dir, threshold) + generate_poses(output_dir, colmap) + finally: + shutil.rmtree(frames_dir, ignore_errors=True) + os.rmdir(frames_dir) + + +if __name__ == "__main__": + main() \ No newline at end of file