Skip to content

Commit

Permalink
add sliding window
Browse files Browse the repository at this point in the history
  • Loading branch information
Your Name committed Oct 23, 2024
1 parent 84f8adf commit a6a8089
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import argparse
import logging
import math
from pathlib import Path
from typing import Optional

Expand Down Expand Up @@ -98,44 +99,90 @@ def get_args():
help="Stop processing pieces until this number (exclusive).",
)

parser.add_argument(
"--window-duration",
type=float,
default=300.0,
)

parser.add_argument(
"--shift-duration",
type=float,
default=250.0,
)

return parser.parse_args()


@torch.no_grad()
def extract_and_save_one_cuts(
raw_cuts_path, cuts_path, model, apply_kmeans, do_normalize, device
raw_cuts_path,
cuts_path,
model,
apply_kmeans,
do_normalize,
window_duration,
shift_duration,
):
logging.info(f"Loading {raw_cuts_path}")
cut_set = CutSet.from_file(raw_cuts_path)

logging.info("Extracting kmeans")
cuts = []

assert window_duration >= shift_duration
window_size = int(window_duration * 16000)
shift_size = int(shift_duration * 16000)
overlap_size = window_size - shift_size
out_overlap_size = get_out_length(overlap_size)

for cut in tqdm(cut_set):
assert cut.sampling_rate == 16000, f"Sampling rate: {cut.sampling_rate}"

audio = cut.load_audio()

offsets = 0
if True:
x = torch.from_numpy(audio).float().to(device)
T = audio.shape[1]
start = 0
kmeans = []
while start < T:
real_window_size = min(window_size, T - start)
audio_window = audio[:, start : start + real_window_size]

x = (
torch.from_numpy(audio_window)
.float()
.to(next(model.parameters()).device)
)
if do_normalize:
x = torch.nn.functional.layer_norm(x, x.shape)

feature, _ = model.extract_features(
source=x,
padding_mask=None,
mask=False,
output_layer=9,
)
feature = feature.squeeze(0)

with torch.no_grad():
if do_normalize:
x = torch.nn.functional.layer_norm(x, x.shape)
current_kmeans = apply_kmeans(feature).tolist()

feature, _ = model.extract_features(
source=x,
padding_mask=None,
mask=False,
output_layer=9,
)
feature = feature.squeeze(0)
if start == 0:
kmeans.extend(current_kmeans)
else:
kmeans.extend(current_kmeans[out_overlap_size:])

kmeans = " ".join(map(str, apply_kmeans(feature).tolist()))
if T - start <= window_size:
break

cut_with_kmeans = fastcopy(
cut,
custom={"kmeans": kmeans},
)
cuts.append(cut_with_kmeans)
start += shift_size

kmeans = " ".join(map(str, kmeans))

cut_with_kmeans = fastcopy(
cut,
custom={"kmeans": kmeans},
)
cuts.append(cut_with_kmeans)

cuts = CutSet(cuts)

Expand Down Expand Up @@ -166,6 +213,9 @@ def extract_kmeans(args):
model = model[0].eval().to(device)
do_normalize = task.cfg.normalize

window_duration = args.window_duration
shift_duration = args.shift_duration

if args.subset == "small":
cuts_path = output_dir / f"{prefix}_cuts_{args.subset}.jsonl.gz"
if cuts_path.is_file():
Expand All @@ -183,7 +233,8 @@ def extract_kmeans(args):
model,
apply_kmeans,
do_normalize,
device,
window_duration,
shift_duration,
)
else:
num_digits = 8 # num_digits is fixed by lhotse split-lazy
Expand Down Expand Up @@ -213,10 +264,19 @@ def extract_kmeans(args):
model,
apply_kmeans,
do_normalize,
device,
window_duration,
shift_duration,
)


def get_out_length(T):
conv_layers = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
for i, (out_channels, kernel_size, stride) in enumerate(conv_layers):
T = math.floor((T - kernel_size) / stride) + 1

return max(0, T)


if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

Expand Down
6 changes: 3 additions & 3 deletions egs/librilight/SSL/prepare.sh
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,15 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
wget https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960_L9_km500.bin -P download
fi
if [ ! -e data/kmeans/.extract_small.done ]; then
./local/extract_kmeans_from_hubert_base.py --subset small
./local/extract_kmeans.py --subset small
touch data/kmeans/.extract_small.done
fi
if [ ! -e data/kmeans/.extract_medium.done ]; then
./local/extract_kmeans_from_hubert_base.py --subset medium
./local/extract_kmeans.py --subset medium
touch data/kmeans/.extract_medium.done
fi
if [ ! -e data/kmeans/.extract_large.done ]; then
./local/extract_kmeans_from_hubert_base.py --subset large
./local/extract_kmeans.py --subset large
touch data/kmeans/.extract_large.done
fi
fi

0 comments on commit a6a8089

Please sign in to comment.