From d047d05903332eb418e798f7ed3bcce01e61250e Mon Sep 17 00:00:00 2001 From: Chandramouli Shama Sastry Date: Sun, 15 Oct 2023 00:24:50 +0000 Subject: [PATCH 1/2] conformer oom fixes --- .../librispeech_pytorch/models.py | 298 ++++-------------- .../librispeech_pytorch/workload.py | 12 +- submission_runner.py | 4 +- .../librispeech_conformer/compare.py | 10 +- 4 files changed, 78 insertions(+), 246 deletions(-) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py index 665c3c894..2da7dcfb3 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py @@ -71,7 +71,10 @@ def forward(self, x): class Subsample(nn.Module): - def __init__(self, encoder_dim: int = 0, input_dropout_rate: float = 0.0): + def __init__(self, + encoder_dim: int = 0, + input_dropout_rate: float = 0.0, + num_bins: int = 80): super().__init__() self.encoder_dim = encoder_dim self.input_dropout_rate = input_dropout_rate @@ -81,7 +84,10 @@ def __init__(self, encoder_dim: int = 0, input_dropout_rate: float = 0.0): self.conv2 = Conv2dSubsampling( input_channels=encoder_dim, output_channels=encoder_dim) - self.linear = nn.LazyLinear(out_features=self.encoder_dim, bias=True) + self.linear = nn.Linear( + in_features=self.encoder_dim * num_bins // 4, + out_features=self.encoder_dim, + bias=True) self.pos_encode = AddPositionalEmbedding(embedding_dim=self.encoder_dim) self.dropout = nn.Dropout(p=self.input_dropout_rate) @@ -123,6 +129,7 @@ def __init__(self, self.kernel = nn.Parameter( torch.nn.init.xavier_uniform_(torch.empty(*self.filter_shape))) self.bias = nn.Parameter(torch.zeros(output_channels)) + self.register_buffer('paddings_kernel', torch.ones([1, 1, 1])) def get_same_padding(self, input_shape): in_height, in_width = input_shape[2:] @@ -162,15 +169,11 @@ def forward(self, inputs, paddings): input_length = paddings.shape[1] stride = self.filter_stride[0] pad_len = (input_length + stride - 1) // stride * stride - input_length - padded_paddings = torch.cat([ - paddings[:, None, :], - torch.zeros( - size=(paddings.shape[0], 1, pad_len), device=paddings.device) - ], - dim=2) + padded_paddings = F.pad( + paddings[:, None, :], (0, pad_len), mode='constant', value=0) out_padding = F.conv1d( input=padded_paddings, - weight=torch.ones([1, 1, 1], device=paddings.device), + weight=self.paddings_kernel, stride=self.filter_stride[:1]) out_padding = out_padding.squeeze(dim=1) outputs = outputs * (1 - out_padding[:, None, :, None]) @@ -184,11 +187,15 @@ def __init__(self, config: ConformerConfig): self.config = config self.ln = LayerNorm(dim=config.encoder_dim) - self.linear1 = nn.LazyLinear( + self.linear1 = nn.Linear( + in_features=config.encoder_dim, out_features=config.encoder_dim * config.feed_forward_expansion_factor, bias=True) self.dropout1 = nn.Dropout(p=config.feed_forward_dropout_rate) - self.linear2 = nn.LazyLinear(out_features=config.encoder_dim, bias=True) + self.linear2 = nn.Linear( + in_features=config.encoder_dim * config.feed_forward_expansion_factor, + out_features=config.encoder_dim, + bias=True) if config.feed_forward_residual_dropout_rate is None: feed_forward_residual_dropout_rate = 0.1 @@ -253,217 +260,32 @@ def forward(self, inputs): return inputs * scale -class MHSAwithQS(nn.MultiheadAttention): - # pylint: disable=locally-disabled, use-a-generator, 
line-too-long, invalid-name +class MHSAwithQS(nn.Module): + def __init__(self, config: ConformerConfig): - super().__init__( - embed_dim=config.encoder_dim, - num_heads=config.num_attention_heads, - dropout=config.attention_dropout_rate, - bias=True, - batch_first=True) + super().__init__() + self.embed_dim = config.encoder_dim + self.num_heads = config.num_attention_heads + self.dropout = config.attention_dropout_rate + self.in_proj = nn.Linear(config.encoder_dim, 3 * config.encoder_dim) + self.out_proj = nn.Linear(config.encoder_dim, config.encoder_dim) self.qs = QueryScaler(dim=config.encoder_dim // config.num_attention_heads) - def _scaled_in_proj_weight(self): - # Scale the query projection weight. - qs_input = self.in_proj_weight[:self.embed_dim].view( - self.num_heads, self.embed_dim // self.num_heads, -1).transpose(1, 2) - in_proj_queryW_scaled = self.qs(qs_input).transpose( - 1, 2).view(*self.in_proj_weight[:self.embed_dim].shape) - in_proj_weight = torch.cat( - [in_proj_queryW_scaled, self.in_proj_weight[self.embed_dim:]]) - return in_proj_weight - - def _scaled_in_proj_bias(self): - # Scale the query bias. - in_proj_queryb_scaled = self.qs(self.in_proj_bias[:self.embed_dim].view( - self.num_heads, self.embed_dim // self.num_heads)).view(-1) - in_proj_bias = torch.cat( - [in_proj_queryb_scaled, self.in_proj_bias[self.embed_dim:]]) - return in_proj_bias - - def forward(self, - query, - key, - value, - key_padding_mask=None, - need_weights: bool = True, - attn_mask=None, - average_attn_weights: bool = True): - r""" - Args: - query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False`` - or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length, - :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``. - Queries are compared against key-value pairs to produce the output. - See "Attention Is All You Need" for more details. - key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False`` - or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length, - :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``. - See "Attention Is All You Need" for more details. - value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when - ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source - sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``. - See "Attention Is All You Need" for more details. - key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` - to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. - Binary and byte masks are supported. - For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for - the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value. - need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``. - Default: ``True``. - attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. 
Must be of shape - :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size, - :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be - broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. - Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the - corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the - corresponding position is not allowed to attend. For a float mask, the mask values will be added to - the attention weight. - average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across - heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an - effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads) - - Outputs: - - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched, - :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``, - where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the - embedding dimension ``embed_dim``. - - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``, - returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or - :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and - :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per - head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`. - - .. note:: - `batch_first` argument is ignored for unbatched inputs. - """ - is_batched = query.dim() == 3 - if key_padding_mask is not None: - _kpm_dtype = key_padding_mask.dtype - if _kpm_dtype != torch.bool and not torch.is_floating_point( - key_padding_mask): - raise AssertionError( - "only bool and floating types of key_padding_mask are supported") - why_not_fast_path = '' - if not is_batched: - why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}" - elif query is not key or key is not value: - # When lifting this restriction, don't forget to either - # enforce that the dtypes all match or test cases where - # they don't! - why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)" - elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype: - why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match" - elif self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype: - # this case will fail anyway, but at least they'll get a useful error message. 
- why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match" - elif self.training: - why_not_fast_path = "training is enabled" - elif not self.batch_first: - why_not_fast_path = "batch_first was not True" - elif self.bias_k is not None: - why_not_fast_path = "self.bias_k was not None" - elif self.bias_v is not None: - why_not_fast_path = "self.bias_v was not None" - elif self.dropout: - why_not_fast_path = f"dropout was {self.dropout}, required zero" - elif self.add_zero_attn: - why_not_fast_path = "add_zero_attn was enabled" - elif not self._qkv_same_embed_dim: - why_not_fast_path = "_qkv_same_embed_dim was not True" - elif attn_mask is not None: - why_not_fast_path = "attn_mask was not None" - elif query.is_nested and key_padding_mask is not None: - why_not_fast_path = "key_padding_mask is not supported with NestedTensor input" - elif self.num_heads % 2 == 1: - why_not_fast_path = "num_heads is odd" - elif torch.is_autocast_enabled(): - why_not_fast_path = "autocast is enabled" - - if not why_not_fast_path: - tensor_args = ( - query, - key, - value, - self.in_proj_weight, - self.in_proj_bias, - self.out_proj.weight, - self.out_proj.bias, - ) - # We have to use list comprehensions below because TorchScript does not support - # generator expressions. - if torch.overrides.has_torch_function(tensor_args): - why_not_fast_path = "some Tensor argument has_torch_function" - elif not all([(x is None or x.is_cuda or 'cpu' in str(x.device)) - for x in tensor_args]): - why_not_fast_path = "some Tensor argument is neither CUDA nor CPU" - elif torch.is_grad_enabled() and any( - [x is not None and x.requires_grad for x in tensor_args]): - why_not_fast_path = ( - "grad is enabled and at least one of query or the " - "input/output projection weights or biases requires_grad") - if not why_not_fast_path: - # Scale the query bias parameter and the query projection weight. - in_proj_weight = self._scaled_in_proj_weight() - in_proj_bias = self._scaled_in_proj_bias() - return torch._native_multi_head_attention( - query, - key, - value, - self.embed_dim, - self.num_heads, - in_proj_weight, - in_proj_bias, - self.out_proj.weight, - self.out_proj.bias, - key_padding_mask if key_padding_mask is not None else attn_mask, - need_weights, - average_attn_weights, - 1 if key_padding_mask is not None else - 0 if attn_mask is not None else None) - any_nested = query.is_nested or key.is_nested or value.is_nested - assert not any_nested, ("MultiheadAttention does not support NestedTensor outside of its fast path. 
" + - f"The fast path was not hit because {why_not_fast_path}") - - if self.batch_first and is_batched: - # make sure that the transpose op does not affect the "is" property - if key is value: - if query is key: - query = key = value = query.transpose(1, 0) - else: - query, key = [x.transpose(1, 0) for x in (query, key)] - value = key - else: - query, key, value = [x.transpose(1, 0) for x in (query, key, value)] - - if not self._qkv_same_embed_dim: - attn_output, attn_output_weights = F.multi_head_attention_forward( - query, key, value, self.embed_dim, self.num_heads, - self.in_proj_weight, self.in_proj_bias, - self.bias_k, self.bias_v, self.add_zero_attn, - self.dropout, self.out_proj.weight, self.out_proj.bias, - training=self.training, - key_padding_mask=key_padding_mask, need_weights=need_weights, - attn_mask=attn_mask, use_separate_proj_weight=True, - q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, - v_proj_weight=self.v_proj_weight, average_attn_weights=average_attn_weights) - else: - # Scale the query bias parameter and the query projection weight. - in_proj_weight = self._scaled_in_proj_weight() - in_proj_bias = self._scaled_in_proj_bias() - attn_output, attn_output_weights = F.multi_head_attention_forward( - query, key, value, self.embed_dim, self.num_heads, - in_proj_weight, in_proj_bias, - self.bias_k, self.bias_v, self.add_zero_attn, - self.dropout, self.out_proj.weight, self.out_proj.bias, - training=self.training, - key_padding_mask=key_padding_mask, need_weights=need_weights, - attn_mask=attn_mask, average_attn_weights=average_attn_weights) - if self.batch_first and is_batched: - return attn_output.transpose(1, 0), attn_output_weights - else: - return attn_output, attn_output_weights + def forward(self, inputs, key_padding_mask=None): + batch_size, seq_len, embed_dim = inputs.shape + q, k, v = self.in_proj(inputs).split(self.embed_dim, dim=2) + q = self.qs(q.view(batch_size, seq_len, self.num_heads, -1)).transpose(1, 2) + k = k.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2) + v = v.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2) + out = F.scaled_dot_product_attention( + query=q, + key=k, + value=v, + attn_mask=~key_padding_mask[:, None, None], + dropout_p=self.dropout, + ).transpose(1, 2).reshape(batch_size, seq_len, embed_dim) + out = self.out_proj(out) + return out class MultiHeadedSelfAttention(nn.Module): @@ -483,12 +305,9 @@ def __init__(self, config: ConformerConfig): def forward(self, outputs, paddings): outputs = self.ln(outputs) - outputs, _ = self.self_attention( - query=outputs, - key=outputs, - value=outputs, - key_padding_mask=paddings==1, - need_weights=False, + outputs = self.self_attention( + outputs, + key_padding_mask=paddings == 1, ) outputs = self.dropout(outputs) return outputs @@ -504,18 +323,29 @@ def __init__(self, config: ConformerConfig): self.register_buffer('running_var', running_var) self.scale = nn.Parameter(torch.zeros(config.encoder_dim)) self.bias = nn.Parameter(torch.zeros(config.encoder_dim)) - self.register_buffer('momentum', - torch.FloatTensor([config.batch_norm_momentum])) - self.register_buffer('epsilon', - torch.FloatTensor([config.batch_norm_epsilon])) + self.register_buffer('dim', torch.FloatTensor([config.encoder_dim])) - # self.momentum = config.batch_norm_momentum - # self.epsilon = config.batch_norm_epsilon - # self.dim = config.encoder_dim + self.momentum = config.batch_norm_momentum + self.epsilon = config.batch_norm_epsilon def forward(self, inputs, input_paddings): #inputs: 
NHD #padding: NH + """ + Alternatively: + inputs[input_paddings==0] = F.batch_norm( + input = inputs[input_paddings==0], + running_mean = self.running_mean, + running_var = self.running_var, + weight = 1+self.scale, + bias = self.bias, + training = self.training, + momentum=1-self.momentum, + eps=self.epsilon + ) + inputs.masked_fill(input_paddings[...,None] != 0, 0) + return inputs + """ mask = 1 - input_paddings[:, :, None] if self.training: count = mask.sum() @@ -627,7 +457,9 @@ def __init__(self, config: ConformerConfig): else: input_dropout_rate = config.input_dropout_rate self.subsample = Subsample( - encoder_dim=config.encoder_dim, input_dropout_rate=input_dropout_rate) + encoder_dim=config.encoder_dim, + input_dropout_rate=input_dropout_rate, + num_bins=preprocessing_config.num_bins) self.conformers = nn.ModuleList( [ConformerBlock(config) for _ in range(config.num_encoder_layers)]) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 24f4eb1fc..c4f4a1247 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -47,8 +47,11 @@ def init_model_fn( input_dropout_rate. """ torch.random.manual_seed(rng[0]) - # Disable cudnn benchmark to avoid OOM errors. + # Configure torch backends to avoid OOM errors. torch.backends.cudnn.benchmark = False + torch.backends.cuda.enable_flash_sdp(False) + torch.backends.cuda.enable_mem_efficient_sdp(False) + torch.backends.cuda.enable_math_sdp(True) model = conformer_model.ConformerEncoderDecoder( conformer_model.ConformerConfig( attention_residual_dropout_rate=dropout_rate, @@ -57,13 +60,6 @@ def init_model_fn( input_dropout_rate=aux_dropout_rate, use_specaug=self.use_specaug)) self.ctc_loss = torch.nn.CTCLoss(blank=0, reduction='none') - # Run model once to initialize lazy layers. - # Run the initialization in eval mode to disable BN tracking. 
- model = model.eval() - t = MAX_INPUT_LENGTH - wave = torch.randn((2, t)) - pad = torch.zeros_like(wave) - _ = model(wave, pad) conformer_model.initialize(model) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) diff --git a/submission_runner.py b/submission_runner.py index 2289d39d3..d7f28bc22 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -206,9 +206,9 @@ def train_once( model_params, model_state = workload.init_model_fn( model_init_rng, dropout_rate, aux_dropout_rate) if FLAGS.framework == 'pytorch' and FLAGS.torch_compile: - compile_error_workloads = ['ogbg', 'criteo1tb'] + compile_error_workloads = ['librispeech_conformer', 'ogbg', 'criteo1tb'] eager_backend_workloads = [ - 'librispeech_conformer', 'librispeech_deepspeech' + 'librispeech_deepspeech' ] aot_eager_backend_workloads = [] if FLAGS.workload in compile_error_workloads: diff --git a/tests/modeldiffs/librispeech_conformer/compare.py b/tests/modeldiffs/librispeech_conformer/compare.py index 1d243d83e..d414001dd 100644 --- a/tests/modeldiffs/librispeech_conformer/compare.py +++ b/tests/modeldiffs/librispeech_conformer/compare.py @@ -38,11 +38,15 @@ def sd_transform(sd): out = {} for k in sd: if 'Attention' in ''.join(k): - if 'in_proj' in k[-1]: - new_key = k[:-1] + if 'Dense_0' in k[-2]: + # In-proj + new_key = k[:-2] chunks = sd[k].chunk(3) for t, c in zip(['query', 'key', 'value'], chunks): - out[new_key + (t, k[-1].split('_')[-1])] = c + out[new_key + (t, k[-1])] = c + elif 'Dense_1' in k[-2]: + # Out-proj + out[(*k[:-2], 'out', k[-1])] = sd[k] else: out[k] = sd[k] else: From 28a1ff039a8f096fbb9360825e6ede88a79030ce Mon Sep 17 00:00:00 2001 From: Chandramouli Shama Sastry Date: Sun, 15 Oct 2023 00:43:45 +0000 Subject: [PATCH 2/2] style fix --- submission_runner.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index d7f28bc22..6577204f2 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -207,9 +207,7 @@ def train_once( model_init_rng, dropout_rate, aux_dropout_rate) if FLAGS.framework == 'pytorch' and FLAGS.torch_compile: compile_error_workloads = ['librispeech_conformer', 'ogbg', 'criteo1tb'] - eager_backend_workloads = [ - 'librispeech_deepspeech' - ] + eager_backend_workloads = ['librispeech_deepspeech'] aot_eager_backend_workloads = [] if FLAGS.workload in compile_error_workloads: logging.warning(
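
Note on the attention rewrite above: the patch drops the nn.MultiheadAttention subclass (and its copied fast-path dispatch logic) in favor of a small module that projects Q/K/V through an explicit nn.Linear and calls F.scaled_dot_product_attention, passing the inverted key padding mask as a boolean attn_mask. The sketch below is a minimal, self-contained illustration of that pattern, not the repository code: the name PaddedSelfAttention and the smoke-test values are invented for this example, the QueryScaler is omitted for brevity, and the dropout_p guard for eval mode is an added assumption rather than part of the diff.

import torch
from torch import nn
import torch.nn.functional as F


class PaddedSelfAttention(nn.Module):
  """Self-attention over padded sequences via F.scaled_dot_product_attention."""

  def __init__(self, embed_dim: int, num_heads: int, dropout_rate: float = 0.0):
    super().__init__()
    assert embed_dim % num_heads == 0
    self.embed_dim = embed_dim
    self.num_heads = num_heads
    self.dropout_rate = dropout_rate
    # Explicit in_features (no LazyLinear), so no warm-up forward pass is needed.
    self.in_proj = nn.Linear(embed_dim, 3 * embed_dim)
    self.out_proj = nn.Linear(embed_dim, embed_dim)

  def forward(self, inputs, paddings):
    # inputs: (N, S, D); paddings: (N, S) with 1.0 marking padded frames.
    n, s, d = inputs.shape
    q, k, v = self.in_proj(inputs).split(self.embed_dim, dim=2)
    # Reshape each projection to (N, num_heads, S, head_dim).
    q = q.view(n, s, self.num_heads, -1).transpose(1, 2)
    k = k.view(n, s, self.num_heads, -1).transpose(1, 2)
    v = v.view(n, s, self.num_heads, -1).transpose(1, 2)
    # Boolean attn_mask for SDPA: True means the key position may be attended
    # to, so the key padding mask (True = padded) is inverted and broadcast
    # over heads and query positions, giving shape (N, 1, 1, S).
    attn_mask = ~(paddings == 1)[:, None, None, :]
    out = F.scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=attn_mask,
        dropout_p=self.dropout_rate if self.training else 0.0)
    out = out.transpose(1, 2).reshape(n, s, d)
    return self.out_proj(out)


if __name__ == '__main__':
  # Tiny smoke test: the last two frames of each sequence are padding.
  attn = PaddedSelfAttention(embed_dim=16, num_heads=4)
  features = torch.randn(2, 6, 16)
  pads = torch.zeros(2, 6)
  pads[:, -2:] = 1.0
  print(attn(features, pads).shape)  # torch.Size([2, 6, 16])

The companion change in workload.py disables the flash and memory-efficient SDP backends and enables only the math backend, so the scaled_dot_product_attention call takes a single, predictable code path; and because the LazyLinear layers in Subsample and the feed-forward block are now ordinary nn.Linear layers with explicit in_features, the warm-up forward pass that previously materialized the lazy parameters at init time could be removed.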