Skip to content

Commit

Permalink
Set default max batch size in megabytes to None
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Sep 3, 2024
1 parent 283ab05 commit d1f43a8
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions torchflows/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def fit(self,
keep_best_weights: bool = True,
early_stopping: bool = False,
early_stopping_threshold: int = 50,
max_batch_size_mb: int = 2000):
max_batch_size_mb: int = None):
"""Fit the normalizing flow to a dataset.
Fitting the flow means finding the parameters of the bijection that maximize the probability of training data.
Expand Down Expand Up @@ -117,8 +117,9 @@ def fit(self,
min_batch_size = max(32, min(1024, len(x_train) // 100))
max_batch_size = min(4096, len(x_train) // 10)

event_size_mb = self.event_size / 2 ** 20
max_batch_size = max(1, min(max_batch_size, int(max_batch_size_mb / event_size_mb)))
if max_batch_size_mb is not None:
event_size_mb = self.event_size / 2 ** 20
max_batch_size = max(1, min(max_batch_size, int(max_batch_size_mb / event_size_mb)))

batch_size_adaptation_interval = 10 # double the batch size every 10 epochs
adaptive_batch_size = True
Expand Down

0 comments on commit d1f43a8

Please sign in to comment.