Skip to content
This repository has been archived by the owner on May 3, 2022. It is now read-only.

Commit

Permalink
Switch output writing to h5py to reduce memory footprint. Update version
Browse files Browse the repository at this point in the history
  • Loading branch information
thomas-a-neil committed Aug 8, 2019
1 parent c570ba8 commit 769a421
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 12 deletions.
3 changes: 2 additions & 1 deletion VERSION.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
1.1.1
1.1.2

24 changes: 13 additions & 11 deletions rinokeras/core/v1x/train/RinokerasGraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Sequence, Union, Any, Optional, Dict
import pickle as pkl

import h5py
import tensorflow as tf
from tensorflow.python.client import timeline
from tqdm import tqdm
Expand Down Expand Up @@ -160,21 +161,22 @@ def run_epoch(self,
epoch_num: Optional[int] = None,
summary_writer: Optional[tf.summary.FileWriter] = None,
save_outputs: Optional[str] = None) -> MetricsAccumulator:
all_outputs = []

with self.add_progress_bar(data_len, epoch_num).initialize():
assert self.epoch_metrics is not None
while True:
if save_outputs is not None:
loss, outputs = self.run('default', return_outputs=True)
all_outputs.append(outputs)
else:
if save_outputs is not None:
i = 0
with h5py.File(save_outputs, 'w') as f:
while True:
loss, outputs = self.run('default', return_outputs=True)
grp = f.create_group(str(i))
outputs = outputs[0] # can we rely on this being a tuple of length 1?
for key in outputs.keys():
grp.create_dataset(key, data=outputs[key])
i += 1
else:
while True:
self.run('default')

if save_outputs is not None:
with open(save_outputs, 'wb') as f:
pkl.dump(all_outputs, f)

return self.epoch_metrics

def run_for_n_steps(self,
Expand Down

0 comments on commit 769a421

Please sign in to comment.