-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathatayal_phoneme_recog.py
65 lines (46 loc) · 2.02 KB
/
atayal_phoneme_recog.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import numpy as np
from typing import List
import argparse
import torch
import torchaudio
import pandas as pd
import pathlib
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2CTCTokenizer
# Pretrained wav2vec2 CTC checkpoint shared by the whole script. All three
# objects are loaded once, at import time, and serve as the default
# model/processor/tokenizer arguments of get_phonemes().
_CHECKPOINT = "facebook/wav2vec2-xlsr-53-espeak-cv-ft"
processor = Wav2Vec2Processor.from_pretrained(_CHECKPOINT)
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(_CHECKPOINT)
model = Wav2Vec2ForCTC.from_pretrained(_CHECKPOINT)
def get_parser():
    """Build the command-line parser for batch phoneme recognition.

    Options:
        -i / --input_dir   directory containing the .wav files (required).
        -o / --output_csv  CSV file to write; defaults to
                           ``atayal_phoneme_result.csv``.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(
        description="Recognize phonemes in every .wav file of a directory "
                    "and write the predictions to a CSV file."
    )
    parser.add_argument(
        "-i", "--input_dir",
        type=str,
        required=True,
        help="directory containing the .wav files to transcribe",
    )
    parser.add_argument(
        "-o", "--output_csv",
        type=str,
        default="atayal_phoneme_result.csv",
        help="output CSV path (default: %(default)s)",
    )
    return parser
def get_phonemes(audio: np.ndarray, samplerate: int, model: Wav2Vec2ForCTC = model,
                 processor: Wav2Vec2Processor = processor,
                 tokenizer: Wav2Vec2CTCTokenizer = tokenizer) -> str:
    """Transcribe a single-channel waveform into a phoneme string.

    Decoding is greedy: the highest-scoring token is taken for every frame
    and the resulting id sequence is decoded by ``processor``.

    Args:
        audio: 1-D array of audio samples (one channel).
        samplerate: sampling rate of ``audio`` in Hz. If it differs from the
            rate expected by ``processor.feature_extractor``, the audio is
            resampled to that rate before feature extraction.
        model: CTC model producing per-frame logits.
        processor: processor used for feature extraction and decoding.
        tokenizer: not used by this function; kept for backward
            compatibility of the signature.

    Returns:
        The decoded phoneme sequence.

    Raises:
        ValueError: if ``audio`` is not one-dimensional.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if audio.ndim != 1:
        raise ValueError(
            f"expected 1-D (single-channel) audio, got shape {audio.shape}"
        )

    # The wav2vec2 feature extractor raises when sampling_rate differs from
    # the rate it was configured with, so resample rather than fail.
    target_rate = processor.feature_extractor.sampling_rate
    if samplerate != target_rate:
        signal = torch.from_numpy(np.ascontiguousarray(audio, dtype=np.float32))
        audio = torchaudio.functional.resample(
            signal, orig_freq=samplerate, new_freq=target_rate
        ).numpy()
        samplerate = target_rate

    inputs = processor(audio, return_tensors="pt", padding="longest",
                       sampling_rate=samplerate)
    with torch.no_grad():
        # Batch of one: keep the per-frame logits of the only item.
        logits = model(inputs.input_values).logits.cpu()[0]
    # Best token id per frame (greedy CTC path).
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.decode(predicted_ids)
def recognition_and_save(input_dir, output_csv):
    """Run phoneme recognition on every WAV file in a directory and save a CSV.

    Files are processed in sorted filename order so the output is
    reproducible. Only the first channel of each recording is transcribed.

    Args:
        input_dir: directory scanned (non-recursively) for files whose
            extension is ``.wav`` (case-insensitive).
        output_csv: path of the CSV to write, with columns ``file`` and
            ``pred_phonemes``. A header-only CSV is written when the
            directory contains no WAV files.

    Raises:
        NotADirectoryError: if ``input_dir`` is not an existing directory.
    """
    input_path = pathlib.Path(input_dir)
    if not input_path.is_dir():
        # Globbing a missing path silently yields nothing; fail loudly rather
        # than writing an empty CSV for a mistyped path.
        raise NotADirectoryError(f"input directory not found: {input_path}")

    wav_files = sorted(
        path for path in input_path.iterdir()
        if path.is_file() and path.suffix.lower() == ".wav"
    )

    rows = []
    for wav_path in wav_files:
        waveform, sample_rate = torchaudio.load(wav_path)
        # torchaudio.load returns (channels, samples) by default; use channel 0.
        rows.append((wav_path.name, get_phonemes(waveform[0].numpy(), sample_rate)))

    # Build the frame once instead of growing it row by row with .loc.
    result = pd.DataFrame(rows, columns=["file", "pred_phonemes"])
    result.to_csv(output_csv, index=False)
if __name__ == '__main__':
    # Script entry point: parse CLI options and run the batch recognition.
    cli_args = get_parser().parse_args()
    recognition_and_save(
        input_dir=cli_args.input_dir,
        output_csv=cli_args.output_csv,
    )