From 769a4210ff4981f26a29d6c9e4f57fbf71f29b28 Mon Sep 17 00:00:00 2001 From: Neil Thomas Date: Thu, 8 Aug 2019 16:45:33 -0700 Subject: [PATCH] Switch output writing to h5py to reduce memory footprint. Update version --- VERSION.txt | 3 ++- rinokeras/core/v1x/train/RinokerasGraph.py | 24 ++++++++++++---------- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/VERSION.txt b/VERSION.txt index 524cb55..b37a96a 100644 --- a/VERSION.txt +++ b/VERSION.txt @@ -1 +1,2 @@ -1.1.1 +1.1.2 + diff --git a/rinokeras/core/v1x/train/RinokerasGraph.py b/rinokeras/core/v1x/train/RinokerasGraph.py index b12a461..ccb4960 100644 --- a/rinokeras/core/v1x/train/RinokerasGraph.py +++ b/rinokeras/core/v1x/train/RinokerasGraph.py @@ -2,6 +2,7 @@ from typing import Sequence, Union, Any, Optional, Dict import pickle as pkl +import h5py import tensorflow as tf from tensorflow.python.client import timeline from tqdm import tqdm @@ -160,21 +161,22 @@ def run_epoch(self, epoch_num: Optional[int] = None, summary_writer: Optional[tf.summary.FileWriter] = None, save_outputs: Optional[str] = None) -> MetricsAccumulator: - all_outputs = [] with self.add_progress_bar(data_len, epoch_num).initialize(): assert self.epoch_metrics is not None - while True: - if save_outputs is not None: - loss, outputs = self.run('default', return_outputs=True) - all_outputs.append(outputs) - else: + if save_outputs is not None: + i = 0 + with h5py.File(save_outputs, 'w') as f: + while True: + loss, outputs = self.run('default', return_outputs=True) + grp = f.create_group(str(i)) + outputs = outputs[0] # can we rely on this being a tuple of length 1? + for key in outputs.keys(): + grp.create_dataset(key, data=outputs[key]) + i += 1 + else: + while True: self.run('default') - - if save_outputs is not None: - with open(save_outputs, 'wb') as f: - pkl.dump(all_outputs, f) - return self.epoch_metrics def run_for_n_steps(self,