
Commit

Update run.py.
PiperOrigin-RevId: 696161473
grenlayk authored and copybara-github committed Nov 13, 2024
1 parent 08ad3d1 commit a891bc2
Showing 1 changed file with 62 additions and 15 deletions.
77 changes: 62 additions & 15 deletions clrs/examples/run.py
@@ -18,7 +18,7 @@
import functools
import os
import shutil
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

from absl import app
from absl import flags
@@ -271,16 +271,52 @@ def collect_and_eval(sampler, predict_fn, sample_count, rng_key, extras):
return {k: unpack(v) for k, v in out.items()}


def create_samplers(rng, train_lengths: List[int]):
"""Create all the samplers."""
def create_samplers(
rng,
train_lengths: List[int],
*,
algorithms: Optional[List[str]] = None,
val_lengths: Optional[List[int]] = None,
test_lengths: Optional[List[int]] = None,
train_batch_size: int = 32,
val_batch_size: int = 32,
test_batch_size: int = 32,
):
"""Create samplers for training, validation and testing.
Args:
rng: Numpy random state.
train_lengths: list of training lengths to use for each algorithm.
algorithms: list of algorithms to generate samplers for. Set to
FLAGS.algorithms if not provided.
val_lengths: list of lengths for validation samplers for each algorithm. Set
to the maximum training length if not provided.
test_lengths: list of lengths for test samplers for each algorithm. Set to
[-1] to use the benchmark dataset if not provided.
train_batch_size: batch size for training samplers.
val_batch_size: batch size for validation samplers.
test_batch_size: batch size for test samplers.

Returns:
Tuple of:
train_samplers: list of samplers for training.
val_samplers: list of samplers for validation.
val_sample_counts: list of sample counts for validation.
test_samplers: list of samplers for testing.
test_sample_counts: list of sample counts for testing.
spec_list: list of specs for each algorithm.
"""

train_samplers = []
val_samplers = []
val_sample_counts = []
test_samplers = []
test_sample_counts = []
spec_list = []

for algo_idx, algorithm in enumerate(FLAGS.algorithms):
algorithms = algorithms or FLAGS.algorithms
for algo_idx, algorithm in enumerate(algorithms):
# Make full dataset pipeline run on CPU (including prefetching).
with tf.device('/cpu:0'):

@@ -310,7 +346,7 @@ def create_samplers(rng, train_lengths: List[int]):
sampler_kwargs.pop('length_needle')

common_sampler_args = dict(
algorithm=FLAGS.algorithms[algo_idx],
algorithm=algorithms[algo_idx],
rng=rng,
enforce_pred_as_input=FLAGS.enforce_pred_as_input,
enforce_permutations=FLAGS.enforce_permutations,
@@ -319,7 +355,7 @@

train_args = dict(sizes=train_lengths,
split='train',
batch_size=FLAGS.batch_size,
batch_size=train_batch_size,
multiplier=-1,
randomize_pos=FLAGS.random_pos,
chunked=FLAGS.chunked_training,
@@ -328,19 +364,19 @@
train_sampler, _, spec = make_multi_sampler(**train_args)

mult = clrs.CLRS_30_ALGS_SETTINGS[algorithm]['num_samples_multiplier']
val_args = dict(sizes=[np.amax(train_lengths)],
val_args = dict(sizes=val_lengths or [np.amax(train_lengths)],
split='val',
batch_size=32,
batch_size=val_batch_size,
multiplier=2 * mult,
randomize_pos=FLAGS.random_pos,
chunked=False,
sampler_kwargs=sampler_kwargs,
**common_sampler_args)
val_sampler, val_samples, spec = make_multi_sampler(**val_args)

test_args = dict(sizes=[-1],
test_args = dict(sizes=test_lengths or [-1],
split='test',
batch_size=32,
batch_size=test_batch_size,
multiplier=2 * mult,
randomize_pos=False,
chunked=False,
@@ -380,16 +416,27 @@ def main(unused_argv):
rng_key = jax.random.PRNGKey(rng.randint(2**32))

# Create samplers
(train_samplers,
val_samplers, val_sample_counts,
test_samplers, test_sample_counts,
spec_list) = create_samplers(rng, train_lengths)
(
train_samplers,
val_samplers,
val_sample_counts,
test_samplers,
test_sample_counts,
spec_list,
) = create_samplers(
rng=rng,
train_lengths=train_lengths,
algorithms=FLAGS.algorithms,
val_lengths=[np.amax(train_lengths)],
test_lengths=[-1],
train_batch_size=FLAGS.batch_size,
)

processor_factory = clrs.get_processor_factory(
FLAGS.processor_type,
use_ln=FLAGS.use_ln,
nb_triplet_fts=FLAGS.nb_triplet_fts,
nb_heads=FLAGS.nb_heads
nb_heads=FLAGS.nb_heads,
)
model_params = dict(
processor_factory=processor_factory,
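For readers adapting this change, here is a minimal usage sketch of the refactored create_samplers call. It is not part of the commit; it assumes rng is a numpy RandomState (as the docstring states), and the algorithm name and lengths are illustrative values only:

import numpy as np

rng = np.random.RandomState(seed=42)        # docstring: "Numpy random state"
train_lengths = [4, 7, 11, 13, 16]          # illustrative training lengths

(train_samplers,
 val_samplers, val_sample_counts,
 test_samplers, test_sample_counts,
 spec_list) = create_samplers(
     rng=rng,
     train_lengths=train_lengths,
     algorithms=['bfs'],                     # defaults to FLAGS.algorithms if omitted
     val_lengths=[np.amax(train_lengths)],   # the default when not provided
     test_lengths=[-1],                      # -1 selects the benchmark dataset
     train_batch_size=32,                    # val/test batch sizes also default to 32
 )

Note that the bare * in the new signature makes every argument after train_lengths keyword-only, so any old positional call site fails loudly rather than silently misbinding arguments.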
