[do not merge] Running integration tests for random util fixes PR #626

Closed · 77 commits

Commits
a8187b8
Add pass/fail thresholds to traindiffs test
runame Jan 13, 2024
2373e15
Add traindiffs_test option to docker startup script
runame Jan 13, 2024
d1da9c7
Rename PytWorkload to PyTorchWorkload
runame Jan 13, 2024
6a5d63a
Add traindiffs tests to workflows (self-hosted)
runame Jan 13, 2024
047475c
Merge branch 'dev' into traindiffs
runame Jan 17, 2024
1683ba3
add variant scoring conditions
priyakasimbeg Jan 18, 2024
370687d
add flag for self-tuning ruleset
priyakasimbeg Jan 18, 2024
2128ce8
score group of submissions
priyakasimbeg Jan 18, 2024
c65794d
Merge branch 'dev' into scoring_fixes
priyakasimbeg Jan 19, 2024
d43ccf4
correct max number of steps
priyakasimbeg Jan 19, 2024
fb81436
add heldout workloads"
priyakasimbeg Jan 19, 2024
1ea2282
add trial args to docker startup.sh"
priyakasimbeg Jan 23, 2024
0bcb969
add script for sampling held out workloads
priyakasimbeg Jan 24, 2024
ce5f202
add code for run workloads
priyakasimbeg Jan 25, 2024
f431eef
add workload sampling
priyakasimbeg Jan 25, 2024
f260497
formatting
priyakasimbeg Jan 25, 2024
1a41f8b
imports
priyakasimbeg Jan 25, 2024
87df162
make seed splitting parallelizable
priyakasimbeg Jan 25, 2024
9d9cdb9
fix
priyakasimbeg Jan 25, 2024
1775307
formatting
priyakasimbeg Jan 25, 2024
8108c00
Merge pull request #613 from runame/traindiffs
priyakasimbeg Jan 25, 2024
2a11708
held out workloads example
priyakasimbeg Jan 25, 2024
a8385a2
add docker for run_workloads.py
priyakasimbeg Jan 25, 2024
ffddbdc
fix run_workloads.py
priyakasimbeg Jan 25, 2024
91cdf34
fix
priyakasimbeg Jan 25, 2024
95572ad
add rng seed to startup.sh docker script
priyakasimbeg Jan 25, 2024
d577d5c
fix
priyakasimbeg Jan 25, 2024
91ff705
fix
priyakasimbeg Jan 25, 2024
296dc1e
fix
priyakasimbeg Jan 26, 2024
a5b1154
fix
priyakasimbeg Jan 26, 2024
226544d
fix
priyakasimbeg Jan 26, 2024
6faad04
fix
priyakasimbeg Jan 26, 2024
9e7def9
fix log message
priyakasimbeg Jan 26, 2024
9b410b7
fix
priyakasimbeg Jan 26, 2024
7634a0b
debug
priyakasimbeg Jan 26, 2024
235bc69
debugging
priyakasimbeg Jan 26, 2024
a8d04cc
debugging
priyakasimbeg Jan 26, 2024
b2571b2
fix
priyakasimbeg Jan 26, 2024
18bc347
remove debugging statements
priyakasimbeg Jan 26, 2024
4a98698
fix
priyakasimbeg Jan 26, 2024
4d38e55
formatting
priyakasimbeg Jan 26, 2024
4d413f4
take into account median of studies for scoring
priyakasimbeg Jan 27, 2024
84c87b9
remove debugging
priyakasimbeg Jan 27, 2024
d6e2a36
formatting
priyakasimbeg Jan 27, 2024
f34838a
documentation
priyakasimbeg Jan 27, 2024
be263a2
Merge branch 'dev' into scoring_fixes
priyakasimbeg Jan 29, 2024
84dbb07
fix
priyakasimbeg Jan 30, 2024
af0b608
Merge branch 'scoring_fixes' of github.com:mlcommons/algorithmic-effi…
priyakasimbeg Jan 30, 2024
6d3b0ae
remove indexing for rng_subkeys
priyakasimbeg Jan 31, 2024
7b23443
add documentation
priyakasimbeg Jan 31, 2024
4b77ddd
fix documentation
priyakasimbeg Jan 31, 2024
2f48009
add warning
priyakasimbeg Jan 31, 2024
d39eb24
typo
priyakasimbeg Jan 31, 2024
aecb37f
fix documentation
priyakasimbeg Jan 31, 2024
5135cc8
remove prng import from generate_held_out_workloads.py
priyakasimbeg Jan 31, 2024
6d4f82e
fix technical documentation
priyakasimbeg Feb 1, 2024
aaa1014
formatting
priyakasimbeg Feb 1, 2024
761a877
add default for workload metadata config file
priyakasimbeg Feb 1, 2024
6b3827a
yapf fix
priyakasimbeg Feb 1, 2024
c0e1aad
import order
priyakasimbeg Feb 1, 2024
ff3c9b0
Merge pull request #618 from mlcommons/scoring_fixes
priyakasimbeg Feb 1, 2024
c794339
fix fold_in in pytorch
priyakasimbeg Feb 2, 2024
247dcb0
random utils fixes
priyakasimbeg Feb 2, 2024
759c90d
remove indexing from rngs in pytorch workloads
priyakasimbeg Feb 2, 2024
6baacd7
formatting
priyakasimbeg Feb 2, 2024
368e348
formatting
priyakasimbeg Feb 2, 2024
4be495f
fix seed shapes
priyakasimbeg Feb 2, 2024
5495d72
fix dataset seed
priyakasimbeg Feb 3, 2024
355ebe0
fix rng utils
priyakasimbeg Feb 3, 2024
43c381a
fix
priyakasimbeg Feb 3, 2024
f81e877
remove unused types from random_utils
priyakasimbeg Feb 3, 2024
3494f16
fix overflow error in jax sampling
priyakasimbeg Feb 3, 2024
665049a
use bits instead of randint
priyakasimbeg Feb 3, 2024
d0d3e3e
add numpy bits function
priyakasimbeg Feb 3, 2024
e4476aa
change seed method call for mnist dataset
priyakasimbeg Feb 3, 2024
9ebdca7
add documentation
priyakasimbeg Feb 6, 2024
f888a99
add fold in method
priyakasimbeg Feb 6, 2024
32 changes: 32 additions & 0 deletions .github/workflows/traindiffs_tests.yml
@@ -0,0 +1,32 @@
name: Containerized training differences tests between Jax and PyTorch

on:
  pull_request:
    branches:
      - 'main'

jobs:
  build_and_push_docker_image:
    runs-on: self-hosted
    steps:
      - uses: actions/checkout@v2
      - name: Build and push docker image
        run: |
          GIT_BRANCH=${{ github.head_ref || github.ref_name }}
          FRAMEWORK=both
          IMAGE_NAME="algoperf_${GIT_BRANCH}"
          cd $HOME/algorithmic-efficiency/docker
          docker build --no-cache -t $IMAGE_NAME . --build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH
          BUILD_RETURN=$?
          if [[ ${BUILD_RETURN} != 0 ]]; then exit ${BUILD_RETURN}; fi
          docker tag $IMAGE_NAME us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME
          docker push us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME
  traindiffs_tests:
    runs-on: self-hosted
    needs: build_and_push_docker_image
    steps:
      - uses: actions/checkout@v2
      - name: Run containerized traindiffs test
        run: |
          docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_${{ github.head_ref || github.ref_name }}
          docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_${{ github.head_ref || github.ref_name }} algoperf_${{ github.head_ref || github.ref_name }} --traindiffs_test true
2 changes: 1 addition & 1 deletion DOCUMENTATION.md
@@ -409,7 +409,7 @@ The held-out workloads function similarly to a holdout test set discouraging sub

Modifications could, for example, include changing the number of layers or units (drawn from an interval), swapping the activation function (drawn from a set of applicable functions), or using different data augmentations (drawn from a list of possible pre-processing steps). The sample space should be wide enough to discourage submitters from simply trying them all out, but at the same time should be restricted enough to produce realistic workloads with acceptable achievable performances.

In the first iteration of this benchmark, we manually designed three different workload variants for each fixed workload. The variants are designed such that they achieve a comparable performance to the fixed workload and that they might require different hyperparameters to achieve this performance. After the submission deadline, one held-out workload will be sampled for each fixed workload.
In the first iteration of this benchmark, we manually designed three different workload variants for each fixed workload. The variants are designed such that they achieve a comparable performance to the fixed workload and that they might require different hyperparameters to achieve this performance. After the submission deadline, one held-out workload will be sampled for each dataset.

Our scoring procedure uses the held-out workloads only to penalize submissions that can't handle the introduced modifications (see the [Scoring](#scoring) section for further details).
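To make the sampling step concrete, here is a minimal sketch of how one held-out variant per fixed workload could be drawn after the deadline. The mapping and variant names are illustrative placeholders, not the repository's actual registry.

```python
import random

# Hypothetical registry: each fixed workload maps to its designed variants.
FIXED_TO_VARIANTS = {
    'ogbg': ['ogbg_gelu', 'ogbg_silu', 'ogbg_model_size'],
    'wmt': ['wmt_post_ln', 'wmt_attention_temp', 'wmt_glu_tanh'],
}


def sample_held_out_workloads(seed: int) -> dict:
  """Draws one held-out variant per fixed workload, deterministic in the seed."""
  rng = random.Random(seed)
  return {fixed: rng.choice(variants)
          for fixed, variants in FIXED_TO_VARIANTS.items()}


print(sample_held_out_workloads(seed=2024))
```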

42 changes: 40 additions & 2 deletions GETTING_STARTED.md
@@ -336,11 +336,49 @@ docker exec -it <container_id> /bin/bash
```

## Score your Submission
To score your submission, we will score over all workloads, held-out workloads, and studies as described in the rules.
We will sample 1 held-out workload per dataset for a total of 6 held-out workloads and will use the sampled
held-out workloads in the scoring criteria for the matching base workloads.
In other words, the total number of runs expected for official scoring is (the sketch below spells out the arithmetic):
- for the external ruleset: (8 workloads + 6 held-out workloads) x 5 studies x 5 trials
- for the internal ruleset: (8 workloads + 6 held-out workloads) x 5 studies
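As a quick sanity check, these formulas imply 350 total runs for the external ruleset and 70 for the internal ruleset; a few lines of Python reproduce the arithmetic:

```python
workloads = 8
held_out = 6
studies = 5
trials = 5

# External (tuning) ruleset: every trial of every study is run.
external_runs = (workloads + held_out) * studies * trials  # 14 * 25 = 350

# Internal (self-tuning) ruleset: a single run per study.
internal_runs = (workloads + held_out) * studies  # 14 * 5 = 70

print(external_runs, internal_runs)
```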

To produce the performance profiles and performance table:


### Running workloads
To run workloads for scoring you may specify a "virtual" list of held-out workloads. Note that
the official set of held-out workloads will be sampled by the competition organizers at scoring time.

An example config for held-out workloads is stored in `scoring/held_out_workloads_example.json`.
To generate a new sample of held-out workloads, run:

```bash
python3 generate_held_out_workloads.py --seed <optional_rng_seed> --output_filename <output_filename>
```
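The resulting config is what `run_workloads.py` consumes below via `--held_out_workloads_config_path`. As a rough sketch — assuming the config is a flat JSON list of workload variant names; treat `held_out_workloads_example.json` as the authoritative schema — writing such a file by hand could look like:

```python
import json

# Hypothetical sample; real names are produced by generate_held_out_workloads.py.
held_out_workloads = [
    'ogbg_gelu',
    'wmt_post_ln',
    'fastmri_model_size',
]

with open('held_out_workloads.json', 'w') as f:
  json.dump(held_out_workloads, f, indent=2)
```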

To run a number of studies and trials over all workloads, using a Docker container for each run:

```bash
python scoring/run_workloads.py \
--framework <framework> \
--experiment_name <experiment_name> \
--docker_image_url <docker_image_url> \
--submission_path <submission_path> \
--tuning_search_space <tuning_search_space_path> \
--held_out_workloads_config_path held_out_workloads_example.json \
--num_studies <num_studies> \
--seed <rng_seed>
```

Note that to run the above script you will need at least the jax_cpu and pytorch_cpu installations of the algorithmic-efficiency package.

During submission development, it might be useful to do faster, approximate scoring (e.g. without 5 different
studies or when some trials are missing), so the scoring scripts allow some flexibility. To simulate official scoring,
pass the `--strict=True` flag to `score_submissions.py`. To get the raw scores and performance profiles of a group of
submissions or a single submission:

```bash
python3 scoring/score_submission.py --experiment_path=<path_to_experiment_dir> --output_dir=<output_dir>
python score_submissions.py --submission_directory <directory_with_submissions> --output_dir <output_dir> --compute_performance_profiles
```

We provide the scores and performance profiles for the [paper baseline algorithms](/reference_algorithms/paper_baselines/) in the "Baseline Results" section in [Benchmarking Neural Network Training Algorithms](https://arxiv.org/abs/2306.07179).
32 changes: 25 additions & 7 deletions algorithmic_efficiency/random_utils.py
@@ -1,6 +1,6 @@
"""Proxy functions in front of the Jax RNG API or a compatible Numpy RNG API."""

from typing import Any, List, Union
from typing import Union

from absl import flags
from absl import logging
@@ -21,6 +21,10 @@
MAX_INT32 = 2**31
MIN_INT32 = -MAX_INT32

# SALT constants
_SALT1 = np.random.RandomState(seed=5).randint(MIN_INT32, MAX_INT32, dtype=np.int32)
_SALT2 = np.random.RandomState(seed=6).randint(MIN_INT32, MAX_INT32, dtype=np.int32)

SeedType = Union[int, list, np.ndarray]


@@ -33,15 +37,16 @@ def _signed_to_unsigned(seed: SeedType) -> SeedType:
    return np.array([s + 2**32 if s < 0 else s for s in seed.tolist()])


def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]:
  rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
  new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32)
  return [new_seed, data]
def _fold_in(seed: SeedType, data: int) -> int:
  # Salted triple hash: mix seed and data independently, then combine.
  a = np.random.RandomState(seed=_signed_to_unsigned(seed ^ _SALT1)).randint(MIN_INT32, MAX_INT32, dtype=np.int32)
  b = np.random.RandomState(seed=_signed_to_unsigned(data ^ _SALT2)).randint(MIN_INT32, MAX_INT32, dtype=np.int32)
  c = np.random.RandomState(seed=_signed_to_unsigned(a ^ b)).randint(MIN_INT32, MAX_INT32, dtype=np.int32)
  return c


def _split(seed: SeedType, num: int = 2) -> SeedType:
  rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
  return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2])
  return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num])


def _PRNGKey(seed: SeedType) -> SeedType:  # pylint: disable=invalid-name
@@ -58,7 +63,13 @@ def _check_jax_install() -> None:
      '--framework=pytorch to use the Numpy version instead.')


def fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]:
def _bits(seed: SeedType) -> int:
  rng = np.random.RandomState(_signed_to_unsigned(seed))
  b = rng.bytes(4)
  return int.from_bytes(b, byteorder='little')


def fold_in(seed: SeedType, data: int) -> SeedType:
  if FLAGS.framework == 'jax':
    _check_jax_install()
    return jax_rng.fold_in(seed, data)
@@ -77,3 +88,10 @@ def PRNGKey(seed: SeedType) -> SeedType:  # pylint: disable=invalid-name
    _check_jax_install()
    return jax_rng.PRNGKey(seed)
  return _PRNGKey(seed)


def bits(seed: SeedType) -> int:
  if FLAGS.framework == 'jax':
    _check_jax_install()
    return jax_rng.bits(seed)
  return _bits(seed)
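Taken together, these proxies let workload code stay framework-agnostic: the same call returns a JAX PRNG key under `--framework=jax` and a plain int32 seed under `--framework=pytorch`. A minimal usage sketch (assuming the module also exposes a public `split` wrapper around `_split`, in line with the other public/private pairs above, and that absl flags have been parsed with `--framework` set):

```python
from algorithmic_efficiency import random_utils as prng

rng = prng.PRNGKey(42)  # JAX path: PRNGKey array; Numpy path: int32 scalar seed
data_rng, model_rng = prng.split(rng, num=2)

# Mix a step counter into the model rng; the same (seed, data) pair always
# yields the same child seed.
step_rng = prng.fold_in(model_rng, data=100)

# Extract a non-negative Python int for APIs that want one, e.g.
# tf.data's Dataset.shuffle(seed=...).
int_seed = prng.bits(step_rng)
print(int_seed)
```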
@@ -65,7 +65,7 @@ def _build_dataset(
    }
    if split == 'eval_train':
      train_indices = indices_split['train']
      random.Random(data_rng[0]).shuffle(train_indices)
      random.Random(data_rng).shuffle(train_indices)
      indices_split['eval_train'] = train_indices[:self.num_eval_train_examples]
    if split in indices_split:
      dataset = torch.utils.data.Subset(dataset, indices_split[split])
@@ -111,7 +111,7 @@ def init_model_fn(
      self._model.reset_parameters()
      return self._model, None

    torch.random.manual_seed(rng[0])
    torch.random.manual_seed(rng)
    self._model = resnet18(num_classes=self._num_classes)
    self._param_shapes = param_utils.pytorch_param_shapes(self._model)
    self._param_types = param_utils.pytorch_param_types(self._param_shapes)
@@ -72,7 +72,7 @@ def init_model_fn(
      aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
    """Only dropout is used."""
    del aux_dropout_rate
    torch.random.manual_seed(rng[0])
    torch.random.manual_seed(rng)
    # Disable cudnn benchmark to avoid OOM errors.
    torch.backends.cudnn.benchmark = False
    if self.use_resnet:
@@ -113,7 +113,7 @@ def init_model_fn(
      dropout_rate: Optional[float] = None,
      aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
    del aux_dropout_rate
    torch.random.manual_seed(rng[0])
    torch.random.manual_seed(rng)
    model = UNet(
        num_pool_layers=self.num_pool_layers,
        num_channels=self.num_channels,
@@ -103,7 +103,7 @@ def _build_dataset(

    if split == 'eval_train':
      indices = list(range(self.num_train_examples))
      random.Random(data_rng[0]).shuffle(indices)
      random.Random(data_rng).shuffle(indices)
      dataset = torch.utils.data.Subset(dataset,
                                        indices[:self.num_eval_train_examples])

@@ -147,7 +147,7 @@ def init_model_fn(
    """Dropout is unused."""
    del dropout_rate
    del aux_dropout_rate
    torch.random.manual_seed(rng[0])
    torch.random.manual_seed(rng)

    if self.use_silu and self.use_gelu:
      raise RuntimeError('Cannot use both GELU and SiLU activations.')
@@ -30,7 +30,7 @@ def init_model_fn(
      dropout_rate: Optional[float] = None,
      aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
    del aux_dropout_rate
    torch.random.manual_seed(rng[0])
    torch.random.manual_seed(rng)
    model = models.ViT(
        dropout_rate=dropout_rate,
        num_classes=self._num_classes,
@@ -58,7 +58,7 @@ def init_model_fn(
    Here we use dropout_rate as residual_dropout_rate, and aux_dropout_rate as
    input_dropout_rate.
    """
    torch.random.manual_seed(rng[0])
    torch.random.manual_seed(rng)
    # Configure torch backends to avoid OOM errors.
    torch.backends.cudnn.benchmark = False
    torch.backends.cuda.enable_flash_sdp(False)
@@ -32,7 +32,7 @@ def init_model_fn(
    Here we use dropout_rate as feed_forward_dropout_rate, and aux_dropout_rate
    as input_dropout_rate.
    """
    torch.random.manual_seed(rng[0])
    torch.random.manual_seed(rng)
    model = DeepspeechEncoderDecoder(
        DeepspeechConfig(
            feed_forward_dropout_rate=dropout_rate,
@@ -133,7 +133,7 @@ def init_model_fn(
      self._model.reset_parameters()
      return self._model, None

    torch.random.manual_seed(rng[0])
    torch.random.manual_seed(rng)
    self._model = _Model()
    self._param_shapes = param_utils.pytorch_param_shapes(self._model)
    self._param_types = param_utils.pytorch_param_types(self._param_shapes)
2 changes: 1 addition & 1 deletion algorithmic_efficiency/workloads/mnist/workload.py
@@ -55,7 +55,7 @@ def _build_mnist_dataset(

    if shuffle:
      ds = ds.repeat()
      ds = ds.shuffle(16 * global_batch_size, seed=data_rng[0])
      ds = ds.shuffle(16 * global_batch_size, seed=prng.bits(data_rng))
      ds = ds.batch(global_batch_size, drop_remainder=is_train)

    if repeat_final_dataset:
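The switch from `data_rng[0]` to `prng.bits(data_rng)` follows from the seed-representation change above: indexing assumed a `[2]`-shaped seed array, whereas `bits` works for both a JAX key and the new scalar Numpy seed. A standalone sketch of what the Numpy branch of `bits` computes (mirroring the `_bits` helper in the random_utils diff):

```python
import numpy as np


def numpy_bits(seed: int) -> int:
  """Four random bytes from the seeded generator, read as an unsigned int."""
  rng = np.random.RandomState(seed % 2**32)  # RandomState wants a non-negative seed
  return int.from_bytes(rng.bytes(4), byteorder='little')


print(numpy_bits(42))  # deterministic, non-negative; usable as a tf.data shuffle seed
```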
@@ -143,7 +143,7 @@ def init_model_fn(
      aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
    """aux_dropout_rate is unused."""
    del aux_dropout_rate
    torch.random.manual_seed(rng[0])
    torch.random.manual_seed(rng)
    model = GNN(
        num_outputs=self._num_outputs,
        dropout_rate=dropout_rate,
@@ -171,7 +171,7 @@ def init_model_fn(
      dropout_rate: Optional[float] = None,
      aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
    """aux_dropout_rate is used as attention_dropout_rate."""
    torch.random.manual_seed(rng[0])
    torch.random.manual_seed(rng)

    if self.activation == 'relu':
      activation = F.relu
61 changes: 57 additions & 4 deletions docker/scripts/startup.sh
@@ -14,8 +14,8 @@ function usage() {
  $0 [--dataset dataset] [--framework framework] [--submission_path submission_path]
     [--tuning_search_space tuning_search_space] [--experiment_name experiment_name]
     [--workload workload] [--max_global_steps max_global_steps] [--rsync_data rsync_data]
     [--internal_contributor true]
     [--internal_contributor true] [--traindiffs_test false]

Options:
  -d | --dataset: Can be imagenet, criteo1tb, ogbg, fastmri, wmt, librispeech.
  -f | --framework: Can be jax or pytorch.
@@ -34,11 +34,17 @@ function usage() {
                    from internal GCP bucket.
  -i | --internal_contributor: If true, allow rsync of data and transfer of experiment results
                    with GCP project.
  --num_tuning_trials: Number of tuning trials for an externally tuned ruleset submission.
  --hparam_start_index: Should be > 0 and < num_tuning_trials - 1.
  --hparam_end_index: Should be > 0 and < num_tuning_trials - 1.
  --rng_seed: RNG seed to pass to the workload submission_runner.
  --traindiffs_test: If true, ignore all other options and run the traindiffs test.
USAGE
exit 1
}

# Defaults
TEST="false"
INTERNAL_CONTRIBUTOR_MODE="false"
HOME_DIR=""
RSYNC_DATA="true"
@@ -47,7 +53,11 @@ SAVE_CHECKPOINTS="true"

# Pass flag
while [ "$1" != "" ]; do
  case $1 in
    --traindiffs_test)
      shift
      TEST=$1
      ;;
    -d | --dataset)
      shift
      DATASET=$1
@@ -100,14 +110,37 @@ while [ "$1" != "" ]; do
      shift
      HOME_DIR=$1
      ;;
    --num_tuning_trials)
      shift
      NUM_TUNING_TRIALS=$1
      ;;
    --hparam_start_index)
      shift
      HPARAM_START_INDEX=$1
      ;;
    --hparam_end_index)
      shift
      HPARAM_END_INDEX=$1
      ;;
    --rng_seed)
      shift
      RNG_SEED=$1
      ;;
    *)
      usage
      exit 1
      ;;
  esac
  shift
done

if [[ ${TEST} == "true" ]]; then
  cd algorithmic-efficiency
  COMMAND="python3 tests/test_traindiffs.py"
  echo $COMMAND
  eval $COMMAND
  exit
fi

# Check if arguments are valid
VALID_DATASETS=("criteo1tb" "imagenet" "fastmri" "ogbg" "librispeech" \
@@ -180,6 +213,22 @@ if [[ ! -z ${SUBMISSION_PATH+x} ]]; then
  MAX_STEPS_FLAG="--max_global_steps=${MAX_GLOBAL_STEPS}"
fi

if [[ ! -z ${NUM_TUNING_TRIALS+x} ]]; then
  NUM_TUNING_TRIALS_FLAG="--num_tuning_trials=${NUM_TUNING_TRIALS}"
fi

if [[ ! -z ${HPARAM_START_INDEX+x} ]]; then
  HPARAM_START_INDEX_FLAG="--hparam_start_index=${HPARAM_START_INDEX}"
fi

if [[ ! -z ${HPARAM_END_INDEX+x} ]]; then
  HPARAM_END_INDEX_FLAG="--hparam_end_index=${HPARAM_END_INDEX}"
fi

if [[ ! -z ${RNG_SEED+x} ]]; then
  RNG_SEED_FLAG="--rng_seed=${RNG_SEED}"
fi

# Define special flags for imagenet and librispeech workloads
if [[ ${DATASET} == "imagenet" ]]; then
  SPECIAL_FLAGS="--imagenet_v2_data_dir=${DATA_DIR}"
@@ -204,6 +253,10 @@ if [[ ! -z ${SUBMISSION_PATH+x} ]]; then
    --experiment_name=${EXPERIMENT_NAME} \
    --overwrite=${OVERWRITE} \
    --save_checkpoints=${SAVE_CHECKPOINTS} \
    ${NUM_TUNING_TRIALS_FLAG} \
    ${HPARAM_START_INDEX_FLAG} \
    ${HPARAM_END_INDEX_FLAG} \
    ${RNG_SEED_FLAG} \
    ${MAX_STEPS_FLAG} \
    ${SPECIAL_FLAGS} \
    ${TORCH_COMPILE_FLAG} 2>&1 | tee -a ${LOG_FILE}"