Skip to content

Commit

Permalink
Merge pull request #511 from mlcommons/dev
Browse files Browse the repository at this point in the history
dev -> main
  • Loading branch information
priyakasimbeg authored Sep 26, 2023
2 parents c3273d9 + ae3587d commit ddf5e14
Show file tree
Hide file tree
Showing 19 changed files with 482 additions and 597 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ See instructions [here](https://github.com/NVIDIA/nvidia-docker).

### Running Docker Container (Interactive)
To use the Docker container as an interactive virtual environment, you can run a container mounted to your local data and code directories and execute the `bash` program. This may be useful if you are in the process of developing a submission.
1. Run detached Docker Container. The container_id will be printed if the container is run successfully.
1. Run detached Docker Container. The `container_id` will be printed if the container is running successfully.
```bash
docker run -t -d \
-v $HOME/data/:/data/ \
Expand All @@ -122,7 +122,7 @@ To use the Docker container as an interactive virtual environment, you can run a
-v $HOME/algorithmic-efficiency:/algorithmic-efficiency \
--gpus all \
--ipc=host \
<docker_image_name>
<docker_image_name> \
--keep_container_alive true
```
2. Open a bash terminal
Expand Down
31 changes: 15 additions & 16 deletions algorithmic_efficiency/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,15 @@ def shard_and_maybe_pad_np(
inputs = batch['inputs']
current_batch_size = inputs[0].shape[0] if isinstance(
inputs, tuple) else inputs.shape[0]
if global_batch_size is not None:
assert global_batch_size >= current_batch_size, \
'global_batch_size must be larger than or equal to current_batch_size.'
# Always pad to global_batch_size if it is provided.
pad_to_global_batch_size = global_batch_size > current_batch_size
else:
pad_to_global_batch_size = False
remainder_size = current_batch_size % local_device_count
if remainder_size != 0:
if remainder_size != 0 or pad_to_global_batch_size:
if global_batch_size is not None:
pad_size = global_batch_size - current_batch_size
else:
Expand All @@ -50,8 +57,8 @@ def _prepare(x):
x = x._numpy() # pylint: disable=protected-access

# Pad if remainder_size != 0 (should only be possible during evaluation).
if remainder_size != 0:
x = pad(x, pad_size, 'jax', padding_value=padding_value)
if remainder_size != 0 or pad_to_global_batch_size:
x = pad(x, pad_size, padding_value=padding_value)

# Reshape (global_batch_size, ...) to
# (local_device_count, per_device_batch_size, ...).
Expand All @@ -61,21 +68,13 @@ def _prepare(x):
return jax.tree_map(_prepare, batch)


def pad(tensor: spec.Tensor,
def pad(tensor: np.ndarray,
pad_size: int,
framework: str,
padding_value: int = 0) -> spec.Tensor:
if len(tensor) > 1:
padding_value: int = 0) -> np.ndarray:
if tensor.ndim > 1:
pad_size = (pad_size, *tensor.shape[1:])
if framework == 'pytorch':
padding = torch.full(
pad_size, padding_value, dtype=tensor.dtype, device=tensor.device)
padded_tensor = torch.cat((tensor, padding), dim=0)
elif framework == 'jax':
padding = np.full(pad_size, padding_value, dtype=tensor.dtype)
padded_tensor = np.concatenate((tensor, padding), axis=0)
else:
raise ValueError(f'Framework has to be pytorch or jax, but is {framework}.')
padding = np.full(pad_size, padding_value, dtype=tensor.dtype)
padded_tensor = np.concatenate((tensor, padding), axis=0)
return padded_tensor


Expand Down
6 changes: 6 additions & 0 deletions algorithmic_efficiency/logger_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,12 @@ def get_meta_data(workload: spec.Workload) -> dict:
return meta_data


def save_meta_data(workload: spec.Workload, rng_seed: int, meta_file_name: str):
meta_data = get_meta_data(workload)
meta_data.update({'rng_seed': rng_seed})
write_json(meta_file_name, meta_data)


class MetricLogger(object):
"""Used to log all measurements during training.
Expand Down
6 changes: 4 additions & 2 deletions algorithmic_efficiency/param_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def pytorch_param_types(
elif 'attn' in name or 'attention' in name:
if 'bias' in name:
param_types[name] = spec.ParameterType.ATTENTION_BIAS
elif 'in_proj' in name:
param_types[name] = spec.ParameterType.ATTENTION_QKV
elif 'kv_proj' in name:
param_types[name] = spec.ParameterType.ATTENTION_KV
elif 'k_proj' in name or 'key' in name:
param_types[name] = spec.ParameterType.ATTENTION_K
elif 'q_proj' in name or 'query' in name:
Expand All @@ -51,8 +55,6 @@ def pytorch_param_types(
param_types[name] = spec.ParameterType.ATTENTION_OUT
elif 'scale' in name:
param_types[name] = spec.ParameterType.WEIGHT
elif 'in_proj_weight' in name:
param_types[name] = spec.ParameterType.ATTENTION_QKV
else:
raise ValueError(f'Unrecognized attention parameter: {name}.')
elif 'bias' in name:
Expand Down
15 changes: 11 additions & 4 deletions algorithmic_efficiency/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@
from typing import Dict, Generator, List, Optional, Tuple

import numpy as np
import torch


def _get_monotonic_time() -> float:
if torch.cuda.is_available() and torch.cuda.is_initialized():
torch.cuda.synchronize()
return time.monotonic()


class Profiler:
Expand All @@ -20,7 +27,7 @@ def __init__(self, local_rank: Optional[int] = None) -> None:

self.current_actions: Dict[str, float] = {}
self.recorded_durations = defaultdict(list)
self.start_time = time.monotonic()
self.start_time = _get_monotonic_time()

def set_local_rank(self, local_rank: int) -> None:
self._local_rank = local_rank
Expand All @@ -35,12 +42,12 @@ def start(self, action_name: str) -> None:
if action_name in self.current_actions:
raise ValueError(
f'Attempted to start {action_name} which has already started.')
self.current_actions[action_name] = time.monotonic()
self.current_actions[action_name] = _get_monotonic_time()

def stop(self, action_name: str) -> None:
if self.local_rank != 0:
pass
end_time = time.monotonic()
end_time = _get_monotonic_time()
if action_name not in self.current_actions:
raise ValueError(f'Attempting to stop recording an action '
f'({action_name}) which was never started.')
Expand All @@ -59,7 +66,7 @@ def profile(self, action_name: str) -> Generator:
def _make_report(
self
) -> Tuple[List[Tuple[str, float, float, int, float, float]], int, float]:
total_duration = time.monotonic() - self.start_time
total_duration = _get_monotonic_time() - self.start_time
report = [(str(a),
float(np.mean(d)),
float(np.std(d)),
Expand Down
7 changes: 4 additions & 3 deletions algorithmic_efficiency/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,10 @@ class ParameterType(enum.Enum):
ATTENTION_V = 10
ATTENTION_OUT = 11
ATTENTION_QKV = 12 # This is used for implementations that fuse QKV together.
# We need to split this out because otherwise fused QKV models will have a
# different number of biases.
ATTENTION_BIAS = 13
ATTENTION_KV = 13 # This is used for implementations that fuse KV together.
# We sometimes need to split this out because otherwise fused models will have
# a different number of biases.
ATTENTION_BIAS = 14


# Of course, Tensor knows its shape and dtype.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def _eval_batch(self,
summed_loss = self.loss_fn(
label_batch=batch['targets'], logits_batch=logits,
mask_batch=weights)['summed']
return summed_loss
return summed_loss.to(dtype=torch.float64)


class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload):
Expand Down
4 changes: 2 additions & 2 deletions algorithmic_efficiency/workloads/criteo1tb/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,11 @@ def num_eval_train_examples(self) -> int:

@property
def num_validation_examples(self) -> int:
return 89_000_000
return 83_274_637

@property
def num_test_examples(self) -> int:
return 89_274_637
return 95_000_000

@property
def train_mean(self):
Expand Down
Loading

0 comments on commit ddf5e14

Please sign in to comment.