
merge fix
priyakasimbeg committed Sep 30, 2023
2 parents 3c5882b + 20376e0 commit 70cfc72
Showing 13 changed files with 111 additions and 59 deletions.
24 changes: 22 additions & 2 deletions README.md
@@ -33,7 +33,7 @@


## Installation
You can install this package and dependencies in a [python virtual environment](#virtual-environment) or use a [Docker container](#install-in-docker) (recommended).
You can install this package and dependencies in a [python virtual environment](#virtual-environment) or use a [Docker/Singularity/Apptainer container](#install-in-docker) (recommended).

*TL;DR: to install the Jax version for GPU, run:*

@@ -89,7 +89,8 @@ pip3 install -e '.[full]'
</details>

## Docker
We recommend using a Docker container to ensure a similar environment to our scoring and testing environments.
Alternatively, a Singularity/Apptainer container can also be used (see instructions below).


**Prerequisites for NVIDIA GPU setup**: You may have to install the NVIDIA Container Toolkit so that the containers can locate the NVIDIA drivers and GPUs.
@@ -133,6 +134,25 @@ To use the Docker container as an interactive virtual environment, you can run a
### Running Docker Container (End-to-end)
To run a submission end-to-end in a containerized environment, see the [Getting Started Document](./getting_started.md#run-your-submission-in-a-docker-container).

### Using Singularity/Apptainer instead of Docker
Since many compute clusters don't allow the use of Docker due to security concerns and instead encourage the use of [Singularity/Apptainer](https://github.com/apptainer/apptainer) (formerly Singularity, now called Apptainer), we also provide instructions on how to build an Apptainer container based on the provided Dockerfile.

To convert the Dockerfile into an Apptainer definition file, we will use [spython](https://github.com/singularityhub/singularity-cli):
```bash
pip3 install spython
cd algorithmic-efficiency/docker
spython recipe Dockerfile &> Singularity.def
```
Now we can build the Apptainer image by running
```bash
singularity build --fakeroot <singularity_image_name>.sif Singularity.def
```
To start a shell session with GPU support (by using the `--nv` flag), we can run
```bash
singularity shell --nv <singularity_image_name>.sif
```
Similarly to Docker, Apptainer allows you to bind specific paths of the host system into the container by specifying the `--bind` flag, as explained [here](https://docs.sylabs.io/guides/3.7/user-guide/bind_paths_and_mounts.html).

# Getting Started
For instructions on developing and scoring your own algorithm in the benchmark, see the [Getting Started Document](./getting_started.md).
## Running a workload
31 changes: 15 additions & 16 deletions algorithmic_efficiency/data_utils.py
@@ -28,8 +28,15 @@ def shard_and_maybe_pad_np(
inputs = batch['inputs']
current_batch_size = inputs[0].shape[0] if isinstance(
inputs, tuple) else inputs.shape[0]
if global_batch_size is not None:
assert global_batch_size >= current_batch_size, \
'global_batch_size must be larger than or equal to current_batch_size.'
# Always pad to global_batch_size if it is provided.
pad_to_global_batch_size = global_batch_size > current_batch_size
else:
pad_to_global_batch_size = False
remainder_size = current_batch_size % local_device_count
if remainder_size != 0:
if remainder_size != 0 or pad_to_global_batch_size:
if global_batch_size is not None:
pad_size = global_batch_size - current_batch_size
else:
@@ -50,8 +57,8 @@ def _prepare(x):
x = x._numpy() # pylint: disable=protected-access

# Pad if remainder_size != 0 (should only be possible during evaluation).
if remainder_size != 0:
x = pad(x, pad_size, 'jax', padding_value=padding_value)
if remainder_size != 0 or pad_to_global_batch_size:
x = pad(x, pad_size, padding_value=padding_value)

# Reshape (global_batch_size, ...) to
# (local_device_count, per_device_batch_size, ...).
@@ -61,21 +68,13 @@ def _prepare(x):
return jax.tree_map(_prepare, batch)


def pad(tensor: spec.Tensor,
def pad(tensor: np.ndarray,
pad_size: int,
framework: str,
padding_value: int = 0) -> spec.Tensor:
if len(tensor) > 1:
padding_value: int = 0) -> np.ndarray:
if tensor.ndim > 1:
pad_size = (pad_size, *tensor.shape[1:])
if framework == 'pytorch':
padding = torch.full(
pad_size, padding_value, dtype=tensor.dtype, device=tensor.device)
padded_tensor = torch.cat((tensor, padding), dim=0)
elif framework == 'jax':
padding = np.full(pad_size, padding_value, dtype=tensor.dtype)
padded_tensor = np.concatenate((tensor, padding), axis=0)
else:
raise ValueError(f'Framework has to be pytorch or jax, but is {framework}.')
padding = np.full(pad_size, padding_value, dtype=tensor.dtype)
padded_tensor = np.concatenate((tensor, padding), axis=0)
return padded_tensor


6 changes: 6 additions & 0 deletions algorithmic_efficiency/logger_utils.py
@@ -275,6 +275,12 @@ def get_meta_data(workload: spec.Workload) -> dict:
return meta_data


def save_meta_data(workload: spec.Workload, rng_seed: int, meta_file_name: str):
meta_data = get_meta_data(workload)
meta_data.update({'rng_seed': rng_seed})
write_json(meta_file_name, meta_data)


class MetricLogger(object):
"""Used to log all measurements during training.
2 changes: 1 addition & 1 deletion algorithmic_efficiency/profiler.py
@@ -15,7 +15,7 @@


def _get_monotonic_time() -> float:
if torch.cuda.is_available():
if torch.cuda.is_available() and torch.cuda.is_initialized():
torch.cuda.synchronize()
return time.monotonic()
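The added `torch.cuda.is_initialized()` guard means that reading the clock no longer forces CUDA initialization: `torch.cuda.synchronize()` lazily creates a CUDA context, which is unnecessary if the first timing call happens before any GPU work has run (for example, in a CPU-only or JAX run on a machine that happens to have CUDA available). A minimal standalone version of the guarded timer, assuming PyTorch is installed:

```python
import time

import torch

def get_monotonic_time() -> float:
  # Synchronize only if a CUDA context already exists, so that timing calls
  # stay cheap and side-effect free until the GPU is actually used.
  if torch.cuda.is_available() and torch.cuda.is_initialized():
    torch.cuda.synchronize()
  return time.monotonic()

print(get_monotonic_time())
```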

@@ -233,7 +233,7 @@ def _eval_batch(self,
summed_loss = self.loss_fn(
label_batch=batch['targets'], logits_batch=logits,
mask_batch=weights)['summed']
return summed_loss
return summed_loss.to(dtype=torch.float64)
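Casting the summed per-batch loss to `float64` matters once these sums are accumulated over many evaluation batches: `float32` addition silently drops small contributions next to a large running total. A tiny, framework-agnostic illustration (the magnitudes are arbitrary, chosen only to make the rounding visible):

```python
import numpy as np

running_total_32 = np.float32(1e8)  # pretend this much summed loss has accumulated
running_total_64 = np.float64(1e8)

for _ in range(10_000):
  batch_loss = np.float32(1.0)                # a new batch's (small) summed loss
  running_total_32 += batch_loss              # rounds away: below float32 resolution at 1e8
  running_total_64 += np.float64(batch_loss)

print(running_total_32 - np.float32(1e8))  # 0.0 -- the contributions vanished
print(running_total_64 - 1e8)              # 10000.0
```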


class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload):
4 changes: 2 additions & 2 deletions algorithmic_efficiency/workloads/criteo1tb/workload.py
@@ -63,11 +63,11 @@ def num_eval_train_examples(self) -> int:

@property
def num_validation_examples(self) -> int:
return 89_000_000
return 83_274_637

@property
def num_test_examples(self) -> int:
return 89_274_637
return 95_000_000

@property
def train_mean(self):
2 changes: 1 addition & 1 deletion datasets/README.md
@@ -100,7 +100,7 @@ python3 datasets/dataset_setup.py \
--imagenet \
--temp_dir $DATA_DIR/tmp \
--imagenet_train_url <imagenet_train_url> \
--imagenet_val_url <imagenet_val_url\
--imagenet_val_url <imagenet_val_url> \
--framework jax

```
14 changes: 6 additions & 8 deletions datasets/dataset_setup.py
@@ -298,7 +298,7 @@ def download_criteo1tb(data_dir,
logging.info(f'Running Criteo 1TB unzip command:\n{unzip_cmd}')
p = subprocess.Popen(unzip_cmd, shell=True)
p.communicate()
_maybe_prompt_for_deletion(all_days_zip_filepath, interactive_deletion)
_maybe_prompt_for_deletion([all_days_zip_filepath], interactive_deletion)

# Unzip the individual days.
processes = []
@@ -316,9 +316,9 @@
_maybe_prompt_for_deletion(gz_paths, interactive_deletion)

# Split into files with 5M lines each: day_1.csv -> day_1_[0-39].csv.
unzipped_paths = []
for batch in range(6):
batch_processes = []
unzipped_paths = []
for day_offset in range(4):
day = batch * 4 + day_offset
unzipped_path = os.path.join(criteo_dir, f'day_{day}.csv')
@@ -330,7 +330,7 @@
batch_processes.append(subprocess.Popen(split_cmd, shell=True))
for p in batch_processes:
p.communicate()
_maybe_prompt_for_deletion(unzipped_paths, interactive_deletion)
_maybe_prompt_for_deletion(unzipped_paths, interactive_deletion)
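The restructured loop resets `unzipped_paths` for each batch of four days, so the freshly split `day_*.csv` files can be offered for deletion batch by batch (6 batches x 4 days = 24 days) rather than only once at the very end, which keeps peak disk usage lower. A simplified, self-contained sketch of the pattern; the child process below is just a placeholder for the real `split` command:

```python
import subprocess
import sys

def maybe_delete(paths):
  # Stand-in for `_maybe_prompt_for_deletion`, which can ask the user first.
  print('would delete:', paths)

days_per_batch = 4
for batch in range(6):
  batch_processes = []
  unzipped_paths = []  # reset per batch
  for day_offset in range(days_per_batch):
    day = batch * days_per_batch + day_offset
    unzipped_paths.append(f'day_{day}.csv')
    # Placeholder child process standing in for `split` over each day's CSV.
    batch_processes.append(subprocess.Popen([sys.executable, '-c', 'pass']))
  for p in batch_processes:
    p.communicate()  # wait for this batch to finish
  maybe_delete(unzipped_paths)  # free disk space before the next batch
```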


def download_cifar(data_dir, framework):
@@ -567,14 +567,12 @@ def download_librispeech(dataset_dir, tmp_dir):
# After extraction the result is a folder named Librispeech containing audio
# files in .flac format along with transcripts containing name of audio file
# and corresponding transcription.
# tmp_librispeech_dir = os.path.join(dataset_dir, 'librispeech')
# extracted_data_dir = os.path.join(tmp_librispeech_dir, 'LibriSpeech')
# final_data_dir = os.path.join(dataset_dir, 'librispeech_processed')
tmp_librispeech_dir = os.path.join(tmp_dir, 'librispeech_raw')
extracted_data_dir = os.path.join(tmp_dir, 'librispeech_extracted')
tmp_librispeech_dir = os.path.join(tmp_dir, 'librispeech')
extracted_data_dir = os.path.join(tmp_librispeech_dir, 'LibriSpeech')
final_data_dir = os.path.join(dataset_dir, 'librispeech')

_maybe_mkdir(tmp_librispeech_dir)
_maybe_mkdir(final_data_dir)

for split in ['dev', 'test']:
for version in ['clean', 'other']:
1 change: 1 addition & 0 deletions datasets/librispeech_preprocess.py
@@ -32,6 +32,7 @@
'train-clean-360': 104014,
'train-other-500': 148688,
'test-clean': 2620,
'test-other': 2939,
'dev-clean': 2703,
'dev-other': 2864,
}
23 changes: 11 additions & 12 deletions docker/Dockerfile
@@ -6,13 +6,12 @@

# To build Docker image
FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04
ARG DEBIAN_FRONTEND=noninteractive

# Installing machine packages
RUN echo "Setting up machine"
RUN apt-get update
RUN apt-get install -y curl tar
RUN apt-get install -y git python3 pip wget ffmpeg
RUN DEBIAN_FRONTEND=noninteractive apt-get install -y git python3 pip wget ffmpeg
RUN apt-get install libtcmalloc-minimal4
RUN apt-get install unzip
RUN apt-get install pigz
@@ -34,38 +33,38 @@ RUN echo "Setting up algorithmic_efficiency repo"
ARG branch="main"
ARG framework="both"
ARG git_url=https://github.com/mlcommons/algorithmic-efficiency.git
RUN git clone $git_url && cd algorithmic-efficiency
RUN cd algorithmic-efficiency && git checkout $branch
RUN git clone $git_url && cd /algorithmic-efficiency
RUN cd /algorithmic-efficiency && git checkout $branch

RUN cd algorithmic-efficiency && pip install -e '.[full]'
RUN cd /algorithmic-efficiency && pip install -e '.[full]'

RUN if [ "$framework" = "jax" ] ; then \
echo "Installing Jax GPU" \
&& cd algorithmic-efficiency \
&& cd /algorithmic-efficiency \
&& pip install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \
&& pip install -e '.[pytorch_cpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'; \
elif [ "$framework" = "pytorch" ] ; then \
echo "Installing Pytorch GPU" \
&& cd algorithmic-efficiency \
&& cd /algorithmic-efficiency \
&& pip install -e '.[jax_cpu]' \
&& pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'; \
elif [ "$framework" = "both" ] ; then \
echo "Installing Jax GPU and Pytorch GPU" \
&& cd algorithmic-efficiency \
&& cd /algorithmic-efficiency \
&& pip install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \
&& pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'; \
else \
echo "Invalid build-arg $framework: framework should be either jax, pytorch or both." >&2 \
&& exit 1 ; \
fi

RUN cd algorithmic-efficiency && pip install -e '.[wandb]'
RUN cd /algorithmic-efficiency && pip install -e '.[wandb]'

RUN cd algorithmic-efficiency && git fetch origin
RUN cd algorithmic-efficiency && git pull
RUN cd /algorithmic-efficiency && git fetch origin
RUN cd /algorithmic-efficiency && git pull

# Todo: remove this, this is temporary for development
COPY scripts/startup.sh /algorithmic-efficiency/docker/scripts/startup.sh
RUN chmod a+x /algorithmic-efficiency/docker/scripts/startup.sh

ENTRYPOINT ["bash", "algorithmic-efficiency/docker/scripts/startup.sh"]
ENTRYPOINT ["bash", "/algorithmic-efficiency/docker/scripts/startup.sh"]
1 change: 1 addition & 0 deletions setup.cfg
@@ -115,6 +115,7 @@ jax_core_deps =
# Todo(kasimbeg): verify if this is necessary after we
# upgrade jax.
chex==0.1.7
ml_dtypes==0.2.0

# JAX CPU
jax_cpu =
24 changes: 17 additions & 7 deletions submission_runner.py
@@ -142,6 +142,10 @@
'hparam_end_index',
None,
'End index to slice set of hyperparameters in tuning search space.')
'rng_seed',
None,
'Value of rng seed. If None, a random seed will '
'be generated from hardware.')
FLAGS = flags.FLAGS
USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup()
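With the new `rng_seed` flag, a run becomes reproducible: when the flag is set its value is used directly, and only when it is left unset does the runner fall back to drawing a seed from the operating system. A standalone sketch of that fallback logic (not the runner's exact code):

```python
import os
import struct

def resolve_rng_seed(rng_seed=None):
  # Fall back to an OS/hardware-derived seed only when none was supplied.
  # Note that a seed of 0 is falsy and would also trigger the fallback,
  # mirroring the `if not rng_seed` check used in the runner.
  if not rng_seed:
    rng_seed = struct.unpack('I', os.urandom(4))[0]
  return rng_seed

print(resolve_rng_seed())      # fresh seed drawn from os.urandom
print(resolve_rng_seed(1996))  # a user-supplied seed is used as-is
```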

@@ -182,6 +186,7 @@ def train_once(
update_params: spec.UpdateParamsFn,
data_selection: spec.DataSelectionFn,
hyperparameters: Optional[spec.Hyperparameters],
rng_seed: int,
rng: spec.RandomState,
profiler: Profiler,
max_global_steps: int = None,
@@ -276,10 +281,9 @@
global_step,
preemption_count,
checkpoint_dir=log_dir)
meta_data = logger_utils.get_meta_data(workload)
meta_file_name = os.path.join(log_dir, f'meta_data_{preemption_count}.json')
logging.info(f'Saving meta data to {meta_file_name}.')
logger_utils.write_json(meta_file_name, meta_data)
logger_utils.save_meta_data(workload, rng_seed, meta_file_name)
flag_file_name = os.path.join(log_dir, f'flags_{preemption_count}.json')
logging.info(f'Saving flags to {flag_file_name}.')
logger_utils.write_json(flag_file_name, flags.FLAGS.flag_values_dict())
@@ -462,7 +466,8 @@ def score_submission_on_workload(
save_checkpoints: Optional[bool] = True,
hparam_start_index: Optional[bool] = None,
hparam_end_index: Optional[bool] = None,
):
rng_seed: Optional[int] = None
):
# Expand paths because '~' may not be recognized
data_dir = os.path.expanduser(data_dir)
if imagenet_v2_data_dir:
@@ -511,7 +516,8 @@
enumerate(tuning_search_space), hparam_start_index, hparam_end_index)
for hi, hyperparameters in tuning_search_space_iter:
# Generate a new seed from hardware sources of randomness for each trial.
rng_seed = struct.unpack('I', os.urandom(4))[0]
if not rng_seed:
rng_seed = struct.unpack('I', os.urandom(4))[0]
logging.info('Using RNG seed %d', rng_seed)
rng = prng.PRNGKey(rng_seed)
# Because we initialize the PRNGKey with only a single 32 bit int, in the
@@ -543,7 +549,9 @@
data_dir, imagenet_v2_data_dir,
init_optimizer_state,
update_params, data_selection,
hyperparameters, rng,
hyperparameters,
rng_seed,
rng,
profiler,
max_global_steps,
tuning_dir_name,
@@ -560,7 +568,8 @@
logging.info(f'Total number of evals: {num_evals}')
logging.info('=' * 20)
else:
rng_seed = struct.unpack('q', os.urandom(8))[0]
if not rng_seed:
rng_seed = struct.unpack('q', os.urandom(8))[0]
rng = prng.PRNGKey(rng_seed)
# If the submission is responsible for tuning itself, we only need to run it
# once and return the total time.
@@ -569,7 +578,7 @@
workload, global_batch_size, global_eval_batch_size,
data_dir, imagenet_v2_data_dir,
init_optimizer_state, update_params, data_selection,
None, rng, profiler, max_global_steps, log_dir,
None, rng_seed, rng, profiler, max_global_steps, log_dir,
save_checkpoints=save_checkpoints)
return score

@@ -628,6 +637,7 @@ def main(_):
save_checkpoints=FLAGS.save_checkpoints,
hparam_start_index=FLAGS.hparam_start_index,
hparam_end_index=FLAGS.hparam_end_index)
hparam_end_index=FLAGS.hparam_end_index,
rng_seed=FLAGS.rng_seed)
logging.info(f'Final {FLAGS.workload} score: {score}')

if FLAGS.profile:
