-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
124dcf8
commit abc239c
Showing 9 changed files with 39 additions and 50 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,6 +26,7 @@ | |
|
||
class WhisperDecoderInitOpenai(torch.nn.Module): | ||
"""WhisperDecoderInit for Openai.""" | ||
|
||
def __init__( | ||
self, | ||
model: torch.nn.Module, | ||
|
@@ -43,13 +44,12 @@ def forward( | |
audio_features, | ||
past=None, | ||
): | ||
|
||
# Create a kv_cache for past_values | ||
past_kv_cache = dict() | ||
if past is not None: | ||
# Convert past values from 4D to 3D | ||
past = [torch.transpose(val, 1, 2) for val in past] | ||
past = [val.reshape(val.shape[:2] + (-1, )) for val in past] | ||
past = [val.reshape(val.shape[:2] + (-1,)) for val in past] | ||
half_idx = len(past) // 2 | ||
for idx, block in enumerate(self.whisper_decoder.blocks): | ||
past_kv_cache[block.attn.key] = past[2 * idx] | ||
|
@@ -65,8 +65,12 @@ def forward( | |
# Add concat node for past values | ||
if past is not None: | ||
for idx, block in enumerate(self.whisper_decoder.blocks): | ||
Check warning: Code scanning / lintrunner reported RUFF/B007 (Warning).
Loop control variable `idx` is not used within the loop body.
See https://docs.astral.sh/ruff/rules/unused-loop-control-variable |
||
self.kv_cache[block.attn.key] = torch.cat([past_kv_cache[block.attn.key], self.kv_cache[block.attn.key]], dim=1).detach() | ||
self.kv_cache[block.attn.value] = torch.cat([past_kv_cache[block.attn.value], self.kv_cache[block.attn.value]], dim=1).detach() | ||
self.kv_cache[block.attn.key] = torch.cat( | ||
[past_kv_cache[block.attn.key], self.kv_cache[block.attn.key]], dim=1 | ||
).detach() | ||
self.kv_cache[block.attn.value] = torch.cat( | ||
[past_kv_cache[block.attn.value], self.kv_cache[block.attn.value]], dim=1 | ||
).detach() | ||
|
||
present_self, present_cross = [], [] | ||
# Group self and cross values | ||
|
@@ -79,7 +83,7 @@ def forward( | |
|
||
present_self = present_self + present_cross | ||
# Add reshape and transpose ops to convert from 3D to 4D | ||
present_self = [present_val.reshape( | ||
present_val.shape[:2] + (-1, 64) | ||
).transpose(1, 2) for present_val in present_self] | ||
present_self = [ | ||
present_val.reshape(present_val.shape[:2] + (-1, 64)).transpose(1, 2) for present_val in present_self | ||
] | ||
return logits, present_self |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters