Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
Shiaoming committed Oct 27, 2020
1 parent e5047d4 commit c03e88f
Showing 1 changed file with 12 additions and 19 deletions.
31 changes: 12 additions & 19 deletions DataLoader/TUMRGBLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,17 @@

import cv2
import numpy as np
import glob
from tqdm import tqdm
import logging
from pathlib import Path
import torch.utils.data as data
from torchvision import transforms
import torch
import sys

from utils.PinholeCamera import PinholeCamera


class TUMRGBLoader(object):
default_config = {
'root_path': '/mnt/dataset_hdd/tumrgbd',
'scene': "rgbd_dataset_freiburg1_360",
'scene': "rgbd_dataset_freiburg1_360",
"start": 0,
}

Expand All @@ -27,14 +22,11 @@ def __init__(self, config={}):
logging.info("TUMRGB Dataset config: ")
logging.info(self.config)


self.img_id = self.config["start"]
self.root = self.config['root_path']


assert (Path(self.root).exists()), f"Dataset root path {self.root} dose not exist!"


self.imgs = []
self.poses = []
self.read_imgs()
Expand All @@ -43,7 +35,7 @@ def __init__(self, config={}):

self.cam = PinholeCamera(640.0, 4801.0, 525, 525, 319.5, 239.5)

def quaternion_to_rotation_matrix( self, quat):
def quaternion_to_rotation_matrix(self, quat):
"""
Args:
quat: x,y,z,w
Expand All @@ -55,7 +47,8 @@ def quaternion_to_rotation_matrix( self, quat):
r = R.from_quat(quat)
rot_matrix = r.as_matrix()
return rot_matrix
def pose_transform( self, posexyzquat):

def pose_transform(self, posexyzquat):
"""
Args:
posexyzquat: tx,ty,tz,qx,qy,qz,qw
Expand All @@ -69,13 +62,12 @@ def pose_transform( self, posexyzquat):
return pose

def read_imgs(self):

root_path = self.root
scene = self.config['scene']

scene = scene.strip()

rgbd_gt_file= Path(root_path) / scene / "rgbd_gt.txt"
rgbd_gt_file = Path(root_path) / scene / "rgbd_gt.txt"
with open(rgbd_gt_file, 'r') as f:
lines = f.readlines()
lines = np.array([line.strip().split() for line in lines[2:]])
Expand All @@ -84,15 +76,17 @@ def read_imgs(self):
rgb_list = lines[:, 1]
# depth_list = lines[:, 2]
poses = [self.pose_transform(pose) for pose in lines[:, 3:]]
imgs = [ os.path.join(root_path,scene,img) for img in rgb_list]
imgs = [os.path.join(root_path, scene, img) for img in rgb_list]

self.imgs += imgs
self.poses += poses

def __iter__(self):
return self

def __getitem__(self, item):
return cv2.imread(self.imgs[item])

def __next__(self):
if self.img_id < self.__len__():
img = self.__getitem__(self.img_id)
Expand All @@ -101,15 +95,15 @@ def __next__(self):

return img
raise StopIteration()

def __len__(self):
return self.img_N - self.config["start"]
def get_cur_pose(self):
return self.poses[self.img_id -1]

if __name__ == "__main__":

def get_cur_pose(self):
return self.poses[self.img_id - 1]


if __name__ == "__main__":
loader = TUMRGBLoader()

for img in tqdm(loader):
Expand All @@ -119,4 +113,3 @@ def get_cur_pose(self):
# press Esc to exit
if cv2.waitKey(10) == 27:
break

0 comments on commit c03e88f

Please sign in to comment.