Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

kwargs for Whisper methods #108

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions audiotools/core/whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def setup_whisper(
).to(self.whisper_device)
self.is_initialized = True

-def get_whisper_features(self) -> torch.Tensor:
+def get_whisper_features(self, **kwargs) -> torch.Tensor:
"""Preprocess audio signal as per the whisper model's training config.

Returns
Expand Down Expand Up @@ -49,11 +49,12 @@ def get_whisper_features(self) -> torch.Tensor:
raw_speech,
sampling_rate=self.whisper_processor.feature_extractor.sampling_rate,
return_tensors="pt",
+**kwargs
).input_features

return input_features

-def get_whisper_transcript(self) -> str:
+def get_whisper_transcript(self, **kwargs) -> str:
"""Get the transcript of the audio signal using the whisper model.

Returns
Expand All @@ -69,12 +70,12 @@ def get_whisper_transcript(self) -> str:

with torch.inference_mode():
input_features = input_features.to(self.whisper_device)
-generated_ids = self.whisper_model.generate(inputs=input_features)
+generated_ids = self.whisper_model.generate(input_features=input_features, **kwargs)

-transcription = self.whisper_processor.batch_decode(generated_ids)
+transcription = self.whisper_processor.batch_decode(generated_ids, skip_special_tokens=True)
return transcription[0]

-def get_whisper_embeddings(self) -> torch.Tensor:
+def get_whisper_embeddings(self, **kwargs) -> torch.Tensor:
"""Get the last hidden state embeddings of the audio signal using the whisper model.

Returns
Expand All @@ -92,6 +93,6 @@ def get_whisper_embeddings(self) -> torch.Tensor:

with torch.inference_mode():
input_features = input_features.to(self.whisper_device)
-embeddings = encoder(input_features)
+embeddings = encoder(input_features, **kwargs)

return embeddings.last_hidden_state
Loading