
Commit

fix for black
yifanyeung committed Feb 18, 2024
1 parent b070d04 commit 809bdb0
Showing 9 changed files with 179 additions and 56 deletions.
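
This commit only re-wraps lines for the black formatter; none of the hunks below change behavior. As a rough sketch of how this kind of reformatting is produced (the exact line length and configuration used by the project are assumptions, not taken from this commit), black can be driven from Python on any statement longer than the configured limit:

import black

# A long statement of the kind seen in the diff below; the names are
# placeholders and are never executed, only reformatted.
src = (
    "attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)"
    ".masked_fill(key_padding_mask, float('-inf'))\n"
)

# format_str applies the same rules as the `black` CLI: statements longer
# than line_length are split across lines, which is exactly the kind of
# change this commit makes. line_length=80 is an assumed setting.
print(black.format_str(src, mode=black.Mode(line_length=80)))
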
161 changes: 120 additions & 41 deletions egs/librispeech/SSL/hubert/attention_module.py
@@ -155,7 +155,8 @@ def __init__(
self.encoder_decoder_attention = encoder_decoder_attention

assert not self.self_attention or self.qkv_same_dim, (
"Self-attention requires query, key and " "value to be of the same size"
"Self-attention requires query, key and "
"value to be of the same size"
)

self.k_proj = quant_noise(
@@ -217,22 +218,36 @@ def _get_reserve_head_index(self, num_heads_to_keep: int):
start_idx = i * self.head_dim
end_idx = (i + 1) * self.head_dim
k_proj_heads_norm.append(
torch.sum(torch.abs(self.k_proj.weight[start_idx:end_idx,])).tolist()
+ torch.sum(torch.abs(self.k_proj.bias[start_idx:end_idx])).tolist()
torch.sum(
torch.abs(self.k_proj.weight[start_idx:end_idx,])
).tolist()
+ torch.sum(
torch.abs(self.k_proj.bias[start_idx:end_idx])
).tolist()
)
q_proj_heads_norm.append(
torch.sum(torch.abs(self.q_proj.weight[start_idx:end_idx,])).tolist()
+ torch.sum(torch.abs(self.q_proj.bias[start_idx:end_idx])).tolist()
torch.sum(
torch.abs(self.q_proj.weight[start_idx:end_idx,])
).tolist()
+ torch.sum(
torch.abs(self.q_proj.bias[start_idx:end_idx])
).tolist()
)
v_proj_heads_norm.append(
torch.sum(torch.abs(self.v_proj.weight[start_idx:end_idx,])).tolist()
+ torch.sum(torch.abs(self.v_proj.bias[start_idx:end_idx])).tolist()
torch.sum(
torch.abs(self.v_proj.weight[start_idx:end_idx,])
).tolist()
+ torch.sum(
torch.abs(self.v_proj.bias[start_idx:end_idx])
).tolist()
)

heads_norm = []
for i in range(self.num_heads):
heads_norm.append(
k_proj_heads_norm[i] + q_proj_heads_norm[i] + v_proj_heads_norm[i]
k_proj_heads_norm[i]
+ q_proj_heads_norm[i]
+ v_proj_heads_norm[i]
)

sorted_head_index = sorted(
@@ -266,7 +281,9 @@ def _adaptive_prune_heads(self, reserve_head_index: List[Tuple[int, int]]):
new_v_weight.append(self.v_proj.weight[start_idx:end_idx,])
new_v_bias.append(self.v_proj.bias[start_idx:end_idx])

new_out_proj_weight.append(self.out_proj.weight[:, start_idx:end_idx])
new_out_proj_weight.append(
self.out_proj.weight[:, start_idx:end_idx]
)

new_q_weight = torch.cat(new_q_weight).detach()
new_k_weight = torch.cat(new_k_weight).detach()
@@ -313,7 +330,9 @@ def _pad_masks(
) -> Tuple[Optional[Tensor], Optional[Tensor]]:
if attn_mask is not None:
shape = attn_mask.size()[:-1] + torch.Size([1])
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(shape)], dim=-1)
attn_mask = torch.cat(
[attn_mask, attn_mask.new_zeros(shape)], dim=-1
)
if key_padding_mask is not None:
shape = key_padding_mask.size()[:-1] + torch.Size([1])
key_padding_mask = torch.cat(
@@ -351,10 +370,12 @@ def _append_zero_attn(
) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
zero_attn_shape = k.size()[:-2] + torch.Size([1]) + k.size()[-1:]
k = torch.cat(
[k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=-2
[k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)],
dim=-2,
)
v = torch.cat(
[v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=-2
[v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)],
dim=-2,
)
key_padding_mask, attn_mask = self._pad_masks(
key_padding_mask=key_padding_mask, attn_mask=attn_mask
@@ -367,7 +388,9 @@ def forward(
key: Optional[Tensor],
value: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
incremental_state: Optional[
Dict[str, Dict[str, Optional[Tensor]]]
] = None,
need_weights: bool = True,
static_kv: bool = False,
attn_mask: Optional[Tensor] = None,
@@ -432,15 +455,19 @@ def forward(
self.embed_dim,
self.num_heads,
torch.empty([0]),
torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
torch.cat(
(self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)
),
self.bias_k,
self.bias_v,
self.add_zero_attn,
self.dropout_module.p,
self.out_proj.weight,
self.out_proj.bias,
self.training or self.dropout_module.apply_during_inference,
key_padding_mask.bool() if key_padding_mask is not None else None,
key_padding_mask.bool()
if key_padding_mask is not None
else None,
need_weights,
attn_mask,
use_separate_proj_weight=True,
@@ -455,7 +482,10 @@ def forward(
# previous time steps are cached - no need to recompute
# key and value if they are static
if static_kv:
assert self.encoder_decoder_attention and not self.self_attention
assert (
self.encoder_decoder_attention
and not self.self_attention
)
key = value = None
else:
saved_state = None
@@ -473,9 +503,9 @@ def forward(
else:
if self.beam_size > 1 and bsz == key.size(1):
# key is [T, bsz*beam_size, C], reduce to [T, bsz, C]
key = key.view(key.size(0), -1, self.beam_size, key.size(2))[
:, :, 0, :
]
key = key.view(
key.size(0), -1, self.beam_size, key.size(2)
)[:, :, 0, :]
if key_padding_mask is not None:
key_padding_mask = key_padding_mask.view(
-1, self.beam_size, key_padding_mask.size(1)
@@ -522,7 +552,9 @@ def forward(
_prev_key = saved_state["prev_key"]
assert _prev_key is not None
kv_bsz = _prev_key.size(0)
prev_key = _prev_key.view(kv_bsz * self.num_heads, -1, self.head_dim)
prev_key = _prev_key.view(
kv_bsz * self.num_heads, -1, self.head_dim
)
if static_kv:
k = prev_key
else:
@@ -553,14 +585,18 @@ def forward(
static_kv=static_kv,
)

saved_state["prev_key"] = k.view(kv_bsz, self.num_heads, -1, self.head_dim)
saved_state["prev_key"] = k.view(
kv_bsz, self.num_heads, -1, self.head_dim
)
saved_state["prev_value"] = v.view(
kv_bsz, self.num_heads, -1, self.head_dim
)
saved_state["prev_key_padding_mask"] = key_padding_mask
# In this branch incremental_state is never None
assert incremental_state is not None
incremental_state = self._set_input_buffer(incremental_state, saved_state)
incremental_state = self._set_input_buffer(
incremental_state, saved_state
)
assert k is not None
assert k.size(1) == src_len

@@ -586,12 +622,20 @@ def forward(
q.view((kv_bsz, -1, self.num_heads) + q.size()[1:]),
k.view((kv_bsz, self.num_heads) + k.size()[1:]),
)
attn_weights = attn_weights.reshape((-1,) + attn_weights.size()[-2:])
attn_weights = attn_weights.reshape(
(-1,) + attn_weights.size()[-2:]
)
else:
attn_weights = torch.bmm(q, k.transpose(1, 2))
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
attn_weights = self.apply_sparse_mask(
attn_weights, tgt_len, src_len, bsz
)

assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
assert list(attn_weights.size()) == [
bsz * self.num_heads,
tgt_len,
src_len,
]

if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(0)
@@ -601,7 +645,9 @@ def forward(

if key_padding_mask is not None:
# don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(
bsz, self.num_heads, tgt_len, src_len
)
if not is_tpu:
attn_weights = attn_weights.view(
kv_bsz, -1, self.num_heads, tgt_len, src_len
@@ -615,9 +661,13 @@ def forward(
)
else:
attn_weights = attn_weights.transpose(0, 2)
attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
attn_weights = attn_weights.masked_fill(
key_padding_mask, float("-inf")
)
attn_weights = attn_weights.transpose(0, 2)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(
bsz * self.num_heads, tgt_len, src_len
)

if before_softmax:
return attn_weights, v
@@ -652,13 +702,21 @@ def forward(
attn = attn.reshape((-1,) + attn.size()[-2:])
else:
attn = torch.bmm(attn_probs, v)
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
assert list(attn.size()) == [
bsz * self.num_heads,
tgt_len,
self.head_dim,
]
if self.onnx_trace and attn.size(1) == 1:
# when ONNX tracing a single decoder step (sequence length == 1)
# the transpose is a no-op copy before view, thus unnecessary
attn = attn.contiguous().view(tgt_len, bsz, self.embed_dim)
else:
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)
attn = (
attn.transpose(0, 1)
.contiguous()
.view(tgt_len, bsz, self.embed_dim)
)
attn = self.out_proj(attn)
attn_weights: Optional[Tensor] = None
if need_weights:
@@ -728,7 +786,9 @@ def reorder_incremental_state(
input_buffer_k = input_buffer[k]
if input_buffer_k is not None:
if self.encoder_decoder_attention:
if input_buffer_k.size(0) * self.beam_size == new_order.size(0):
if input_buffer_k.size(
0
) * self.beam_size == new_order.size(0):
return incremental_state
elif self.beam_size > 1:
input_buffer[k] = input_buffer_k.index_select(
@@ -737,18 +797,25 @@ def reorder_incremental_state(
// self.beam_size,
)
else:
input_buffer[k] = input_buffer_k.index_select(0, new_order)
input_buffer[k] = input_buffer_k.index_select(
0, new_order
)
else:
input_buffer[k] = input_buffer_k.index_select(0, new_order)
incremental_state = self._set_input_buffer(incremental_state, input_buffer)
input_buffer[k] = input_buffer_k.index_select(
0, new_order
)
incremental_state = self._set_input_buffer(
incremental_state, input_buffer
)
return incremental_state

def set_beam_size(self, beam_size):
"""Used for effiecient beamable enc-dec attention"""
self.beam_size = beam_size

def _get_input_buffer(
self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
self,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
) -> Dict[str, Optional[Tensor]]:
result = self.get_incremental_state(incremental_state, "attn_state")
if result is not None:
@@ -762,9 +829,13 @@ def _set_input_buffer(
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
buffer: Dict[str, Optional[Tensor]],
):
return self.set_incremental_state(incremental_state, "attn_state", buffer)
return self.set_incremental_state(
incremental_state, "attn_state", buffer
)

def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
def apply_sparse_mask(
self, attn_weights, tgt_len: int, src_len: int, bsz: int
):
return attn_weights

def upgrade_state_dict_named(self, state_dict, name):
@@ -776,19 +847,27 @@ def upgrade_state_dict_named(self, state_dict, name):
# in_proj_weight used to be q + k + v with same dimensions
dim = int(state_dict[k].shape[0] / 3)
items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]
items_to_add[prefix + "k_proj.weight"] = state_dict[k][
dim : 2 * dim
]
items_to_add[prefix + "v_proj.weight"] = state_dict[k][
2 * dim :
]

keys_to_remove.append(k)

k_bias = prefix + "in_proj_bias"
if k_bias in state_dict.keys():
dim = int(state_dict[k].shape[0] / 3)
items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][
:dim
]
items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][
dim : 2 * dim
]
items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][
2 * dim :
]

keys_to_remove.append(prefix + "in_proj_bias")

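
The attention_module.py hunks above only re-wrap existing logic. For readers skimming the diff, the following is a small, self-contained sketch of the idea behind _get_reserve_head_index: score each attention head by the L1 norm of its slice of the q/k/v projection weights and biases, then keep the best-scoring heads. It is illustrative only, not the module's actual API; the real method returns per-head (start, end) index pairs, and since the sorted(...) call is truncated in the diff, keeping the largest-norm heads is an assumption here.

import torch.nn as nn


def rank_heads(q_proj, k_proj, v_proj, num_heads, num_heads_to_keep):
    # Score head i by the L1 norm of its slice of the q/k/v projections,
    # mirroring the per-head sums computed in _get_reserve_head_index.
    head_dim = q_proj.weight.size(0) // num_heads
    scores = []
    for i in range(num_heads):
        start, end = i * head_dim, (i + 1) * head_dim
        score = 0.0
        for proj in (q_proj, k_proj, v_proj):
            score += proj.weight[start:end, :].abs().sum().item()
            score += proj.bias[start:end].abs().sum().item()
        scores.append((i, score))
    # Assumption: the heads with the largest total norm are the ones kept.
    scores.sort(key=lambda item: item[1], reverse=True)
    return sorted(i for i, _ in scores[:num_heads_to_keep])


embed_dim, num_heads = 768, 12
q_proj = nn.Linear(embed_dim, embed_dim)
k_proj = nn.Linear(embed_dim, embed_dim)
v_proj = nn.Linear(embed_dim, embed_dim)
print(rank_heads(q_proj, k_proj, v_proj, num_heads, num_heads_to_keep=8))
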
4 changes: 3 additions & 1 deletion egs/librispeech/SSL/hubert/finetune.py
@@ -925,7 +925,9 @@ def save_bad_model(suffix: str = ""):
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if params.use_fp16:
tb_writer.add_scalar(
"train/grad_scale", cur_grad_scale, params.batch_idx_train
"train/grad_scale",
cur_grad_scale,
params.batch_idx_train,
)

if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
4 changes: 3 additions & 1 deletion egs/librispeech/SSL/hubert/finetune_ce.py
@@ -925,7 +925,9 @@ def save_bad_model(suffix: str = ""):
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if params.use_fp16:
tb_writer.add_scalar(
"train/grad_scale", cur_grad_scale, params.batch_idx_train
"train/grad_scale",
cur_grad_scale,
params.batch_idx_train,
)

if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
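
The finetune.py and finetune_ce.py hunks re-wrap a single TensorBoard logging call. Below is a minimal, runnable sketch of that call in isolation (the SummaryWriter/GradScaler setup, the log directory, and the stand-in variables are assumptions; in the recipe they come from the surrounding training loop, and the tensorboard package must be installed):

from torch.cuda.amp import GradScaler
from torch.utils.tensorboard import SummaryWriter

tb_writer = SummaryWriter(log_dir="tb")  # hypothetical log directory
scaler = GradScaler(enabled=False)  # disabled so the sketch also runs without a GPU

batch_idx_train = 0  # stands in for params.batch_idx_train
cur_grad_scale = scaler.get_scale()  # 1.0 while AMP is disabled

# The call being re-wrapped in the hunks above.
tb_writer.add_scalar(
    "train/grad_scale",
    cur_grad_scale,
    batch_idx_train,
)
tb_writer.close()
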
(Diffs for the remaining 6 changed files are not shown here.)
