diff --git a/yarr/envs/rlbench_env.py b/yarr/envs/rlbench_env.py index 495b6b6..af2096a 100644 --- a/yarr/envs/rlbench_env.py +++ b/yarr/envs/rlbench_env.py @@ -84,7 +84,7 @@ def _get_cam_observation_elements(camera: CameraConfig, prefix: str, channels_la ObservationElement("%s_camera_intrinsics" % prefix, (3, 3), np.float32) ) if camera.depth: - shape = img_s + [1] if schannels_last else [1] + img_s + shape = img_s + [1] if channels_last else [1] + img_s elements.append(ObservationElement("%s_depth" % prefix, shape, np.float32)) if camera.mask: raise NotImplementedError() @@ -137,6 +137,11 @@ def _observation_elements( observation_config.wrist_camera, "wrist", channels_last ) ) + elements.extend( + _get_cam_observation_elements( + observation_config.overhead_camera, "overhead", channels_last + ) + ) return elements diff --git a/yarr/runners/pytorch_train_runner.py b/yarr/runners/pytorch_train_runner.py index 65c58ea..1ba01bf 100644 --- a/yarr/runners/pytorch_train_runner.py +++ b/yarr/runners/pytorch_train_runner.py @@ -43,6 +43,9 @@ def __init__(self, replay_ratio: Optional[float] = None, tensorboard_logging: bool = True, csv_logging: bool = False, + wandb_logging: bool = True, + wandb_cfg: dict = {}, + project_name: str = "c2farm", buffers_per_batch: int = -1 # -1 = all ): super(PyTorchTrainRunner, self).__init__( @@ -78,7 +81,7 @@ def __init__(self, logging.info("'logdir' was None. No logging will take place.") else: self._writer = LogWriter( - self._logdir, tensorboard_logging, csv_logging) + self._logdir, tensorboard_logging, csv_logging, wandb_logging, wandb_cfg, project_name) if weightsdir is None: logging.info( "'weightsdir' was None. No weight saving will take place.") diff --git a/yarr/utils/log_writer.py b/yarr/utils/log_writer.py index f9ccff1..f083f2a 100644 --- a/yarr/utils/log_writer.py +++ b/yarr/utils/log_writer.py @@ -8,16 +8,20 @@ from yarr.agents.agent import ScalarSummary, HistogramSummary, ImageSummary, \ VideoSummary from torch.utils.tensorboard import SummaryWriter - +import wandb class LogWriter(object): def __init__(self, logdir: str, tensorboard_logging: bool, - csv_logging: bool): + csv_logging: bool, + wandb_logging: bool, + wandb_cfg: dict = None, + project_name: str = 'c2farm'): self._tensorboard_logging = tensorboard_logging self._csv_logging = csv_logging + self._wandb_logging = wandb_logging os.makedirs(logdir, exist_ok=True) if tensorboard_logging: self._tf_writer = SummaryWriter(logdir) @@ -25,36 +29,66 @@ def __init__(self, self._prev_row_data = self._row_data = OrderedDict() self._csv_file = os.path.join(logdir, 'data.csv') self._field_names = None - + if wandb_logging: + try: + task_name = wandb_cfg['rlbench']['task'] + method_name = wandb_cfg['method']['name'] + exp_name = task_name + '-' + method_name + except: + exp_name = None + wandb.init( + project=project_name, + config=wandb_cfg, + name=exp_name + ) def add_scalar(self, i, name, value): if self._tensorboard_logging: self._tf_writer.add_scalar(name, value, i) if self._csv_logging: if len(self._row_data) == 0: self._row_data['step'] = i - self._row_data[name] = value.item() if isinstance( - value, torch.Tensor) else value + self._row_data[name] = value.item() if isinstance(value, torch.Tensor) else value + if self._wandb_logging: + wandb.log({name: value, 'step': i}) def add_summaries(self, i, summaries): for summary in summaries: try: - if isinstance(summary, ScalarSummary): - self.add_scalar(i, summary.name, summary.value) - elif self._tensorboard_logging: - if isinstance(summary, HistogramSummary): + if self._csv_logging and isinstance(summary, ScalarSummary): + self._row_data['step'] = i + name, value = summary.name, summary.value + self._row_data[name] = value.item() if isinstance(value, torch.Tensor) else value + + if isinstance(summary, HistogramSummary): + if self._tensorboard_logging: self._tf_writer.add_histogram( summary.name, summary.value, i) - elif isinstance(summary, ImageSummary): - # Only grab first item in batch - v = (summary.value if summary.value.ndim == 3 else - summary.value[0]) + if self._wandb_logging: + wandb.log({summary.name: wandb.Histogram(summary.value.cpu()), 'step': i}) + elif isinstance(summary, ImageSummary): + # Only grab first item in batch + + v = (summary.value if summary.value.ndim == 3 else + summary.value[0]) + if self._tensorboard_logging: self._tf_writer.add_image(summary.name, v, i) - elif isinstance(summary, VideoSummary): - # Only grab first item in batch - v = (summary.value if summary.value.ndim == 5 else - np.array([summary.value])) + if self._wandb_logging: + wandb.log({summary.name: wandb.Image(v), 'step': i}) + elif isinstance(summary, VideoSummary): + # Only grab first item in batch + v = (summary.value if summary.value.ndim == 5 else + np.array([summary.value])) + if self._tensorboard_logging: self._tf_writer.add_video( summary.name, v, i, fps=summary.fps) + if self._wandb_logging: + wandb.log({summary.name: wandb.Video(v, fps=summary.fps), 'step': i}) + elif isinstance(summary, ScalarSummary): + if self._tensorboard_logging: + self._tf_writer.add_scalar(summary.name, summary.value, i) + if self._wandb_logging: + wandb.log({summary.name: summary.value, 'step': i}) + except Exception as e: logging.error('Error on summary: %s' % summary.name) raise e