From 0f608e50f2dcec4d237f6990898d7d4c197692de Mon Sep 17 00:00:00 2001 From: runame Date: Tue, 15 Aug 2023 16:43:45 +0200 Subject: [PATCH 01/66] Add parameter type for fused KV attention --- algorithmic_efficiency/param_utils.py | 6 ++++-- algorithmic_efficiency/spec.py | 7 ++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/algorithmic_efficiency/param_utils.py b/algorithmic_efficiency/param_utils.py index 00c50ee4f..b430366b1 100644 --- a/algorithmic_efficiency/param_utils.py +++ b/algorithmic_efficiency/param_utils.py @@ -41,6 +41,10 @@ def pytorch_param_types( elif 'attn' in name or 'attention' in name: if 'bias' in name: param_types[name] = spec.ParameterType.ATTENTION_BIAS + elif 'in_proj' in name: + param_types[name] = spec.ParameterType.ATTENTION_QKV + elif 'kv_proj' in name: + param_types[name] = spec.ParameterType.ATTENTION_KV elif 'k_proj' in name or 'key' in name: param_types[name] = spec.ParameterType.ATTENTION_K elif 'q_proj' in name or 'query' in name: @@ -51,8 +55,6 @@ def pytorch_param_types( param_types[name] = spec.ParameterType.ATTENTION_OUT elif 'scale' in name: param_types[name] = spec.ParameterType.WEIGHT - elif 'in_proj_weight' in name: - param_types[name] = spec.ParameterType.ATTENTION_QKV else: raise ValueError(f'Unrecognized attention parameter: {name}.') elif 'bias' in name: diff --git a/algorithmic_efficiency/spec.py b/algorithmic_efficiency/spec.py index 570b7c55b..285983957 100644 --- a/algorithmic_efficiency/spec.py +++ b/algorithmic_efficiency/spec.py @@ -39,9 +39,10 @@ class ParameterType(enum.Enum): ATTENTION_V = 10 ATTENTION_OUT = 11 ATTENTION_QKV = 12 # This is used for implementations that fuse QKV together. - # We need to split this out because otherwise fused QKV models will have a - # different number of biases. - ATTENTION_BIAS = 13 + ATTENTION_KV = 13 # This is used for implementations that fuse KV together. + # We sometimes need to split this out because otherwise fused models will have + # a different number of biases. + ATTENTION_BIAS = 14 # Of course, Tensor knows its shape and dtype. From 4187b0c374aead59c5bcba340f3ae3325e193972 Mon Sep 17 00:00:00 2001 From: runame Date: Tue, 15 Aug 2023 16:44:34 +0200 Subject: [PATCH 02/66] Refactor MultiheadAttention module --- .../workloads/wmt/wmt_pytorch/models.py | 634 +++++------------- 1 file changed, 164 insertions(+), 470 deletions(-) diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py index 9fbc48578..391bd03b8 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py @@ -1,7 +1,6 @@ import copy import math from typing import Any, Callable, Dict, Optional, Tuple, Union -import warnings import torch from torch import nn @@ -430,8 +429,7 @@ def forward( # TransformerEncoderLayer and TransformerDecoderLayer are taken from: # https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/transformer.py -# Only difference is using custom MultiheadAttention modules without bias and -# '_qkv_same_embed_dim' always set to 'False'. +# Main difference is the use of custom MultiheadAttention modules. class TransformerEncoderLayer(nn.Module): r"""TransformerEncoderLayer is made up of self-attn and feedforward network. This standard encoder layer is based on the paper "Attention Is All You Need". @@ -450,22 +448,16 @@ class TransformerEncoderLayer(nn.Module): string ("relu" or "gelu") or a unary callable (default=F.relu). layer_norm_eps: the eps value in layer normalization components (default=1e-6). - batch_first: If ``True``, then the input and output tensors are provided - as (batch, seq, feature). Default: ``True`` (batch, seq, feature). norm_first: if ``True``, layer norm is done prior to attention and feedforward operations, respectivaly. Otherwise it's done after. Default: ``True``. Examples:: - >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> out = encoder_layer(src) - Alternatively, when ``batch_first`` is ``True``: >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True) >>> src = torch.rand(32, 10, 512) >>> out = encoder_layer(src) """ - __constants__ = ['batch_first', 'norm_first'] + __constants__ = ['norm_first'] def __init__(self, d_model: int = 1024, @@ -475,7 +467,6 @@ def __init__(self, attention_dropout_rate: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, layer_norm_eps: float = 1e-6, - batch_first: bool = True, norm_first: bool = True, device=None, dtype=None) -> None: @@ -484,8 +475,8 @@ def __init__(self, self.self_attn = MultiheadAttention( d_model, nhead, + self_attn=True, dropout_rate=attention_dropout_rate, - batch_first=batch_first, bias=False, **factory_kwargs) @@ -532,7 +523,7 @@ def forward(self, 'Only bool and floating types of key_padding_mask are supported') x = src if self.norm_first: - x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask) + x = x + self._sa_block(self.norm1(x), src_mask) x = x + self._ff_block(self.norm2(x)) else: x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask)) @@ -541,17 +532,8 @@ def forward(self, return x # Self-attention block: - def _sa_block(self, - x: Tensor, - attn_mask: Optional[Tensor], - key_padding_mask: Optional[Tensor]) -> Tensor: - x = self.self_attn( - x, - x, - x, - attn_mask=attn_mask, - key_padding_mask=key_padding_mask, - need_weights=False)[0] + def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor]) -> Tensor: + x, _ = self.self_attn(x, attn_mask=attn_mask) return self.dropout1(x) # Feed forward block: @@ -560,7 +542,8 @@ def _ff_block(self, x: Tensor) -> Tensor: return self.dropout2(x) -# Modified to use cache for autoregressive decoding. +# Modified to use cache for autoregressive decoding and custom +# MultiheadAttention modules. class TransformerDecoder(nn.Module): r"""TransformerDecoder is a stack of N decoder layers Args: @@ -643,7 +626,8 @@ def forward(self, return output -# Modified to use cache for autoregressive decoding. +# Modified to use cache for autoregressive decoding and custom +# MultiheadAttention modules. class TransformerDecoderLayer(nn.Module): r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network. @@ -663,17 +647,10 @@ class TransformerDecoderLayer(nn.Module): string ("relu" or "gelu") or a unary callable (default=F.relu). layer_norm_eps: the eps value in layer normalization components (default=1e-6). - batch_first: If ``True``, then the input and output tensors are provided - as (batch, seq, feature). Default: ``True`` (batch, seq, feature). norm_first: if ``True``, layer norm is done prior to self attention, multihead attention and feedforward operations, respectivaly. Otherwise it's done after. Default: ``True``. Examples:: - >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) - >>> memory = torch.rand(10, 32, 512) - >>> tgt = torch.rand(20, 32, 512) - >>> out = decoder_layer(tgt, memory) - Alternatively, when ``batch_first`` is ``True``: >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True) >>> memory = torch.rand(32, 10, 512) @@ -690,7 +667,6 @@ def __init__(self, attention_dropout_rate: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, layer_norm_eps: float = 1e-6, - batch_first: bool = True, norm_first: bool = True, device=None, dtype=None) -> None: @@ -699,15 +675,15 @@ def __init__(self, self.self_attn = MultiheadAttention( d_model, nhead, + self_attn=True, dropout_rate=attention_dropout_rate, - batch_first=batch_first, bias=False, **factory_kwargs) self.multihead_attn = MultiheadAttention( d_model, nhead, + self_attn=False, dropout_rate=attention_dropout_rate, - batch_first=batch_first, bias=False, **factory_kwargs) @@ -759,7 +735,7 @@ def forward( # pylint: disable=arguments-renamed cache=cache, index=index) x = x + sa_out - x = x + self._mha_block(self.norm2(x), memory, memory_mask, None) + x = x + self._mha_block(self.norm2(x), memory, memory_mask) x = x + self._ff_block(self.norm3(x)) else: sa_out, cache = self._sa_block( @@ -770,7 +746,7 @@ def forward( # pylint: disable=arguments-renamed cache=cache, index=index) x = self.norm1(x + sa_out) - x = self.norm2(x + self._mha_block(x, memory, memory_mask, None)) + x = self.norm2(x + self._mha_block(x, memory, memory_mask)) x = self.norm3(x + self._ff_block(x)) return x, cache @@ -784,12 +760,9 @@ def _sa_block( # pylint: disable=arguments-renamed max_len: Optional[int] = None, cache: Optional[dict] = None, index: Optional[int] = None) -> Any: - x, _, cache = self.self_attn( - x, - x, + x, cache = self.self_attn( x, attn_mask=attn_mask, - need_weights=False, decode=decode, max_len=max_len, cache=cache, @@ -797,18 +770,9 @@ def _sa_block( # pylint: disable=arguments-renamed return self.dropout1(x), cache # Multihead attention block: - def _mha_block(self, - x: Tensor, - mem: Tensor, - attn_mask: Optional[Tensor], - key_padding_mask: Optional[Tensor]) -> Tensor: - x = self.multihead_attn( - x, - mem, - mem, - attn_mask=attn_mask, - key_padding_mask=key_padding_mask, - need_weights=False)[0] + def _mha_block(self, x: Tensor, mem: Tensor, + attn_mask: Optional[Tensor]) -> Tensor: + x, _ = self.multihead_attn(x, mem, attn_mask=attn_mask) return self.dropout2(x) # Feed forward block. @@ -817,12 +781,10 @@ def _ff_block(self, x: Tensor) -> Tensor: return self.dropout3(x) -# Only difference to standard PyTorch class is that 'self._qkv_same_embed_dim' -# is always set to 'False' and the use of a cache registered as a buffer for -# autoregressive decoding. -class MultiheadAttention(nn.MultiheadAttention): +class MultiheadAttention(nn.Module): r"""Allows the model to jointly attend to information - from different representation subspaces. + from different representation subspaces. Supports self-attention and + encoder-decoder attention. See `Attention Is All You Need `_. .. math:: \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O @@ -832,117 +794,75 @@ class MultiheadAttention(nn.MultiheadAttention): num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``). + self_attn: Whether self attention or encoder-decoder attention is used. + Default: ``True``. dropout_rate: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout_rate). bias: If specified, adds bias to input / output projection layers. - Default: ``True``. - add_bias_kv: If specified, adds bias to the key and value sequences at - dim=0. Default: ``False``. - add_zero_attn: If specified, adds a new batch of zeros to the key and value - sequences at dim=1. Default: ``False``. - kdim: Total number of features for keys. Default: ``None`` - (uses ``kdim=embed_dim``). - vdim: Total number of features for values. Default: ``None`` - (uses ``vdim=embed_dim``). - batch_first: If ``True``, then the input and output tensors are provided - as (batch, seq, feature). Default: ``False`` (seq, batch, feature). + Default: ``False``. + device: The device of the module. + dtype: The dtype of the module. Examples:: >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) - >>> attn_output, attn_output_weights = multihead_attn(query, key, value) + >>> attn_output, cache = multihead_attn(x) """ def __init__(self, - embed_dim, - num_heads, - dropout_rate=0., - bias=True, - add_bias_kv=False, - add_zero_attn=False, - kdim=None, - vdim=None, - batch_first=True, - device=None, - dtype=None) -> None: - super().__init__( - embed_dim, - num_heads, - dropout=dropout_rate, - bias=bias, - add_bias_kv=add_bias_kv, - add_zero_attn=add_zero_attn, - kdim=kdim, - vdim=vdim, - batch_first=batch_first, - device=device, - dtype=dtype) - # This is set to 'True' for kdim == vdim == embed_dim in the standard - # PyTorch class. - self._qkv_same_embed_dim = False + embed_dim: int, + num_heads: int, + self_attn: bool = True, + dropout_rate: float = 0., + bias: bool = False, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None) -> None: + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.self_attn = self_attn + self.dropout = dropout_rate + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, \ + 'embed_dim must be divisible by num_heads.' factory_kwargs = {'device': device, 'dtype': dtype} - self.q_proj_weight = nn.Parameter( - torch.empty((embed_dim, embed_dim), **factory_kwargs)) - self.k_proj_weight = nn.Parameter( - torch.empty((embed_dim, self.kdim), **factory_kwargs)) - self.v_proj_weight = nn.Parameter( - torch.empty((embed_dim, self.vdim), **factory_kwargs)) - self.register_parameter('in_proj_weight', None) + if self_attn: + # Self-attention. + self.in_proj = nn.Linear( + embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs) + else: + # Encoder-decoder attention. + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs) + self.kv_proj = nn.Linear( + embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs) self._reset_parameters() def _reset_parameters(self): - if self._qkv_same_embed_dim: - xavier_uniform_(self.in_proj_weight) - else: - xavier_uniform_(self.q_proj_weight) - xavier_uniform_(self.k_proj_weight) - xavier_uniform_(self.v_proj_weight) - - if self.in_proj_bias is not None: - normal_(self.in_proj_bias, std=1e-6) - normal_(self.out_proj.bias, std=1e-6) - if self.bias_k is not None: - normal_(self.bias_k, std=1e-6) - if self.bias_v is not None: - normal_(self.bias_v, std=1e-6) + """Initiate parameters in the MultiheadAttention module.""" + for module in self.modules(): + if isinstance(module, nn.Linear): + xavier_uniform_(module.weight) + if module.bias is not None: + normal_(module.bias, std=1e-6) def forward(self, - query: Tensor, - key: Tensor, - value: Tensor, - key_padding_mask: Optional[Tensor] = None, - need_weights: bool = True, + x: Tensor, + mem: Optional[Tensor] = None, attn_mask: Optional[Tensor] = None, - average_attn_weights: bool = True, decode: bool = False, max_len: Optional[int] = None, cache: Optional[dict] = None, index: Optional[int] = None) -> Any: r""" Args: - query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, - :math:`(L, N, E_q)` when ``batch_first=False`` or :math:`(N, L, E_q)` - when ``batch_first=True``, where :math:`L` is the target sequence - length, :math:`N` is the batch size, and :math:`E_q` is the query - embedding dimension ``embed_dim``. - Queries are compared against key-value pairs to produce the output. - See "Attention Is All You Need" for more details. - key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, - :math:`(S, N, E_k)` when ``batch_first=False`` or :math:`(N, S, E_k)` - when ``batch_first=True``, where :math:`S` is the source sequence - length, :math:`N` is the batch size, and :math:`E_k` is the key - embedding dimension ``kdim``. - See "Attention Is All You Need" for more details. - value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, - :math:`(S, N, E_v)` when ``batch_first=False`` or :math:`(N, S, E_v)` - when ``batch_first=True``, where :math:`S` is the source - sequence length, :math:`N` is the batch size, and :math:`E_v` is the - value embedding dimension ``vdim``. - See "Attention Is All You Need" for more details. - key_padding_mask: Dummy argument to make MultiheadAttention compatible - with standard PyTorch TransformerEncoder implementation. - need_weights: If specified, returns ``attn_output_weights`` in addition - to ``attn_outputs``.Default: ``True``. + x: Batch of input sequences of shape + (batch size, sequence length, embedding dimensionality) for self + attention mechanism. See "Attention Is All You Need" for more details. + mem: Batch of input sequences of shape + (batch size, sequence length, embedding dimensionality) for + encoder-decoder attention. See "Attention Is All You Need" for more + details. attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the @@ -955,338 +875,112 @@ def forward(self, a non-zero value indicates that the corresponding position is not allowed to attend. For a float mask, the mask values will be added to the attention weight. - average_attn_weights: If true, indicates that the returned - ``attn_weights`` should be averaged across heads. Otherwise, - ``attn_weights`` are provided separately per head. Note that this - flag only has an effect when ``need_weights=True``. Default: - ``True`` (i.e. average weights across heads) decode: wether to use cache for autoregressive decoding or not. max_len: maximum sequence length, necessary for decoding cache. + cache: cache dictionary for autoregressive decoding. + index: index of the current decoding step, necessary for decoding cache. Outputs: - - **attn_output** - Attention outputs of shape :math:`(L, E)` when input - is unbatched, :math:`(L, N, E)` when ``batch_first=False`` or - :math:`(N, L, E)` when ``batch_first=True``, - where :math:`L` is the target sequence length, :math:`N` is the batch - size, and :math:`E` is the embedding dimension ``embed_dim``. - - **attn_output_weights** - Only returned when ``need_weights=True``. - If ``average_attn_weights=True``, returns attention weights averaged - across heads of shape :math:`(L, S)` when input is unbatched or - :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the - target sequence length, and :math:`S` is the source sequence length. - If ``average_weights=False``, returns attention weights per - head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched - or :math:`(N, \text{num\_heads}, L, S)`. - .. note:: - `batch_first` argument is ignored for unbatched inputs. + - **attn_output** - Attention outputs of shape :math:`(N, L, E)` when + ``batch_first=True``, where :math:`L` is the target sequence length, + :math:`N` is the batch size, and :math:`E` is the embedding dimension + ``embed_dim``. + - **cache** - For autoregressive decoding. """ - del key_padding_mask - is_batched = query.dim() == 3 - if self.batch_first and is_batched: - # make sure that the transpose op does not affect the "is" property - if key is value: - if query is key: - query = key = value = query.transpose(1, 0) - else: - query, key = [x.transpose(1, 0) for x in (query, key)] - value = key - else: - query, key, value = [x.transpose(1, 0) for x in (query, key, value)] + # Shape: (batch size, sequence length, embedding dimensionality) + bsz, seq_len, embed_dim = x.size() + # In projection. + if self.self_attn: + q, k, v = self.in_proj(x).split(self.embed_dim, dim=2) + else: + q = self.q_proj(x) + k, v = self.kv_proj(mem).split(self.embed_dim, dim=2) + # This is 1 (!= seq_len) during autoreregressive decoding. + tgt_len = q.size(1) + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. name = f'decoder.layers.{index}.self_attn' loc_cache = cache[name] if decode and name in cache else None + if decode: + if loc_cache is None: + loc_cache = { + 'cached_key': + torch.zeros((bsz, max_len, embed_dim), + dtype=k.dtype, + device=k.device), + 'cached_value': + torch.zeros((bsz, max_len, embed_dim), + dtype=v.dtype, + device=v.device), + 'cache_index': + torch.tensor(0, dtype=torch.long, device=k.device), + } + cached_key = loc_cache['cached_key'] + cached_value = loc_cache['cached_value'] + cache_index = loc_cache['cache_index'] + batch_size, max_length, num_features = cached_key.shape + assert batch_size == bsz, f'{batch_size} != {bsz}' + assert max_length == max_len, f'{max_length} != {max_len}' + assert num_features == embed_dim, f'{num_features} != {embed_dim}' + # Shape check of cached keys against query input. + expected_shape = (batch_size, 1, num_features) + if expected_shape != x.shape: + raise ValueError('Autoregressive cache shape error, expected query ' + f'shape {expected_shape} instead got {x.shape}.') + # Update key, value caches with our new 1d spatial slices. + cached_key[:, cache_index:cache_index + 1, :] = k + cached_value[:, cache_index:cache_index + 1, :] = v + k = cached_key + v = cached_value + cache_index += 1 + # Causal mask for cached decoder self-attention: + # our single query position should only attend to those key + # positions that have already been generated and cached, + # not the remaining zero elements. + if attn_mask is not None: + raise ValueError('Attention mask has to be None for decode == True.') + attn_mask = (torch.arange(max_length, device=k.device) >= + cache_index).reshape(1, max_length) + + # Update sequence length to account for complete sequence. + seq_len = k.size(1) + + # Reshape q, k, v for multihead attention. + q, k, v = q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1) + q, k, v = q.contiguous(), k.contiguous(), v.contiguous() + q = q.view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) + k = k.view(seq_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) + v = v.view(seq_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) + + # Check dtype and shape of attention mask. + if not decode and attn_mask is not None: + assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \ + f'Float and bool dtypes are supported, not {attn_mask.dtype}.' + # Ensure attn_mask's dim is 3. + if attn_mask.dim() == 3: + correct_3d_size = (bsz * self.num_heads, tgt_len, seq_len) + if attn_mask.shape != correct_3d_size: + raise RuntimeError(f'The shape of attn_mask is {attn_mask.shape}, ' + f'but should be {correct_3d_size}.') + else: + raise RuntimeError( + f"attn_mask's dimension {attn_mask.dim()} is not supported") + + # Convert attention mask to float. + if attn_mask is not None and attn_mask.dtype == torch.bool: + new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) + new_attn_mask.masked_fill_(attn_mask, -1e10) + attn_mask = new_attn_mask - attn_output, attn_output_weights, loc_cache = multi_head_attention_forward( - query, key, value, self.embed_dim, self.num_heads, - self.in_proj_bias, self.bias_k, self.bias_v, - self.dropout, self.out_proj.weight, self.out_proj.bias, - training=self.training, need_weights=need_weights, attn_mask=attn_mask, - q_proj_weight=self.q_proj_weight, - k_proj_weight=self.k_proj_weight, - v_proj_weight=self.v_proj_weight, - average_attn_weights=average_attn_weights, - decode=decode, cache=loc_cache, max_len=max_len) + # Adjust dropout_rate probability. + dropout_rate = self.dropout if self.training else 0.0 + + # Calculate attention and out projection. + attn_output = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask, dropout_rate) + attn_output = self.out_proj(attn_output.view(bsz, tgt_len, embed_dim)) if decode: cache[name] = loc_cache - if self.batch_first and is_batched: - return attn_output.transpose(1, 0), attn_output_weights, cache - else: - return attn_output, attn_output_weights, cache - - -def _in_projection( - q: Tensor, - k: Tensor, - v: Tensor, - w_q: Tensor, - w_k: Tensor, - w_v: Tensor, - b_q: Optional[Tensor] = None, - b_k: Optional[Tensor] = None, - b_v: Optional[Tensor] = None, -) -> Tuple[Tensor, Tensor, Tensor]: - r"""Performs the in-projection step of the attention operation. This is simply - a triple of linear projections, with shape constraints on the weights which - ensure embedding dimension uniformity in the projected outputs. - Output is a triple containing projection tensors for query, key and value. - """ - eq, ek = q.size(-1), k.size(-1) - assert w_q.shape == (eq, eq), \ - f'Expecting query weights shape of {(eq, eq)}, but got {w_q.shape}' - assert w_k.shape == (eq, ek), \ - f'Expecting key weights shape of {(eq, ek)}, but got {w_k.shape}' - assert w_v.shape == (eq, ek), \ - f'Expecting value weights shape of {(eq, ek)}, but got {w_v.shape}' - assert b_q is None or b_q.shape == (eq,), \ - f'Expecting query bias shape of {(eq,)}, but got {b_q.shape}' - assert b_k is None or b_k.shape == (eq,), \ - f'Expecting key bias shape of {(eq,)}, but got {b_k.shape}' - assert b_v is None or b_v.shape == (eq,), \ - f'Expecting value bias shape of {(eq,)}, but got {b_v.shape}' - return torch.nn.functional.linear(q, w_q, b_q), \ - torch.nn.functional.linear(k, w_k, b_k), \ - torch.nn.functional.linear(v, w_v, b_v) - - -# Modified to create cache for autoregressive decoding. -def multi_head_attention_forward(query: Tensor, - key: Tensor, - value: Tensor, - embed_dim_to_check: int, - num_heads: int, - in_proj_bias: Optional[Tensor], - bias_k: Optional[Tensor], - bias_v: Optional[Tensor], - dropout_rate: float, - out_proj_weight: Tensor, - out_proj_bias: Optional[Tensor], - training: bool = True, - need_weights: bool = True, - attn_mask: Optional[Tensor] = None, - q_proj_weight: Optional[Tensor] = None, - k_proj_weight: Optional[Tensor] = None, - v_proj_weight: Optional[Tensor] = None, - average_attn_weights: bool = True, - decode: bool = False, - cache: Optional[dict] = None, - max_len: Optional[int] = None) -> Any: - r""" - Args: - query, key, value: map a query and a set of key-value pairs to an output. - See "Attention Is All You Need" for more details. - embed_dim_to_check: total dimension of the model. - num_heads: parallel attention heads. - in_proj_bias: input projection bias. - bias_k, bias_v: bias of the key and value sequences to be added at dim=0. - dropout_rate: probability of an element to be zeroed. - out_proj_weight, out_proj_bias: the output projection weight and bias. - training: apply dropout_rate if is ``True``. - need_weights: output attn_output_weights. - attn_mask: 2D or 3D mask that prevents attention to certain positions. - A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the - entries of each batch. - q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: - input projection weight and bias. - average_attn_weights: If true, indicates that the returned ``attn_weights`` - should be averaged across heads. - Otherwise, ``attn_weights`` are provided separately per head. - Note that this flag only has an effect when ``need_weights=True.``. - Default: True - decode: wether to use cache for autoregressive decoding or not. - cache: dict which contains cache for decoding for the current - MulitheadAttention module. - max_len: maximum sequence length, necessary for decoding cache. - Shape: - Inputs: - - query: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence - length, N is the batch size, E is the embedding dimension. - - key: :math:`(S, E)` or :math:`(S, N, E)`, where S is the source sequence - length, N is the batch size, E is the embedding dimension. - - value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence - length, N is the batch size, E is the embedding dimension. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, - S is the source sequence length. 3D mask :math:`(N*num_heads, L, S)` - where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensures that position i is - allowed to attend the unmasked positions. If a ByteTensor is provided, - the non-zero positions are not allowed to attend while the zero positions - will be unchanged. If a BoolTensor is provided, positions with ``True`` - are not allowed to attend while ``False`` values will be unchanged. - If a FloatTensor is provided, it will be added to the attention weight. - Outputs: - - attn_output: :math:`(L, E)` or :math:`(L, N, E)` where L is the target - sequence length, N is the batch size, E is the embedding dimension. - - attn_output_weights: Only returned when ``need_weights=True``. - If ``average_attn_weights=True``, returns - attention weights averaged across heads of shape :math:`(L, S)` when input - is unbatched or :math:`(N, L, S)`, where :math:`N` is the batch size, - :math:`L` is the target sequence length, and :math:`S` is the source - sequence length. If ``average_weights=False``, returns attention weights - per head of shape :math:`(num_heads, L, S)` when input is unbatched or - :math:`(N, num_heads, L, S)`. - """ - # Set up shape variables. - tgt_len, bsz, embed_dim = query.shape - src_len, _, _ = key.shape - assert embed_dim == embed_dim_to_check, \ - f'was expecting dimension of {embed_dim_to_check}, but got {embed_dim}' - if isinstance(embed_dim, torch.Tensor): - # `embed_dim` can be a tensor when JIT tracing. - head_dim = embed_dim.div(num_heads, rounding_mode='trunc') - else: - head_dim = embed_dim // num_heads - assert head_dim * num_heads == embed_dim, \ - f'embed_dim {embed_dim} not divisible by num_heads {num_heads}' - # Allow MHA to have different embedding dimensions when separate projection - # weights are used. - assert key.shape[:2] == value.shape[:2], \ - (f"key's sequence and batch dims {key.shape[:2]} do not match value's " - f'{value.shape[:2]}') - - # Compute in-projection. - assert q_proj_weight is not None, \ - 'use_separate_proj_weight is True but q_proj_weight is None' - assert k_proj_weight is not None, \ - 'use_separate_proj_weight is True but k_proj_weight is None' - assert v_proj_weight is not None, \ - 'use_separate_proj_weight is True but v_proj_weight is None' - if in_proj_bias is None: - b_q = b_k = b_v = None - else: - b_q, b_k, b_v = in_proj_bias.chunk(3) - q, k, v = _in_projection( - query, key, value, q_proj_weight, k_proj_weight, - v_proj_weight, b_q, b_k, b_v) - - # During fast autoregressive decoding, we feed one position at a time, - # and cache the keys and values step by step. - if decode: - if cache is None: - cache = { - 'cached_key': - torch.zeros((bsz, max_len, embed_dim), - dtype=k.dtype, - device=k.device), - 'cached_value': - torch.zeros((bsz, max_len, embed_dim), - dtype=v.dtype, - device=v.device), - 'cache_index': - torch.tensor(0, dtype=torch.long, device=k.device), - } - cached_key = cache['cached_key'] - cached_value = cache['cached_value'] - cache_index = cache['cache_index'] - batch_size, max_length, num_features = cached_key.shape - assert batch_size == bsz, f'{batch_size} != {bsz}' - assert max_length == max_len, f'{max_length} != {max_len}' - assert num_features == embed_dim, f'{num_features} != {embed_dim}' - # Shape check of cached keys against query input. - expected_shape = (1, batch_size, num_features) - if expected_shape != query.shape: - raise ValueError('Autoregressive cache shape error, expected query shape ' - f'{expected_shape} instead got {query.shape}.') - # Update key, value caches with our new 1d spatial slices. - cached_key[:, cache_index:cache_index + 1, :] = k.transpose(dim0=0, dim1=1) - cached_value[:, cache_index:cache_index + 1, :] = v.transpose( - dim0=0, dim1=1) - k = cached_key.transpose(dim0=0, dim1=1) - v = cached_value.transpose(dim0=0, dim1=1) - cache_index += 1 - # Causal mask for cached decoder self-attention: - # our single query position should only attend to those key - # positions that have already been generated and cached, - # not the remaining zero elements. - if attn_mask is not None: - raise ValueError('Attention mask has to be None for decode == True.') - attn_mask = (torch.arange(max_length, device=k.device) >= - cache_index).reshape(1, max_length) - - # Prepare attention mask. - if not decode and attn_mask is not None: - if attn_mask.dtype == torch.uint8: - warnings.warn( - 'Byte tensor for attn_mask in nn.MultiheadAttention is deprecated.' - 'Use bool tensor instead.') - attn_mask = attn_mask.to(torch.bool) - else: - assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \ - f'float, byte, and bool types are supported, not {attn_mask.dtype}' - # ensure attn_mask's dim is 3 - if attn_mask.dim() == 2: - correct_2d_size = (tgt_len, src_len) - if attn_mask.shape != correct_2d_size: - raise RuntimeError( - f'The shape of the 2D attn_mask is {attn_mask.shape}, ' - f'but should be {correct_2d_size}.') - attn_mask = attn_mask.unsqueeze(0) - elif attn_mask.dim() == 3: - correct_3d_size = (bsz * num_heads, tgt_len, src_len) - if attn_mask.shape != correct_3d_size: - raise RuntimeError(f'The shape of attn_mask is {attn_mask.shape}, ' - f'should be {correct_3d_size}.') - else: - raise RuntimeError( - f"attn_mask's dimension {attn_mask.dim()} is not supported") - - # Add bias along batch dimension (currently second). - if bias_k is not None and bias_v is not None: - k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) - v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) - if attn_mask is not None: - attn_mask = F.pad(attn_mask, (0, 1)) - else: - assert bias_k is None - assert bias_v is None - - # Reshape q, k, v for multihead attention and make em batch first. - q = \ - q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) - k = \ - k.contiguous().view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1) - v = \ - v.contiguous().view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1) - - # Update source sequence length after adjustments. - src_len = k.size(1) - - # Convert mask to float. - if attn_mask is not None and attn_mask.dtype == torch.bool: - new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) - new_attn_mask.masked_fill_(attn_mask, -1e10) - attn_mask = new_attn_mask - - # Adjust dropout_rate probability. - if not training: - dropout_rate = 0.0 - - # Calculate attention and out projection. - attn_output = torch.nn.functional.scaled_dot_product_attention( - q, k, v, attn_mask, dropout_rate) - attn_output = attn_output.transpose(0, 1).contiguous().view( - tgt_len * bsz, embed_dim) - attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias) - attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1)) - - if need_weights: - q_scaled = q / math.sqrt(q.shape[-1]) - - if attn_mask is not None: - attn_output_weights = torch.baddbmm(attn_mask, - q_scaled, - k.transpose(-2, -1)) - else: - attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1)) - - # Optionally average attention weights over heads. - attn_output_weights = attn_output_weights.view(bsz, - num_heads, - tgt_len, - src_len) - if average_attn_weights: - attn_output_weights = attn_output_weights.sum(dim=1) / num_heads - return attn_output, attn_output_weights, cache - else: - return attn_output, None, cache + return attn_output, cache From 4490c8e494d5722901c11ef33297a05ed9f9ad16 Mon Sep 17 00:00:00 2001 From: runame Date: Tue, 15 Aug 2023 16:48:06 +0200 Subject: [PATCH 03/66] Fix pylint and yapf --- .../workloads/wmt/wmt_pytorch/models.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py index 391bd03b8..6798b0016 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py @@ -542,7 +542,7 @@ def _ff_block(self, x: Tensor) -> Tensor: return self.dropout2(x) -# Modified to use cache for autoregressive decoding and custom +# Modified to use cache for autoregressive decoding and custom # MultiheadAttention modules. class TransformerDecoder(nn.Module): r"""TransformerDecoder is a stack of N decoder layers @@ -626,7 +626,7 @@ def forward(self, return output -# Modified to use cache for autoregressive decoding and custom +# Modified to use cache for autoregressive decoding and custom # MultiheadAttention modules. class TransformerDecoderLayer(nn.Module): r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and @@ -939,8 +939,8 @@ def forward(self, # not the remaining zero elements. if attn_mask is not None: raise ValueError('Attention mask has to be None for decode == True.') - attn_mask = (torch.arange(max_length, device=k.device) >= - cache_index).reshape(1, max_length) + attn_mask = (torch.arange(max_length, device=k.device) + >= cache_index).reshape(1, max_length) # Update sequence length to account for complete sequence. seq_len = k.size(1) From 49292e48f146e0308df6f50beb3283dac40e5eb3 Mon Sep 17 00:00:00 2001 From: runame Date: Tue, 15 Aug 2023 16:53:30 +0200 Subject: [PATCH 04/66] Fix for old yapf version --- algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py index 6798b0016..2ef36b44f 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py @@ -939,8 +939,8 @@ def forward(self, # not the remaining zero elements. if attn_mask is not None: raise ValueError('Attention mask has to be None for decode == True.') - attn_mask = (torch.arange(max_length, device=k.device) - >= cache_index).reshape(1, max_length) + attn_mask = (torch.arange(max_length, device=k.device) >= + cache_index).reshape(1, max_length) # Update sequence length to account for complete sequence. seq_len = k.size(1) From 12578fd109f63263edd5fbf2680775356fe89aca Mon Sep 17 00:00:00 2001 From: runame Date: Tue, 15 Aug 2023 18:16:40 +0200 Subject: [PATCH 05/66] Fix param shapes test for fused attention layers in PyTorch WMT workload --- tests/test_param_shapes.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/tests/test_param_shapes.py b/tests/test_param_shapes.py index fef9c2978..098ab9121 100644 --- a/tests/test_param_shapes.py +++ b/tests/test_param_shapes.py @@ -1,3 +1,5 @@ +from itertools import zip_longest + import jax import numpy as np import pytest @@ -53,13 +55,23 @@ def test_param_shapes(workload): jax_workload.param_shapes.unfreeze()) pytorch_param_shapes = jax.tree_util.tree_leaves( pytorch_workload.param_shapes) - assert len(jax_param_shapes) == len(pytorch_param_shapes) + if workload == 'wmt': + # The PyTorch transformer for WMT is implemented with fused linear layers + # for the projection of QKV inside of the MultiheadAttention module. + # Two weight matrices for each of the two self-attention layers less and one + # less for the encoder-decoder attention layer -> 5 weight matrices less. + # We have 6 encoder/decoder layers, hence 30 weight matrices less in total. + assert len(jax_param_shapes) == len(pytorch_param_shapes) + 30 + else: + assert len(jax_param_shapes) == len(pytorch_param_shapes) # Check if total number of params deduced from shapes match. num_jax_params = 0 num_pytorch_params = 0 - for jax_shape, pytorch_shape in zip(jax_param_shapes, pytorch_param_shapes): + for jax_shape, pytorch_shape in zip_longest(jax_param_shapes, + pytorch_param_shapes): num_jax_params += np.prod(jax_shape.shape_tuple) - num_pytorch_params += np.prod(pytorch_shape.shape_tuple) + if pytorch_shape is not None: + num_pytorch_params += np.prod(pytorch_shape.shape_tuple) assert num_jax_params == num_pytorch_params From 8d5463a5c37dc3f4d5b5c348ec9bfbc83492baa8 Mon Sep 17 00:00:00 2001 From: runame Date: Wed, 16 Aug 2023 15:39:10 +0200 Subject: [PATCH 06/66] Fix param types test for fused attention layers in PyTorch WMT workload --- tests/test_param_types.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/tests/test_param_types.py b/tests/test_param_types.py index 3679289ed..6e7e5a9ec 100644 --- a/tests/test_param_types.py +++ b/tests/test_param_types.py @@ -71,6 +71,12 @@ def _check_attention_qkv_match(jax_param_types_dict, pytorch_param_types_dict): 'pytorch': pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_QKV, 0), } + num_kv = { + 'jax': + jax_param_types_dict.get(spec.ParameterType.ATTENTION_KV, 0), + 'pytorch': + pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_KV, 0), + } num_q = { 'jax': jax_param_types_dict.get(spec.ParameterType.ATTENTION_Q, 0), @@ -96,11 +102,13 @@ def _check_attention_qkv_match(jax_param_types_dict, pytorch_param_types_dict): pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_BIAS, 0), } qkv_match = num_qkv['jax'] == num_qkv['pytorch'] + kv_match = num_kv['jax'] == num_kv['pytorch'] q_match = num_q['jax'] == num_q['pytorch'] k_match = num_k['jax'] == num_k['pytorch'] v_match = num_v['jax'] == num_v['pytorch'] bias_match = num_bias['jax'] == num_bias['pytorch'] - qkv_match = qkv_match and q_match and k_match and v_match and bias_match + qkv_match = ( + qkv_match and kv_match and q_match and k_match and v_match and bias_match) # We subtract 2 * num_qkv from the number of biases because there are 2 # missing for each of q, k, v. @@ -112,7 +120,12 @@ def _check_attention_qkv_match(jax_param_types_dict, pytorch_param_types_dict): num_q['jax'] == num_k['jax'] == num_v['jax'] == num_qkv['pytorch'] and (num_qkv['pytorch'] != 0 and (num_bias['jax'] - 2 * num_qkv['pytorch']) == num_bias['pytorch'])) - qkv_match = qkv_match or jax_qkv_match or pytorch_qkv_match + pytorch_kv_match = ( + num_q['jax'] == num_k['jax'] == num_v['jax'] == + num_qkv['pytorch'] + num_kv['pytorch'] and + num_q['pytorch'] == num_kv['pytorch']) + qkv_match = ( + qkv_match or jax_qkv_match or pytorch_qkv_match or pytorch_kv_match) return qkv_match @@ -144,6 +157,7 @@ def test_param_types(workload_name): # Check if total number of each type match. attention_keys = { spec.ParameterType.ATTENTION_QKV, + spec.ParameterType.ATTENTION_KV, spec.ParameterType.ATTENTION_Q, spec.ParameterType.ATTENTION_K, spec.ParameterType.ATTENTION_V, From 7308c91f2cd5e6343f189beee42a6a9be1094bfa Mon Sep 17 00:00:00 2001 From: runame Date: Wed, 16 Aug 2023 16:10:15 +0200 Subject: [PATCH 07/66] Fix arrangement of elements for attention --- .../workloads/wmt/wmt_pytorch/models.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py index 2ef36b44f..3451f7b5e 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py @@ -918,12 +918,8 @@ def forward(self, cached_key = loc_cache['cached_key'] cached_value = loc_cache['cached_value'] cache_index = loc_cache['cache_index'] - batch_size, max_length, num_features = cached_key.shape - assert batch_size == bsz, f'{batch_size} != {bsz}' - assert max_length == max_len, f'{max_length} != {max_len}' - assert num_features == embed_dim, f'{num_features} != {embed_dim}' # Shape check of cached keys against query input. - expected_shape = (batch_size, 1, num_features) + expected_shape = (bsz, 1, embed_dim) if expected_shape != x.shape: raise ValueError('Autoregressive cache shape error, expected query ' f'shape {expected_shape} instead got {x.shape}.') @@ -939,18 +935,16 @@ def forward(self, # not the remaining zero elements. if attn_mask is not None: raise ValueError('Attention mask has to be None for decode == True.') - attn_mask = (torch.arange(max_length, device=k.device) >= - cache_index).reshape(1, max_length) + attn_mask = (torch.arange(max_len, device=k.device) >= + cache_index).reshape(1, max_len) # Update sequence length to account for complete sequence. seq_len = k.size(1) - # Reshape q, k, v for multihead attention. - q, k, v = q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1) - q, k, v = q.contiguous(), k.contiguous(), v.contiguous() - q = q.view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) - k = k.view(seq_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) - v = v.view(seq_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) + # Rearrange q, k, v for multihead attention. + q = q.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # Check dtype and shape of attention mask. if not decode and attn_mask is not None: @@ -965,6 +959,8 @@ def forward(self, else: raise RuntimeError( f"attn_mask's dimension {attn_mask.dim()} is not supported") + # Reshape attention mask to be consistent with q, k, v. + attn_mask = attn_mask.view(bsz, self.num_heads, tgt_len, seq_len) # Convert attention mask to float. if attn_mask is not None and attn_mask.dtype == torch.bool: @@ -975,10 +971,14 @@ def forward(self, # Adjust dropout_rate probability. dropout_rate = self.dropout if self.training else 0.0 - # Calculate attention and out projection. + # Calculate attention. attn_output = torch.nn.functional.scaled_dot_product_attention( q, k, v, attn_mask, dropout_rate) - attn_output = self.out_proj(attn_output.view(bsz, tgt_len, embed_dim)) + # Rearrange for output projection. + attn_output = attn_output.transpose(1, 2).contiguous().view( + bsz, tgt_len, embed_dim) + # Output projection. + attn_output = self.out_proj(attn_output) if decode: cache[name] = loc_cache From dc92fedb136e890e7806584260261417ee6c2d23 Mon Sep 17 00:00:00 2001 From: runame Date: Tue, 29 Aug 2023 16:51:28 +0200 Subject: [PATCH 08/66] Remove redundant dtype conversions --- .../workloads/wmt/wmt_pytorch/models.py | 25 +++++++------------ 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py index 59c1cbc97..b787785a1 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py @@ -10,36 +10,31 @@ from torch.nn.init import xavier_uniform_ -def make_causal_mask(x: Tensor, - device: str = 'cuda:0', - dtype: torch.dtype = torch.float32) -> Tensor: +def make_causal_mask(x: Tensor, device: str = 'cuda:0') -> Tensor: """Make a causal mask for self-attention. Args: x: input array of shape `[batch..., len]` device: device to store the idxs - dtype: mask return dtype Returns: A `[batch..., len, len]` shaped causal attention mask. """ idxs = torch.broadcast_to( torch.arange(x.shape[-1], dtype=torch.int32, device=device), x.shape) - return torch.greater_equal(idxs.unsqueeze(-1), - idxs.unsqueeze(-2)).to(dtype=dtype) + return torch.greater_equal(idxs.unsqueeze(-1), idxs.unsqueeze(-2)) def make_src_mask(src, inputs_segmentation, nhead): """Utility for creating src mask and adjust it for PyTorch Transformer API.""" - src_mask = torch.mul((src > 0).unsqueeze(-1), - (src > 0).unsqueeze(-2)).to(dtype=torch.float32) + src_mask = torch.mul((src > 0).unsqueeze(-1), (src > 0).unsqueeze(-2)) # Add segmentation block-diagonal attention mask if using segmented data. if inputs_segmentation is not None: src_mask = torch.logical_and( src_mask, torch.eq( inputs_segmentation.unsqueeze(-1), - inputs_segmentation.unsqueeze(-2)).to(dtype=torch.float32)) + inputs_segmentation.unsqueeze(-2))) # Flip values and ensure numerical stability. src_mask = torch.repeat_interleave( torch.logical_not(src_mask), repeats=nhead, dim=0) @@ -58,27 +53,25 @@ def make_tgt_and_memory_mask(tgt, Transformer API.""" if not decode: tgt_mask = torch.logical_and( - torch.mul((tgt > 0).unsqueeze(-1), - (tgt > 0).unsqueeze(-2)).to(dtype=torch.float32), + torch.mul((tgt > 0).unsqueeze(-1), (tgt > 0).unsqueeze(-2)), make_causal_mask(tgt, device=tgt.device)) - memory_mask = torch.mul((tgt > 0).unsqueeze(-1), - (src > 0).unsqueeze(-2)).to(dtype=torch.float32) + memory_mask = torch.mul((tgt > 0).unsqueeze(-1), (src > 0).unsqueeze(-2)) else: tgt_mask = None memory_mask = torch.mul((torch.ones_like(tgt) > 0).unsqueeze(-1), - (src > 0).unsqueeze(-2)).to(dtype=torch.float32) + (src > 0).unsqueeze(-2)) # Add segmentation block-diagonal attention masks if using segmented data. if inputs_segmentation is not None: tgt_mask = torch.logical_and( tgt_mask, torch.eq( targets_segmentation.unsqueeze(-1), - targets_segmentation.unsqueeze(-2)).to(dtype=torch.float32)) + targets_segmentation.unsqueeze(-2))) memory_mask = torch.logical_and( memory_mask, torch.eq( targets_segmentation.unsqueeze(-1), - inputs_segmentation.unsqueeze(-2)).to(dtype=torch.float32)) + inputs_segmentation.unsqueeze(-2))) # Flip values and ensure numerical stability. memory_mask = torch.repeat_interleave( torch.logical_not(memory_mask), repeats=nhead, dim=0) From 4a2b58a006ae252e269a5f79845df266dc29d4e5 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 30 Aug 2023 00:54:23 +0000 Subject: [PATCH 09/66] update datasetup documentation --- README.md | 4 +- datasets/README.md | 128 +++++++++++++++++++++++++++++++++----- datasets/dataset_setup.py | 6 +- 3 files changed, 116 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index 216354e73..a4536c35e 100644 --- a/README.md +++ b/README.md @@ -113,7 +113,7 @@ See instructions [here](https://github.com/NVIDIA/nvidia-docker). ### Running Docker Container (Interactive) To use the Docker container as an interactive virtual environment, you can run a container mounted to your local data and code directories and execute the `bash` program. This may be useful if you are in the process of developing a submission. -1. Run detached Docker Container. The container_id will be printed if the container is run successfully. +1. Run detached Docker Container. The `container_id` will be printed if the container is running successfully. ```bash docker run -t -d \ -v $HOME/data/:/data/ \ @@ -122,7 +122,7 @@ To use the Docker container as an interactive virtual environment, you can run a -v $HOME/algorithmic-efficiency:/algorithmic-efficiency \ --gpus all \ --ipc=host \ - + \ --keep_container_alive true ``` 2. Open a bash terminal diff --git a/datasets/README.md b/datasets/README.md index 5afe257fe..e191c780a 100644 --- a/datasets/README.md +++ b/datasets/README.md @@ -1,28 +1,114 @@ # Dataset Setup -Use `dataset_setup.py` to download datasets, for example: +TL;DR: +Use `dataset_setup.py` to download datasets. +Usage: ``` python3 datasets/dataset_setup.py \ --data_dir=~/data \ + -- + -- +``` +The complete benchmark uses 6 datasets: +- OGBG +- WMT +- FastMRI +- Imagenet +- Criteo 1TB +- Librispeech + + + +Some dataset setups will require you to sign a third party agreement with the +dataset in order to get the donwload URLs. + + +# Per dataset instructions +## Environment + +### Set data directory (Docker container) +If you are running the `dataset_setup.py` script from a Docker container, please +make sure the data directory is mounted to a directory on your host with +-v flag. If you are following instructions from the README you will have used +the `-v $HOME/data:/data` flag in the `docker run` command. This will mount +the `$HOME/data` directory to the `/data` directory in the container. +In this case set --data_dir to `\data`. +``` +DATA_DIR=\data +``` +### Set data directory (on host) +Alternatively, if you are running the data download script directly on your host, feel free +to choose whatever directory you find suitable, further submission instructions +assume the data is stored in `~/data`. +``` +DATA_DIR=~/data +``` +#### Start tmux session (Recommended) +If running the dataset_setup.py on directly on host it is recommended to run +the dataset_setup.py script in a tmux session because some of the data downloads may +take several hours. To avoid your setup being interrupted start a tmux session: +``` +tmux new -s data_setup +``` + + +## Datasets + +### OGBG +From `algorithmic-efficiency` run: +``` +python3 datasets/dataset_setup.py \ + --data_dir=$DATA_DIR/ogbg \ --ogbg ``` -This will require the same pip dependencies as `submission_runner.py`. +### WMT +From `algorithmic-efficiency` run: +``` +python3 datasets/dataset_setup.py \ + --data_dir=$DATA_DIR/wmt \ + --wmt +``` + + +## FastMRI +Fill out form on https://fastmri.med.nyu.edu/. After filling out the form +you should get an email containing the URLS for "knee_singlecoil_train", +"knee_singlecoil_val" and "knee_singlecoil_test". + +``` +python3 datasets/dataset_setup.py \ + --data_dir=$DATA_DIR/fastmri \ + --fastmri \ + --fastmri_knee_singlecoil_train_url "" \ + --fastmri_knee_singlecoil_val_url "" \ + --fastmri_knee_singlecoil_test_url "" +``` -Some datasets require signing a form before downloading: +## ImageNet +Register on https://image-net.org/ and follow directions to obtain the +URLS for the ILSVRC2012 train and validation images. -FastMRI: -Fill out form on https://fastmri.med.nyu.edu/ and run this script with the -links that are emailed to you for "knee_singlecoil_train" and -"knee_singlecoil_val". +Imagenet dataset processsing is resource intensive. To avoid potential +ResourcExhausted errors increase the maximum number of open file descriptors: +``` +ulimit -n 8192 +``` -ImageNet: -Register on https://image-net.org/ and run this script with the links to the -ILSVRC2012 train and validation images. +Alo note that some functions use subprocess.Popen(..., shell=True), which can be +dangerous if the user injects code into the --data_dir or --temp_dir flags. We +do some basic sanitization in main(), but submitters should not let untrusted +users run this script on their systems. -Note for tfds ImageNet, you may have to increase the max number of files allowed -open at once using `ulimit -n 8192`. +``` +python3 datasets/dataset_setup.py \ + --data_dir=$DATA_DIR/imagenet \ + --imagenet + --imagenet_train_url + --imagenet_val_url +``` -Note that in order to avoid potential accidental deletion, this script does NOT +### Cleanup +Note: that in order to avoid potential accidental deletion, this script does NOT delete any intermediate temporary files (such as zip archives) without a user confirmation. Deleting temp files is particularly important for Criteo 1TB, as there can be multiple copies of the dataset on disk during preprocessing if @@ -31,10 +117,7 @@ can pass --interactive_deletion=false and then all files will be downloaded to the provided --temp_dir, and the user can manually delete these after downloading has finished. -Note that some functions use subprocess.Popen(..., shell=True), which can be -dangerous if the user injects code into the --data_dir or --temp_dir flags. We -do some basic sanitization in main(), but submitters should not let untrusted -users run this script on their systems. + ## Librispeech @@ -57,3 +140,14 @@ The preprocessing script will generate `.npy` files for audio data, `features.cs ```bash python3 librispeech_preprocess.py --data_dir=$DATA_DIR/librispeech --tokenizer_vocab_path=$DATA_DIR/librispeech/spm_model.vocab ``` + +## Criteo1tb + +Note: that in order to avoid potential accidental deletion, this script does NOT +delete any intermediate temporary files (such as zip archives) without a user +confirmation. Deleting temp files is particularly important for Criteo 1TB, as +there can be multiple copies of the dataset on disk during preprocessing if +files are not cleaned up. If you do not want any temp files to be deleted, you +can pass --interactive_deletion=false and then all files will be downloaded to +the provided --temp_dir, and the user can manually delete these after +downloading has finished. diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index d1636a3e5..2f09eb68d 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -55,7 +55,7 @@ Example command: python3 datasets/dataset_setup.py \ - --data_dir=~/data \ + --data_dir=~/data/imagenet \ --temp_dir=/tmp/mlcommons_data --imagenet \ --imagenet_train_url= \ @@ -170,7 +170,7 @@ flags.DEFINE_string( 'fastmri_knee_singlecoil_test_url', None, - 'Only necessary if you want this script to `wget` the FastMRI validation ' + 'Only necessary if you want this script to `wget` the FastMRI test ' 'split. If not, you can supply the path to --data_dir in ' 'submission_runner.py.') @@ -345,7 +345,7 @@ def download_cifar(data_dir, framework): raise ValueError('Invalid value for framework: {}'.format(framework)) -def extract_filename_from_url(url, start_str='knee', end_str='.xz'): +def extract_filename_from_url(url, start_str='knee', end_str='.g-uz'): """ The url filenames are sometimes couched within a urldefense+aws access id etc. string. Unfortunately querying the content disposition in requests fails (not provided)... so fast search is done here within the url. From 14089b479d4690bbe93c04b053a0ad1debf008db Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 30 Aug 2023 00:56:26 +0000 Subject: [PATCH 10/66] fastmri datasetup fix --- datasets/dataset_setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 2f09eb68d..2ee9ca022 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -345,7 +345,7 @@ def download_cifar(data_dir, framework): raise ValueError('Invalid value for framework: {}'.format(framework)) -def extract_filename_from_url(url, start_str='knee', end_str='.g-uz'): +def extract_filename_from_url(url, start_str="knee", end_str=".gz"): """ The url filenames are sometimes couched within a urldefense+aws access id etc. string. Unfortunately querying the content disposition in requests fails (not provided)... so fast search is done here within the url. @@ -355,7 +355,7 @@ def extract_filename_from_url(url, start_str='knee', end_str='.g-uz'): end = url.find(end_str) if failure in (start, end): raise ValueError( - f'Unable to locate filename wrapped in {start}--{end} in {url}') + f'Unable to locate filename wrapped in {start_str}--{end_str} in {url}') end += len(end_str) # make it inclusive return url[start:end] From 1f12ac4b0f68b12fc175a6279e1dd46d3a964225 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 30 Aug 2023 01:01:54 +0000 Subject: [PATCH 11/66] make data_dir consistent --- datasets/README.md | 6 +++--- datasets/dataset_setup.py | 7 ++++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/datasets/README.md b/datasets/README.md index e191c780a..1ed1f77ee 100644 --- a/datasets/README.md +++ b/datasets/README.md @@ -65,7 +65,7 @@ python3 datasets/dataset_setup.py \ From `algorithmic-efficiency` run: ``` python3 datasets/dataset_setup.py \ - --data_dir=$DATA_DIR/wmt \ + --data_dir=$DATA_DIR \ --wmt ``` @@ -77,7 +77,7 @@ you should get an email containing the URLS for "knee_singlecoil_train", ``` python3 datasets/dataset_setup.py \ - --data_dir=$DATA_DIR/fastmri \ + --data_dir=$DATA_DIR \ --fastmri \ --fastmri_knee_singlecoil_train_url "" \ --fastmri_knee_singlecoil_val_url "" \ @@ -101,7 +101,7 @@ users run this script on their systems. ``` python3 datasets/dataset_setup.py \ - --data_dir=$DATA_DIR/imagenet \ + --data_dir=$DATA_DIR \ --imagenet --imagenet_train_url --imagenet_val_url diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 2ee9ca022..2ceb6b4b6 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -55,7 +55,7 @@ Example command: python3 datasets/dataset_setup.py \ - --data_dir=~/data/imagenet \ + --data_dir=~/dataa \ --temp_dir=/tmp/mlcommons_data --imagenet \ --imagenet_train_url= \ @@ -345,7 +345,7 @@ def download_cifar(data_dir, framework): raise ValueError('Invalid value for framework: {}'.format(framework)) -def extract_filename_from_url(url, start_str="knee", end_str=".gz"): +def extract_filename_from_url(url, start_str='knee', end_str='.gz'): """ The url filenames are sometimes couched within a urldefense+aws access id etc. string. Unfortunately querying the content disposition in requests fails (not provided)... so fast search is done here within the url. @@ -364,7 +364,6 @@ def download_fastmri(data_dir, fastmri_train_url, fastmri_val_url, fastmri_test_url): - data_dir = os.path.join(data_dir, 'fastmri') # Download fastmri train dataset knee_train_filename = extract_filename_from_url(fastmri_train_url) @@ -597,11 +596,13 @@ def download_mnist(data_dir): def download_ogbg(data_dir): + data_dir = os.path.join(data_dir, 'ogbg') tfds.builder('ogbg_molpcba:0.1.3', data_dir=data_dir).download_and_prepare() def download_wmt(data_dir): """WMT14 and WMT17 de-en.""" + data_dir = os.path.join(data_dir, 'wmt') for ds_name in ['wmt14_translate/de-en:1.0.0', 'wmt17_translate/de-en:1.0.0']: dataset_builder = tfds.builder(ds_name, data_dir=data_dir) dataset_builder.download_and_prepare() From d43264c1dbe63270206f13c1ce5fe5cbcf5e4ec5 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 30 Aug 2023 01:25:57 +0000 Subject: [PATCH 12/66] fix fastmri datasetup --- datasets/dataset_setup.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 2ceb6b4b6..3badd0614 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -95,9 +95,9 @@ IMAGENET_TRAIN_TAR_FILENAME = 'ILSVRC2012_img_train.tar' IMAGENET_VAL_TAR_FILENAME = 'ILSVRC2012_img_val.tar' -FASTMRI_TRAIN_TAR_FILENAME = 'knee_singlecoil_train.tar.xz' -FASTMRI_VAL_TAR_FILENAME = 'knee_singlecoil_val.tar.xz' -FASTMRI_TEST_TAR_FILENAME = 'knee_singlecoil_test.tar.xz' +FASTMRI_TRAIN_TAR_FILENAME = 'knee_singlecoil_train.tar.gz' +FASTMRI_VAL_TAR_FILENAME = 'knee_singlecoil_val.tar.gz' +FASTMRI_TEST_TAR_FILENAME = 'knee_singlecoil_test.tar.gz' flags.DEFINE_boolean( 'interactive_deletion', From a5f42c283d34d2c261f4c01dcd71eeb0140a7eae Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 30 Aug 2023 01:28:43 +0000 Subject: [PATCH 13/66] fix end str --- datasets/dataset_setup.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 3badd0614..a9087e004 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -95,9 +95,9 @@ IMAGENET_TRAIN_TAR_FILENAME = 'ILSVRC2012_img_train.tar' IMAGENET_VAL_TAR_FILENAME = 'ILSVRC2012_img_val.tar' -FASTMRI_TRAIN_TAR_FILENAME = 'knee_singlecoil_train.tar.gz' -FASTMRI_VAL_TAR_FILENAME = 'knee_singlecoil_val.tar.gz' -FASTMRI_TEST_TAR_FILENAME = 'knee_singlecoil_test.tar.gz' +FASTMRI_TRAIN_TAR_FILENAME = 'knee_singlecoil_train.tar.xz' +FASTMRI_VAL_TAR_FILENAME = 'knee_singlecoil_val.tar.xz' +FASTMRI_TEST_TAR_FILENAME = 'knee_singlecoil_test.tar.xz' flags.DEFINE_boolean( 'interactive_deletion', @@ -304,12 +304,12 @@ def download_criteo1tb(data_dir, # Unzip the individual days. processes = [] - gz_paths = [] + xz_paths = [] for day in range(24): - input_path = os.path.join(tmp_criteo_dir, f'day_{day}.gz') - gz_paths.append(input_path) + input_path = os.path.join(tmp_criteo_dir, f'day_{day}.xz') + xz_paths.append(input_path) unzipped_path = os.path.join(criteo_dir, f'day_{day}.csv') - unzip_cmd = (f'pigz -d -c -p{num_decompression_threads} "{input_path}" > ' + unzip_cmd = (f'pixz -d -c -p{num_decompression_threads} "{input_path}" > ' f'"{unzipped_path}"') logging.info(f'Running Criteo unzip command for day {day}:\n{unzip_cmd}') processes.append(subprocess.Popen(unzip_cmd, shell=True)) @@ -345,7 +345,7 @@ def download_cifar(data_dir, framework): raise ValueError('Invalid value for framework: {}'.format(framework)) -def extract_filename_from_url(url, start_str='knee', end_str='.gz'): +def extract_filename_from_url(url, start_str='knee', end_str='.xz'): """ The url filenames are sometimes couched within a urldefense+aws access id etc. string. Unfortunately querying the content disposition in requests fails (not provided)... so fast search is done here within the url. From 8a7302299cafe9f66c127fd51337084ae948f025 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 30 Aug 2023 01:40:24 +0000 Subject: [PATCH 14/66] fix extract method --- datasets/dataset_setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index a9087e004..898dbdce8 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -392,7 +392,7 @@ def extract(source, dest): if not os.path.exists(dest): os.path.makedirs(dest) logging.info(f'Extracting {source} to {dest}') - tar = tarfile.open(source) + tar = tarfile.open(source, 'r:xz') logging.info('Opened tar') tar.extractall(dest) From 426a5b8c4f43db104ef3035b93a2ff9675819789 Mon Sep 17 00:00:00 2001 From: runame Date: Wed, 30 Aug 2023 19:12:36 +0200 Subject: [PATCH 15/66] Add torch.cuda.synchronize() to profiler --- algorithmic_efficiency/profiler.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/algorithmic_efficiency/profiler.py b/algorithmic_efficiency/profiler.py index 0a1c1be79..d52a532b2 100644 --- a/algorithmic_efficiency/profiler.py +++ b/algorithmic_efficiency/profiler.py @@ -11,6 +11,13 @@ from typing import Dict, Generator, List, Optional, Tuple import numpy as np +import torch + + +def _get_monotonic_time() -> float: + if torch.cuda.is_available(): + torch.cuda.synchronize() + return time.monotonic() class Profiler: @@ -20,7 +27,7 @@ def __init__(self, local_rank: Optional[int] = None) -> None: self.current_actions: Dict[str, float] = {} self.recorded_durations = defaultdict(list) - self.start_time = time.monotonic() + self.start_time = _get_monotonic_time() def set_local_rank(self, local_rank: int) -> None: self._local_rank = local_rank @@ -35,12 +42,12 @@ def start(self, action_name: str) -> None: if action_name in self.current_actions: raise ValueError( f'Attempted to start {action_name} which has already started.') - self.current_actions[action_name] = time.monotonic() + self.current_actions[action_name] = _get_monotonic_time() def stop(self, action_name: str) -> None: if self.local_rank != 0: pass - end_time = time.monotonic() + end_time = _get_monotonic_time() if action_name not in self.current_actions: raise ValueError(f'Attempting to stop recording an action ' f'({action_name}) which was never started.') @@ -59,7 +66,7 @@ def profile(self, action_name: str) -> Generator: def _make_report( self ) -> Tuple[List[Tuple[str, float, float, int, float, float]], int, float]: - total_duration = time.monotonic() - self.start_time + total_duration = _get_monotonic_time() - self.start_time report = [(str(a), float(np.mean(d)), float(np.std(d)), From 96361d3a21d2b5747d7a6427a21601d333850c9f Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 31 Aug 2023 00:12:17 +0000 Subject: [PATCH 16/66] change librispeech datasetup folder names --- datasets/README.md | 80 +++++++++++++++++++++++---------------- datasets/dataset_setup.py | 9 +++-- 2 files changed, 53 insertions(+), 36 deletions(-) diff --git a/datasets/README.md b/datasets/README.md index 1ed1f77ee..65249be35 100644 --- a/datasets/README.md +++ b/datasets/README.md @@ -2,7 +2,7 @@ TL;DR: Use `dataset_setup.py` to download datasets. Usage: -``` +```bash python3 datasets/dataset_setup.py \ --data_dir=~/data \ -- @@ -17,7 +17,6 @@ The complete benchmark uses 6 datasets: - Librispeech - Some dataset setups will require you to sign a third party agreement with the dataset in order to get the donwload URLs. @@ -32,21 +31,21 @@ make sure the data directory is mounted to a directory on your host with the `-v $HOME/data:/data` flag in the `docker run` command. This will mount the `$HOME/data` directory to the `/data` directory in the container. In this case set --data_dir to `\data`. -``` +```bash DATA_DIR=\data ``` ### Set data directory (on host) Alternatively, if you are running the data download script directly on your host, feel free to choose whatever directory you find suitable, further submission instructions assume the data is stored in `~/data`. -``` +```bash DATA_DIR=~/data ``` #### Start tmux session (Recommended) If running the dataset_setup.py on directly on host it is recommended to run the dataset_setup.py script in a tmux session because some of the data downloads may take several hours. To avoid your setup being interrupted start a tmux session: -``` +```bash tmux new -s data_setup ``` @@ -55,17 +54,17 @@ tmux new -s data_setup ### OGBG From `algorithmic-efficiency` run: -``` +```bash python3 datasets/dataset_setup.py \ - --data_dir=$DATA_DIR/ogbg \ + --data_dir $DATA_DIR/ogbg \ --ogbg ``` ### WMT From `algorithmic-efficiency` run: -``` +```bash python3 datasets/dataset_setup.py \ - --data_dir=$DATA_DIR \ + --data_dir $DATA_DIR \ --wmt ``` @@ -75,13 +74,13 @@ Fill out form on https://fastmri.med.nyu.edu/. After filling out the form you should get an email containing the URLS for "knee_singlecoil_train", "knee_singlecoil_val" and "knee_singlecoil_test". -``` +```bash python3 datasets/dataset_setup.py \ - --data_dir=$DATA_DIR \ + --data_dir $DATA_DIR \ --fastmri \ - --fastmri_knee_singlecoil_train_url "" \ - --fastmri_knee_singlecoil_val_url "" \ - --fastmri_knee_singlecoil_test_url "" + --fastmri_knee_singlecoil_train_url '' \ + --fastmri_knee_singlecoil_val_url '' \ + --fastmri_knee_singlecoil_test_url '' ``` ## ImageNet @@ -90,25 +89,49 @@ URLS for the ILSVRC2012 train and validation images. Imagenet dataset processsing is resource intensive. To avoid potential ResourcExhausted errors increase the maximum number of open file descriptors: -``` +```bash ulimit -n 8192 ``` -Alo note that some functions use subprocess.Popen(..., shell=True), which can be +The imagenet data pipeline differs between the pytorch and jax workloads. +Therefore, you will have to specify the framework (pytorch or jax) through the +framework flag. + +```bash +python3 datasets/dataset_setup.py \ + --data_dir $DATA_DIR \ + --temp_dir $DATA_DIR/tmp \ + --imagenet \ + --imagenet_train_url \ + --imagenet_val_url \ + --framework +``` + +Note that some functions use subprocess.Popen(..., shell=True), which can be dangerous if the user injects code into the --data_dir or --temp_dir flags. We do some basic sanitization in main(), but submitters should not let untrusted users run this script on their systems. -``` +### Cleanup +In order to avoid potential accidental deletion, this script does NOT +delete any intermediate temporary files (such as zip archives) without a user +confirmation. Deleting temp files is particularly important for Criteo 1TB, as +there can be multiple copies of the dataset on disk during preprocessing if +files are not cleaned up. If you do not want any temp files to be deleted, you +can pass --interactive_deletion=false and then all files will be downloaded to +the provided --temp_dir, and the user can manually delete these after +downloading has finished. + +## Criteo1tb +```bash python3 datasets/dataset_setup.py \ - --data_dir=$DATA_DIR \ - --imagenet - --imagenet_train_url - --imagenet_val_url + --data_dir $DATA_DIR \ + --temp_dir $DATA_DIR/tmp \ + --criteo1tb \ ``` -### Cleanup -Note: that in order to avoid potential accidental deletion, this script does NOT +### Clean up +In order to avoid potential accidental deletion, this script does NOT delete any intermediate temporary files (such as zip archives) without a user confirmation. Deleting temp files is particularly important for Criteo 1TB, as there can be multiple copies of the dataset on disk during preprocessing if @@ -118,7 +141,6 @@ the provided --temp_dir, and the user can manually delete these after downloading has finished. - ## Librispeech ### Training SPM Tokenizer @@ -141,13 +163,5 @@ The preprocessing script will generate `.npy` files for audio data, `features.cs python3 librispeech_preprocess.py --data_dir=$DATA_DIR/librispeech --tokenizer_vocab_path=$DATA_DIR/librispeech/spm_model.vocab ``` -## Criteo1tb -Note: that in order to avoid potential accidental deletion, this script does NOT -delete any intermediate temporary files (such as zip archives) without a user -confirmation. Deleting temp files is particularly important for Criteo 1TB, as -there can be multiple copies of the dataset on disk during preprocessing if -files are not cleaned up. If you do not want any temp files to be deleted, you -can pass --interactive_deletion=false and then all files will be downloaded to -the provided --temp_dir, and the user can manually delete these after -downloading has finished. + diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 898dbdce8..8dd83733e 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -548,9 +548,12 @@ def download_librispeech(dataset_dir, tmp_dir): # After extraction the result is a folder named Librispeech containing audio # files in .flac format along with transcripts containing name of audio file # and corresponding transcription. - tmp_librispeech_dir = os.path.join(dataset_dir, 'librispeech') - extracted_data_dir = os.path.join(tmp_librispeech_dir, 'LibriSpeech') - final_data_dir = os.path.join(dataset_dir, 'librispeech_processed') + # tmp_librispeech_dir = os.path.join(dataset_dir, 'librispeech') + # extracted_data_dir = os.path.join(tmp_librispeech_dir, 'LibriSpeech') + # final_data_dir = os.path.join(dataset_dir, 'librispeech_processed') + tmp_librispeech_dir = os.path.join(tmp_dir, 'librispeech_raw') + extracted_data_dir = os.path.join(tmp_dir, 'librispeech_extracted') + final_data_dir = os.path.join(dataset_dir, 'librispeech') _maybe_mkdir(tmp_librispeech_dir) From d45d7bf4afa776b69743aedfe022c80de2884528 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 31 Aug 2023 21:02:49 +0000 Subject: [PATCH 17/66] imagenet debugging --- datasets/dataset_setup.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 8dd83733e..c45a5a8a1 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -666,6 +666,9 @@ def main(_): logging.info('fastMRI download completed. Extracting...') setup_fastmri(data_dir, updated_data_dir) + if not FLAGS.imagenet: + print('not imagenet') + if FLAGS.all or FLAGS.imagenet: flags.mark_flag_as_required('imagenet_train_url') flags.mark_flag_as_required('imagenet_val_url') From 66395f6fdb11c5d27978bbfe2716abf1aaf928ca Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 31 Aug 2023 23:06:33 +0000 Subject: [PATCH 18/66] imagenet fix --- datasets/README.md | 12 ++++++------ datasets/dataset_setup.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/datasets/README.md b/datasets/README.md index 65249be35..b8e09343c 100644 --- a/datasets/README.md +++ b/datasets/README.md @@ -98,13 +98,13 @@ Therefore, you will have to specify the framework (pytorch or jax) through the framework flag. ```bash -python3 datasets/dataset_setup.py \ - --data_dir $DATA_DIR \ - --temp_dir $DATA_DIR/tmp \ +python3 datasets/dataset_setup.py \ + --data_dir=/data \ --imagenet \ - --imagenet_train_url \ - --imagenet_val_url \ - --framework + --temp_dir=$DATA_DIR/tmp \ --imagenet_train_url=https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar \ + --imagenet_val_url=https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar \ + --framework=jax + ``` Note that some functions use subprocess.Popen(..., shell=True), which can be diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index c45a5a8a1..965596e64 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -433,13 +433,13 @@ def download_imagenet(data_dir, imagenet_train_url, imagenet_val_url): if not os.path.exists(imagenet_train_filepath): logging.info( 'Downloading imagenet train dataset from {}'.format(imagenet_train_url)) - _download_url(url=imagenet_train_url, data_dir=data_dir).download() + _download_url(url=imagenet_train_url, data_dir=data_dir) # Download imagenet val dataset if not os.path.exists(imagenet_val_filepath): logging.info('Downloading imagenet validation dataset from {}'.format( imagenet_val_url)) - _download_url(url=imagenet_val_url, data_dir=data_dir).download() + _download_url(url=imagenet_val_url, data_dir=data_dir) # Download imagenet test set download_imagenet_v2(data_dir) From 239cb5bba5bc8d63d80ee68ba1cee0a6e2fb9b2b Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 31 Aug 2023 23:07:56 +0000 Subject: [PATCH 19/66] download fix --- datasets/dataset_setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 965596e64..6b4d26322 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -213,7 +213,7 @@ def _download_url(url, data_dir, name=None): file_path = os.path.join(data_dir, url.split('/')[-1]) else: file_path = os.path.join(data_dir, name) - logging.info(f'About to download to {file_path}') + logging.info(f'Downloading URL {url} to {file_path}') response = requests.get(url, stream=True, timeout=600) total_size_in_bytes = int(response.headers.get('Content-length', 0)) From 6bd1087db3dceee618dd2b92819cbacd754e8ff2 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 31 Aug 2023 23:10:05 +0000 Subject: [PATCH 20/66] imagenet download fix --- datasets/dataset_setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 6b4d26322..e2440e719 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -230,7 +230,7 @@ def _download_url(url, data_dir, name=None): break logging.info('Invalid response. Try again.') if overwrite == 'n': - logging.info('Skipping download to {}'.format(file_path)) + logging.info(f'Skipping download URL {url} to {}'.format(file_path)) return with open(file_path, 'wb') as f: From 9a2c6d271624c3c27ab819d7605fc111c20b1085 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 31 Aug 2023 23:14:45 +0000 Subject: [PATCH 21/66] remove expand user from download_url --- datasets/dataset_setup.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index e2440e719..d52d4808f 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -207,8 +207,6 @@ def _maybe_prompt_for_deletion(paths, interactive_deletion): def _download_url(url, data_dir, name=None): - - data_dir = os.path.expanduser(data_dir) if not name: file_path = os.path.join(data_dir, url.split('/')[-1]) else: From 8c078122ae193beaa55d6ffce48429cdcbb105fe Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 31 Aug 2023 23:16:20 +0000 Subject: [PATCH 22/66] fix --- datasets/dataset_setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index d52d4808f..134ba35d1 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -228,7 +228,7 @@ def _download_url(url, data_dir, name=None): break logging.info('Invalid response. Try again.') if overwrite == 'n': - logging.info(f'Skipping download URL {url} to {}'.format(file_path)) + logging.info(f'Skipping download URL {url} to {}'.format(file_path)') return with open(file_path, 'wb') as f: From 86ed7de3daabefb4815020f60612588e39017762 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 31 Aug 2023 23:16:40 +0000 Subject: [PATCH 23/66] string fix --- datasets/dataset_setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 134ba35d1..282ede823 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -228,7 +228,7 @@ def _download_url(url, data_dir, name=None): break logging.info('Invalid response. Try again.') if overwrite == 'n': - logging.info(f'Skipping download URL {url} to {}'.format(file_path)') + logging.info(f'Skipping download URL {url} to {file_path}') return with open(file_path, 'wb') as f: From 5f9005ea2339abe4f4f03c5bc8fec1a981b839ee Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 31 Aug 2023 23:23:14 +0000 Subject: [PATCH 24/66] debugging --- datasets/dataset_setup.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 282ede823..e968f8d47 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -628,6 +628,8 @@ def main(_): raise ValueError(f'Invalid data_dir: {data_dir}.') if any(s in tmp_dir for s in bad_chars): raise ValueError(f'Invalid temp_dir: {tmp_dir}.') + print('data dir before expand user') + print(data_dir) data_dir = os.path.abspath(os.path.expanduser(data_dir)) logging.info('Downloading data to %s...', data_dir) From 519cd5dc8e4947ee4d8477fdba9571d2b43d9ab9 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 31 Aug 2023 23:43:09 +0000 Subject: [PATCH 25/66] fix expanduser condition --- datasets/README.md | 4 ++-- datasets/dataset_setup.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/datasets/README.md b/datasets/README.md index b8e09343c..abbc762ad 100644 --- a/datasets/README.md +++ b/datasets/README.md @@ -32,14 +32,14 @@ the `-v $HOME/data:/data` flag in the `docker run` command. This will mount the `$HOME/data` directory to the `/data` directory in the container. In this case set --data_dir to `\data`. ```bash -DATA_DIR=\data +DATA_DIR='/data' ``` ### Set data directory (on host) Alternatively, if you are running the data download script directly on your host, feel free to choose whatever directory you find suitable, further submission instructions assume the data is stored in `~/data`. ```bash -DATA_DIR=~/data +DATA_DIR='~/data' ``` #### Start tmux session (Recommended) If running the dataset_setup.py on directly on host it is recommended to run diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index e968f8d47..6553d9963 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -630,7 +630,8 @@ def main(_): raise ValueError(f'Invalid temp_dir: {tmp_dir}.') print('data dir before expand user') print(data_dir) - data_dir = os.path.abspath(os.path.expanduser(data_dir)) + if '~' in data_dir: + data_dir = os.path.abspath(os.path.expanduser(data_dir)) logging.info('Downloading data to %s...', data_dir) if FLAGS.all or FLAGS.criteo1tb: From 541ce57aeb67b7bcd318e90c3e957cf1cb76a7cd Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 31 Aug 2023 23:52:42 +0000 Subject: [PATCH 26/66] remove set resource limit --- datasets/README.md | 29 +++++++++++++++-------------- datasets/dataset_setup.py | 7 ------- 2 files changed, 15 insertions(+), 21 deletions(-) diff --git a/datasets/README.md b/datasets/README.md index abbc762ad..8dfdc7328 100644 --- a/datasets/README.md +++ b/datasets/README.md @@ -56,16 +56,16 @@ tmux new -s data_setup From `algorithmic-efficiency` run: ```bash python3 datasets/dataset_setup.py \ - --data_dir $DATA_DIR/ogbg \ - --ogbg +--data_dir $DATA_DIR/ogbg \ +--ogbg ``` ### WMT From `algorithmic-efficiency` run: ```bash python3 datasets/dataset_setup.py \ - --data_dir $DATA_DIR \ - --wmt +--data_dir $DATA_DIR \ +--wmt ``` @@ -76,11 +76,11 @@ you should get an email containing the URLS for "knee_singlecoil_train", ```bash python3 datasets/dataset_setup.py \ - --data_dir $DATA_DIR \ - --fastmri \ - --fastmri_knee_singlecoil_train_url '' \ - --fastmri_knee_singlecoil_val_url '' \ - --fastmri_knee_singlecoil_test_url '' +--data_dir $DATA_DIR \ +--fastmri \ +--fastmri_knee_singlecoil_train_url '' \ +--fastmri_knee_singlecoil_val_url '' \ +--fastmri_knee_singlecoil_test_url '' ``` ## ImageNet @@ -99,11 +99,12 @@ framework flag. ```bash python3 datasets/dataset_setup.py \ - --data_dir=/data \ - --imagenet \ - --temp_dir=$DATA_DIR/tmp \ --imagenet_train_url=https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar \ - --imagenet_val_url=https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar \ - --framework=jax +--data_dir=/data \ +--imagenet \ +--temp_dir=$DATA_DIR/tmp \ +--imagenet_train_url= \ +--imagenet_val_url= Date: Fri, 1 Sep 2023 00:04:53 +0000 Subject: [PATCH 27/66] formatting --- datasets/dataset_setup.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 2d01799ba..4017e56ca 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -427,14 +427,25 @@ def download_imagenet(data_dir, imagenet_train_url, imagenet_val_url): imagenet_train_filepath = os.path.join(data_dir, IMAGENET_TRAIN_TAR_FILENAME) imagenet_val_filepath = os.path.join(data_dir, IMAGENET_VAL_TAR_FILENAME) + imagenet_jax_data_dir = os.path.join(data_dir, 'jax') + manual_download_dir = os.path.join(imagenet_jax_data_dir, + 'downloads', + 'manual') + imagenet_train_download_filepath = os.path.join(manual_download_dir, + IMAGENET_TRAIN_TAR_FILENAME) + imagenet_val_download_filepath = os.path.join(manual_download_dir, + IMAGENET_VAL_TAR_FILENAME) + # Download imagnet train dataset - if not os.path.exists(imagenet_train_filepath): + if not os.path.exists(imagenet_train_filepath) and not os.path.exists( + imagenet_train_download_filepath): logging.info( 'Downloading imagenet train dataset from {}'.format(imagenet_train_url)) _download_url(url=imagenet_train_url, data_dir=data_dir) # Download imagenet val dataset - if not os.path.exists(imagenet_val_filepath): + if not os.path.exists(imagenet_val_filepath) and not os.path.exists( + imagenet_val_download_filepath): logging.info('Downloading imagenet validation dataset from {}'.format( imagenet_val_url)) _download_url(url=imagenet_val_url, data_dir=data_dir) @@ -626,8 +637,7 @@ def main(_): raise ValueError(f'Invalid data_dir: {data_dir}.') if any(s in tmp_dir for s in bad_chars): raise ValueError(f'Invalid temp_dir: {tmp_dir}.') - if '~' in data_dir: - data_dir = os.path.abspath(os.path.expanduser(data_dir)) + data_dir = os.path.abspath(os.path.expanduser(data_dir)) logging.info('Downloading data to %s...', data_dir) if FLAGS.all or FLAGS.criteo1tb: From 0c33882f2fd4f5ab1ef44d6cfc7754f8cd811317 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 1 Sep 2023 00:16:21 +0000 Subject: [PATCH 28/66] formatting --- datasets/dataset_setup.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 4017e56ca..83c4baf92 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -431,21 +431,19 @@ def download_imagenet(data_dir, imagenet_train_url, imagenet_val_url): manual_download_dir = os.path.join(imagenet_jax_data_dir, 'downloads', 'manual') - imagenet_train_download_filepath = os.path.join(manual_download_dir, + imagenet_train_download_filepath = os.path.join(manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME) - imagenet_val_download_filepath = os.path.join(manual_download_dir, + imagenet_val_download_filepath = os.path.join(manual_download_dir, IMAGENET_VAL_TAR_FILENAME) - + # Download imagnet train dataset - if not os.path.exists(imagenet_train_filepath) and not os.path.exists( - imagenet_train_download_filepath): + if not os.path.exists(imagenet_train_filepath) and not os.path.exists(imagenet_train_download_filepath): logging.info( 'Downloading imagenet train dataset from {}'.format(imagenet_train_url)) _download_url(url=imagenet_train_url, data_dir=data_dir) # Download imagenet val dataset - if not os.path.exists(imagenet_val_filepath) and not os.path.exists( - imagenet_val_download_filepath): + if not os.path.exists(imagenet_val_filepath) and not os.path.exists(imagenet_val_download_filepath): logging.info('Downloading imagenet validation dataset from {}'.format( imagenet_val_url)) _download_url(url=imagenet_val_url, data_dir=data_dir) From 3c418e77570d1ef1b0ea5aaceb3259850887175f Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 1 Sep 2023 18:30:51 +0000 Subject: [PATCH 29/66] move imagenet_v2 folder --- datasets/dataset_setup.py | 42 ++++++++++++++++++++++++++------------- 1 file changed, 28 insertions(+), 14 deletions(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 83c4baf92..e36826800 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -431,19 +431,21 @@ def download_imagenet(data_dir, imagenet_train_url, imagenet_val_url): manual_download_dir = os.path.join(imagenet_jax_data_dir, 'downloads', 'manual') - imagenet_train_download_filepath = os.path.join(manual_download_dir, + imagenet_train_download_filepath = os.path.join(manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME) - imagenet_val_download_filepath = os.path.join(manual_download_dir, + imagenet_val_download_filepath = os.path.join(manual_download_dir, IMAGENET_VAL_TAR_FILENAME) - + # Download imagnet train dataset - if not os.path.exists(imagenet_train_filepath) and not os.path.exists(imagenet_train_download_filepath): + if not os.path.exists(imagenet_train_filepath) and not os.path.exists( + imagenet_train_download_filepath): logging.info( 'Downloading imagenet train dataset from {}'.format(imagenet_train_url)) _download_url(url=imagenet_train_url, data_dir=data_dir) # Download imagenet val dataset - if not os.path.exists(imagenet_val_filepath) and not os.path.exists(imagenet_val_download_filepath): + if not os.path.exists(imagenet_val_filepath) and not os.path.exists( + imagenet_val_download_filepath): logging.info('Downloading imagenet validation dataset from {}'.format( imagenet_val_url)) _download_url(url=imagenet_val_url, data_dir=data_dir) @@ -466,6 +468,7 @@ def setup_imagenet(data_dir, framework=None): def setup_imagenet_jax(data_dir): train_tar_file_path = os.path.join(data_dir, IMAGENET_TRAIN_TAR_FILENAME) val_tar_file_path = os.path.join(data_dir, IMAGENET_VAL_TAR_FILENAME) + test_dir_path = os.path.join(data_dir, 'imagenet_v2') # Setup jax dataset dir imagenet_jax_data_dir = os.path.join(data_dir, 'jax') @@ -478,14 +481,19 @@ def setup_imagenet_jax(data_dir): logging.info('Checking if tar files already exists in jax/downloads/manual.') if not os.path.exists( os.path.join(manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME)): - logging.info('Copying {} to {}'.format(train_tar_file_path, - manual_download_dir)) + logging.info('Moving {} to {}'.format(train_tar_file_path, + manual_download_dir)) shutil.move(train_tar_file_path, manual_download_dir) if not os.path.exists( os.path.join(manual_download_dir, IMAGENET_VAL_TAR_FILENAME)): - logging.info('Copying {} to {}'.format(val_tar_file_path, - manual_download_dir)) + logging.info('Moving {} to {}'.format(val_tar_file_path, + manual_download_dir)) shutil.move(val_tar_file_path, manual_download_dir) + if not os.path.exists(os.path.join(imagenet_jax_data_dir, 'imagenet_v2')): + logging.info('Moving imagenet_v2 to {}'.format( + os.path.join(imagenet_jax_data_dir, 'imagenet_v2'))) + shutil.move(test_dir_path, + os.path.join(imagenet_jax_data_dir, 'imagenet_v2')) logging.info('Preparing imagenet data.') ds_builder = tfds.builder( 'imagenet2012:5.1.0', data_dir=os.path.join(imagenet_jax_data_dir)) @@ -496,6 +504,7 @@ def setup_imagenet_jax(data_dir): def setup_imagenet_pytorch(data_dir): train_tar_file_path = os.path.join(data_dir, IMAGENET_TRAIN_TAR_FILENAME) val_tar_file_path = os.path.join(data_dir, IMAGENET_VAL_TAR_FILENAME) + test_dir_path = os.path.join(data_dir, 'imagenet_v2') # Setup jax dataset dir imagenet_pytorch_data_dir = os.path.join(data_dir, 'pytorch') @@ -503,13 +512,18 @@ def setup_imagenet_pytorch(data_dir): os.makedirs(os.path.join(imagenet_pytorch_data_dir, 'train')) os.makedirs(os.path.join(imagenet_pytorch_data_dir, 'val')) - # Copy tar file into pytorch directory - logging.info('Copying {} to {}'.format(train_tar_file_path, - imagenet_pytorch_data_dir)) + # Move tar files and imagenet_v2 into pytorch directory + logging.info('Moving {} to {}'.format(train_tar_file_path, + imagenet_pytorch_data_dir)) shutil.move(train_tar_file_path, imagenet_pytorch_data_dir) - logging.info('Copying {} to {}'.format(val_tar_file_path, - imagenet_pytorch_data_dir)) + logging.info('Moving {} to {}'.format(val_tar_file_path, + imagenet_pytorch_data_dir)) shutil.move(val_tar_file_path, imagenet_pytorch_data_dir) + if not os.path.exists(os.path.join(imagenet_jax_data_dir, 'imagenet_v2')): + logging.info('Moving imagenet_v2 to {}'.format( + os.path.join(imagenet_jax_data_dir, 'imagenet_v2'))) + shutil.move(test_dir_path, + os.path.join(imagenet_pytorch_data_dir, 'imagenet_v2')) # Extract train data\ logging.info('Extracting imagenet train data') From 024b7e47fc7c373b97518e3b85416d22745581d3 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 1 Sep 2023 20:21:41 +0000 Subject: [PATCH 30/66] update librispeech instructions --- datasets/README.md | 34 +++++++++++++++++++++------------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/datasets/README.md b/datasets/README.md index 8dfdc7328..025178659 100644 --- a/datasets/README.md +++ b/datasets/README.md @@ -99,12 +99,12 @@ framework flag. ```bash python3 datasets/dataset_setup.py \ ---data_dir=/data \ +--data_dir /data \ --imagenet \ ---temp_dir=$DATA_DIR/tmp \ ---imagenet_train_url= \ ---imagenet_val_url= \ +--imagenet_val_url Date: Fri, 1 Sep 2023 20:23:31 +0000 Subject: [PATCH 31/66] documentation fix --- datasets/README.md | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/datasets/README.md b/datasets/README.md index 025178659..e3de6940a 100644 --- a/datasets/README.md +++ b/datasets/README.md @@ -17,9 +17,7 @@ The complete benchmark uses 6 datasets: - Librispeech -Some dataset setups will require you to sign a third party agreement with the -dataset in order to get the donwload URLs. - +Some dataset setups will require you to sign a third party agreement with the dataset owners in order to get the donwload URLs. # Per dataset instructions ## Environment @@ -94,8 +92,7 @@ ulimit -n 8192 ``` The imagenet data pipeline differs between the pytorch and jax workloads. -Therefore, you will have to specify the framework (pytorch or jax) through the -framework flag. +Therefore, you will have to specify the framework (pytorch or jax) through theframework flag. ```bash python3 datasets/dataset_setup.py \ From 254fdb97893aaef76b11fbbadb0d9381daf96dd7 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 1 Sep 2023 21:27:56 +0000 Subject: [PATCH 32/66] fix --- datasets/dataset_setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index e36826800..87d628888 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -55,7 +55,7 @@ Example command: python3 datasets/dataset_setup.py \ - --data_dir=~/dataa \ + --data_dir=~/data \ --temp_dir=/tmp/mlcommons_data --imagenet \ --imagenet_train_url= \ From ae93bce2f5837dd5d412336b9e7d0ea4941cdc25 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 1 Sep 2023 21:41:37 +0000 Subject: [PATCH 33/66] undo unintional criteo download changes --- datasets/README.md | 10 ---------- datasets/dataset_setup.py | 6 +++--- 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/datasets/README.md b/datasets/README.md index e3de6940a..65a964e0e 100644 --- a/datasets/README.md +++ b/datasets/README.md @@ -110,16 +110,6 @@ dangerous if the user injects code into the --data_dir or --temp_dir flags. We do some basic sanitization in main(), but submitters should not let untrusted users run this script on their systems. -### Cleanup -In order to avoid potential accidental deletion, this script does NOT -delete any intermediate temporary files (such as zip archives) without a user -confirmation. Deleting temp files is particularly important for Criteo 1TB, as -there can be multiple copies of the dataset on disk during preprocessing if -files are not cleaned up. If you do not want any temp files to be deleted, you -can pass --interactive_deletion=false and then all files will be downloaded to -the provided --temp_dir, and the user can manually delete these after -downloading has finished. - ## Criteo1tb ```bash python3 datasets/dataset_setup.py \ diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 87d628888..8bc28d73b 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -302,12 +302,12 @@ def download_criteo1tb(data_dir, # Unzip the individual days. processes = [] - xz_paths = [] + gz_paths = [] for day in range(24): input_path = os.path.join(tmp_criteo_dir, f'day_{day}.xz') - xz_paths.append(input_path) + gz_paths.append(input_path) unzipped_path = os.path.join(criteo_dir, f'day_{day}.csv') - unzip_cmd = (f'pixz -d -c -p{num_decompression_threads} "{input_path}" > ' + unzip_cmd = (f'pigz -d -c -p{num_decompression_threads} "{input_path}" > ' f'"{unzipped_path}"') logging.info(f'Running Criteo unzip command for day {day}:\n{unzip_cmd}') processes.append(subprocess.Popen(unzip_cmd, shell=True)) From e65822c54443de8b6cf0c1e11aa139d36c50aad2 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Sat, 2 Sep 2023 00:49:03 +0000 Subject: [PATCH 34/66] fix --- datasets/dataset_setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 8bc28d73b..8fc442b76 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -304,7 +304,7 @@ def download_criteo1tb(data_dir, processes = [] gz_paths = [] for day in range(24): - input_path = os.path.join(tmp_criteo_dir, f'day_{day}.xz') + input_path = os.path.join(tmp_criteo_dir, f'day_{day}.gz') gz_paths.append(input_path) unzipped_path = os.path.join(criteo_dir, f'day_{day}.csv') unzip_cmd = (f'pigz -d -c -p{num_decompression_threads} "{input_path}" > ' From 260bfba542ed7e3b216b29c81b0fb3aca576b30a Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 1 Sep 2023 20:52:27 -0700 Subject: [PATCH 35/66] Update README.md Fix documentation --- datasets/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datasets/README.md b/datasets/README.md index 65a964e0e..93d7d4b9e 100644 --- a/datasets/README.md +++ b/datasets/README.md @@ -135,7 +135,7 @@ To download, train a tokenizer and preprocess the librispeech dataset: python3 datasets/dataset_setup.py \ --data_dir librispeech \ --temp_dir $DATA_DIR/tmp \ ---criteo1tb +--librispeech ``` ### Notes on librispeech preprocessing From 9ecc7cfc22149b0fdba673feb49b4ab9a9a3b1aa Mon Sep 17 00:00:00 2001 From: Chandramouli Shama Sastry Date: Thu, 7 Sep 2023 17:51:07 +0000 Subject: [PATCH 36/66] critero traindiff fix --- tests/modeldiffs/criteo1tb/compare.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/modeldiffs/criteo1tb/compare.py b/tests/modeldiffs/criteo1tb/compare.py index 761da427b..a3c49e559 100644 --- a/tests/modeldiffs/criteo1tb/compare.py +++ b/tests/modeldiffs/criteo1tb/compare.py @@ -35,7 +35,16 @@ def key_transform(k): return tuple(new_key) -sd_transform = None +def sd_transform(sd): + out = {} + chunks = [] + for k in sd: + if 'embedding_chunk' in ''.join(k): + chunks.append(sd[k]) + else: + out[k] = sd[k] + out[('embedding_table',)] = torch.cat(chunks,dim=0) + return out if __name__ == '__main__': # pylint: disable=locally-disabled, not-callable From de747434b9ae9131b07a00d45bd412df4c6916e6 Mon Sep 17 00:00:00 2001 From: Chandramouli Shama Sastry Date: Thu, 7 Sep 2023 17:58:15 +0000 Subject: [PATCH 37/66] style fix --- tests/modeldiffs/criteo1tb/compare.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/modeldiffs/criteo1tb/compare.py b/tests/modeldiffs/criteo1tb/compare.py index a3c49e559..56b15fa38 100644 --- a/tests/modeldiffs/criteo1tb/compare.py +++ b/tests/modeldiffs/criteo1tb/compare.py @@ -43,9 +43,10 @@ def sd_transform(sd): chunks.append(sd[k]) else: out[k] = sd[k] - out[('embedding_table',)] = torch.cat(chunks,dim=0) + out[('embedding_table',)] = torch.cat(chunks, dim=0) return out + if __name__ == '__main__': # pylint: disable=locally-disabled, not-callable From 11997e9037de8bba878bb86173fde384fecc064a Mon Sep 17 00:00:00 2001 From: Chandramouli Shama Sastry Date: Thu, 7 Sep 2023 18:54:58 +0000 Subject: [PATCH 38/66] fixes --- tests/modeldiffs/criteo1tb/compare.py | 2 +- tests/test_traindiffs.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/modeldiffs/criteo1tb/compare.py b/tests/modeldiffs/criteo1tb/compare.py index 56b15fa38..9a95f3656 100644 --- a/tests/modeldiffs/criteo1tb/compare.py +++ b/tests/modeldiffs/criteo1tb/compare.py @@ -40,7 +40,7 @@ def sd_transform(sd): chunks = [] for k in sd: if 'embedding_chunk' in ''.join(k): - chunks.append(sd[k]) + chunks.append(sd[k].cpu()) else: out[k] = sd[k] out[('embedding_table',)] = torch.cat(chunks, dim=0) diff --git a/tests/test_traindiffs.py b/tests/test_traindiffs.py index fec1f9085..a1b64a573 100644 --- a/tests/test_traindiffs.py +++ b/tests/test_traindiffs.py @@ -42,14 +42,14 @@ def test_workload(self): jax_logs = '/tmp/jax_log.pkl' pyt_logs = '/tmp/pyt_log.pkl' run( - f'python3 tests/reference_algorithm_tests.py --workload={workload} --framework=jax --global_batch_size={GLOBAL_BATCH_SIZE} --log_file={jax_logs}' + f'python3 -m tests.reference_algorithm_tests --workload={workload} --framework=jax --global_batch_size={GLOBAL_BATCH_SIZE} --log_file={jax_logs}' f' --submission_path=tests/modeldiffs/vanilla_sgd_jax.py --identical=True --tuning_search_space=None --num_train_steps={NUM_TRAIN_STEPS}', shell=True, stdout=DEVNULL, stderr=STDOUT, check=True) run( - f'torchrun --standalone --nnodes 1 --nproc_per_node 8 tests/reference_algorithm_tests.py --workload={workload} --framework=pytorch --global_batch_size={GLOBAL_BATCH_SIZE} --log_file={pyt_logs}' + f'torchrun --standalone --nnodes 1 --nproc_per_node 8 -m tests.reference_algorithm_tests --workload={workload} --framework=pytorch --global_batch_size={GLOBAL_BATCH_SIZE} --log_file={pyt_logs}' f' --submission_path=tests/modeldiffs/vanilla_sgd_pytorch.py --identical=True --tuning_search_space=None --num_train_steps={NUM_TRAIN_STEPS}', shell=True, stdout=DEVNULL, From c15adb0708788a76b86d9366ac4fd134495ed428 Mon Sep 17 00:00:00 2001 From: Chandramouli Shama Sastry Date: Sat, 9 Sep 2023 22:33:25 +0000 Subject: [PATCH 39/66] fix wmt comparator --- tests/modeldiffs/wmt/compare.py | 36 ++++++++++++++++++++++++--------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/tests/modeldiffs/wmt/compare.py b/tests/modeldiffs/wmt/compare.py index 52c96481c..382f2bf26 100644 --- a/tests/modeldiffs/wmt/compare.py +++ b/tests/modeldiffs/wmt/compare.py @@ -47,20 +47,38 @@ def sd_transform(sd): out = {} for k in sd: k_str = ''.join(k) - if 'Dense' in k_str: - new_key = (*k[:2], 'MlpBlock_0', *k[2:]) - out[new_key] = sd[k] - elif 'SelfAttention' in k_str: + if 'SelfAttention' in k_str: new_key = list(k) - if '_' in new_key[-1]: - qkv = {'q': 'query', 'k': 'key', 'v': 'value'}[new_key[-1][0]] - new_key[-1] = qkv - new_key.append('kernel') new_key = [ i if i != 'SelfAttention_1' else 'MultiHeadDotProductAttention_0' for i in new_key ] - new_key = tuple(new_key) + if 'SelfAttention_0' in k_str: + if new_key[-2] == 'Dense_0': + # qkv + for name, value in zip(('query','key','value'),sd[k].chunk(3)): + out[(*new_key[:-2],name,new_key[-1])] = value + pass + elif new_key[-2] == 'Dense_1': + # out + out[(*new_key[:-2],'out',new_key[-1])] = sd[k] + pass + else: + if new_key[-2] == 'Dense_0': + #q + out[(*new_key[:-2],'query',new_key[-1])] = sd[k] + pass + elif new_key[-2] == 'Dense_1': + # kv + for name, value in zip(('key','value'),sd[k].chunk(2)): + out[(*new_key[:-2],name,new_key[-1])] = value + pass + elif new_key[-2] == 'Dense_2': + # out + out[(*new_key[:-2],'out',new_key[-1])] = sd[k] + pass + elif 'Dense' in k_str: + new_key = (*k[:2], 'MlpBlock_0', *k[2:]) out[new_key] = sd[k] elif 'LayerNorm' in k_str: new_key = list(k) From ec96fbed552f46ee4802ca9b8595a78351684ffa Mon Sep 17 00:00:00 2001 From: Chandramouli Shama Sastry Date: Sat, 9 Sep 2023 22:36:20 +0000 Subject: [PATCH 40/66] comparator fix --- tests/modeldiffs/wmt/compare.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/modeldiffs/wmt/compare.py b/tests/modeldiffs/wmt/compare.py index 382f2bf26..806022687 100644 --- a/tests/modeldiffs/wmt/compare.py +++ b/tests/modeldiffs/wmt/compare.py @@ -56,26 +56,26 @@ def sd_transform(sd): if 'SelfAttention_0' in k_str: if new_key[-2] == 'Dense_0': # qkv - for name, value in zip(('query','key','value'),sd[k].chunk(3)): - out[(*new_key[:-2],name,new_key[-1])] = value + for name, value in zip(('query', 'key', 'value'), sd[k].chunk(3)): + out[(*new_key[:-2], name, new_key[-1])] = value pass elif new_key[-2] == 'Dense_1': # out - out[(*new_key[:-2],'out',new_key[-1])] = sd[k] + out[(*new_key[:-2], 'out', new_key[-1])] = sd[k] pass else: if new_key[-2] == 'Dense_0': #q - out[(*new_key[:-2],'query',new_key[-1])] = sd[k] - pass + out[(*new_key[:-2], 'query', new_key[-1])] = sd[k] + pass elif new_key[-2] == 'Dense_1': - # kv - for name, value in zip(('key','value'),sd[k].chunk(2)): - out[(*new_key[:-2],name,new_key[-1])] = value - pass + # kv + for name, value in zip(('key', 'value'), sd[k].chunk(2)): + out[(*new_key[:-2], name, new_key[-1])] = value + pass elif new_key[-2] == 'Dense_2': # out - out[(*new_key[:-2],'out',new_key[-1])] = sd[k] + out[(*new_key[:-2], 'out', new_key[-1])] = sd[k] pass elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) From 04a83802d620c6237a257f5dea5f6781c8aa6d2a Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 13 Sep 2023 00:12:48 +0000 Subject: [PATCH 41/66] fix arg for deletion prompt --- datasets/dataset_setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 8fc442b76..77885d660 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -298,7 +298,7 @@ def download_criteo1tb(data_dir, logging.info(f'Running Criteo 1TB unzip command:\n{unzip_cmd}') p = subprocess.Popen(unzip_cmd, shell=True) p.communicate() - _maybe_prompt_for_deletion(all_days_zip_filepath, interactive_deletion) + _maybe_prompt_for_deletion([all_days_zip_filepath], interactive_deletion) # Unzip the individual days. processes = [] From 2f76cb9e324258fe2e974379faa2f5ed4702255a Mon Sep 17 00:00:00 2001 From: runame Date: Wed, 13 Sep 2023 15:56:25 +0200 Subject: [PATCH 42/66] Simplify pad function --- algorithmic_efficiency/data_utils.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/algorithmic_efficiency/data_utils.py b/algorithmic_efficiency/data_utils.py index 38744716b..96fc699c0 100644 --- a/algorithmic_efficiency/data_utils.py +++ b/algorithmic_efficiency/data_utils.py @@ -51,7 +51,7 @@ def _prepare(x): # Pad if remainder_size != 0 (should only be possible during evaluation). if remainder_size != 0: - x = pad(x, pad_size, 'jax', padding_value=padding_value) + x = pad(x, pad_size, padding_value=padding_value) # Reshape (global_batch_size, ...) to # (local_device_count, per_device_batch_size, ...). @@ -61,21 +61,13 @@ def _prepare(x): return jax.tree_map(_prepare, batch) -def pad(tensor: spec.Tensor, +def pad(tensor: np.ndarray, pad_size: int, - framework: str, - padding_value: int = 0) -> spec.Tensor: + padding_value: int = 0) -> np.ndarray: if len(tensor) > 1: pad_size = (pad_size, *tensor.shape[1:]) - if framework == 'pytorch': - padding = torch.full( - pad_size, padding_value, dtype=tensor.dtype, device=tensor.device) - padded_tensor = torch.cat((tensor, padding), dim=0) - elif framework == 'jax': - padding = np.full(pad_size, padding_value, dtype=tensor.dtype) - padded_tensor = np.concatenate((tensor, padding), axis=0) - else: - raise ValueError(f'Framework has to be pytorch or jax, but is {framework}.') + padding = np.full(pad_size, padding_value, dtype=tensor.dtype) + padded_tensor = np.concatenate((tensor, padding), axis=0) return padded_tensor From 1a3679d3f25e1289dcc3615a43fbf68a58959c5c Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 13 Sep 2023 17:29:49 +0000 Subject: [PATCH 43/66] move delete prompt to end of criteo download --- datasets/dataset_setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 77885d660..dfcecbd85 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -316,9 +316,9 @@ def download_criteo1tb(data_dir, _maybe_prompt_for_deletion(gz_paths, interactive_deletion) # Split into files with 5M lines each: day_1.csv -> day_1_[0-39].csv. + unzipped_paths = [] for batch in range(6): batch_processes = [] - unzipped_paths = [] for day_offset in range(4): day = batch * 4 + day_offset unzipped_path = os.path.join(criteo_dir, f'day_{day}.csv') @@ -330,7 +330,7 @@ def download_criteo1tb(data_dir, batch_processes.append(subprocess.Popen(split_cmd, shell=True)) for p in batch_processes: p.communicate() - _maybe_prompt_for_deletion(unzipped_paths, interactive_deletion) + _maybe_prompt_for_deletion(unzipped_paths, interactive_deletion) def download_cifar(data_dir, framework): From 35c873615d65cc1c435bf2b2bd3dce51dcd17de4 Mon Sep 17 00:00:00 2001 From: runame Date: Thu, 14 Sep 2023 18:27:20 +0200 Subject: [PATCH 44/66] Always pad to global_batch_size when it is provided --- algorithmic_efficiency/data_utils.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/data_utils.py b/algorithmic_efficiency/data_utils.py index 96fc699c0..245d3768e 100644 --- a/algorithmic_efficiency/data_utils.py +++ b/algorithmic_efficiency/data_utils.py @@ -28,8 +28,15 @@ def shard_and_maybe_pad_np( inputs = batch['inputs'] current_batch_size = inputs[0].shape[0] if isinstance( inputs, tuple) else inputs.shape[0] + if global_batch_size is not None: + assert global_batch_size >= current_batch_size, \ + 'global_batch_size must be larger than or equal to current_batch_size.' + # Always pad to global_batch_size if it is provided. + pad_to_global_batch_size = global_batch_size > current_batch_size + else: + pad_to_global_batch_size = False remainder_size = current_batch_size % local_device_count - if remainder_size != 0: + if remainder_size != 0 or pad_to_global_batch_size: if global_batch_size is not None: pad_size = global_batch_size - current_batch_size else: @@ -50,7 +57,7 @@ def _prepare(x): x = x._numpy() # pylint: disable=protected-access # Pad if remainder_size != 0 (should only be possible during evaluation). - if remainder_size != 0: + if remainder_size != 0 or pad_to_global_batch_size: x = pad(x, pad_size, padding_value=padding_value) # Reshape (global_batch_size, ...) to From ad64fd18c9e3907f32347e49953a7067293fe36b Mon Sep 17 00:00:00 2001 From: runame Date: Thu, 14 Sep 2023 18:28:07 +0200 Subject: [PATCH 45/66] Fix pad_size in pad function --- algorithmic_efficiency/data_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/data_utils.py b/algorithmic_efficiency/data_utils.py index 245d3768e..14e3c7c6c 100644 --- a/algorithmic_efficiency/data_utils.py +++ b/algorithmic_efficiency/data_utils.py @@ -71,7 +71,7 @@ def _prepare(x): def pad(tensor: np.ndarray, pad_size: int, padding_value: int = 0) -> np.ndarray: - if len(tensor) > 1: + if tensor.ndim > 1: pad_size = (pad_size, *tensor.shape[1:]) padding = np.full(pad_size, padding_value, dtype=tensor.dtype) padded_tensor = np.concatenate((tensor, padding), axis=0) From efdd670c336b74c4a06c59f47d4fa27013a045e1 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 14 Sep 2023 19:53:56 +0000 Subject: [PATCH 46/66] librispeech processing --- datasets/dataset_setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index dfcecbd85..8099fe7cc 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -584,7 +584,7 @@ def download_librispeech(dataset_dir, tmp_dir): subprocess.Popen(wget_cmd, shell=True).communicate() tar_path = os.path.join(tmp_librispeech_dir, f'{split}-{version}.tar.gz') subprocess.Popen( - f'tar xzvf {tar_path} --directory {tmp_librispeech_dir}', + f'tar xzvf {tar_path} --directory {extracted_data_dir}', shell=True).communicate() tars = [ @@ -599,7 +599,7 @@ def download_librispeech(dataset_dir, tmp_dir): subprocess.Popen(wget_cmd, shell=True).communicate() tar_path = os.path.join(tmp_librispeech_dir, tar_filename) subprocess.Popen( - f'tar xzvf {tar_path} --directory {tmp_librispeech_dir}', + f'tar xzvf {tar_path} --directory {extracted_data_dir}', shell=True).communicate() tokenizer_vocab_path = os.path.join(extracted_data_dir, 'spm_model.vocab') From 26713bcbed056c9c6232c5c998a1aa5b186856da Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 14 Sep 2023 20:27:12 +0000 Subject: [PATCH 47/66] fix --- datasets/dataset_setup.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 8099fe7cc..71efa3434 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -575,6 +575,8 @@ def download_librispeech(dataset_dir, tmp_dir): final_data_dir = os.path.join(dataset_dir, 'librispeech') _maybe_mkdir(tmp_librispeech_dir) + _maybe_mkdir(extracted_data_dir) + _maybe_mkdir(final_data_dir) for split in ['dev', 'test']: for version in ['clean', 'other']: From f3881daf7775f268f86892163e49f0d42986dc11 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 14 Sep 2023 20:49:49 +0000 Subject: [PATCH 48/66] librispeech fix --- datasets/dataset_setup.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 71efa3434..b6ae48378 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -571,11 +571,10 @@ def download_librispeech(dataset_dir, tmp_dir): # extracted_data_dir = os.path.join(tmp_librispeech_dir, 'LibriSpeech') # final_data_dir = os.path.join(dataset_dir, 'librispeech_processed') tmp_librispeech_dir = os.path.join(tmp_dir, 'librispeech_raw') - extracted_data_dir = os.path.join(tmp_dir, 'librispeech_extracted') + extracted_data_dir = os.path.join(tmp_dir, 'LibriSpeech) final_data_dir = os.path.join(dataset_dir, 'librispeech') _maybe_mkdir(tmp_librispeech_dir) - _maybe_mkdir(extracted_data_dir) _maybe_mkdir(final_data_dir) for split in ['dev', 'test']: @@ -586,7 +585,7 @@ def download_librispeech(dataset_dir, tmp_dir): subprocess.Popen(wget_cmd, shell=True).communicate() tar_path = os.path.join(tmp_librispeech_dir, f'{split}-{version}.tar.gz') subprocess.Popen( - f'tar xzvf {tar_path} --directory {extracted_data_dir}', + f'tar xzvf {tar_path} --directory {tmp_librispeech_dir}', shell=True).communicate() tars = [ @@ -601,7 +600,7 @@ def download_librispeech(dataset_dir, tmp_dir): subprocess.Popen(wget_cmd, shell=True).communicate() tar_path = os.path.join(tmp_librispeech_dir, tar_filename) subprocess.Popen( - f'tar xzvf {tar_path} --directory {extracted_data_dir}', + f'tar xzvf {tar_path} --directory {tmp_librispeech_dir}', shell=True).communicate() tokenizer_vocab_path = os.path.join(extracted_data_dir, 'spm_model.vocab') From fd710ab40df007640072cc953d5274f27b4057b1 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 14 Sep 2023 20:52:00 +0000 Subject: [PATCH 49/66] syntax fix --- datasets/dataset_setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index b6ae48378..df3ba22fe 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -571,7 +571,7 @@ def download_librispeech(dataset_dir, tmp_dir): # extracted_data_dir = os.path.join(tmp_librispeech_dir, 'LibriSpeech') # final_data_dir = os.path.join(dataset_dir, 'librispeech_processed') tmp_librispeech_dir = os.path.join(tmp_dir, 'librispeech_raw') - extracted_data_dir = os.path.join(tmp_dir, 'LibriSpeech) + extracted_data_dir = os.path.join(tmp_dir, 'LibriSpeech') final_data_dir = os.path.join(dataset_dir, 'librispeech') _maybe_mkdir(tmp_librispeech_dir) From fa0862601c72fa657662562366e59243807895a6 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 14 Sep 2023 23:25:33 +0000 Subject: [PATCH 50/66] fix --- datasets/dataset_setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index df3ba22fe..fe5e2a9a0 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -570,8 +570,8 @@ def download_librispeech(dataset_dir, tmp_dir): # tmp_librispeech_dir = os.path.join(dataset_dir, 'librispeech') # extracted_data_dir = os.path.join(tmp_librispeech_dir, 'LibriSpeech') # final_data_dir = os.path.join(dataset_dir, 'librispeech_processed') - tmp_librispeech_dir = os.path.join(tmp_dir, 'librispeech_raw') - extracted_data_dir = os.path.join(tmp_dir, 'LibriSpeech') + tmp_librispeech_dir = os.path.join(tmp_dir, 'librispeech') + extracted_data_dir = os.path.join(tmp_librispeech_dir, 'LibriSpeech') final_data_dir = os.path.join(dataset_dir, 'librispeech') _maybe_mkdir(tmp_librispeech_dir) From e9119b9f5084827820a631c3dfd57aa73f427c7a Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 14 Sep 2023 23:26:20 +0000 Subject: [PATCH 51/66] documentation --- datasets/dataset_setup.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index fe5e2a9a0..e7f8c1d13 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -567,9 +567,6 @@ def download_librispeech(dataset_dir, tmp_dir): # After extraction the result is a folder named Librispeech containing audio # files in .flac format along with transcripts containing name of audio file # and corresponding transcription. - # tmp_librispeech_dir = os.path.join(dataset_dir, 'librispeech') - # extracted_data_dir = os.path.join(tmp_librispeech_dir, 'LibriSpeech') - # final_data_dir = os.path.join(dataset_dir, 'librispeech_processed') tmp_librispeech_dir = os.path.join(tmp_dir, 'librispeech') extracted_data_dir = os.path.join(tmp_librispeech_dir, 'LibriSpeech') final_data_dir = os.path.join(dataset_dir, 'librispeech') From ae9d46f0e7fea1fe77b41d8a0ce94099f121bf5b Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 15 Sep 2023 00:29:56 +0000 Subject: [PATCH 52/66] typo fix --- datasets/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datasets/README.md b/datasets/README.md index 93d7d4b9e..5ff0e18a7 100644 --- a/datasets/README.md +++ b/datasets/README.md @@ -100,7 +100,7 @@ python3 datasets/dataset_setup.py \ --imagenet \ --temp_dir $DATA_DIR/tmp \ --imagenet_train_url \ ---imagenet_val_url \ --framework jax ``` From 241e546dc1b737e066a054cac68b41a43c8da921 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 15 Sep 2023 17:53:10 +0000 Subject: [PATCH 53/66] add test-other counts to librispeech preprocessing --- datasets/librispeech_preprocess.py | 1 + 1 file changed, 1 insertion(+) diff --git a/datasets/librispeech_preprocess.py b/datasets/librispeech_preprocess.py index 0968f2a00..acdaa8e98 100644 --- a/datasets/librispeech_preprocess.py +++ b/datasets/librispeech_preprocess.py @@ -32,6 +32,7 @@ 'train-clean-360': 104014, 'train-other-500': 148688, 'test-clean': 2620, + 'test-other': 2939, 'dev-clean': 2703, 'dev-other': 2864, } From df542c2ea88bd7d89ef9547596005fa7ff9ae140 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 18 Sep 2023 23:48:28 +0000 Subject: [PATCH 54/66] add rng_seed flag and save seed to metadata --- algorithmic_efficiency/logger_utils.py | 8 ++++++++ submission_runner.py | 17 +++++++++++------ 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/algorithmic_efficiency/logger_utils.py b/algorithmic_efficiency/logger_utils.py index af2e61581..39c039f18 100644 --- a/algorithmic_efficiency/logger_utils.py +++ b/algorithmic_efficiency/logger_utils.py @@ -275,6 +275,14 @@ def get_meta_data(workload: spec.Workload) -> dict: return meta_data +def save_meta_data(workload: spec.Workload, + rng_seed: int, + preemption_count: int): + meta_data = get_meta_data(workload) + meta_data.update({'rng_seed': rng_seed}) + meta_file_name = os.path.join(log_dir, f'meta_data_{preemption_count}.json') + write_json(meta_file_name, meta_data) + class MetricLogger(object): """Used to log all measurements during training. diff --git a/submission_runner.py b/submission_runner.py index f4ee32ede..bed3d1e22 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -133,6 +133,10 @@ flags.DEFINE_boolean('save_checkpoints', True, 'Whether or not to checkpoint the model at every eval.') +flags.DEFINE_integer('rng_seed', + None, + 'Value of rng seed. If None, a random seed will' + 'be generated from hardware.') FLAGS = flags.FLAGS USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() @@ -267,10 +271,8 @@ def train_once( global_step, preemption_count, checkpoint_dir=log_dir) - meta_data = logger_utils.get_meta_data(workload) - meta_file_name = os.path.join(log_dir, f'meta_data_{preemption_count}.json') logging.info(f'Saving meta data to {meta_file_name}.') - logger_utils.write_json(meta_file_name, meta_data) + logger_utils.save_meta_data(workload, rng_seed, preemption_count) flag_file_name = os.path.join(log_dir, f'flags_{preemption_count}.json') logging.info(f'Saving flags to {flag_file_name}.') logger_utils.write_json(flag_file_name, flags.FLAGS.flag_values_dict()) @@ -449,7 +451,8 @@ def score_submission_on_workload(workload: spec.Workload, tuning_search_space: Optional[str] = None, num_tuning_trials: Optional[int] = None, log_dir: Optional[str] = None, - save_checkpoints: Optional[bool] = True): + save_checkpoints: Optional[bool] = True, + rng_seed: Optional[int] = None): # Expand paths because '~' may not be recognized data_dir = os.path.expanduser(data_dir) if imagenet_v2_data_dir: @@ -496,7 +499,8 @@ def score_submission_on_workload(workload: spec.Workload, all_metrics = [] for hi, hyperparameters in enumerate(tuning_search_space): # Generate a new seed from hardware sources of randomness for each trial. - rng_seed = struct.unpack('I', os.urandom(4))[0] + if not rng_seed: + rng_seed = struct.unpack('I', os.urandom(4))[0] logging.info('Using RNG seed %d', rng_seed) rng = prng.PRNGKey(rng_seed) # Because we initialize the PRNGKey with only a single 32 bit int, in the @@ -610,7 +614,8 @@ def main(_): tuning_search_space=FLAGS.tuning_search_space, num_tuning_trials=FLAGS.num_tuning_trials, log_dir=logging_dir_path, - save_checkpoints=FLAGS.save_checkpoints) + save_checkpoints=FLAGS.save_checkpoints, + rng_seed=FLAGS.rng_seed) logging.info(f'Final {FLAGS.workload} score: {score}') if FLAGS.profile: From 5ff2ec2f7affd2091fcddc1b227122f34fae2a78 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 18 Sep 2023 23:56:13 +0000 Subject: [PATCH 55/66] fix --- algorithmic_efficiency/logger_utils.py | 3 +-- submission_runner.py | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/logger_utils.py b/algorithmic_efficiency/logger_utils.py index 39c039f18..559859515 100644 --- a/algorithmic_efficiency/logger_utils.py +++ b/algorithmic_efficiency/logger_utils.py @@ -277,10 +277,9 @@ def get_meta_data(workload: spec.Workload) -> dict: def save_meta_data(workload: spec.Workload, rng_seed: int, - preemption_count: int): + meta_file_name: str): meta_data = get_meta_data(workload) meta_data.update({'rng_seed': rng_seed}) - meta_file_name = os.path.join(log_dir, f'meta_data_{preemption_count}.json') write_json(meta_file_name, meta_data) class MetricLogger(object): diff --git a/submission_runner.py b/submission_runner.py index bed3d1e22..1f4bcf603 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -271,6 +271,7 @@ def train_once( global_step, preemption_count, checkpoint_dir=log_dir) + meta_file_name = os.path.join(log_dir, f'meta_data_{preemption_count}.json') logging.info(f'Saving meta data to {meta_file_name}.') logger_utils.save_meta_data(workload, rng_seed, preemption_count) flag_file_name = os.path.join(log_dir, f'flags_{preemption_count}.json') From df016238f4d9b7e1a6ff21d90e155104ea1fffd1 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 19 Sep 2023 00:02:52 +0000 Subject: [PATCH 56/66] fix --- submission_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/submission_runner.py b/submission_runner.py index 1f4bcf603..8f55cd882 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -273,7 +273,7 @@ def train_once( checkpoint_dir=log_dir) meta_file_name = os.path.join(log_dir, f'meta_data_{preemption_count}.json') logging.info(f'Saving meta data to {meta_file_name}.') - logger_utils.save_meta_data(workload, rng_seed, preemption_count) + logger_utils.save_meta_data(workload, rng, preemption_count) flag_file_name = os.path.join(log_dir, f'flags_{preemption_count}.json') logging.info(f'Saving flags to {flag_file_name}.') logger_utils.write_json(flag_file_name, flags.FLAGS.flag_values_dict()) From 87ecd5b477b4e99484fe53216fff2282651155b6 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 19 Sep 2023 00:05:58 +0000 Subject: [PATCH 57/66] debug --- algorithmic_efficiency/logger_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/logger_utils.py b/algorithmic_efficiency/logger_utils.py index 559859515..dcc8754a9 100644 --- a/algorithmic_efficiency/logger_utils.py +++ b/algorithmic_efficiency/logger_utils.py @@ -279,7 +279,7 @@ def save_meta_data(workload: spec.Workload, rng_seed: int, meta_file_name: str): meta_data = get_meta_data(workload) - meta_data.update({'rng_seed': rng_seed}) + # meta_data.update({'rng_seed': rng_seed}) write_json(meta_file_name, meta_data) class MetricLogger(object): From 828765cbf400e05641a2641beb548baea6a60939 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 19 Sep 2023 00:12:08 +0000 Subject: [PATCH 58/66] fix --- algorithmic_efficiency/logger_utils.py | 2 +- submission_runner.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/algorithmic_efficiency/logger_utils.py b/algorithmic_efficiency/logger_utils.py index dcc8754a9..559859515 100644 --- a/algorithmic_efficiency/logger_utils.py +++ b/algorithmic_efficiency/logger_utils.py @@ -279,7 +279,7 @@ def save_meta_data(workload: spec.Workload, rng_seed: int, meta_file_name: str): meta_data = get_meta_data(workload) - # meta_data.update({'rng_seed': rng_seed}) + meta_data.update({'rng_seed': rng_seed}) write_json(meta_file_name, meta_data) class MetricLogger(object): diff --git a/submission_runner.py b/submission_runner.py index 8f55cd882..fb8df198c 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -177,6 +177,7 @@ def train_once( update_params: spec.UpdateParamsFn, data_selection: spec.DataSelectionFn, hyperparameters: Optional[spec.Hyperparameters], + rng_seed: int, rng: spec.RandomState, profiler: Profiler, max_global_steps: int = None, @@ -273,7 +274,7 @@ def train_once( checkpoint_dir=log_dir) meta_file_name = os.path.join(log_dir, f'meta_data_{preemption_count}.json') logging.info(f'Saving meta data to {meta_file_name}.') - logger_utils.save_meta_data(workload, rng, preemption_count) + logger_utils.save_meta_data(workload, rng_seed, preemption_count) flag_file_name = os.path.join(log_dir, f'flags_{preemption_count}.json') logging.info(f'Saving flags to {flag_file_name}.') logger_utils.write_json(flag_file_name, flags.FLAGS.flag_values_dict()) @@ -533,7 +534,9 @@ def score_submission_on_workload(workload: spec.Workload, data_dir, imagenet_v2_data_dir, init_optimizer_state, update_params, data_selection, - hyperparameters, rng, + hyperparameters, + rng_seed, + rng, profiler, max_global_steps, tuning_dir_name, @@ -559,7 +562,7 @@ def score_submission_on_workload(workload: spec.Workload, workload, global_batch_size, global_eval_batch_size, data_dir, imagenet_v2_data_dir, init_optimizer_state, update_params, data_selection, - None, rng, profiler, max_global_steps, log_dir, + None, rng_seed, rng, profiler, max_global_steps, log_dir, save_checkpoints=save_checkpoints) return score From f861353ccc7bf7a50b064ebb207163758ba13074 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 19 Sep 2023 21:43:04 +0000 Subject: [PATCH 59/66] lint fix --- algorithmic_efficiency/logger_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/logger_utils.py b/algorithmic_efficiency/logger_utils.py index 559859515..18652dcaa 100644 --- a/algorithmic_efficiency/logger_utils.py +++ b/algorithmic_efficiency/logger_utils.py @@ -275,8 +275,8 @@ def get_meta_data(workload: spec.Workload) -> dict: return meta_data -def save_meta_data(workload: spec.Workload, - rng_seed: int, +def save_meta_data(workload: spec.Workload, + rng_seed: int, meta_file_name: str): meta_data = get_meta_data(workload) meta_data.update({'rng_seed': rng_seed}) From 18a8c20362b71c3a90361240832380348fcb7cfc Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 19 Sep 2023 21:53:16 +0000 Subject: [PATCH 60/66] pylint --- submission_runner.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index fb8df198c..af3741812 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -134,7 +134,7 @@ True, 'Whether or not to checkpoint the model at every eval.') flags.DEFINE_integer('rng_seed', - None, + None, 'Value of rng seed. If None, a random seed will' 'be generated from hardware.') FLAGS = flags.FLAGS @@ -177,7 +177,7 @@ def train_once( update_params: spec.UpdateParamsFn, data_selection: spec.DataSelectionFn, hyperparameters: Optional[spec.Hyperparameters], - rng_seed: int, + rng_seed: int, rng: spec.RandomState, profiler: Profiler, max_global_steps: int = None, @@ -534,7 +534,7 @@ def score_submission_on_workload(workload: spec.Workload, data_dir, imagenet_v2_data_dir, init_optimizer_state, update_params, data_selection, - hyperparameters, + hyperparameters, rng_seed, rng, profiler, From 25b05b848a52ce4423772ba5843c69d1fca3414d Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 19 Sep 2023 21:55:57 +0000 Subject: [PATCH 61/66] formatting --- algorithmic_efficiency/logger_utils.py | 5 ++--- submission_runner.py | 9 +++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/algorithmic_efficiency/logger_utils.py b/algorithmic_efficiency/logger_utils.py index 18652dcaa..2b3cf86f6 100644 --- a/algorithmic_efficiency/logger_utils.py +++ b/algorithmic_efficiency/logger_utils.py @@ -275,13 +275,12 @@ def get_meta_data(workload: spec.Workload) -> dict: return meta_data -def save_meta_data(workload: spec.Workload, - rng_seed: int, - meta_file_name: str): +def save_meta_data(workload: spec.Workload, rng_seed: int, meta_file_name: str): meta_data = get_meta_data(workload) meta_data.update({'rng_seed': rng_seed}) write_json(meta_file_name, meta_data) + class MetricLogger(object): """Used to log all measurements during training. diff --git a/submission_runner.py b/submission_runner.py index af3741812..8096eeda3 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -133,10 +133,11 @@ flags.DEFINE_boolean('save_checkpoints', True, 'Whether or not to checkpoint the model at every eval.') -flags.DEFINE_integer('rng_seed', - None, - 'Value of rng seed. If None, a random seed will' - 'be generated from hardware.') +flags.DEFINE_integer( + 'rng_seed', + None, + 'Value of rng seed. If None, a random seed will' + 'be generated from hardware.') FLAGS = flags.FLAGS USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() From d54b8660c8a05f7ada6a0385cfc212bc2fa0115b Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 20 Sep 2023 22:18:17 +0000 Subject: [PATCH 62/66] pass rng_seed arg for self-tuning submission as well --- submission_runner.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/submission_runner.py b/submission_runner.py index 8096eeda3..2289d39d3 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -554,7 +554,8 @@ def score_submission_on_workload(workload: spec.Workload, logging.info(f'Total number of evals: {num_evals}') logging.info('=' * 20) else: - rng_seed = struct.unpack('q', os.urandom(8))[0] + if not rng_seed: + rng_seed = struct.unpack('q', os.urandom(8))[0] rng = prng.PRNGKey(rng_seed) # If the submission is responsible for tuning itself, we only need to run it # once and return the total time. From a7b60fa9fead9b453245dce35e5487deacef09be Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 21 Sep 2023 00:34:16 +0000 Subject: [PATCH 63/66] pin ml_dytpes version --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index 6f53cd51b..a7ce5ebb2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -115,6 +115,7 @@ jax_core_deps = # Todo(kasimbeg): verify if this is necessary after we # upgrade jax. chex==0.1.7 + ml_dtypes==0.2.0 # JAX CPU jax_cpu = From 33a8a9fccb114a5f7cfe0d4aea98a0b38d0bb9e7 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 21 Sep 2023 21:00:32 +0000 Subject: [PATCH 64/66] add guards for cuda context initializion --- algorithmic_efficiency/profiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/profiler.py b/algorithmic_efficiency/profiler.py index d52a532b2..fa2a1bee2 100644 --- a/algorithmic_efficiency/profiler.py +++ b/algorithmic_efficiency/profiler.py @@ -15,7 +15,7 @@ def _get_monotonic_time() -> float: - if torch.cuda.is_available(): + if torch.cuda.is_available() and torch.cuda.is_initialized(): torch.cuda.synchronize() return time.monotonic() From cc8b820af5d698befe4fd19af4152964f1614135 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Mon, 25 Sep 2023 15:44:39 -0400 Subject: [PATCH 65/66] minor --- .../workloads/criteo1tb/criteo1tb_pytorch/workload.py | 5 +++-- algorithmic_efficiency/workloads/criteo1tb/workload.py | 7 +++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py index 993d82c9d..c514d0a9c 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -136,6 +136,8 @@ def _build_input_queue( cache: Optional[bool] = None, repeat_final_dataset: Optional[bool] = None, num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]: + del num_batches + not_train = split != 'train' per_device_batch_size = int(global_batch_size / N_GPUS) @@ -147,7 +149,6 @@ def _build_input_queue( split=split, data_dir=data_dir, global_batch_size=global_batch_size, - num_batches=num_batches, repeat_final_dataset=repeat_final_dataset) weights = None while True: @@ -233,7 +234,7 @@ def _eval_batch(self, summed_loss = self.loss_fn( label_batch=batch['targets'], logits_batch=logits, mask_batch=weights)['summed'] - return summed_loss + return summed_loss.to(dtype=torch.float64) class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload): diff --git a/algorithmic_efficiency/workloads/criteo1tb/workload.py b/algorithmic_efficiency/workloads/criteo1tb/workload.py index 801716de7..b341d1022 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/workload.py @@ -63,11 +63,11 @@ def num_eval_train_examples(self) -> int: @property def num_validation_examples(self) -> int: - return 89_000_000 + return 83_274_637 @property def num_test_examples(self) -> int: - return 89_274_637 + return 95_000_000 @property def train_mean(self): @@ -95,13 +95,13 @@ def _build_input_queue( repeat_final_dataset: Optional[bool] = None, num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]: del cache + del num_batches ds = input_pipeline.get_criteo1tb_dataset( split=split, shuffle_rng=data_rng, data_dir=data_dir, num_dense_features=self.num_dense_features, global_batch_size=global_batch_size, - num_batches=num_batches, repeat_final_dataset=repeat_final_dataset) for batch in iter(ds): @@ -132,7 +132,6 @@ def _eval_model_on_split(self, split=split, data_dir=data_dir, global_batch_size=global_batch_size, - num_batches=num_batches, repeat_final_dataset=True) loss = 0.0 for _ in range(num_batches): From 86ad0af2196741fd813237281c9563612d9f3294 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Mon, 25 Sep 2023 16:34:58 -0400 Subject: [PATCH 66/66] Add num_batch configs --- .../workloads/criteo1tb/criteo1tb_pytorch/workload.py | 3 +-- algorithmic_efficiency/workloads/criteo1tb/workload.py | 3 ++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py index c514d0a9c..55b68fb2f 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -136,8 +136,6 @@ def _build_input_queue( cache: Optional[bool] = None, repeat_final_dataset: Optional[bool] = None, num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]: - del num_batches - not_train = split != 'train' per_device_batch_size = int(global_batch_size / N_GPUS) @@ -149,6 +147,7 @@ def _build_input_queue( split=split, data_dir=data_dir, global_batch_size=global_batch_size, + num_batches=num_batches, repeat_final_dataset=repeat_final_dataset) weights = None while True: diff --git a/algorithmic_efficiency/workloads/criteo1tb/workload.py b/algorithmic_efficiency/workloads/criteo1tb/workload.py index b341d1022..ef971bb75 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/workload.py @@ -95,13 +95,13 @@ def _build_input_queue( repeat_final_dataset: Optional[bool] = None, num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]: del cache - del num_batches ds = input_pipeline.get_criteo1tb_dataset( split=split, shuffle_rng=data_rng, data_dir=data_dir, num_dense_features=self.num_dense_features, global_batch_size=global_batch_size, + num_batches=num_batches, repeat_final_dataset=repeat_final_dataset) for batch in iter(ds): @@ -132,6 +132,7 @@ def _eval_model_on_split(self, split=split, data_dir=data_dir, global_batch_size=global_batch_size, + num_batches=num_batches, repeat_final_dataset=True) loss = 0.0 for _ in range(num_batches):