From 523936e3ef161c6313c6875c7b51e3f3e90ed988 Mon Sep 17 00:00:00 2001
From: Taejin Park
Date: Mon, 16 Dec 2024 15:39:33 -0800
Subject: [PATCH 01/20] Sortformer Diarizer 4spk v1 model PR Part 2: Unit-tests for Sortformer Diarizer. (#11336)

* Adding the first PR files: models and dataset Signed-off-by: taejinp
* Tested all unit-test files Signed-off-by: taejinp
* Name changes on yaml files and train example Signed-off-by: taejinp
* Apply isort and black reformatting Signed-off-by: tango4j
* Reflecting comments and removing unnecessary parts for this PR Signed-off-by: taejinp
* Apply isort and black reformatting Signed-off-by: tango4j
* Adding docstrings to reflect the PR comments Signed-off-by: taejinp
* Removed the unused find_first_nonzero Signed-off-by: taejinp
* Apply isort and black reformatting Signed-off-by: tango4j
* Fixed all pylint issues Signed-off-by: taejinp
* Apply isort and black reformatting Signed-off-by: tango4j
* Resolving pylint issues Signed-off-by: taejinp
* Apply isort and black reformatting Signed-off-by: tango4j
* Removing unused variable in audio_to_diar_label.py Signed-off-by: taejinp
* Fixed docstrings in training script Signed-off-by: taejinp
* Line-too-long issue from Pylint fixed Signed-off-by: taejinp
* Adding get_subsegments_scriptable to prevent jit.script error Signed-off-by: taejinp
* Apply isort and black reformatting Signed-off-by: tango4j
* Addressed Code-QL issues Signed-off-by: taejinp
* Resolved conflicts on bce_loss.py Signed-off-by: taejinp
* Apply isort and black reformatting Signed-off-by: tango4j
* Adding all the diarization related unit-tests Signed-off-by: taejinp
* Moving speaker task related unit test files to speaker_tasks folder Signed-off-by: taejinp
* Fixed uninit variable issue in bce_loss.py spotted by codeQL Signed-off-by: taejinp
* Apply isort and black reformatting Signed-off-by: tango4j
* Fixing code-QL issues Signed-off-by: taejinp
* Apply isort and black reformatting Signed-off-by: tango4j
* Reflecting PR comments from weiqingw Signed-off-by: taejinp
* Apply isort and black reformatting Signed-off-by: tango4j
* Line too long pylint issue resolved in e2e_diarize_speech.py Signed-off-by: taejinp
* Apply isort and black reformatting Signed-off-by: tango4j
* Resolved unused variable issue in model test Signed-off-by: taejinp
* Reflecting the comment on Nov 21st 2024. Signed-off-by: taejinp
* Apply isort and black reformatting Signed-off-by: tango4j
* Unused variable import time Signed-off-by: taejinp
* Adding docstrings to score_labels() function in der.py Signed-off-by: taejinp
* Apply isort and black reformatting Signed-off-by: tango4j
* Reflecting comments on YAML files and model file variable changes.
Signed-off-by: taejinp * Apply isort and black reformatting Signed-off-by: tango4j * Added get_subsegments_scriptable for legacy get_subsegment functions Signed-off-by: taejinp * Apply isort and black reformatting Signed-off-by: tango4j * Resolved line too long pylint issues Signed-off-by: taejinp * Apply isort and black reformatting Signed-off-by: tango4j * Added training and inference CI-tests Signed-off-by: taejinp * Added the missing parse_func in preprocessing/collections.py Signed-off-by: taejinp * Adding the missing parse_func in preprocessing/collections.py Signed-off-by: taejinp * Fixed an indentation error Signed-off-by: taejinp * Resolved multi_bin_acc and bce_loss issues Signed-off-by: taejinp * Resolved line-too-long for msdd_models.py Signed-off-by: taejinp * Apply isort and black reformatting Signed-off-by: tango4j * Code QL issues and fixed test errors Signed-off-by: taejinp * Apply isort and black reformatting Signed-off-by: tango4j * line too long in audio_to_diar_label.py Signed-off-by: taejinp * Apply isort and black reformatting Signed-off-by: tango4j * resolving CICD test issues Signed-off-by: taejinp * Fixing codeQL issues Signed-off-by: taejinp * Fixed pin memory False for inference Signed-off-by: taejinp --------- Signed-off-by: taejinp Signed-off-by: tango4j Co-authored-by: tango4j --- .github/workflows/cicd-main.yml | 29 + .../neural_diarizer/e2e_diarize_speech.py | 1 + .../asr/data/audio_to_diar_label.py | 12 +- .../asr/models/sortformer_diar_models.py | 3 +- .../speaker_tasks/test_diar_datasets.py | 110 ++++ .../test_diar_label_models.py | 42 +- .../test_diar_lhotse_datasets.py | 173 ++++++ .../test_diar_metrics.py | 0 .../test_diar_neural_inference.py | 10 +- .../test_diar_sortformer_models.py | 168 ++++++ .../test_speaker_label_models.py | 12 +- .../utils/test_data_simul_utils.py | 549 ++++++++++++++++++ .../utils}/test_diar_utils.py | 109 +++- .../utils/test_multispeaker_utils.py | 352 +++++++++++ .../speaker_tasks/utils/test_vad_utils.py | 126 ++++ 15 files changed, 1686 insertions(+), 10 deletions(-) create mode 100644 tests/collections/speaker_tasks/test_diar_datasets.py rename tests/collections/{asr => speaker_tasks}/test_diar_label_models.py (79%) create mode 100644 tests/collections/speaker_tasks/test_diar_lhotse_datasets.py rename tests/collections/{asr => speaker_tasks}/test_diar_metrics.py (100%) rename tests/collections/{asr => speaker_tasks}/test_diar_neural_inference.py (87%) create mode 100644 tests/collections/speaker_tasks/test_diar_sortformer_models.py rename tests/collections/{asr => speaker_tasks}/test_speaker_label_models.py (95%) create mode 100644 tests/collections/speaker_tasks/utils/test_data_simul_utils.py rename tests/collections/{asr => speaker_tasks/utils}/test_diar_utils.py (92%) create mode 100644 tests/collections/speaker_tasks/utils/test_multispeaker_utils.py create mode 100644 tests/collections/speaker_tasks/utils/test_vad_utils.py diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 686b066652c0..310d580e43f6 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -816,6 +816,33 @@ jobs: +trainer.fast_dev_run=True \ exp_manager.exp_dir=/tmp/speaker_diarization_results + L2_Speaker_dev_run_EndtoEnd_Speaker_Diarization_Sortformer: + needs: [cicd-test-container-setup] + uses: ./.github/workflows/_test_template.yml + if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_Speaker_dev_run_EndtoEnd_Speaker_Diarization_Sortformer') || 
needs.cicd-test-container-setup.outputs.all == 'true' + with: + RUNNER: self-hosted-azure-gpus-1 + SCRIPT: | + python examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py \ + trainer.devices="[0]" \ + batch_size=3 \ + model.train_ds.manifest_filepath=/home/TestData/an4_diarizer/simulated_train/eesd_train_tiny.json \ + model.validation_ds.manifest_filepath=/home/TestData/an4_diarizer/simulated_valid/eesd_valid_tiny.json \ + exp_manager.exp_dir=/tmp/speaker_diarization_results \ + +trainer.fast_dev_run=True + + L2_Speaker_dev_run_EndtoEnd_Diarizer_Inference: + needs: [cicd-test-container-setup] + uses: ./.github/workflows/_test_template.yml + if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_Speaker_dev_run_EndtoEnd_Diarizer_Inference') || needs.cicd-test-container-setup.outputs.all == 'true' + with: + RUNNER: self-hosted-azure + SCRIPT: | + python examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py \ + model_path=/home/TestData/an4_diarizer/diar_sortformer_4spk-v1-tiny.nemo \ + dataset_manifest=/home/TestData/an4_diarizer/simulated_valid/eesd_valid_tiny.json \ + batch_size=1 + L2_Speaker_dev_run_Speech_to_Label: needs: [cicd-test-container-setup] uses: ./.github/workflows/_test_template.yml @@ -4753,6 +4780,8 @@ jobs: - L2_Speech_to_Text_EMA - L2_Speaker_dev_run_Speaker_Recognition - L2_Speaker_dev_run_Speaker_Diarization + - L2_Speaker_dev_run_EndtoEnd_Speaker_Diarization_Sortformer + - L2_Speaker_dev_run_EndtoEnd_Diarizer_Inference - L2_Speaker_dev_run_Speech_to_Label - L2_Speaker_dev_run_Speaker_Diarization_with_ASR_Inference - L2_Speaker_dev_run_Clustering_Diarizer_Inference diff --git a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py index 1767a16cbe02..147d7a3aa002 100644 --- a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py +++ b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py @@ -386,6 +386,7 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: diar_model._cfg.test_ds.manifest_filepath = cfg.dataset_manifest infer_audio_rttm_dict = audio_rttm_map(cfg.dataset_manifest) diar_model._cfg.test_ds.batch_size = cfg.batch_size + diar_model._cfg.test_ds.pin_memory = False # Model setup for inference diar_model._cfg.test_ds.num_workers = cfg.num_workers diff --git a/nemo/collections/asr/data/audio_to_diar_label.py b/nemo/collections/asr/data/audio_to_diar_label.py index 0824c9c6ab51..3f4ae61e0d08 100644 --- a/nemo/collections/asr/data/audio_to_diar_label.py +++ b/nemo/collections/asr/data/audio_to_diar_label.py @@ -1065,6 +1065,7 @@ def __init__( round_digits: int = 2, soft_targets: bool = False, subsampling_factor: int = 8, + device: str = 'cpu', ): super().__init__() self.collection = EndtoEndDiarizationSpeechLabel( @@ -1084,6 +1085,7 @@ def __init__( self.soft_targets = soft_targets self.round_digits = 2 self.floor_decimal = 10**self.round_digits + self.device = device def __len__(self): return len(self.collection) @@ -1232,11 +1234,13 @@ def __getitem__(self, index): audio_signal = audio_signal[: round(self.featurizer.sample_rate * session_len_sec)] audio_signal_length = torch.tensor(audio_signal.shape[0]).long() - audio_signal, audio_signal_length = audio_signal.to('cpu'), audio_signal_length.to('cpu') - target_len = self.get_segment_timestamps(duration=session_len_sec, sample_rate=self.featurizer.sample_rate) + audio_signal, audio_signal_length = 
audio_signal.to(self.device), audio_signal_length.to(self.device) + target_len = self.get_segment_timestamps(duration=session_len_sec, sample_rate=self.featurizer.sample_rate).to( + self.device + ) targets = self.parse_rttm_for_targets_and_lens( rttm_file=sample.rttm_file, offset=offset, duration=session_len_sec, target_len=target_len - ) + ).to(self.device) return audio_signal, audio_signal_length, targets, target_len @@ -1355,6 +1359,7 @@ def __init__( window_stride, global_rank: int, soft_targets: bool, + device: str, ): super().__init__( manifest_filepath=manifest_filepath, @@ -1365,6 +1370,7 @@ def __init__( window_stride=window_stride, global_rank=global_rank, soft_targets=soft_targets, + device=device, ) def eesd_train_collate_fn(self, batch): diff --git a/nemo/collections/asr/models/sortformer_diar_models.py b/nemo/collections/asr/models/sortformer_diar_models.py index f6b0eab4c895..71de10cc2f79 100644 --- a/nemo/collections/asr/models/sortformer_diar_models.py +++ b/nemo/collections/asr/models/sortformer_diar_models.py @@ -175,6 +175,7 @@ def __setup_dataloader_from_config(self, config): window_stride=self._cfg.preprocessor.window_stride, global_rank=global_rank, soft_targets=config.soft_targets if 'soft_targets' in config else False, + device=self.device, ) self.data_collection = dataset.collection @@ -557,13 +558,13 @@ def test_batch( audio_signal=audio_signal, audio_signal_length=audio_signal_length, ) + self._get_aux_test_batch_evaluations(batch_idx, preds, targets, target_lens) preds = preds.detach().to('cpu') if preds.shape[0] == 1: # batch size = 1 self.preds_total_list.append(preds) else: self.preds_total_list.extend(torch.split(preds, [1] * preds.shape[0])) torch.cuda.empty_cache() - self._get_aux_test_batch_evaluations(batch_idx, preds, targets, target_lens) logging.info(f"Batch F1Acc. MEAN: {torch.mean(torch.tensor(self.batch_f1_accs_list))}") logging.info(f"Batch Precision MEAN: {torch.mean(torch.tensor(self.batch_precision_list))}") diff --git a/tests/collections/speaker_tasks/test_diar_datasets.py b/tests/collections/speaker_tasks/test_diar_datasets.py new file mode 100644 index 000000000000..28272d63bd43 --- /dev/null +++ b/tests/collections/speaker_tasks/test_diar_datasets.py @@ -0,0 +1,110 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import tempfile + +import pytest +import torch.cuda + +from nemo.collections.asr.data.audio_to_diar_label import AudioToSpeechE2ESpkDiarDataset +from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer +from nemo.collections.asr.parts.utils.speaker_utils import get_vad_out_from_rttm_line, read_rttm_lines + + +def is_rttm_length_too_long(rttm_file_path, wav_len_in_sec): + """ + Check if the maximum RTTM duration exceeds the length of the provided audio file. + + Args: + rttm_file_path (str): Path to the RTTM file. + wav_len_in_sec (float): Length of the audio file in seconds. 
+ + Returns: + bool: True if the maximum RTTM duration is less than or equal to the length of the audio file, False otherwise. + """ + rttm_lines = read_rttm_lines(rttm_file_path) + max_rttm_sec = 0 + for line in rttm_lines: + start, dur = get_vad_out_from_rttm_line(line) + max_rttm_sec = max(max_rttm_sec, start + dur) + return max_rttm_sec <= wav_len_in_sec + + +class TestAudioToSpeechE2ESpkDiarDataset: + + @pytest.mark.unit + def test_e2e_speaker_diar_dataset(self, test_data_dir): + manifest_path = os.path.abspath(os.path.join(test_data_dir, 'asr/diarizer/lsm_val.json')) + + batch_size = 4 + num_samples = 8 + device = 'cuda' if torch.cuda.is_available() else 'cpu' + data_dict_list = [] + with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8') as f: + with open(manifest_path, 'r', encoding='utf-8') as mfile: + for ix, line in enumerate(mfile): + if ix >= num_samples: + break + + line = line.replace("tests/data/", test_data_dir + "/").replace("\n", "") + f.write(f"{line}\n") + data_dict = json.loads(line) + data_dict_list.append(data_dict) + + f.seek(0) + featurizer = WaveformFeaturizer(sample_rate=16000, int_values=False, augmentor=None) + + dataset = AudioToSpeechE2ESpkDiarDataset( + manifest_filepath=f.name, + soft_label_thres=0.5, + session_len_sec=90, + num_spks=4, + featurizer=featurizer, + window_stride=0.01, + global_rank=0, + soft_targets=False, + device=device, + ) + dataloader_instance = torch.utils.data.DataLoader( + dataset=dataset, + batch_size=batch_size, + collate_fn=dataset.eesd_train_collate_fn, + drop_last=False, + shuffle=False, + num_workers=1, + pin_memory=False, + ) + assert len(dataloader_instance) == (num_samples / batch_size) # Check if the number of batches is correct + batch_counts = len(dataloader_instance) + + deviation_thres_rate = 0.01 # 1% deviation allowed + for batch_index, batch in enumerate(dataloader_instance): + if batch_index != batch_counts - 1: + assert len(batch) == batch_size, "Batch size does not match the expected value" + audio_signals, audio_signal_len, targets, target_lens = batch + for sample_index in range(audio_signals.shape[0]): + dataloader_audio_in_sec = audio_signal_len[sample_index].item() + data_dur_in_sec = abs( + data_dict_list[batch_size * batch_index + sample_index]['duration'] * featurizer.sample_rate + - dataloader_audio_in_sec + ) + assert ( + data_dur_in_sec <= deviation_thres_rate * dataloader_audio_in_sec + ), "Duration deviation exceeds 1%" + assert not torch.isnan(audio_signals).any(), "audio_signals tensor contains NaN values" + assert not torch.isnan(audio_signal_len).any(), "audio_signal_len tensor contains NaN values" + assert not torch.isnan(targets).any(), "targets tensor contains NaN values" + assert not torch.isnan(target_lens).any(), "target_lens tensor contains NaN values" diff --git a/tests/collections/asr/test_diar_label_models.py b/tests/collections/speaker_tasks/test_diar_label_models.py similarity index 79% rename from tests/collections/asr/test_diar_label_models.py rename to tests/collections/speaker_tasks/test_diar_label_models.py index 2ed6177d3cb2..f01a8add7aab 100644 --- a/tests/collections/asr/test_diar_label_models.py +++ b/tests/collections/speaker_tasks/test_diar_label_models.py @@ -16,6 +16,7 @@ import torch from omegaconf import DictConfig +from nemo.collections.asr.losses import BCELoss from nemo.collections.asr.models import EncDecDiarLabelModel @@ -24,7 +25,12 @@ def msdd_model(): preprocessor = { 'cls': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor', - 'params': 
{"features": 80, "window_size": 0.025, "window_stride": 0.01, "sample_rate": 16000,}, + 'params': { + "features": 80, + "window_size": 0.025, + "window_stride": 0.01, + "sample_rate": 16000, + }, } speaker_model_encoder = { @@ -165,3 +171,37 @@ def test_forward_infer(self, msdd_model): assert diff <= 1e-6 diff = torch.max(torch.abs(scale_weights_instance - scale_weights_batch)) assert diff <= 1e-6 + + +class TestBCELoss: + @pytest.mark.unit + @pytest.mark.parametrize( + "probs, labels, target_lens, reduction, expected_output", + [ + ( + torch.tensor([[[0.5, 0.5], [0.5, 0.5]]], dtype=torch.float32), + torch.tensor([[[1, 0], [0, 1]]], dtype=torch.float32), + torch.tensor([2]), + "mean", + torch.tensor(0.693147, dtype=torch.float32), + ), + ( + torch.tensor([[[0.5, 0.5], [0.0, 1.0]]], dtype=torch.float32), + torch.tensor([[[1, 0], [0, 1]]], dtype=torch.float32), + torch.tensor([1]), + "mean", + torch.tensor(0.693147, dtype=torch.float32), + ), + ( + torch.tensor([[[0, 1], [1, 0]]], dtype=torch.float32), + torch.tensor([[[1, 0], [0, 1]]], dtype=torch.float32), + torch.tensor([2]), + "mean", + torch.tensor(100, dtype=torch.float32), + ), + ], + ) + def test_loss(self, probs, labels, target_lens, reduction, expected_output): + loss = BCELoss(reduction=reduction) + result = loss(probs=probs, labels=labels, target_lens=target_lens) + assert torch.allclose(result, expected_output), f"Expected {expected_output}, but got {result}" diff --git a/tests/collections/speaker_tasks/test_diar_lhotse_datasets.py b/tests/collections/speaker_tasks/test_diar_lhotse_datasets.py new file mode 100644 index 000000000000..281742be9174 --- /dev/null +++ b/tests/collections/speaker_tasks/test_diar_lhotse_datasets.py @@ -0,0 +1,173 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +import os +import tempfile +from unittest import mock + +import pytest +import torch +import torch.cuda +from omegaconf import DictConfig + +from nemo.collections.asr.data.audio_to_diar_label_lhotse import LhotseAudioToSpeechE2ESpkDiarDataset +from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config + + +def get_train_ds_config(manifest_filepath, batch_size, num_workers) -> DictConfig: + return DictConfig( + { + 'manifest_filepath': manifest_filepath, + 'sample_rate': 16000, + 'num_spks': 4, + 'session_len_sec': 90, + 'soft_label_thres': 0.5, + 'soft_targets': False, + 'labels': None, + 'batch_size': batch_size, + 'shuffle': True, + 'num_workers': num_workers, + 'validation_mode': False, + 'use_lhotse': True, + 'use_bucketing': True, + 'num_buckets': 10, + 'bucket_duration_bins': [10, 20, 30, 40, 50, 60, 70, 80, 90], + 'pin_memory': True, + 'min_duration': 80, + 'max_duration': 90, + 'batch_duration': 400, + 'quadratic_duration': 1200, + 'bucket_buffer_size': 20000, + 'shuffle_buffer_size': 10000, + 'window_stride': 0.01, + 'subsampling_factor': 8, + } + ) + + +def get_validation_ds_config(manifest_filepath, batch_size, num_workers) -> DictConfig: + return DictConfig( + { + 'manifest_filepath': manifest_filepath, + 'is_tarred': False, + 'tarred_audio_filepaths': None, + 'sample_rate': 16000, + 'num_spks': 4, + 'session_len_sec': 90, + 'soft_label_thres': 0.5, + 'soft_targets': False, + 'labels': None, + 'batch_size': batch_size, + 'shuffle': False, + 'seq_eval_mode': True, + 'num_workers': num_workers, + 'validation_mode': True, + 'use_lhotse': False, + 'use_bucketing': False, + 'drop_last': False, + 'pin_memory': True, + 'window_stride': 0.01, + 'subsampling_factor': 8, + } + ) + + +def get_test_ds_config(manifest_filepath, batch_size, num_workers) -> DictConfig: + return DictConfig( + { + 'manifest_filepath': manifest_filepath, + 'is_tarred': False, + 'tarred_audio_filepaths': None, + 'sample_rate': 16000, + 'num_spks': 4, + 'session_len_sec': 90, + 'soft_label_thres': 0.5, + 'soft_targets': False, + 'labels': None, + 'batch_size': batch_size, + 'shuffle': False, + 'seq_eval_mode': True, + 'num_workers': num_workers, + 'validation_mode': True, + 'use_lhotse': False, + 'use_bucketing': False, + 'drop_last': False, + 'pin_memory': True, + 'window_stride': 0.01, + 'subsampling_factor': 8, + } + ) + + +class TestLhotseAudioToSpeechE2ESpkDiarDataset: + + @pytest.mark.unit + @pytest.mark.parametrize( + "batch_size, num_workers, split", + [ + (4, 8, 'train'), # Example 1 + (4, 0, 'train'), # Example 2 + (2, 4, 'validation'), # Example 3 + (8, 2, 'test'), # Example 4 + ], + ) + def test_e2e_speaker_diar_lhotse_dataset(self, test_data_dir, batch_size, num_workers, split): + manifest_path = os.path.abspath(os.path.join(test_data_dir, 'asr/diarizer/lsm_val.json')) + num_samples = 8 + device = 'gpu' if torch.cuda.is_available() else 'cpu' + data_dict_list = [] + with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8') as f: + with open(manifest_path, 'r', encoding='utf-8') as mfile: + for ix, line in enumerate(mfile): + if ix >= num_samples: + break + + line = line.replace("tests/data/", test_data_dir + "/").replace("\n", "") + f.write(f"{line}\n") + data_dict = json.loads(line) + data_dict_list.append(data_dict) + + f.seek(0) + config = None + if split == 'train': + config = get_train_ds_config(manifest_filepath=f.name, batch_size=batch_size, num_workers=num_workers) + elif split == 'validation': + config = get_train_ds_config(manifest_filepath=f.name, 
batch_size=batch_size, num_workers=num_workers) + elif split == 'test': + config = get_test_ds_config(manifest_filepath=f.name, batch_size=batch_size, num_workers=num_workers) + + dataloader_instance = get_lhotse_dataloader_from_config( + config, + global_rank=0, + world_size=1, + dataset=LhotseAudioToSpeechE2ESpkDiarDataset(cfg=config), + ) + + deviation_thres_rate = 0.01 # 1% deviation allowed + for batch_index, batch in enumerate(dataloader_instance): + audio_signals, audio_signal_len, targets, target_lens = batch + for sample_index in range(audio_signals.shape[0]): + dataloader_audio_in_sec = audio_signal_len[sample_index].item() + data_dur_in_sec = abs( + data_dict_list[batch_size * batch_index + sample_index]['duration'] * config.sample_rate + - dataloader_audio_in_sec + ) + assert ( + data_dur_in_sec <= deviation_thres_rate * dataloader_audio_in_sec + ), "Duration deviation exceeds 1%" + assert not torch.isnan(audio_signals).any(), "audio_signals tensor contains NaN values" + assert not torch.isnan(audio_signal_len).any(), "audio_signal_len tensor contains NaN values" + assert not torch.isnan(targets).any(), "targets tensor contains NaN values" + assert not torch.isnan(target_lens).any(), "target_lens tensor contains NaN values" diff --git a/tests/collections/asr/test_diar_metrics.py b/tests/collections/speaker_tasks/test_diar_metrics.py similarity index 100% rename from tests/collections/asr/test_diar_metrics.py rename to tests/collections/speaker_tasks/test_diar_metrics.py diff --git a/tests/collections/asr/test_diar_neural_inference.py b/tests/collections/speaker_tasks/test_diar_neural_inference.py similarity index 87% rename from tests/collections/asr/test_diar_neural_inference.py rename to tests/collections/speaker_tasks/test_diar_neural_inference.py index 076eac129293..64c1196cd9a6 100644 --- a/tests/collections/asr/test_diar_neural_inference.py +++ b/tests/collections/speaker_tasks/test_diar_neural_inference.py @@ -28,13 +28,16 @@ class TestNeuralDiarizerInference: torch.device("cpu"), pytest.param( torch.device("cuda"), - marks=pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA required for test.',), + marks=pytest.mark.skipif( + not torch.cuda.is_available(), + reason='CUDA required for test.', + ), ), ], ) @pytest.mark.parametrize("num_speakers", [None, 1]) @pytest.mark.parametrize("max_num_speakers", [4]) - def test_diar_inference(self, tmpdir, test_data_dir, device, num_speakers, max_num_speakers): + def test_msdd_diar_inference(self, tmpdir, test_data_dir, device, num_speakers, max_num_speakers): """ Test to ensure diarization inference works correctly. - Ensures multiple audio files can be diarized sequentially @@ -69,3 +72,6 @@ def test_diar_inference(self, tmpdir, test_data_dir, device, num_speakers, max_n # assert only 1 speaker & segment assert len(annotation.labels()) == 1 assert len(list(annotation.itersegments())) == 1 + + # class TestSortformerDiarizerInference: + # TODO: This test can only be implemented once SortformerDiarizer model is uploaded. diff --git a/tests/collections/speaker_tasks/test_diar_sortformer_models.py b/tests/collections/speaker_tasks/test_diar_sortformer_models.py new file mode 100644 index 000000000000..41bd1537f16a --- /dev/null +++ b/tests/collections/speaker_tasks/test_diar_sortformer_models.py @@ -0,0 +1,168 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +from omegaconf import DictConfig + +from nemo.collections.asr.models import SortformerEncLabelModel + + +@pytest.fixture() +def sortformer_model(): + + model = { + 'sample_rate': 16000, + 'pil_weight': 0.5, + 'ats_weight': 0.5, + 'max_num_of_spks': 4, + } + model_defaults = { + 'fc_d_model': 512, + 'tf_d_model': 192, + } + preprocessor = { + '_target_': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor', + 'normalize': 'per_feature', + 'window_size': 0.025, + 'sample_rate': 16000, + 'window_stride': 0.01, + 'window': 'hann', + 'features': 80, + 'n_fft': 512, + 'frame_splicing': 1, + 'dither': 0.00001, + } + + sortformer_modules = { + '_target_': 'nemo.collections.asr.modules.sortformer_modules.SortformerModules', + 'num_spks': model['max_num_of_spks'], + 'dropout_rate': 0.5, + 'fc_d_model': model_defaults['fc_d_model'], + 'tf_d_model': model_defaults['tf_d_model'], + } + + encoder = { + '_target_': 'nemo.collections.asr.modules.ConformerEncoder', + 'feat_in': preprocessor['features'], + 'feat_out': -1, + 'n_layers': 18, + 'd_model': model_defaults['fc_d_model'], + 'subsampling': 'dw_striding', + 'subsampling_factor': 8, + 'subsampling_conv_channels': 256, + 'causal_downsampling': False, + 'ff_expansion_factor': 4, + 'self_attention_model': 'rel_pos', + 'n_heads': 8, + 'att_context_size': [-1, -1], + 'att_context_style': 'regular', + 'xscaling': True, + 'untie_biases': True, + 'pos_emb_max_len': 5000, + 'conv_kernel_size': 9, + 'conv_norm_type': 'batch_norm', + 'conv_context_size': None, + 'dropout': 0.1, + 'dropout_pre_encoder': 0.1, + 'dropout_emb': 0.0, + 'dropout_att': 0.1, + 'stochastic_depth_drop_prob': 0.0, + 'stochastic_depth_mode': 'linear', + 'stochastic_depth_start_layer': 1, + } + + transformer_encoder = { + '_target_': 'nemo.collections.asr.modules.transformer.transformer_encoders.TransformerEncoder', + 'num_layers': 18, + 'hidden_size': model_defaults['tf_d_model'], + 'inner_size': 768, + 'num_attention_heads': 8, + 'attn_score_dropout': 0.5, + 'attn_layer_dropout': 0.5, + 'ffn_dropout': 0.5, + 'hidden_act': 'relu', + 'pre_ln': False, + 'pre_ln_final_layer_norm': True, + } + + loss = { + '_target_': 'nemo.collections.asr.losses.bce_loss.BCELoss', + 'weight': None, + 'reduction': 'mean', + } + + modelConfig = DictConfig( + { + 'sample_rate': 16000, + 'pil_weight': 0.5, + 'ats_weight': 0.5, + 'max_num_of_spks': 4, + 'model_defaults': DictConfig(model_defaults), + 'encoder': DictConfig(encoder), + 'transformer_encoder': DictConfig(transformer_encoder), + 'sortformer_modules': DictConfig(sortformer_modules), + 'preprocessor': DictConfig(preprocessor), + 'loss': DictConfig(loss), + 'optim': { + 'optimizer': 'Adam', + 'lr': 0.001, + 'betas': (0.9, 0.98), + }, + } + ) + model = SortformerEncLabelModel(cfg=modelConfig) + return model + + +class TestSortformerEncLabelModel: + @pytest.mark.unit + def test_constructor(self, sortformer_model): + sortformer_diar_model = sortformer_model.train() + confdict = sortformer_diar_model.to_config_dict() + instance2 = SortformerEncLabelModel.from_config_dict(confdict) + assert 
isinstance(instance2, SortformerEncLabelModel) + + @pytest.mark.unit + @pytest.mark.parametrize( + "batch_size, frame_length, sample_len", + [ + (4, 0.08, 16), # Example 1 + (2, 0.02, 32), # Example 2 + (1, 0.1, 20), # Example 3 + ], + ) + def test_forward_infer(self, sortformer_model, batch_size, frame_length, sample_len, num_spks=4): + sortformer_diar_model = sortformer_model.eval() + confdict = sortformer_diar_model.to_config_dict() + sampling_rate = confdict['preprocessor']['sample_rate'] + input_signal = torch.randn(size=(batch_size, sample_len * sampling_rate)) + input_signal_length = (sample_len * sampling_rate) * torch.ones(batch_size, dtype=torch.int) + + with torch.no_grad(): + # batch size 1 + preds_list = [] + for i in range(input_signal.size(0)): + preds = sortformer_diar_model.forward(input_signal[i : i + 1], input_signal_length[i : i + 1]) + preds_list.append(preds) + preds_instance = torch.cat(preds_list, 0) + + # batch size 4 + preds_batch = sortformer_diar_model.forward(input_signal, input_signal_length) + assert preds_instance.shape == preds_batch.shape + + diff = torch.mean(torch.abs(preds_instance - preds_batch)) + assert diff <= 1e-6 + diff = torch.max(torch.abs(preds_instance - preds_batch)) + assert diff <= 1e-6 diff --git a/tests/collections/asr/test_speaker_label_models.py b/tests/collections/speaker_tasks/test_speaker_label_models.py similarity index 95% rename from tests/collections/asr/test_speaker_label_models.py rename to tests/collections/speaker_tasks/test_speaker_label_models.py index 29b5c9eea643..81a051e32e66 100644 --- a/tests/collections/asr/test_speaker_label_models.py +++ b/tests/collections/speaker_tasks/test_speaker_label_models.py @@ -96,7 +96,11 @@ def test_ecapa_enc_dec(self): } modelConfig = DictConfig( - {'preprocessor': DictConfig(preprocessor), 'encoder': DictConfig(encoder), 'decoder': DictConfig(decoder),} + { + 'preprocessor': DictConfig(preprocessor), + 'encoder': DictConfig(encoder), + 'decoder': DictConfig(decoder), + } ) speaker_model = EncDecSpeakerLabelModel(cfg=modelConfig) speaker_model.train() @@ -142,7 +146,11 @@ def test_titanet_enc_dec(self): } modelConfig = DictConfig( - {'preprocessor': DictConfig(preprocessor), 'encoder': DictConfig(encoder), 'decoder': DictConfig(decoder),} + { + 'preprocessor': DictConfig(preprocessor), + 'encoder': DictConfig(encoder), + 'decoder': DictConfig(decoder), + } ) speaker_model = EncDecSpeakerLabelModel(cfg=modelConfig) speaker_model.train() diff --git a/tests/collections/speaker_tasks/utils/test_data_simul_utils.py b/tests/collections/speaker_tasks/utils/test_data_simul_utils.py new file mode 100644 index 000000000000..9a27820cdfa1 --- /dev/null +++ b/tests/collections/speaker_tasks/utils/test_data_simul_utils.py @@ -0,0 +1,549 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
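The TestGetCtmLine cases below pin down a CTM record of the form source, channel, start time, duration, token, confidence, token type, speaker, one token per line, with times and confidence rounded to the configured precision and NA used for missing optional fields. A minimal sketch of that formatting for the fully specified case (format_ctm_line is a hypothetical helper for illustration, not NeMo's get_ctm_line):

def format_ctm_line(source, channel, start, dur, token, conf, token_type, speaker, precision=2):
    # One token per CTM record: source channel start duration token confidence type speaker
    conf_str = "NA" if conf is None else f"{conf:.{precision}f}"
    return f"{source} {channel} {start:.{precision}f} {dur:.{precision}f} {token} {conf_str} {token_type} {speaker}\n"

print(format_ctm_line("test_source", 1, 0.123, 0.456, "word", 0.789, "lex", "speaker1"))
# "test_source 1 0.12 0.46 word 0.79 lex speaker1\n", the same string test_valid_input expects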
+ +import os + +import numpy as np +import pytest +import torch +from omegaconf import DictConfig + +from nemo.collections.asr.parts.utils.data_simulation_utils import ( + DataAnnotator, + SpeechSampler, + add_silence_to_alignments, + binary_search_alignments, + get_cleaned_base_path, + get_split_points_in_alignments, + normalize_audio, + read_noise_manifest, +) +from nemo.collections.asr.parts.utils.manifest_utils import get_ctm_line + + +@pytest.fixture() +def annotator(): + cfg = get_data_simulation_configs() + return DataAnnotator(cfg) + + +@pytest.fixture() +def sampler(): + cfg = get_data_simulation_configs() + sampler = SpeechSampler(cfg) + # Must get session-wise randomized silence/overlap mean + sampler.get_session_overlap_mean() + sampler.get_session_silence_mean() + return sampler + + +def get_data_simulation_configs(): + config_dict = { + 'data_simulator': { + 'manifest_filepath': '???', + 'sr': 16000, + 'random_seed': 42, + 'multiprocessing_chunksize': 10000, + 'session_config': {'num_speakers': 4, 'num_sessions': 60, 'session_length': 600}, + 'session_params': { + 'max_audio_read_sec': 20, + 'sentence_length_params': [0.4, 0.05], + 'dominance_var': 0.11, + 'min_dominance': 0.05, + 'turn_prob': 0.875, + 'min_turn_prob': 0.5, + 'mean_silence': 0.15, + 'mean_silence_var': 0.01, + 'per_silence_var': 900, + 'per_silence_min': 0.0, + 'per_silence_max': -1, + 'mean_overlap': 0.1, + 'mean_overlap_var': 0.01, + 'per_overlap_var': 900, + 'per_overlap_min': 0.0, + 'per_overlap_max': -1, + 'start_window': True, + 'window_type': 'hamming', + 'window_size': 0.05, + 'start_buffer': 0.1, + 'split_buffer': 0.1, + 'release_buffer': 0.1, + 'normalize': True, + 'normalization_type': 'equal', + 'normalization_var': 0.1, + 'min_volume': 0.75, + 'max_volume': 1.25, + 'end_buffer': 0.5, + }, + 'outputs': { + 'output_dir': '???', + 'output_filename': 'multispeaker_session', + 'overwrite_output': True, + 'output_precision': 3, + }, + 'background_noise': { + 'add_bg': False, + 'background_manifest': None, + 'num_noise_files': 10, + 'snr': 60, + 'snr_min': None, + }, + 'segment_augmentor': { + 'add_seg_aug': False, + 'augmentor': { + 'gain': {'prob': 0.5, 'min_gain_dbfs': -10.0, 'max_gain_dbfs': 10.0}, + }, + }, + 'session_augmentor': { + 'add_sess_aug': False, + 'augmentor': { + 'white_noise': {'prob': 1.0, 'min_level': -90, 'max_level': -46}, + }, + }, + 'speaker_enforcement': {'enforce_num_speakers': True, 'enforce_time': [0.25, 0.75]}, + 'segment_manifest': {'window': 0.5, 'shift': 0.25, 'step_count': 50, 'deci': 3}, + } + } + return DictConfig(config_dict) + + +def generate_words_and_alignments(sample_index): + if sample_index == 0: + words = ['', 'hello', 'world'] + alignments = [0.5, 1.0, 1.5] + elif sample_index == 1: + words = ["", "stephanos", "dedalos", ""] + alignments = [0.51, 1.31, 2.04, 2.215] + elif sample_index == 2: + words = ['', 'hello', 'world', '', 'welcome', 'to', 'nemo', ''] + alignments = [0.5, 1.0, 1.5, 1.7, 1.8, 2.2, 2.7, 2.8] + else: + raise ValueError(f"sample_index {sample_index} not supported") + speaker_id = 'speaker_0' + return words, alignments, speaker_id + + +class TestGetCtmLine: + @pytest.mark.unit + @pytest.mark.parametrize("conf", [0, 1]) + def test_wrong_type_conf_values(self, conf): + # Test with wrong integer confidence values + with pytest.raises(ValueError): + result = get_ctm_line( + source="test_source", + channel=1, + start_time=0.123, + duration=0.456, + token="word", + conf=conf, + type_of_token="lex", + speaker="speaker1", + ) + expected = f"test_source 
1 0.12 0.46 word {conf} lex speaker1\n" + assert result == expected, f"Failed on valid conf value {conf}" + + @pytest.mark.unit + @pytest.mark.parametrize("conf", [0.0, 0.5, 1.0, 0.01, 0.99]) + def test_valid_conf_values(self, conf): + # Test with valid confidence values + output_precision = 2 + result = get_ctm_line( + source="test_source", + channel=1, + start_time=0.123, + duration=0.456, + token="word", + conf=conf, + type_of_token="lex", + speaker="speaker1", + output_precision=output_precision, + ) + expected = "test_source 1 0.12 0.46 word" + f" {conf:.{output_precision}f} lex speaker1\n" + assert result == expected, f"Failed on valid conf value {conf}" + + @pytest.mark.unit + @pytest.mark.parametrize("conf", [-0.1, 1.1, 2, -1, 100, -100]) + def test_invalid_conf_ranges(self, conf): + # Test with invalid confidence values + with pytest.raises(ValueError): + get_ctm_line( + source="test_source", + channel=1, + start_time=0.123, + duration=0.456, + token="word", + conf=conf, + type_of_token="lex", + speaker="speaker1", + ) + + @pytest.mark.unit + @pytest.mark.parametrize( + "start_time, duration, output_precision", + [(0.123, 0.456, 2), (1.0, 2.0, 1), (0.0, 0.0, 2), (0.01, 0.99, 3), (1.23, 4.56, 2)], + ) + def test_valid_start_time_duration_with_precision(self, start_time, duration, output_precision): + # Test with valid beginning time, duration values and output precision + confidence = 0.5 + result = get_ctm_line( + source="test_source", + channel=1, + start_time=start_time, + duration=duration, + token="word", + conf=confidence, + type_of_token="lex", + speaker="speaker1", + output_precision=output_precision, + ) + expected_start_time = ( + f"{start_time:.{output_precision}f}" # Adjusted to match the output format with precision + ) + expected_duration = f"{duration:.{output_precision}f}" # Adjusted to match the output format with precision + expected_confidence = ( + f"{confidence:.{output_precision}f}" # Adjusted to match the output format with precision + ) + expected = f"test_source 1 {expected_start_time} {expected_duration} word {expected_confidence} lex speaker1\n" + assert ( + result == expected + ), f"Failed on valid start_time {start_time}, duration {duration} with precision {output_precision}" + + @pytest.mark.unit + def test_valid_input(self): + # Test with completely valid inputs + result = get_ctm_line( + source="test_source", + channel=1, + start_time=0.123, + duration=0.456, + token="word", + conf=0.789, + type_of_token="lex", + speaker="speaker1", + ) + expected = "test_source 1 0.12 0.46 word 0.79 lex speaker1\n" + assert result == expected, "Failed on valid input" + + @pytest.mark.unit + @pytest.mark.parametrize( + "start_time, duration", + [ + ("not a float", 1.0), + (1.0, "not a float"), + (1, 2.0), # Integers should be converted to float + (2.0, 3), # Same as above + ], + ) + def test_invalid_types_for_time_duration(self, start_time, duration): + # Test with invalid types for start_time and duration + with pytest.raises(ValueError): + get_ctm_line( + source="test_source", + channel=1, + start_time=start_time, + duration=duration, + token="word", + conf=0.5, + type_of_token="lex", + speaker="speaker1", + ) + + @pytest.mark.unit + @pytest.mark.parametrize("conf", [-0.1, 1.1, "not a float"]) + def test_invalid_conf_values(self, conf): + # Test with invalid values for conf + with pytest.raises(ValueError): + get_ctm_line( + source="test_source", + channel=1, + start_time=0.123, + duration=0.456, + token="word", + conf=conf, + type_of_token="lex", + 
speaker="speaker1", + ) + + @pytest.mark.unit + def test_default_values(self): + # Test with missing optional parameters + result = get_ctm_line( + source="test_source", + channel=None, + start_time=0.123, + duration=0.456, + token="word", + conf=None, + type_of_token=None, + speaker=None, + ) + expected = "test_source 1 0.12 0.46 word NA unknown NA\n" + assert result == expected, "Failed on default values" + + +class TestDataSimulatorUtils: + # TODO: add tests for all util functions + @pytest.mark.parametrize("max_audio_read_sec", [2.5, 3.5, 4.5]) + @pytest.mark.parametrize("min_alignment_count", [2, 3, 4]) + def test_binary_search_alignments(self, max_audio_read_sec, min_alignment_count): + inds = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14] + alignments = [0.5, 11.0, 11.5, 12.0, 13.0, 14.0, 14.5, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 30, 40.0] + offset_max = binary_search_alignments(inds, max_audio_read_sec, min_alignment_count, alignments) + assert max_audio_read_sec <= alignments[-1 * min_alignment_count] - alignments[inds[offset_max]] + + @pytest.mark.parametrize("sample_len", [100, 16000]) + @pytest.mark.parametrize("gain", [0.1, 0.5, 1.0, 2.0, 5.0]) + def test_normalize_audio(self, sample_len, gain): + array_raw = np.random.randn(sample_len) + array_input = torch.from_numpy(gain * array_raw / np.max(np.abs(array_raw))) + norm_array = normalize_audio(array_input) + assert torch.max(torch.abs(norm_array)) == 1.0 + assert torch.min(torch.abs(norm_array)) < 1.0 + + @pytest.mark.parametrize("output_dir", [os.path.join(os.getcwd(), "test_dir")]) + def test_get_cleaned_base_path(self, output_dir): + result_path = get_cleaned_base_path(output_dir, overwrite_output=True) + assert os.path.exists(result_path) and not os.path.isfile(result_path) + result_path = get_cleaned_base_path(output_dir, overwrite_output=False) + assert os.path.exists(result_path) and not os.path.isfile(result_path) + os.rmdir(result_path) + assert not os.path.exists(result_path) + + @pytest.mark.parametrize( + "words, alignments, answers", + [ + (['', 'hello', 'world'], [0.5, 1.0, 1.5], [[0, 16000.0]]), + ( + ['', 'hello', 'world', '', 'welcome', 'to', 'nemo', ''], + [0.27, 1.0, 1.7, 2.7, 2.8, 3.2, 3.7, 3.9], + [[0, (1.7 + 0.5) * 16000], [(2.7 - 0.5) * 16000, (3.9 - 0.27) * 16000]], + ), + ], + ) + @pytest.mark.parametrize("sr", [16000]) + @pytest.mark.parametrize("split_buffer", [0.5]) + @pytest.mark.parametrize("new_start", [0.0]) + def test_get_split_points_in_alignments(self, words, alignments, sr, new_start, split_buffer, answers): + sentence_audio_len = sr * (alignments[-1] - alignments[0]) + splits = get_split_points_in_alignments(words, alignments, split_buffer, sr, sentence_audio_len, new_start) + assert len(splits) == len(answers) + for k, interval in enumerate(splits): + assert abs(answers[k][0] - interval[0]) < 1e-4 + assert abs(answers[k][1] - interval[1]) < 1e-4 + + @pytest.mark.parametrize( + "alignments, words", [(['hello', 'world'], [1.0, 1.5]), (['', 'hello', 'world'], [0.0, 1.0, 1.5])] + ) + def test_add_silence_to_alignments(self, alignments, words): + """ + Test add_silence_to_alignments function. 
+ """ + audio_manifest = { + 'audio_filepath': 'test.wav', + 'alignments': alignments, + 'words': words, + } + audio_manifest = add_silence_to_alignments(audio_manifest) + if words[0] == '': + assert audio_manifest['alignments'] == [0.0] + alignments + assert audio_manifest['words'] == [''] + words + else: + assert audio_manifest['alignments'] == alignments + assert audio_manifest['words'] == words + + +class TestDataAnnotator: + def test_init(self, annotator): + assert isinstance(annotator, DataAnnotator) + + def test_create_new_rttm_entry(self, annotator): + words, alignments, speaker_id = generate_words_and_alignments(sample_index=0) + start, end = alignments[0], alignments[-1] + rttm_list = annotator.create_new_rttm_entry( + words=words, alignments=alignments, start=start, end=end, speaker_id=speaker_id + ) + assert rttm_list[0] == f"{start} {end} {speaker_id}" + + def test_create_new_json_entry(self, annotator): + words, alignments, speaker_id = generate_words_and_alignments(sample_index=0) + start, end = alignments[0], alignments[-1] + test_wav_filename = '/path/to/test_wav_filename.wav' + test_rttm_filename = '/path/to/test_rttm_filename.rttm' + test_ctm_filename = '/path/to/test_ctm_filename.ctm' + text = " ".join(words) + + one_line_json_dict = annotator.create_new_json_entry( + text=text, + wav_filename=test_wav_filename, + start=start, + length=end - start, + speaker_id=speaker_id, + rttm_filepath=test_rttm_filename, + ctm_filepath=test_ctm_filename, + ) + start = round(float(start), annotator._params.data_simulator.outputs.output_precision) + length = round(float(end - start), annotator._params.data_simulator.outputs.output_precision) + meta = { + "audio_filepath": test_wav_filename, + "offset": start, + "duration": length, + "label": speaker_id, + "text": text, + "num_speakers": annotator._params.data_simulator.session_config.num_speakers, + "rttm_filepath": test_rttm_filename, + "ctm_filepath": test_ctm_filename, + "uem_filepath": None, + } + assert one_line_json_dict == meta + + def test_create_new_ctm_entry(self, annotator): + words, alignments, speaker_id = generate_words_and_alignments(sample_index=0) + session_name = 'test_session' + ctm_list = annotator.create_new_ctm_entry( + words=words, alignments=alignments, session_name=session_name, speaker_id=speaker_id, start=alignments[0] + ) + assert ctm_list[0] == ( + alignments[1], + get_ctm_line( + source=session_name, + channel="1", + start_time=alignments[1], + duration=float(alignments[1] - alignments[0]), + token=words[1], + conf=None, + type_of_token='lex', + speaker=speaker_id, + ), + ) + assert ctm_list[1] == ( + alignments[2], + get_ctm_line( + source=session_name, + channel="1", + start_time=alignments[2], + duration=float(alignments[2] - alignments[1]), + token=words[2], + conf=None, + type_of_token='lex', + speaker=speaker_id, + ), + ) + + +class TestSpeechSampler: + def test_init(self, sampler): + assert isinstance(sampler, SpeechSampler) + + def test_init_overlap_params(self, sampler): + sampler._init_overlap_params() + assert sampler.per_silence_min_len is not None + assert sampler.per_silence_max_len is not None + assert type(sampler.per_silence_min_len) == int + assert type(sampler.per_silence_max_len) == int + + def test_init_silence_params(self, sampler): + sampler._init_overlap_params() + assert sampler.per_overlap_min_len is not None + assert sampler.per_overlap_max_len is not None + assert type(sampler.per_overlap_min_len) == int + assert type(sampler.per_overlap_max_len) == int + + 
@pytest.mark.parametrize("mean", [0.1, 0.2, 0.3]) + @pytest.mark.parametrize("var", [0.05, 0.07]) + def test_get_session_silence_mean_pass(self, sampler, mean, var): + sampler.mean_silence = mean + sampler.mean_silence_var = var + sampled_silence_mean = sampler.get_session_silence_mean() + assert 0 <= sampled_silence_mean <= 1 + + @pytest.mark.parametrize("mean", [0.5]) + @pytest.mark.parametrize("var", [0.5, 0.6]) + def test_get_session_silence_mean_fail(self, sampler, mean, var): + """ + This test should raise `ValueError` because `mean_silence_var` + should be less than `mean_silence * (1 - mean_silence)`. + """ + sampler.mean_silence = mean + sampler.mean_silence_var = var + with pytest.raises(ValueError) as execinfo: + sampler.get_session_silence_mean() + assert "ValueError" in str(execinfo) and "mean_silence_var" in str(execinfo) + + @pytest.mark.parametrize("mean", [0.1, 0.2, 0.3]) + @pytest.mark.parametrize("var", [0.05, 0.07]) + def test_get_session_overlap_mean_pass(self, sampler, mean, var): + sampler.mean_overlap = mean + sampler.mean_overlap_var = var + sampled_overlap_mean = sampler.get_session_overlap_mean() + assert 0 <= sampled_overlap_mean <= 1 + + @pytest.mark.parametrize("mean", [0.4, 0.5]) + @pytest.mark.parametrize("var", [0.3, 0.8]) + def test_get_session_overlap_mean_fail(self, sampler, mean, var): + """ + This test should raise `ValueError` because `mean_overlap_var` + should be less than `mean_overlap * (1 - mean_overlap)`. + """ + sampler.mean_overlap = mean + sampler.mean_overlap_var = var + sampler._params = DictConfig(sampler._params) + with pytest.raises(ValueError) as execinfo: + sampler.get_session_overlap_mean() + assert "ValueError" in str(execinfo) and "mean_overlap_var" in str(execinfo) + + @pytest.mark.parametrize("non_silence_len_samples", [16000, 32000]) + @pytest.mark.parametrize("running_overlap_len_samples", [8000, 12000]) + def test_sample_from_overlap_model(self, sampler, non_silence_len_samples, running_overlap_len_samples): + sampler.get_session_overlap_mean() + sampler.running_overlap_len_samples = running_overlap_len_samples + overlap_amount = sampler.sample_from_overlap_model(non_silence_len_samples=non_silence_len_samples) + assert type(overlap_amount) == int + assert 0 <= overlap_amount + + @pytest.mark.parametrize("running_len_samples", [8000, 16000]) + @pytest.mark.parametrize("running_overlap_len_samples", [8000, 12000]) + def test_sample_from_silence_model(self, sampler, running_len_samples, running_overlap_len_samples): + sampler.get_session_silence_mean() + self.running_overlap_len_samples = running_overlap_len_samples + silence_amount = sampler.sample_from_silence_model(running_len_samples=running_len_samples) + assert type(silence_amount) == int + assert 0 <= silence_amount + + @pytest.mark.with_downloads() + @pytest.mark.parametrize("num_noise_files", [1, 2, 4]) + def test_sample_noise_manifest(self, sampler, num_noise_files, test_data_dir): + sampler.num_noise_files = num_noise_files + manifest_path = os.path.abspath(os.path.join(test_data_dir, 'asr/an4_val.json')) + noise_manifest = read_noise_manifest(add_bg=True, background_manifest=manifest_path) + sampled_noise_manifests = sampler.sample_noise_manifest(noise_manifest=noise_manifest) + assert len(sampled_noise_manifests) == num_noise_files + + @pytest.mark.parametrize("running_speech_len_samples", [32000, 64000]) + @pytest.mark.parametrize("running_overlap_len_samples", [16000, 32000]) + @pytest.mark.parametrize("running_len_samples", [64000, 96000]) + 
@pytest.mark.parametrize("non_silence_len_samples", [16000, 32000]) + def test_silence_vs_overlap_selector( + self, + sampler, + running_overlap_len_samples, + running_speech_len_samples, + running_len_samples, + non_silence_len_samples, + ): + sampler.running_overlap_len_samples = running_overlap_len_samples + sampler.running_speech_len_samples = running_speech_len_samples + add_overlap = sampler.silence_vs_overlap_selector( + running_len_samples=running_len_samples, non_silence_len_samples=non_silence_len_samples + ) + assert type(add_overlap) == bool diff --git a/tests/collections/asr/test_diar_utils.py b/tests/collections/speaker_tasks/utils/test_diar_utils.py similarity index 92% rename from tests/collections/asr/test_diar_utils.py rename to tests/collections/speaker_tasks/utils/test_diar_utils.py index cb364675fcf4..71ae2dc16d8e 100644 --- a/tests/collections/asr/test_diar_utils.py +++ b/tests/collections/speaker_tasks/utils/test_diar_utils.py @@ -13,7 +13,6 @@ # limitations under the License. import os - import numpy as np import pytest import torch @@ -48,6 +47,7 @@ get_online_subsegments_from_buffer, get_speech_labels_for_update, get_sub_range_list, + get_subsegments, get_subsegments_scriptable, get_target_sig, int2fl, @@ -115,6 +115,10 @@ def generate_toy_data( emb = emb_cent.tile((len(segments), 1)) + 0.1 * torch.rand(len(segments), emb_dim) seg_list.extend(segments) emb_list.append(emb) + if emb.shape[0] == 0: + import ipdb + + ipdb.set_trace() multiscale_segment_counts[scale_idx] += emb.shape[0] if scale_idx == len(multiscale_segment_counts) - 1: @@ -377,6 +381,109 @@ def test_online_speaker_clustering_instance_export(self): isinstance(offline_speaker_clustering, torch.jit._script.RecursiveScriptClass) +class TestGetSubsegments: + @pytest.mark.unit + @pytest.mark.parametrize( + "offset, window, shift, duration, min_subsegment_duration, decimals, use_asr_style_frame_count, sample_rate, feat_per_sec, expected", + [ + (12.05, 1.5, 0.75, 2.4, 0.01, 2, False, 16000, 100, [[12.05, 1.5], [12.8, 1.5], [13.55, 0.9]]), + (0, 1.0, 0.5, 0.4, 0.01, 2, False, 16000, 100, [[0, 0.4]]), + (0, 2.0, 1.0, 1.5, 0.5, 2, False, 16000, 100, [[0, 1.5]]), + ( + 10, + 1.5, + 0.75, + 4.5, + 0.5, + 2, + False, + 16000, + 100, + [[10, 1.5], [10.75, 1.5], [11.5, 1.5], [12.25, 1.5], [13.0, 1.5]], + ), + (0, 1.5, 0.5, 0.3, 0.01, 2, True, 16000, 100, [[0, 0.3]]), + ], + ) + def test_get_subsegments( + self, + offset, + window, + shift, + duration, + min_subsegment_duration, + decimals, + use_asr_style_frame_count, + sample_rate, + feat_per_sec, + expected, + ): + + for is_scriptable in [True, False]: + if is_scriptable: + result = get_subsegments_scriptable( + offset=offset, + window=window, + shift=shift, + duration=duration, + ) + else: + result = get_subsegments( + offset=offset, + window=window, + shift=shift, + duration=duration, + min_subsegment_duration=min_subsegment_duration, + decimals=decimals, + use_asr_style_frame_count=use_asr_style_frame_count, + sample_rate=sample_rate, + feat_per_sec=feat_per_sec, + ) + result_round = [] + for subsegment in result: + result_round.append([round(x, decimals) for x in subsegment]) + assert result_round == expected + + @pytest.mark.unit + def test_min_subsegment_duration_filtering(self): + result = get_subsegments( + offset=0, + window=1.5, + shift=0.5, + duration=3, + min_subsegment_duration=2.0, + decimals=2, + use_asr_style_frame_count=False, + ) + expected = [] # Only subsegments meeting the duration filter should remain + assert result == expected + + 
@pytest.mark.unit + def test_zero_duration(self): + result = get_subsegments( + offset=0, + window=1.0, + shift=0.5, + duration=0, + min_subsegment_duration=0.01, + decimals=2, + use_asr_style_frame_count=False, + ) + assert result == [] + + @pytest.mark.unit + def test_edge_case_short_slice(self): + result = get_subsegments( + offset=0, + window=0.5, + shift=0.25, # Shift larger than duration + duration=0.25, + min_subsegment_duration=0.01, + decimals=2, + use_asr_style_frame_count=False, + ) + assert result == [[0.0, 0.25]] + + class TestDiarizationSegmentationUtils: """ Test segmentation util functions diff --git a/tests/collections/speaker_tasks/utils/test_multispeaker_utils.py b/tests/collections/speaker_tasks/utils/test_multispeaker_utils.py new file mode 100644 index 000000000000..2e01cf4b94da --- /dev/null +++ b/tests/collections/speaker_tasks/utils/test_multispeaker_utils.py @@ -0,0 +1,352 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import pytest +import torch + +from nemo.collections.asr.parts.utils.asr_multispeaker_utils import ( + find_best_permutation, + find_first_nonzero, + get_ats_targets, + get_hidden_length_from_sample_length, + get_pil_targets, + reconstruct_labels, +) + + +def reconstruct_labels_forloop(labels: torch.Tensor, batch_perm_inds: torch.Tensor) -> torch.Tensor: + """ + This is a for-loop implementation of reconstruct_labels built for testing purposes. 
+ """ + # Expanding batch_perm_inds to align with labels dimensions + batch_size, num_frames, num_speakers = labels.shape + batch_perm_inds_exp = batch_perm_inds.unsqueeze(1).expand(-1, num_frames, -1) + + # Reconstructing the labels using advanced indexing + reconstructed_labels = torch.gather(labels, 2, batch_perm_inds_exp) + return reconstructed_labels + + +class TestSortingUtils: + @pytest.mark.unit + @pytest.mark.parametrize( + "mat, max_cap_val, thres, expected", + [ + # Test 1: Basic case with clear first nonzero values + (torch.tensor([[0.1, 0.6, 0.0], [0.0, 0.0, 0.9]]), -1, 0.5, torch.tensor([1, 2])), + # Test 2: All elements are below threshold + (torch.tensor([[0.1, 0.2], [0.3, 0.4]]), -1, 0.5, torch.tensor([-1, -1])), + # Test 3: No nonzero elements, should return max_cap_val (-1) + (torch.tensor([[0.0, 0.0], [0.0, 0.0]]), -1, 0.5, torch.tensor([-1, -1])), + # Test 4: Large matrix with mixed values, some rows with all values below threshold + (torch.tensor([[0.1, 0.7, 0.3], [0.0, 0.0, 0.9], [0.5, 0.6, 0.7]]), -1, 0.5, torch.tensor([1, 2, 0])), + # Test 5: Single row matrix + (torch.tensor([[0.0, 0.0, 0.6]]), -1, 0.5, torch.tensor([2])), + # Test 6: Single column matrix + (torch.tensor([[0.1], [0.6], [0.0]]), -1, 0.5, torch.tensor([-1, 0, -1])), + # Test 7: One element matrix + (torch.tensor([[0.501]]), -1, 0.5, torch.tensor([0], dtype=torch.long)), + # Test 8: All values are zero, should return max_cap_val + (torch.tensor([[0.0, 0.0], [0.0, 0.0]]), -1, 0.5, torch.tensor([-1, -1])), + # Test 9: All values are above threshold + (torch.tensor([[0.6, 0.7], [0.8, 0.9]]), -1, 0.5, torch.tensor([0, 0])), + # Test 10: Custom max_cap_val different from default + (torch.tensor([[0.0, 0.0], [0.0, 0.0]]), 99, 0.5, torch.tensor([99, 99])), + # Test 11: Matrix with 101 columns, first nonzero value is towards the end + (torch.cat([torch.zeros(1, 100), torch.ones(1, 1)], dim=1), -1, 0.5, torch.tensor([100])), + # Test 12: Matrix with 1000 columns, all below threshold except one near the middle + ( + torch.cat([torch.zeros(1, 499), torch.tensor([[0.6]]), torch.zeros(1, 500)], dim=1), + -1, + 0.5, + torch.tensor([499]), + ), + ], + ) + def test_find_first_nonzero(self, mat, max_cap_val, thres, expected): + result = find_first_nonzero(mat, max_cap_val, thres) + assert torch.equal(result, expected), f"Expected {expected} but got {result}" + + @pytest.mark.unit + @pytest.mark.parametrize( + "match_score, speaker_permutations, expected", + [ + # Test 1: Simple case with batch size 1, clear best match + ( + torch.tensor([[0.1, 0.9, 0.2]]), # match_score (batch_size=1, num_permutations=3) + torch.tensor([[0, 1], [1, 0], [0, 1]]), # speaker_permutations (num_permutations=3, num_speakers=2) + torch.tensor([[1, 0]]), # expected best permutation for the batch + ), + # Test 2: Batch size 2, different best matches for each batch + ( + torch.tensor([[0.5, 0.3, 0.7], [0.2, 0.6, 0.4]]), # match_score (batch_size=2, num_permutations=3) + torch.tensor([[0, 1], [1, 0], [0, 1]]), # speaker_permutations + torch.tensor([[0, 1], [1, 0]]), # expected best permutations + ), + # Test 3: Larger number of speakers and permutations + ( + torch.tensor( + [[0.1, 0.4, 0.9, 0.5], [0.6, 0.3, 0.7, 0.2]] + ), # match_score (batch_size=2, num_permutations=4) + torch.tensor( + [[0, 1, 2], [1, 0, 2], [2, 1, 0], [1, 2, 0]] + ), # speaker_permutations (num_permutations=4, num_speakers=3) + torch.tensor([[2, 1, 0], [2, 1, 0]]), # expected best permutations + ), + # Test 4: All match scores are the same, should pick the first 
permutation (argmax behavior) + ( + torch.tensor([[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]), # equal match_score across permutations + torch.tensor([[0, 1], [1, 0], [0, 1]]), # speaker_permutations + torch.tensor([[0, 1], [0, 1]]), # first permutation is chosen as tie-breaker + ), + # Test 5: Single speaker case (num_speakers = 1) + ( + torch.tensor([[0.8, 0.2]]), # match_score (batch_size=1, num_permutations=2) + torch.tensor([[0], [0]]), # speaker_permutations (num_permutations=2, num_speakers=1) + torch.tensor([[0]]), # expected best permutation + ), + # Test 6: Batch size 3, varying permutations + ( + torch.tensor([[0.3, 0.6], [0.4, 0.1], [0.2, 0.7]]), # match_score (batch_size=3, num_permutations=2) + torch.tensor([[0, 1], [1, 0]]), # speaker_permutations + torch.tensor([[1, 0], [0, 1], [1, 0]]), # expected best permutations for each batch + ), + ], + ) + def test_find_best_permutation(self, match_score, speaker_permutations, expected): + result = find_best_permutation(match_score, speaker_permutations) + assert torch.equal(result, expected), f"Expected {expected} but got {result}" + + @pytest.mark.parametrize( + "batch_size, num_frames, num_speakers", + [ + (2, 4, 3), # Original test case + (3, 5, 2), # More frames and speakers + (1, 6, 4), # Single batch with more frames and speakers + (5, 3, 5), # More batch size with equal frames and speakers + ], + ) + def test_reconstruct_labels_with_forloop_ver(self, batch_size, num_frames, num_speakers): + # Generate random labels and batch_perm_inds tensor for testing + labels = torch.rand(batch_size, num_frames, num_speakers) + batch_perm_inds = torch.stack([torch.randperm(num_speakers) for _ in range(batch_size)]) + + # Call both functions + result_matrix = reconstruct_labels(labels, batch_perm_inds) + result_forloop = reconstruct_labels_forloop(labels, batch_perm_inds) + + # Assert that both methods return the same result + assert torch.allclose(result_matrix, result_forloop), "The results are not equal!" 
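For reference, a minimal sketch (illustrative only, not part of the diff above) of how the gather-based reconstruction reorders speaker columns; the tensor values and the [batch, frames, speakers] layout mirror the parametrized cases that follow:

import torch

labels = torch.tensor([[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]])  # [batch=1, frames=2, speakers=3]
batch_perm_inds = torch.tensor([[2, 0, 1]])  # output column j takes input column perm[j]
perm_exp = batch_perm_inds.unsqueeze(1).expand(-1, labels.shape[1], -1)
reconstructed = torch.gather(labels, 2, perm_exp)
# reconstructed == [[[0.3, 0.1, 0.2], [0.6, 0.4, 0.5]]]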
+ + @pytest.mark.parametrize( + "labels, batch_perm_inds, expected_output", + [ + # Example 1: Small batch size with a few frames and speakers + ( + torch.tensor( + [ + [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]], # First batch + [[0.9, 0.8, 0.7], [0.6, 0.5, 0.4], [0.3, 0.2, 0.1]], # Second batch + ] + ), + torch.tensor([[2, 0, 1], [1, 2, 0]]), + torch.tensor( + [ + [[0.3, 0.1, 0.2], [0.6, 0.4, 0.5], [0.9, 0.7, 0.8]], # First batch reconstructed + [[0.8, 0.7, 0.9], [0.5, 0.4, 0.6], [0.2, 0.1, 0.3]], # Second batch reconstructed + ] + ), + ), + # Example 2: batch_size = 1 with more frames and speakers + ( + torch.tensor( + [[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2], [1.3, 1.4, 1.5, 1.6]]] + ), + torch.tensor([[3, 0, 1, 2]]), + torch.tensor( + [[[0.4, 0.1, 0.2, 0.3], [0.8, 0.5, 0.6, 0.7], [1.2, 0.9, 1.0, 1.1], [1.6, 1.3, 1.4, 1.5]]] + ), + ), + # Example 3: Larger batch size with fewer frames and speakers + ( + torch.tensor( + [ + [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], # First batch + [[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]], # Second batch + [[1.3, 1.4], [1.5, 1.6], [1.7, 1.8]], # Third batch + [[1.9, 2.0], [2.1, 2.2], [2.3, 2.4]], # Fourth batch + ] + ), + torch.tensor([[1, 0], [0, 1], [1, 0], [0, 1]]), + torch.tensor( + [ + [[0.2, 0.1], [0.4, 0.3], [0.6, 0.5]], # First batch reconstructed + [[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]], # Second batch unchanged + [[1.4, 1.3], [1.6, 1.5], [1.8, 1.7]], # Third batch reconstructed + [[1.9, 2.0], [2.1, 2.2], [2.3, 2.4]], # Fourth batch unchanged + ] + ), + ), + ], + ) + def test_reconstruct_labels(self, labels, batch_perm_inds, expected_output): + # Call the reconstruct_labels function + result = reconstruct_labels(labels, batch_perm_inds) + # Assert that the result matches the expected output + assert torch.allclose(result, expected_output), f"Expected {expected_output}, but got {result}" + + +class TestTargetGenerators: + + @pytest.mark.parametrize( + "labels, preds, num_speakers, expected_output", + [ + # Test 1: Basic case with simple permutations + ( + torch.tensor( + [ + [[0.9, 0.1, 0.0], [0.1, 0.8, 0.0], [0.0, 0.1, 0.9]], # Batch 1 + [[0.0, 0.0, 0.9], [0.0, 0.9, 0.1], [0.9, 0.1, 0.0]], # Batch 2 + ] + ), + torch.tensor( + [ + [[0.8, 0.2, 0.0], [0.2, 0.7, 0.0], [0.0, 0.1, 0.9]], # Batch 1 + [[0.0, 0.0, 0.8], [0.0, 0.8, 0.2], [0.9, 0.1, 0.0]], # Batch 2 + ] + ), + 3, # Number of speakers + torch.tensor( + [ + [[0.9, 0.1, 0.0], [0.1, 0.8, 0.0], [0.0, 0.1, 0.9]], # Expected labels for Batch 1 + [[0.9, 0.0, 0.0], [0.1, 0.9, 0.0], [0.0, 0.1, 0.9]], # Expected labels for Batch 2 + ] + ), + ), + # Test 2: Ambiguous case + ( + torch.tensor([[[0.9, 0.8, 0.7], [0.2, 0.8, 0.7], [0.2, 0.3, 0.9]]]), # Labels + torch.tensor([[[0.6, 0.7, 0.2], [0.9, 0.4, 0.0], [0.1, 0.7, 0.1]]]), # Preds + 3, # Number of speakers + torch.tensor([[[0.8, 0.7, 0.9], [0.8, 0.7, 0.2], [0.3, 0.9, 0.2]]]), # Expected output + ), + # Test 3: Ambiguous case + ( + torch.tensor([[[0, 0, 1, 1], [0, 0, 1, 1], [0, 0, 0, 1], [0, 0, 0, 0]]]), # Labels + torch.tensor( + [[[0.6, 0.6, 0.1, 0.9], [0.7, 0.7, 0.2, 0.8], [0.4, 0.6, 0.2, 0.7], [0.1, 0.1, 0.1, 0.7]]] + ), # Preds + 4, # Number of speakers + torch.tensor([[[1, 1, 0, 0], [1, 1, 0, 0], [0, 1, 0, 0], [0, 0, 0, 0]]]), # Expected output + ), + ], + ) + def test_get_ats_targets(self, labels, preds, num_speakers, expected_output): + # Generate all permutations for the given number of speakers + speaker_inds = list(range(num_speakers)) + speaker_permutations = torch.tensor(list(itertools.permutations(speaker_inds))) 
+ + # Call the function under test + result = get_ats_targets(labels, preds, speaker_permutations) + # Assert that the result matches the expected output + assert torch.allclose(result, expected_output), f"Expected {expected_output}, but got {result}" + + @pytest.mark.unit + @pytest.mark.parametrize( + "labels, preds, num_speakers, expected_output", + [ + # Test 1: Basic case with simple permutations + ( + torch.tensor( + [[[1, 0], [0, 1]], [[1, 0], [0, 1]]] + ), # Labels (batch_size=2, num_speakers=2, num_classes=2) + torch.tensor( + [[[1, 0], [0, 1]], [[0, 1], [1, 0]]] + ), # Preds (batch_size=2, num_speakers=2, num_classes=2) + 2, # Number of speakers + torch.tensor([[[1, 0], [0, 1]], [[0, 1], [1, 0]]]), # expected max_score_permed_labels + ), + # Test 2: Batch size 1 with more complex permutations + ( + torch.tensor([[[0.8, 0.2], [0.3, 0.7]]]), # Labels + torch.tensor([[[0.9, 0.1], [0.2, 0.8]]]), # Preds + 2, # Number of speakers + torch.tensor( + [[[0.8, 0.2], [0.3, 0.7]]] + ), # expected output (labels remain the same as preds are close) + ), + # Test 3: Ambiguous case + ( + torch.tensor([[[0, 0, 1, 1], [0, 0, 1, 1], [0, 0, 0, 1], [0, 0, 0, 0]]]), # Labels + torch.tensor( + [[[0.61, 0.6, 0.1, 0.9], [0.7, 0.7, 0.2, 0.8], [0.4, 0.6, 0.2, 0.7], [0.1, 0.1, 0.1, 0.7]]] + ), # Preds + 4, # Number of speakers + torch.tensor([[[1, 0, 0, 1], [1, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 0]]]), # Expected output + ), + ], + ) + def test_get_pil_targets(self, labels, preds, num_speakers, expected_output): + # Generate all permutations for the given number of speakers + speaker_inds = list(range(num_speakers)) + speaker_permutations = torch.tensor(list(itertools.permutations(speaker_inds))) + + result = get_pil_targets(labels, preds, speaker_permutations) + assert torch.equal(result, expected_output), f"Expected {expected_output} but got {result}" + + +class TestGetHiddenLengthFromSampleLength: + @pytest.mark.parametrize( + "num_samples, num_sample_per_mel_frame, num_mel_frame_per_asr_frame, expected_hidden_length", + [ + (160, 160, 8, 1), + (1280, 160, 8, 2), + (0, 160, 8, 1), + (159, 160, 8, 1), + (129, 100, 5, 1), + (300, 150, 3, 1), + ], + ) + def test_various_cases( + self, num_samples, num_sample_per_mel_frame, num_mel_frame_per_asr_frame, expected_hidden_length + ): + result = get_hidden_length_from_sample_length( + num_samples, num_sample_per_mel_frame, num_mel_frame_per_asr_frame + ) + assert result == expected_hidden_length + + def test_default_parameters(self): + assert get_hidden_length_from_sample_length(160) == 1 + assert get_hidden_length_from_sample_length(1280) == 2 + assert get_hidden_length_from_sample_length(0) == 1 + assert get_hidden_length_from_sample_length(159) == 1 + + def test_edge_cases(self): + assert get_hidden_length_from_sample_length(159, 160, 8) == 1 + assert get_hidden_length_from_sample_length(160, 160, 8) == 1 + assert get_hidden_length_from_sample_length(161, 160, 8) == 1 + assert get_hidden_length_from_sample_length(1279, 160, 8) == 1 + + def test_real_life_examples(self): + # The samples tried when this function was designed. 
+ assert get_hidden_length_from_sample_length(160000) == 126 + assert get_hidden_length_from_sample_length(159999) == 125 + assert get_hidden_length_from_sample_length(158720) == 125 + assert get_hidden_length_from_sample_length(158719) == 124 + + assert get_hidden_length_from_sample_length(158880) == 125 + assert get_hidden_length_from_sample_length(158879) == 125 + assert get_hidden_length_from_sample_length(1600) == 2 + assert get_hidden_length_from_sample_length(1599) == 2 diff --git a/tests/collections/speaker_tasks/utils/test_vad_utils.py b/tests/collections/speaker_tasks/utils/test_vad_utils.py new file mode 100644 index 000000000000..a7672e1aa43d --- /dev/null +++ b/tests/collections/speaker_tasks/utils/test_vad_utils.py @@ -0,0 +1,126 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest +from pyannote.core import Annotation, Segment + +from nemo.collections.asr.parts.utils.vad_utils import ( + align_labels_to_frames, + convert_labels_to_speech_segments, + frame_vad_construct_pyannote_object_per_file, + get_frame_labels, + get_nonspeech_segments, + load_speech_overlap_segments_from_rttm, + load_speech_segments_from_rttm, + read_rttm_as_pyannote_object, +) + + +def get_simple_rttm_without_overlap(rttm_file="test1.rttm"): + line = "SPEAKER 1 0 2 speech \n" + speech_segments = [[0.0, 2.0]] + with open(rttm_file, "w") as f: + f.write(line) + return rttm_file, speech_segments + + +def get_simple_rttm_with_overlap(rttm_file="test2.rttm"): + speech_segments = [[0.0, 3.0]] + overlap_segments = [[1.0, 2.0]] + with open(rttm_file, "w") as f: + f.write("SPEAKER 1 0 2 speech \n") + f.write("SPEAKER 1 1 2 speech \n") + return rttm_file, speech_segments, overlap_segments + + +def get_simple_rttm_with_silence(rttm_file="test3.rttm"): + line = "SPEAKER 1 1 2 speech \n" + speech_segments = [[1.0, 2.0]] + silence_segments = [[0.0, 1.0]] + with open(rttm_file, "w") as f: + f.write(line) + return rttm_file, speech_segments, silence_segments + + +class TestVADUtils: + @pytest.mark.parametrize(["logits_len", "labels_len"], [(20, 10), (20, 11), (20, 9), (10, 21), (10, 19)]) + @pytest.mark.unit + def test_align_label_logits(self, logits_len, labels_len): + logits = np.arange(logits_len).tolist() + labels = np.arange(labels_len).tolist() + labels_new = align_labels_to_frames(probs=logits, labels=labels) + + assert len(labels_new) == len(logits) + + @pytest.mark.unit + def test_load_speech_segments_from_rttm(self, test_data_dir): + rttm_file, speech_segments = get_simple_rttm_without_overlap(test_data_dir + "/test1.rttm") + speech_segments_new = load_speech_segments_from_rttm(rttm_file) + assert speech_segments_new == speech_segments + + @pytest.mark.unit + def test_load_speech_overlap_segments_from_rttm(self, test_data_dir): + rttm_file, speech_segments, overlap_segments = get_simple_rttm_with_overlap(test_data_dir + "/test2.rttm") + speech_segments_new, overlap_segments_new = 
load_speech_overlap_segments_from_rttm(rttm_file) + assert speech_segments_new == speech_segments + assert overlap_segments_new == overlap_segments + + @pytest.mark.unit + def test_get_nonspeech_segments(self, test_data_dir): + rttm_file, speech_segments, silence_segments = get_simple_rttm_with_silence(test_data_dir + "/test3.rttm") + speech_segments_new = load_speech_segments_from_rttm(rttm_file) + silence_segments_new = get_nonspeech_segments(speech_segments_new) + assert silence_segments_new == silence_segments + + @pytest.mark.unit + def test_get_frame_labels(self, test_data_dir): + rttm_file, speech_segments = get_simple_rttm_without_overlap(test_data_dir + "/test4.rttm") + speech_segments_new = load_speech_segments_from_rttm(rttm_file) + frame_labels = get_frame_labels(speech_segments_new, 0.02, 0.0, 3.0, as_str=False) + assert frame_labels[0] == 1 + assert len(frame_labels) == 150 + + @pytest.mark.unit + def test_convert_labels_to_speech_segments(self, test_data_dir): + rttm_file, speech_segments = get_simple_rttm_without_overlap(test_data_dir + "/test5.rttm") + speech_segments_new = load_speech_segments_from_rttm(rttm_file) + frame_labels = get_frame_labels(speech_segments_new, 0.02, 0.0, 3.0, as_str=False) + speech_segments_new = convert_labels_to_speech_segments(frame_labels, 0.02) + assert speech_segments_new == speech_segments + + @pytest.mark.unit + def test_read_rttm_as_pyannote_object(self, test_data_dir): + rttm_file, speech_segments = get_simple_rttm_without_overlap(test_data_dir + "/test6.rttm") + pyannote_object = read_rttm_as_pyannote_object(rttm_file) + pyannote_object_gt = Annotation() + pyannote_object_gt[Segment(0.0, 2.0)] = 'speech' + assert pyannote_object == pyannote_object_gt + + @pytest.mark.unit + def test_frame_vad_construct_pyannote_object_per_file(self, test_data_dir): + rttm_file, speech_segments = get_simple_rttm_without_overlap(test_data_dir + "/test7.rttm") + # test for rttm input + ref, hyp = frame_vad_construct_pyannote_object_per_file(rttm_file, rttm_file) + pyannote_object_gt = Annotation() + pyannote_object_gt[Segment(0.0, 2.0)] = 'speech' + assert ref == hyp == pyannote_object_gt + + # test for list input + speech_segments = load_speech_segments_from_rttm(rttm_file) + frame_labels = get_frame_labels(speech_segments, 0.02, 0.0, 3.0, as_str=False) + speech_segments_new = convert_labels_to_speech_segments(frame_labels, 0.02) + assert speech_segments_new == speech_segments + ref, hyp = frame_vad_construct_pyannote_object_per_file(frame_labels, frame_labels, 0.02) + assert ref == hyp == pyannote_object_gt From 06a14915134af4858a5989b4870dbee232f8a7eb Mon Sep 17 00:00:00 2001 From: Vladimir Bataev Date: Tue, 17 Dec 2024 12:57:38 +0400 Subject: [PATCH 02/20] 2x more memory efficient Graph-based RNN-T (#11169) * Optimized Graph-Transducer implementation Signed-off-by: Vladimir Bataev --------- Signed-off-by: Vladimir Bataev Signed-off-by: artbataev Co-authored-by: artbataev --- .../asr/parts/k2/graph_transducer.py | 155 ++++++++--- .../collections/asr/parts/k2/rnnt_logprobs.py | 44 +++ .../asr/parts/k2/rnnt_logprobs_triton.py | 250 ++++++++++++++++++ nemo/core/utils/optional_libs.py | 34 +++ requirements/requirements.txt | 1 + .../asr/k2/test_graph_transducer.py | 91 ++++++- 6 files changed, 528 insertions(+), 47 deletions(-) create mode 100644 nemo/collections/asr/parts/k2/rnnt_logprobs.py create mode 100644 nemo/collections/asr/parts/k2/rnnt_logprobs_triton.py create mode 100644 nemo/core/utils/optional_libs.py diff --git 
a/nemo/collections/asr/parts/k2/graph_transducer.py b/nemo/collections/asr/parts/k2/graph_transducer.py index bcd49bcbd7a9..874e6e6fd2b4 100644 --- a/nemo/collections/asr/parts/k2/graph_transducer.py +++ b/nemo/collections/asr/parts/k2/graph_transducer.py @@ -15,11 +15,17 @@ import abc from contextlib import nullcontext from typing import ContextManager + import torch import torch.nn.functional as F from nemo.core.classes.loss import Loss from nemo.core.utils.k2_guard import k2 +from nemo.core.utils.optional_libs import TRITON_AVAILABLE +from nemo.utils import logging + +if TRITON_AVAILABLE: + from nemo.collections.asr.parts.k2.rnnt_logprobs_triton import rnnt_logprobs_triton def force_float32_context() -> ContextManager: @@ -129,13 +135,13 @@ def get_composed_lattice(self, units_tensor: torch.Tensor, num_frames: int, voca return composed def get_graphs_batched( - self, logits_lengths: torch.Tensor, targets: torch.Tensor, target_lengths: torch.Tensor, vocab_size: int + self, source_lengths: torch.Tensor, targets: torch.Tensor, target_lengths: torch.Tensor, vocab_size: int ) -> "k2.Fsa": """ Get batched lattice (grid or composed) for the batch of sequences. Args: - logits_lengths: tensor with lengths of logits + source_lengths: tensor with lengths of logits targets: tensor with target units target_lengths: tensor with lengths of targets vocab_size: vocab size (including blank) @@ -143,14 +149,16 @@ def get_graphs_batched( Returns: batched lattice - FsaVec (k2.Fsa) """ - batch_size = logits_lengths.shape[0] + batch_size = source_lengths.shape[0] with torch.no_grad(): if self.use_grid_implementation: + source_lengths_list = source_lengths.tolist() + target_lengths_list = target_lengths.tolist() return k2.create_fsa_vec( [ self.get_grid( - units_tensor=targets[i, : target_lengths[i].item()], - num_frames=logits_lengths[i].item(), + units_tensor=targets[i, : target_lengths_list[i]], + num_frames=source_lengths_list[i], vocab_size=vocab_size, ) for i in range(batch_size) @@ -167,7 +175,7 @@ def get_graphs_batched( ] temporal_fsas = [ self.get_temporal_schema( - num_frames=logits_lengths[i].item(), vocab_size=vocab_size, device=targets.device + num_frames=source_lengths[i].item(), vocab_size=vocab_size, device=targets.device ) for i in range(batch_size) ] @@ -175,22 +183,20 @@ def get_graphs_batched( k2.create_fsa_vec(text_fsas), k2.create_fsa_vec(temporal_fsas), treat_epsilons_specially=False ) if self.connect_composed: - k2.connect(target_fsas_vec) + target_fsas_vec = k2.connect(target_fsas_vec) return target_fsas_vec - def get_logits_indices(self, target_fsas_vec: k2.Fsa, logits_shape: torch.Size) -> torch.Tensor: + def get_batch_indices(self, target_fsas_vec: k2.Fsa) -> torch.Tensor: """ - Get indices of flatten logits for each arc in the lattices. + Get batch indices (for logits) for each arc in the lattices. 
Args: target_fsas_vec: batch of target FSAs with lattices - logits_shape: shape of the logits tensor Returns: 1d tensor with indices """ - # logits_shape: B x Time x Text+1 x Labels - batch_size = logits_shape[0] + batch_size = target_fsas_vec.shape[0] device = target_fsas_vec.device scores_to_batch_i = torch.repeat_interleave( torch.arange(batch_size, device=device, dtype=torch.int64), @@ -199,6 +205,21 @@ def get_logits_indices(self, target_fsas_vec: k2.Fsa, logits_shape: torch.Size) device=device, ), ) + return scores_to_batch_i + + def get_logits_indices(self, target_fsas_vec: k2.Fsa, logits_shape: torch.Size) -> torch.Tensor: + """ + Get indices of flatten logits for each arc in the lattices. + + Args: + target_fsas_vec: batch of target FSAs with lattices + logits_shape: shape of the logits tensor + + Returns: + 1d tensor with indices + """ + # logits_shape: B x Time x Text+1 x Labels + scores_to_batch_i = self.get_batch_indices(target_fsas_vec=target_fsas_vec) indices = ( scores_to_batch_i * logits_shape[1] * logits_shape[2] * logits_shape[3] # Batch + target_fsas_vec.aux_labels.to(torch.int64) * logits_shape[2] * logits_shape[3] # Time indices @@ -222,6 +243,8 @@ def __init__( connect_composed=False, double_scores=False, cast_to_float32=False, + return_graph=False, + use_triton=True, ): """ Init method @@ -232,8 +255,11 @@ def __init__( connect_composed: Connect graph after composing unit and temporal schemas (only for Compose-Transducer). `connect` operation is slow, it is useful for visualization, but not necessary for loss computation. double_scores: Use calculation of loss in double precision (float64) in the lattice. - Does not significantly affect memory usage since the lattice is ~V/2 times smaller than the joint tensor. + Does not significantly affect memory usage since the lattice is ~V/2 times smaller + than the joint tensor. cast_to_float32: Force cast joint tensor to float32 before log-softmax calculation. 
+ return_graph: Return graph (along with loss) from `forward` function + use_triton: use optimized log probs calculations with Triton (faster and more memory efficient) """ super().__init__( use_grid_implementation=use_grid_implementation, @@ -242,6 +268,10 @@ def __init__( cast_to_float32=cast_to_float32, ) self.blank = blank + self.return_graph = return_graph + self.use_triton = use_triton and TRITON_AVAILABLE + if not self.use_triton: + logging.warning("Triton is disabled, memory usage can be larger") def get_unit_schema(self, units_tensor: torch.Tensor, vocab_size: int) -> "k2.Fsa": """ @@ -370,13 +400,14 @@ def relabel_states(states: torch.Tensor, n: int, m: int) -> torch.Tensor: anti_diag = m + n - 1 - diag max_idx = n * m - 1 cur_diag_idx = i if m > n else m - j - 1 - states = ( + new_states = ( diag.lt(min_mn) * ((diag * (diag + 1) >> 1) + i) + torch.logical_and(diag.ge(min_mn), diag.lt(max_mn)) * ((min_mn * (min_mn + 1) >> 1) + (diag - min_mn) * min_mn + cur_diag_idx) + diag.ge(max_mn) * (max_idx - (anti_diag * (anti_diag + 1) >> 1) + m - j) ) - return states + torch.where(states >= n * m, states, new_states, out=new_states) + return new_states def get_grid(self, units_tensor: torch.Tensor, num_frames: int, vocab_size: int) -> "k2.Fsa": """ @@ -445,13 +476,76 @@ def get_grid(self, units_tensor: torch.Tensor, num_frames: int, vocab_size: int) rnnt_graph.unit_positions = unit_positions return rnnt_graph + def get_weighted_graphs( + self, + logits: torch.Tensor, + targets: torch.Tensor, + source_lengths: torch.Tensor, + target_lengths: torch.Tensor, + use_graph_weight=False, + ) -> "k2.Fsa": + """ + Get batch of graphs (FsaVec) for RNN-T loss calculation. + + Args: + logits: activations (joint tensor). NB: raw logits, not after log-softmax + targets: target labels + source_lengths: lengths of source sequences + target_lengths: length of target sequences + use_graph_weight: uses weight from graphs (if `get_graphs_batched` returns graphs with weights) + + Returns: + FsaVec containing RNN-T graphs for all utterances. 
+ """ + vocab_size = logits.shape[-1] + target_fsas_vec = self.get_graphs_batched(source_lengths, targets, target_lengths, vocab_size) + + with torch.no_grad(): + # last transitions in the graph are labeled with -1 label + last_transition_mask = target_fsas_vec.labels == -1 + batch_indices = self.get_batch_indices(target_fsas_vec=target_fsas_vec) + time_indices = target_fsas_vec.aux_labels.clone().to(torch.int64) + unit_indices = target_fsas_vec.unit_positions.clone().to(torch.int64) + text_units = target_fsas_vec.labels.clone().to(torch.int64) + # fill in the indices outside the logits with 0, replace later + text_units.masked_fill_(last_transition_mask, 0) + + cast_context = force_float32_context() if self.cast_to_float32 else nullcontext() + with cast_context: + # NB: do not assign scores -> modify, k2 will not update all scores correctly (modify -> assign) + if self.use_triton and logits.device.type == "cuda": + unit_scores, blank_scores = rnnt_logprobs_triton( + logits=logits, + targets=targets, + blank_id=self.blank, + source_lengths=source_lengths, + target_lengths=target_lengths, + ) + text_units_blank_mask = text_units == self.blank + scores = torch.where( + text_units_blank_mask, + blank_scores[batch_indices, time_indices, unit_indices], + unit_scores[batch_indices, time_indices, unit_indices], + ).to(torch.float32) + scores[last_transition_mask] = 0.0 # fix weights for the arcs to the last state + else: + log_probs = F.log_softmax(logits, dim=-1) + scores = log_probs[batch_indices, time_indices, unit_indices, text_units].to(torch.float32) + scores[last_transition_mask] = 0.0 + + if use_graph_weight: + target_fsas_vec.scores = target_fsas_vec.scores + scores + else: + target_fsas_vec.scores = scores + return target_fsas_vec + def forward( self, acts: torch.Tensor, labels: torch.Tensor, act_lens: torch.Tensor, label_lens: torch.Tensor, - ) -> torch.Tensor: + ) -> torch.Tensor | tuple[torch.Tensor, "k2.Fsa"]: """ Compute forward method for RNN-T. 
@@ -466,26 +560,11 @@ def forward( """ # argument names are consistent with NeMo, see RNNTLoss.forward: # self._loss(acts=log_probs, labels=targets, act_lens=input_lengths, label_lens=target_lengths) - logits, targets, logits_lengths, target_lengths = acts, labels, act_lens, label_lens - - # logits: B x Time x Text+1 x C - vocab_size = logits.shape[-1] - target_fsas_vec = self.get_graphs_batched(logits_lengths, targets, target_lengths, vocab_size) - - cast_context = force_float32_context() if self.cast_to_float32 else nullcontext() - with cast_context: - log_probs = F.log_softmax(logits, dim=-1) - with torch.no_grad(): - indices = self.get_logits_indices(target_fsas_vec, logits.shape) - # transition to the last state - # use 0 index (for valid index_select) and manually assign score after index_select for this case - indices[target_fsas_vec.labels == -1] = 0 - - # NB: do not assign scores -> modify, k2 will not update all scores correctly (modify -> assign) - scores = log_probs.flatten().index_select(-1, indices) - # fix weights for the arcs to the last state - scores[target_fsas_vec.labels == -1] = 0 + target_fsas_vec = self.get_weighted_graphs( + logits=acts, targets=labels, source_lengths=act_lens, target_lengths=label_lens, use_graph_weight=False + ) - target_fsas_vec.scores = scores - scores = -1 * target_fsas_vec.get_tot_scores(use_double_scores=self.double_scores, log_semiring=True) - return scores + scores = -1 * target_fsas_vec.get_tot_scores(use_double_scores=self.double_scores, log_semiring=True) + if self.return_graph: + return scores, target_fsas_vec + return scores diff --git a/nemo/collections/asr/parts/k2/rnnt_logprobs.py b/nemo/collections/asr/parts/k2/rnnt_logprobs.py new file mode 100644 index 000000000000..c41615f83bf9 --- /dev/null +++ b/nemo/collections/asr/parts/k2/rnnt_logprobs.py @@ -0,0 +1,44 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn.functional as F + + +def rnnt_logprobs_torch( + logits: torch.Tensor, targets: torch.Tensor, blank_id: int +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Given logits, calculate log probabilities for blank and target labels needed for transducer loss calculation. + Naive implementation in PyTorch, for testing and prototyping purposes. + + Args: + logits: Joint tensor of size [B, T, U+1, D] + targets: Targets of size [B, U] + blank_id: id of the blank output + + Returns: + Tuple of tensors with log probabilities for targets and blank labels, both of size [B, T, U+1]. + For the last non-existent target (U+1) output is zero. 
+ """ + device = logits.device + batch_size = logits.shape[0] + log_probs = F.log_softmax(logits, dim=-1) + blank_scores = log_probs[..., blank_id] + targets = torch.cat((targets, torch.zeros([batch_size], dtype=targets.dtype, device=device).unsqueeze(1)), dim=-1) + target_scores = torch.gather( + log_probs, dim=-1, index=targets.unsqueeze(1).expand(log_probs.shape[:-1]).unsqueeze(-1) + ).squeeze(-1) + target_scores[:, :, -1] = 0.0 + return target_scores, blank_scores diff --git a/nemo/collections/asr/parts/k2/rnnt_logprobs_triton.py b/nemo/collections/asr/parts/k2/rnnt_logprobs_triton.py new file mode 100644 index 000000000000..64bc8abbdbeb --- /dev/null +++ b/nemo/collections/asr/parts/k2/rnnt_logprobs_triton.py @@ -0,0 +1,250 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _rnnt_logprobs_fwd_kernel( + logits_ptr, + targets_ptr, + source_lengths_ptr, + target_lengths_ptr, + max_source_len: int, + max_target_len_plus_1: int, + num_labels: int, # vocab size (with blank) + blank_id: int, + target_scores_ptr, + blank_scores_ptr, + BLOCK_SIZE: tl.constexpr, +): + """ + Forward kernel for RNN-T log probs. Stores result in `target_scores_ptr` and `blank_scores_ptr`. + Calculations are performed in float32 (but original tensors can use any precision). 
+ """ + batch_i = tl.program_id(axis=0).to(tl.int64) + source_i = tl.program_id(axis=1).to(tl.int64) + target_i = tl.program_id(axis=2).to(tl.int64) + + # load lengths for source/target + source_len = tl.load(source_lengths_ptr + batch_i) + target_len = tl.load(target_lengths_ptr + batch_i) + + if source_i >= source_len or target_i > target_len: + # no calculations required + return + + # calculate offset in [B, T, U+1, V] tensor for the current vector with target logits + flat_index = ((batch_i * max_source_len + source_i) * max_target_len_plus_1 + target_i) * num_labels + logits_ptr += flat_index + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < num_labels + logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + # stable log softmax calculation + logits_max = tl.max(logits, axis=0) + logits_minus_max = logits - logits_max + denominator = tl.log(tl.sum(tl.exp(logits_minus_max), axis=0)) + blank_logit = tl.load(logits_ptr + blank_id).to(tl.float32) + flat_index_output = (batch_i * max_source_len + source_i) * max_target_len_plus_1 + target_i + tl.store(blank_scores_ptr + flat_index_output, blank_logit - logits_max - denominator) + + # calculate log prob for target if needed + if target_i < target_len: + target_id = tl.load(targets_ptr + batch_i * (max_target_len_plus_1 - 1) + target_i) + target_logit = tl.load(logits_ptr + target_id).to(tl.float32) + tl.store(target_scores_ptr + flat_index_output, target_logit - logits_max - denominator) + + +@triton.jit +def _rnnt_logprobs_bwd_kernel( + logits_ptr, + grad_logits_ptr, + targets_ptr, + source_lengths_ptr, + target_lengths_ptr, + max_source_len: int, + max_target_len_plus_1: int, + num_labels: int, + blank_id: int, + grad_target_scores_ptr, + grad_blank_scores_ptr, + BLOCK_SIZE: tl.constexpr, +): + """ + Backward kernel for RNN-T log probs. Stores result in `grad_target_scores_ptr` and `grad_blank_scores_ptr`. + We recalculate part of the forward here to avoid using extra memory in forward. + Calculations are performed in float32 (but original tensors can use any precision). 
+ """ + batch_i = tl.program_id(axis=0).to(tl.int64) + source_i = tl.program_id(axis=1).to(tl.int64) + target_i = tl.program_id(axis=2).to(tl.int64) + + # load lengths for source/target + source_len = tl.load(source_lengths_ptr + batch_i) + target_len = tl.load(target_lengths_ptr + batch_i) + if source_i >= source_len or target_i > target_len: + # no calculations required + return + + # calculate offset in [B, T, U+1, V] tensor for the current vector with target logits/grad_logits + flat_index = ((batch_i * max_source_len + source_i) * max_target_len_plus_1 + target_i) * num_labels + logits_ptr += flat_index + grad_logits_ptr += flat_index + + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < num_labels + logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + # stable log softmax calculation + logits_max = tl.max(logits, axis=0) + logits_minus_max = logits - logits_max + denominator = tl.log(tl.sum(tl.exp(logits_minus_max), axis=0)) + log_softmax = logits_minus_max - denominator + # softmax for gradient + softmax = tl.exp(log_softmax) + + flat_index_grad = (batch_i * max_source_len + source_i) * max_target_len_plus_1 + target_i + blank_grad = tl.load(grad_blank_scores_ptr + flat_index_grad).to(tl.float32) + target_i_valid = target_i < target_len + target_grad = tl.load(grad_target_scores_ptr + flat_index_grad, mask=target_i_valid, other=0.0).to(tl.float32) + target_id = tl.load(targets_ptr + batch_i * (max_target_len_plus_1 - 1) + target_i, mask=target_i_valid, other=-1) + + grad_not_in_targets = (-softmax) * (blank_grad + target_grad) + grad = tl.where(col_offsets == blank_id, blank_grad + grad_not_in_targets, grad_not_in_targets) + grad = tl.where(col_offsets == target_id, target_grad + grad_not_in_targets, grad) + tl.store(grad_logits_ptr + col_offsets, grad, mask=mask) + + +class RnntLogProbs(torch.autograd.Function): + """ + Function to calculate log probabilities for target and blank labels for RNN-T, supporting torch.autograd. 
+ """ + + @staticmethod + def forward( + ctx, + logits: torch.Tensor, + targets: torch.Tensor, + blank_id: int, + source_lengths: torch.Tensor | None, + target_lengths: torch.Tensor | None, + ): + """ + + Args: + ctx: ctx object for storing the context + logits: Joint tensor of size [B, T, U+1, D] + targets: Targets of size [B, U] + blank_id: id of the blank output + source_lengths: optional tensor with lengths for source utterances + target_lengths: optional tensor with lengths for targets + + Returns: + + """ + assert logits.is_contiguous() # logits are huge, so here we just check if logits are contiguous + targets = targets.contiguous() + device = logits.device + float_dtype = torch.float32 + + target_scores = torch.zeros(logits.shape[:-1], dtype=float_dtype, device=device) + blank_scores = torch.zeros_like(target_scores) + if source_lengths is None: + source_lengths = torch.full([logits.shape[0]], fill_value=logits.shape[1], dtype=torch.int, device=device) + else: + source_lengths = source_lengths.contiguous() + if target_lengths is None: + target_lengths = torch.full( + [logits.shape[0]], fill_value=logits.shape[2] - 1, dtype=torch.int, device=device + ) + else: + target_lengths = target_lengths.contiguous() + + # run Triton kernel + _rnnt_logprobs_fwd_kernel[(logits.shape[0], logits.shape[1], logits.shape[2])]( + logits_ptr=logits, + targets_ptr=targets, + source_lengths_ptr=source_lengths, + target_lengths_ptr=target_lengths, + max_source_len=logits.shape[1], + max_target_len_plus_1=logits.shape[2], + num_labels=logits.shape[3], + blank_id=blank_id, + target_scores_ptr=target_scores, + blank_scores_ptr=blank_scores, + BLOCK_SIZE=triton.next_power_of_2(logits.shape[-1]), + ) + + # saving for backward + ctx.save_for_backward(logits, targets, source_lengths, target_lengths) + ctx.blank_id = blank_id + return target_scores, blank_scores + + @staticmethod + def backward(ctx, grad_target_scores, grad_blank_scores): + """ + Backward calculation for RNN-T log-probs. + + Args: + ctx: ctx object for storing the context + grad_target_scores: upstream gradient for targets + grad_blank_scores: upstream gradient for blank scores + + Returns: + gradient for logits, None for all other arguments for `forward` + """ + (logits, targets, source_lengths, target_lengths) = ctx.saved_tensors + blank_id = ctx.blank_id + grad_logits = torch.zeros_like(logits) + _rnnt_logprobs_bwd_kernel[(logits.shape[0], logits.shape[1], logits.shape[2])]( + logits_ptr=logits, + grad_logits_ptr=grad_logits, + source_lengths_ptr=source_lengths, + target_lengths_ptr=target_lengths, + targets_ptr=targets, + max_source_len=logits.shape[1], + max_target_len_plus_1=logits.shape[2], + num_labels=logits.shape[3], + blank_id=blank_id, + grad_target_scores_ptr=grad_target_scores, + grad_blank_scores_ptr=grad_blank_scores, + BLOCK_SIZE=triton.next_power_of_2(logits.shape[-1]), + ) + return grad_logits, None, None, None, None + + +def rnnt_logprobs_triton( + logits: torch.Tensor, + targets: torch.Tensor, + blank_id: int, + source_lengths: torch.Tensor | None = None, + target_lengths: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Given logits, calculate log probabilities for blank and target labels needed for transducer loss calculation. + Optimized implementation in Triton. 
+ + Args: + logits: Joint tensor of size [B, T, U+1, D] + targets: Targets of size [B, U] + blank_id: id of the blank output + source_lengths: optional tensor with lengths for source utterances + target_lengths: optional tensor with lengths for targets + + Returns: + Tuple of tensors with log probabilities for targets and blank labels, both of size [B, T, U+1]. + For the non-existent targets (U+1 or beyond target_lengths) output is zero. + """ + return RnntLogProbs.apply(logits, targets, blank_id, source_lengths, target_lengths) diff --git a/nemo/core/utils/optional_libs.py b/nemo/core/utils/optional_libs.py new file mode 100644 index 000000000000..9aa39260963c --- /dev/null +++ b/nemo/core/utils/optional_libs.py @@ -0,0 +1,34 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib.util + + +def is_lib_available(name: str) -> bool: + """ + Checks if the library/package with `name` is available in the system + NB: try/catch with importlib.import_module(name) requires importing the library, which can be slow. + So, `find_spec` should be preferred + """ + return importlib.util.find_spec(name) is not None + + +TRITON_AVAILABLE = is_lib_available("triton") + +try: + from nemo.core.utils.k2_guard import k2 as _ + + K2_AVAILABLE = True +except (ImportError, ModuleNotFoundError): + K2_AVAILABLE = False diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 7fd5e88eebe3..1b9fc88000b9 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -11,5 +11,6 @@ tensorboard text-unidecode torch tqdm>=4.41.0 +triton>=3.1.0; sys_platform == 'linux' wget wrapt diff --git a/tests/collections/asr/k2/test_graph_transducer.py b/tests/collections/asr/k2/test_graph_transducer.py index 5879226e782d..592772767484 100644 --- a/tests/collections/asr/k2/test_graph_transducer.py +++ b/tests/collections/asr/k2/test_graph_transducer.py @@ -12,29 +12,36 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import random from typing import List import numpy as np import pytest import torch +from nemo.collections.asr.parts.k2.rnnt_logprobs import rnnt_logprobs_torch from nemo.collections.asr.parts.numba.rnnt_loss.rnnt_numpy import RNNTLoss as RNNTLoss_Numpy +from nemo.core.utils.optional_libs import K2_AVAILABLE, TRITON_AVAILABLE + +if K2_AVAILABLE: + import k2 -try: from nemo.collections.asr.parts.k2.graph_transducer import GraphRnntLoss - from nemo.core.utils.k2_guard import k2 -except (ImportError, ModuleNotFoundError): - pytest.skip("k2 is not installed, skipping Graph-RNNT tests.", allow_module_level=True) + +if TRITON_AVAILABLE: + from nemo.collections.asr.parts.k2.rnnt_logprobs_triton import rnnt_logprobs_triton + EPS_SM_INPUT = 1e-6 EPS_L_INPUT = 1e-4 DEVICES = ['cpu'] -if torch.cuda.is_available() and k2.with_cuda: +if K2_AVAILABLE and torch.cuda.is_available() and k2.with_cuda: DEVICES.append('cuda') +@pytest.mark.skipif(not K2_AVAILABLE, reason="k2 is not installed, skipping Graph-RNNT tests.") class TestGraphRnnt: @pytest.mark.unit @pytest.mark.parametrize("device", DEVICES) @@ -214,9 +221,12 @@ def test_small_grid_transducer(self, device, rnnt_test_helper, rnn_loss_sample_d @pytest.mark.unit @pytest.mark.parametrize("device", DEVICES) - def test_medium_grid_transducer(self, device, rnnt_test_helper, rnn_loss_sample_data): + @pytest.mark.parametrize("use_triton", [True, False]) + def test_medium_grid_transducer(self, device, use_triton: bool, rnnt_test_helper, rnn_loss_sample_data): + if use_triton and device == "cpu": + pytest.skip("Triton does not support CPU yet") sample_data = rnn_loss_sample_data.get_sample_medium() - graph_rnnt = GraphRnntLoss(blank=0, use_grid_implementation=True) + graph_rnnt = GraphRnntLoss(blank=0, use_grid_implementation=True, use_triton=use_triton) graph_cost, graph_grads = rnnt_test_helper.wrap_and_call( graph_rnnt, sample_data.logits, sample_data.targets, device ) @@ -225,9 +235,12 @@ def test_medium_grid_transducer(self, device, rnnt_test_helper, rnn_loss_sample_ @pytest.mark.unit @pytest.mark.parametrize("device", DEVICES) - def test_medium_random_var_size(self, device, rnnt_test_helper, rnn_loss_sample_data): + @pytest.mark.parametrize("use_triton", [True, False]) + def test_medium_random_var_size(self, device, use_triton: bool, rnnt_test_helper, rnn_loss_sample_data): + if use_triton and device == "cpu": + pytest.skip("Triton does not support CPU yet") sample_data = rnn_loss_sample_data.get_sample_medium_random_var_size(blank_first=True) - graph_rnnt = GraphRnntLoss(blank=0, use_grid_implementation=True) + graph_rnnt = GraphRnntLoss(blank=0, use_grid_implementation=True, use_triton=use_triton) graph_cost, graph_grads = rnnt_test_helper.wrap_and_call( graph_rnnt, sample_data.logits.detach(), @@ -261,3 +274,63 @@ def test_small_random_grid_compose_equivalent(self, device: torch.device, blank_ assert k2.is_rand_equivalent( graph_grid, graph_composed, log_semiring=True, treat_epsilons_specially=False ), "Grid and composed graphs are not equivalent." 
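Before the Triton-vs-Torch comparison below, a minimal sketch (illustrative only, not part of the diff) of the shape contract of the reference implementation rnnt_logprobs_torch added earlier in this patch; the sizes and the blank id chosen here are arbitrary:

import torch
from nemo.collections.asr.parts.k2.rnnt_logprobs import rnnt_logprobs_torch

B, T, U, V = 2, 3, 2, 5  # batch, frames, target units, vocab size (blank included)
logits = torch.randn(B, T, U + 1, V)  # raw joint logits, log-softmax is applied inside
targets = torch.randint(0, V - 1, (B, U))
target_scores, blank_scores = rnnt_logprobs_torch(logits=logits, targets=targets, blank_id=V - 1)
assert target_scores.shape == blank_scores.shape == (B, T, U + 1)
assert torch.all(target_scores[:, :, -1] == 0)  # the non-existent (U+1)-th target column is zeroed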
+ + +@pytest.mark.skipif(not TRITON_AVAILABLE, reason="Triton is not installed, skipping RNNT Log Probs tests") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is unavailable") +class TestRnntLogProbs: + @pytest.mark.parametrize( + "batch_size,num_frames,num_text_units,vocab_size", + [ + (1, 4, 2, 4), + (2, 3, 2, 5), + (2, 16, 31, 17), + (16, 129, 65, 2048), + ], + ) + @pytest.mark.parametrize( + "float_dtype", + [torch.float32] + ([torch.bfloat16] if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else []), + ) + def test_rnnt_logprobs_random( + self, batch_size: int, num_frames: int, num_text_units: int, vocab_size: int, float_dtype: torch.dtype + ): + """ + Test Triton-based implementation using etalon Torch-based implementation for RNN-T log-probs. + """ + device = torch.device("cuda") + torch.manual_seed(777) + + targets = torch.tensor( + [[random.randrange(0, vocab_size - 1) for i in range(num_text_units)] for j in range(batch_size)], + device=device, + dtype=torch.long, + ) + + logits = torch.rand( + [batch_size, num_frames, num_text_units + 1, vocab_size + 1], + dtype=float_dtype, + device=device, + requires_grad=True, + ) + + # Triton-based implementation works in float32 precision for accuracy purposes, should compare with float32 + target_scores_etalon, blank_scores_etalon = rnnt_logprobs_torch( + logits=logits.to(torch.float32), targets=targets, blank_id=vocab_size + ) + logits2 = logits.clone().detach() + logits2.requires_grad_(True) + target_scores, blank_scores = rnnt_logprobs_triton(logits=logits2, targets=targets, blank_id=vocab_size) + target_scores[..., -1:] = 0.0 + target_scores_etalon[..., -1:] = 0.0 + assert torch.allclose(blank_scores, blank_scores_etalon, atol=1e-5) + assert torch.allclose(target_scores, target_scores_etalon, atol=1e-5) + + # test backward + target_scales = torch.rand_like(target_scores, requires_grad=False) + blank_scales = torch.rand_like(blank_scores, requires_grad=False) + loss_etalon = (target_scales * target_scores_etalon + blank_scales * blank_scores_etalon).sum() + loss = (target_scales * target_scores + blank_scales * blank_scores).sum() + loss_etalon.backward() + loss.backward() + assert torch.allclose(logits.grad, logits2.grad, atol=1e-5) From f2169a1528c0668bf78d8a3344c2a1a1311ec61c Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Tue, 17 Dec 2024 08:41:01 -0800 Subject: [PATCH 03/20] Use explicit subpaths in io for exporting a checkpoint (#11352) * Fix llm.export_ckpt Signed-off-by: Hemil Desai * fix Signed-off-by: Hemil Desai --------- Signed-off-by: Hemil Desai --- nemo/collections/llm/api.py | 2 +- nemo/collections/llm/gpt/model/llama.py | 9 ++++++--- nemo/lightning/io/connector.py | 2 +- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/nemo/collections/llm/api.py b/nemo/collections/llm/api.py index d030eb88863c..3e63bcea9447 100644 --- a/nemo/collections/llm/api.py +++ b/nemo/collections/llm/api.py @@ -605,7 +605,7 @@ def import_ckpt( def load_connector_from_trainer_ckpt(path: Path, target: str) -> io.ModelConnector: - return io.load_context(path).model.exporter(target, path) + return io.load_context(path, subpath="model").exporter(target, path) @run.cli.entrypoint(name="export", namespace="llm") diff --git a/nemo/collections/llm/gpt/model/llama.py b/nemo/collections/llm/gpt/model/llama.py index a7e995addb83..04540294d82a 100644 --- a/nemo/collections/llm/gpt/model/llama.py +++ b/nemo/collections/llm/gpt/model/llama.py @@ -344,7 +344,10 @@ def apply(self, output_path: Path) -> Path: target 
= target.cpu() target.save_pretrained(output_path) - self.tokenizer.save_pretrained(output_path) + try: + self.tokenizer.save_pretrained(output_path) + except Exception: + logging.warning("Failed to save tokenizer") return output_path @@ -366,11 +369,11 @@ def convert_state(self, source, target): @property def tokenizer(self): - return io.load_context(str(self)).model.tokenizer.tokenizer + return io.load_context(str(self), subpath="model").tokenizer.tokenizer @property def config(self) -> "HFLlamaConfig": - source: LlamaConfig = io.load_context(str(self)).model.config + source: LlamaConfig = io.load_context(str(self), subpath="model.config") from transformers import LlamaConfig as HFLlamaConfig diff --git a/nemo/lightning/io/connector.py b/nemo/lightning/io/connector.py index bf07956f2cd2..258d2848a63a 100644 --- a/nemo/lightning/io/connector.py +++ b/nemo/lightning/io/connector.py @@ -226,7 +226,7 @@ def nemo_load( from nemo.lightning import MegatronStrategy, Trainer, _strategy_lib from nemo.lightning.io.api import load_context - model = load_context(path).model + model = load_context(path, subpath="model") _trainer = trainer or Trainer( devices=1, accelerator="cpu" if cpu else "gpu", From 431cd0880ca7daf8b73216fe34262b3d1b13b1cf Mon Sep 17 00:00:00 2001 From: Dong Hyuk Chang Date: Tue, 17 Dec 2024 12:59:29 -0500 Subject: [PATCH 04/20] Remove triton requirement (#11627) * Specify pytorch-triton instead of triton Signed-off-by: Dong Hyuk Chang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove triton Signed-off-by: Dong Hyuk Chang --------- Signed-off-by: Dong Hyuk Chang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- requirements/requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 1b9fc88000b9..7fd5e88eebe3 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -11,6 +11,5 @@ tensorboard text-unidecode torch tqdm>=4.41.0 -triton>=3.1.0; sys_platform == 'linux' wget wrapt From 993e575bf0c4d733a55f898dd15433ef7491eb8f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?oliver=20k=C3=B6nig?= Date: Tue, 17 Dec 2024 20:08:43 +0100 Subject: [PATCH 05/20] ci: Remove comment if no changes required anymore (#11624) Signed-off-by: Oliver Koenig --- .github/workflows/code-formatting.yml | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/.github/workflows/code-formatting.yml b/.github/workflows/code-formatting.yml index b08e9676aabd..0eaad048b3a5 100644 --- a/.github/workflows/code-formatting.yml +++ b/.github/workflows/code-formatting.yml @@ -118,7 +118,7 @@ jobs: FILTERED=() for file in $CHANGED_FILES; do - DATE=$(git log --format=%ad --date=unix $file | tail -1) + DATE=$(git log --format=%ad --date=unix "$file" | tail -1) if [[ "$STRICT_MODE" == "true" ]]; then if [[ "$DATE" -gt "$THRESHOLD" ]]; then @@ -139,16 +139,18 @@ jobs: echo "Will run on these files: ${FILTERED[@]}" - set +e + set +xe LOG=$(pylint ${FILTERED[@]}) EXIT_CODE=$? set -e - echo "$LOG" echo "OUTPUT<> $GITHUB_ENV echo "$LOG" >> $GITHUB_ENV echo "EOF" >> $GITHUB_ENV echo "log=$LOG" + set -x + + echo "exit-code=$EXIT_CODE" | tee -a "$GITHUB_OUTPUT" if [[ "${{ matrix.strict-mode }}" == "true" ]]; then HEADER="🚨 The following files must be fixed before merge!" 
@@ -160,7 +162,7 @@ jobs: exit $([[ "$EXIT_CODE" -ne 0 && "$STRICT_MODE" == "true" ]] && echo $EXIT_CODE || echo 0) - name: Find Comment - if: ${{ always() && env.OUTPUT != '' }} + if: ${{ always() }} uses: peter-evans/find-comment@v3 id: fc with: @@ -168,7 +170,7 @@ jobs: body-includes: - name: Delete comment - if: ${{ always() && env.OUTPUT != '' && steps.fc.outputs.comment-id != '' }} + if: ${{ always() && steps.fc.outputs.comment-id != '' }} env: GH_TOKEN: ${{ secrets.github_token }} REPOSITORY: ${{ github.repository }} @@ -182,7 +184,7 @@ jobs: https://api.github.com/repos/$REPOSITORY/issues/comments/$COMMENT_ID - name: Add PR comment for PyLint - if: ${{ always() && env.OUTPUT != '' }} + if: ${{ always() && steps.pylint.outputs.exit-code != '0' }} uses: peter-evans/create-or-update-comment@v4 with: issue-number: ${{ github.event.number }} @@ -200,5 +202,13 @@ jobs: ``` --- + + Mitigation guide: + + * Add sensible and useful docstrings to functions and methods + * For trivial methods like getter/setters, consider adding `# pylint: disable=C0116` inside the function itself + * To disable multiple functions/methods at once, put a `# pylint: disable=C0116` before the first and a `# pylint: enable=C0116` after the last. + + By applying these rules, we reduce the occurance of this message in future. Thank you for improving NeMo's documentation! From de0b2e258e0ee6a629465ccfe372960011ff4be0 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com> Date: Tue, 17 Dec 2024 13:26:28 -0800 Subject: [PATCH 06/20] Jit with peft (#11586) * move jitransform at the end Signed-off-by: Alexandros Koumparoulis * add docstring & post-init Signed-off-by: Alexandros Koumparoulis * Add remove_extra_batch_keys and remove align_labels Signed-off-by: Alexandros Koumparoulis * Run JitTransform on_train_epoch_start Signed-off-by: Alexandros Koumparoulis * add --use-torch-jit option Signed-off-by: Alexandros Koumparoulis * add docstrings Signed-off-by: Alexandros Koumparoulis * fix Signed-off-by: Alexandros Koumparoulis * pep8 Signed-off-by: Alexandros Koumparoulis * Apply isort and black reformatting Signed-off-by: akoumpa --------- Signed-off-by: Alexandros Koumparoulis Signed-off-by: akoumpa Co-authored-by: akoumpa --- examples/llm/peft/hf.py | 17 ++++- examples/llm/sft/hf.py | 35 ++++++++-- nemo/collections/llm/api.py | 11 ++- .../gpt/model/hf_auto_model_for_causal_lm.py | 53 ++++++++++----- .../pytorch/callbacks/jit_transform.py | 68 ++++++++++++++++++- 5 files changed, 155 insertions(+), 29 deletions(-) diff --git a/examples/llm/peft/hf.py b/examples/llm/peft/hf.py index 3a0930732e87..45675398a421 100644 --- a/examples/llm/peft/hf.py +++ b/examples/llm/peft/hf.py @@ -16,6 +16,7 @@ from lightning.pytorch.loggers import WandbLogger from nemo import lightning as nl from nemo.collections import llm +from nemo.lightning.pytorch.callbacks import JitConfig, JitTransform def make_squad_hf_dataset(tokenizer): @@ -53,7 +54,7 @@ def formatting_prompts_func(examples): return datamodule -if __name__ == '__main__': +def main(): import argparse parser = argparse.ArgumentParser() @@ -63,6 +64,7 @@ def formatting_prompts_func(examples): parser.add_argument('--accelerator', default='gpu', choices=['gpu']) parser.add_argument('--max-steps', type=int, default=100) parser.add_argument('--wandb-project', type=str, default=None) + parser.add_argument('--use-torch-jit', action='store_true') args = parser.parse_args() wandb = None @@ -74,11 +76,17 @@ def formatting_prompts_func(examples): ) 
grad_clip = 0.5 if args.strategy == 'fsdp': - # See: https://github.com/Lightning-AI/pytorch-lightning/blob/8ad3e29816a63d8ce5c00ac104b14729a4176f4f/src/lightning/pytorch/plugins/precision/fsdp.py#L81 + # See: + # https://github.com/Lightning-AI/pytorch-lightning/blob/8ad3e29816a63d8ce5c00ac104b14729a4176f4f/src/lightning/pytorch/plugins/precision/fsdp.py#L81 grad_clip = None use_dist_samp = False tokenizer = llm.HFAutoModelForCausalLM.configure_tokenizer(args.model) + callbacks = [] + if args.use_torch_jit: + jit_config = JitConfig(use_torch=True, torch_kwargs={'dynamic': True}, use_thunder=False) + callbacks = [JitTransform(jit_config)] + llm.api.finetune( model=llm.HFAutoModelForCausalLM(args.model), data=make_squad_hf_dataset(tokenizer.tokenizer), @@ -94,6 +102,7 @@ def formatting_prompts_func(examples): gradient_clip_val=grad_clip, use_distributed_sampler=use_dist_samp, logger=wandb, + callbacks=callbacks, ), optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(lr=1e-5)), log=None, @@ -102,3 +111,7 @@ def formatting_prompts_func(examples): dim=32, ), ) + + +if __name__ == '__main__': + main() diff --git a/examples/llm/sft/hf.py b/examples/llm/sft/hf.py index ce79e136a1c2..ff85180cf86b 100755 --- a/examples/llm/sft/hf.py +++ b/examples/llm/sft/hf.py @@ -20,10 +20,12 @@ from nemo import lightning as nl from nemo.collections import llm from nemo.lightning.pytorch.accelerate.transformer_engine import is_te_accelerated -from nemo.lightning.pytorch.callbacks import ModelCallback +from nemo.lightning.pytorch.callbacks import JitConfig, JitTransform class SquadDataModuleWithPthDataloader(llm.SquadDataModule): + """Creates a squad dataset with a PT dataloader""" + def _create_dataloader(self, dataset, mode, **kwargs) -> DataLoader: return DataLoader( dataset, @@ -37,17 +39,30 @@ def _create_dataloader(self, dataset, mode, **kwargs) -> DataLoader: def squad(tokenizer) -> pl.LightningDataModule: + """Instantiates a SquadDataModuleWithPthDataloader and return it + + Args: + tokenizer (AutoTokenizer): the tokenizer to use + + Returns: + pl.LightningDataModule: the dataset to train with. 
+ """ return SquadDataModuleWithPthDataloader( tokenizer=tokenizer, - seq_length=2048, + seq_length=512, micro_batch_size=2, global_batch_size=128, # assert gbs == mbs * accumulate_grad_batches num_workers=0, - dataset_kwargs={"sanity_check_dist_workers": False}, + dataset_kwargs={ + "sanity_check_dist_workers": False, + "pad_to_max_length": True, + "get_attention_mask_from_fusion": True, + }, ) -if __name__ == '__main__': +def main(): + """Example script to run SFT with a HF transformers-instantiated model on squad.""" import argparse parser = argparse.ArgumentParser() @@ -60,6 +75,7 @@ def squad(tokenizer) -> pl.LightningDataModule: parser.add_argument("--fp8-autocast", default=False, action='store_true') parser.add_argument('--wandb-project', type=str, default=None) parser.add_argument('--model-save-path', type=str, default=None) + parser.add_argument('--use-torch-jit', action='store_true') args = parser.parse_args() wandb = None @@ -87,6 +103,11 @@ def squad(tokenizer) -> pl.LightningDataModule: model = llm.HFAutoModelForCausalLM(model_name=args.model, model_accelerator=model_accelerator) tokenizer = model.tokenizer + callbacks = [] + if args.use_torch_jit: + jit_config = JitConfig(use_torch=True, torch_kwargs={'dynamic': False}, use_thunder=False) + callbacks = [JitTransform(jit_config)] + llm.api.finetune( model=model, data=squad(tokenizer), @@ -101,8 +122,8 @@ def squad(tokenizer) -> pl.LightningDataModule: accumulate_grad_batches=10, gradient_clip_val=grad_clip, use_distributed_sampler=use_dist_samp, - callbacks=[], logger=wandb, + callbacks=callbacks, ), optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(lr=1e-5)), log=None, @@ -116,3 +137,7 @@ def squad(tokenizer) -> pl.LightningDataModule: if args.model_save_path is not None: model.save_pretrained(args.model_save_path) + + +if __name__ == '__main__': + main() diff --git a/nemo/collections/llm/api.py b/nemo/collections/llm/api.py index 3e63bcea9447..7d7762edef3c 100644 --- a/nemo/collections/llm/api.py +++ b/nemo/collections/llm/api.py @@ -37,7 +37,7 @@ io, ) from nemo.lightning.base import NEMO_MODELS_CACHE -from nemo.lightning.pytorch.callbacks import PEFT, ModelTransform +from nemo.lightning.pytorch.callbacks import PEFT, JitTransform, ModelTransform from nemo.utils import logging from nemo.utils.get_rank import is_global_rank_zero @@ -875,7 +875,14 @@ def _setup( trainer.callbacks.append(model_transform) else: trainer.callbacks.append(ModelTransform()) - + # Move jit callback at the end ensure it's applied on top of any model transformations (peft) + jit_cb = None + for i, cb in enumerate(trainer.callbacks): + if isinstance(cb, JitTransform): + assert jit_cb is None + jit_cb = trainer.callbacks.pop(i) + if jit_cb is not None: + trainer.callbacks.append(jit_cb) return app_state diff --git a/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py b/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py index a51bbffdd6b6..2d8b32964767 100644 --- a/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py +++ b/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py @@ -26,24 +26,11 @@ def masked_cross_entropy(logits, targets, mask=None): if mask is not None: loss = F.cross_entropy(logits, targets, reduction='none') - return torch.mean(loss[mask == 1]) + return torch.mean(loss * mask.view(-1)) else: return F.cross_entropy(logits, targets) -def align_labels(logits, labels): - logits = logits.float() - n_cls = logits.shape[-1] - if logits.shape[-2] == labels.shape[-1]: - logits = logits[..., :-1, 
:].contiguous() - labels = labels[..., 1:].contiguous() - elif logits.shape[-2] == labels.shape[-1] + 1: - logits = logits[..., :-1, :].contiguous() - else: - raise ValueError("Mismatched labels and logits shapes (" + str(labels.shape) + " " + str(logits.shape)) - return logits.view(-1, n_cls), labels.view(-1) - - class HFAutoModelForCausalLM(pl.LightningModule, io.IOMixin, fn.FNMixin): def __init__( self, @@ -111,14 +98,21 @@ def training_step(self, batch): labels = batch.pop('labels').to(self.model.device) loss_mask = batch.pop('loss_mask', None) + # GPTSFTDataset emits `tokens` instead of `input_ids` + if not 'input_ids' in batch and 'tokens' in batch: + batch['input_ids'] = batch['tokens'] + batch = self._remove_extra_batch_keys(batch) + outputs = self.forward(batch) # Prepare for loss calculation - logits, labels = align_labels(outputs.logits.float(), labels) + logits = outputs.logits.float() + n_cls = logits.shape[-1] + logits, labels = logits.view(-1, n_cls), labels.view(-1) assert logits.shape[-2] == labels.shape[-1] loss = self.loss_fn(logits, labels, loss_mask) - self.log('train_log', loss, on_step=True, on_epoch=True, prog_bar=True) + self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True) return loss @torch.no_grad @@ -126,12 +120,20 @@ def validation_step(self, batch, batch_idx): labels = batch.pop('labels').to(self.model.device) loss_mask = batch.pop('loss_mask', None) + # GPTSFTDataset emits `tokens` instead of `input_ids` + if not 'input_ids' in batch and 'tokens' in batch: + batch['input_ids'] = batch['tokens'] + batch = self._remove_extra_batch_keys(batch) + outputs = self.forward(**batch) - logits, labels = align_labels(outputs.logits.float(), labels) + # Prepare for loss calculation + logits = outputs.logits.float() + n_cls = logits.shape[-1] + logits, labels = logits.view(-1, n_cls), labels.view(-1) assert logits.shape[-2] == labels.shape[-1] - loss = self.loss_fn(logits, labels, loss_mask) + loss = self.loss_fn(logits, labels, loss_mask) self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True) def save_pretrained(self, path): @@ -141,3 +143,18 @@ def save_pretrained(self, path): self._tokenizer.save_pretrained(path) else: logging.warning("A tokenizer wasn't created before to save.") + + def _remove_extra_batch_keys(self, batch, reserved_keys=['labels', 'loss_mask']): + """Remove extra keys from batch that are not kwargs in model's forward + + Args: + batch (dict): dictionary of tensors. + + Returns: + dict: dictionary of tensors; keys that are not in model's forward are removed. + """ + import inspect + + fwd_signature = inspect.signature(self.model.forward) + allowed_keys = list(fwd_signature.parameters.keys()) + reserved_keys + return {k: batch[k] for k in allowed_keys if k in batch} diff --git a/nemo/lightning/pytorch/callbacks/jit_transform.py b/nemo/lightning/pytorch/callbacks/jit_transform.py index cbfca8a25d88..33e76555f65d 100644 --- a/nemo/lightning/pytorch/callbacks/jit_transform.py +++ b/nemo/lightning/pytorch/callbacks/jit_transform.py @@ -22,6 +22,17 @@ def extract_module_attr_name(pl_module: "pl.LightningModule") -> str: + """Extracts the held nn.Module from a pl.LightningModule, will try "module", "model", or fail. + + Args: + pl_module (pl.LightningModule): the LightningModule used in training. 
+ + Raises: + ValueError: if the pl_module has neither a .model nor a .module + + Returns: + str: the attr-name of the nn.Module + """ if hasattr(pl_module, 'module'): return 'module' elif hasattr(pl_module, 'model'): @@ -31,12 +42,34 @@ def extract_module_attr_name(pl_module: "pl.LightningModule") -> str: def listify(x): + """Wraps input in a list, if not already a list. + + Args: + x (Anything): the input, can be anything. + + Returns: + Anything | list(Anything): Anything (if it's already a list) o/w list(Anything) + """ if not isinstance(x, list): return [x] return x def get_modules_from_selector(model, module_selector): + """Iterator over model's modules whose FQN matches the module_selector. + + Args: + model (nn.Module): the model to iterate over. + module_selector (str): module selector, if empty or '*' will return the whole model. If + there's an asterisk in the name will match it as a regexp. + + Raises: + AttributeError: if the user provides an invalid selector. + AttributeError: if user's selector selects a non-nn.Module attribute. + + Yields: + Iterator(nn.Module): iterator over modules whose FQN matches module_selector + """ if module_selector is None or module_selector == '' or module_selector == '*': yield model return @@ -50,7 +83,7 @@ def get_modules_from_selector(model, module_selector): # handle wildcard selector # TODO(@akoumparouli): support more complex selectors e.g. net_b.*.net_c.*.conv for name, module in tmp.named_children(): - if re.match(item, name): + if re.match(item.replace('*', '.*'), name): yield module return @@ -65,6 +98,15 @@ def get_modules_from_selector(model, module_selector): def compile_module(config, module): + """Jit-compiles an nn.Module + + Args: + config (JitConfig): jit config + module (nn.Module): the module to be compiled + + Returns: + nn.Module: the (potentially) compiled module + """ if config.use_torch: module.compile(**config.torch_kwargs) return True @@ -88,12 +130,26 @@ def compile_module(config, module): @dataclass class JitConfig: + """Config POD for Jit transforms (e.g. torch.compile or thunder) + Options: + - module_selector (str): reg-exp to match modules to apply JitTransform to, useful for multi-trunk + models where you want to apply it on one of them only. If empty will apply transform to root + module. + - use_torch (bool): whether to use torch.compile or not. + - torch_kwargs (dict): kwargs to pass to torch.compile. + - use_thunder (bool): whether to use thunder or not. + - profile_thunder (bool): toggle for thunder's profiler. + """ + module_selector: str = '' use_torch: bool = False torch_kwargs: dict = field(default_factory=dict) use_thunder: bool = False profile_thunder: bool = False + def __post_init__(self): + assert not (self.use_torch and self.use_thunder), "use_torch cannot be used at the same time with use_thunder" + class JitTransform(Callback, IOMixin): """ @@ -112,7 +168,15 @@ def __init__(self, config: JitConfig): self.config = config assert not (self.config.use_torch and self.config.use_thunder) - def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Jit-compiles the model at the start of the epoch. + While other events such as on_train_start are more suitable, we use on_train_epoch_start + since that is what is used in peft (we want to jit after adding the adapters).
+ + Args: + trainer (pl.Trainer): PTL trainer + pl_module (pl.LightningModule): PTL module + """ if self.config is None: return if not self.config.use_thunder and not self.config.use_torch: From 8ebb847169ab6a4662a245694a4d9c8452576e1f Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com> Date: Tue, 17 Dec 2024 14:32:46 -0800 Subject: [PATCH 07/20] NeMo-UX: add Hf's AutoModelForImageTextToText (#11321) * init commit Signed-off-by: Alexandros Koumparoulis * fix Signed-off-by: Alexandros Koumparoulis * wip Signed-off-by: Alexandros Koumparoulis * fix Signed-off-by: Alexandros Koumparoulis * fix Signed-off-by: Alexandros Koumparoulis * fix Signed-off-by: Alexandros Koumparoulis * Apply isort and black reformatting Signed-off-by: akoumpa * fix Signed-off-by: Alexandros Koumparoulis * peft examp;le Signed-off-by: Alexandros Koumparoulis * fix Signed-off-by: Alexandros Koumparoulis * Apply isort and black reformatting Signed-off-by: akoumpa * move peft example to multimodal_llm Signed-off-by: Alexandros Koumparoulis * surface HFAutoModelForImageTextToText Signed-off-by: Alexandros Koumparoulis * add hf vlm dataset Signed-off-by: Alexandros Koumparoulis * move processor Signed-off-by: Alexandros Koumparoulis * train_log -> train_loss Signed-off-by: Alexandros Koumparoulis * vlm.HFDatasetDataModule pass collate_fn as argument Signed-off-by: Alexandros Koumparoulis * Update peft example Signed-off-by: Alexandros Koumparoulis * typo Signed-off-by: Alexandros Koumparoulis * remove unused var Signed-off-by: Alexandros Koumparoulis * Move example Signed-off-by: Alexandros Koumparoulis * Apply isort and black reformatting Signed-off-by: akoumpa * remove unused Signed-off-by: Alexandros Koumparoulis * Small change Signed-off-by: Alexandros Koumparoulis * Fix loss calculation Signed-off-by: Alexandros Koumparoulis * Add extract_skipped_token_ids Signed-off-by: Alexandros Koumparoulis * Use vlm.HFAutoModelForImageTextToText.extract_skipped_token_ids Signed-off-by: Alexandros Koumparoulis * add test Signed-off-by: Alexandros Koumparoulis * Update logits/labels handling Signed-off-by: Alexandros Koumparoulis * add trust_remote_code to configure_processor Signed-off-by: Alexandros Koumparoulis * Apply isort and black reformatting Signed-off-by: akoumpa * mini refactor Signed-off-by: Alexandros Koumparoulis * add LLAMA_TOKENS Signed-off-by: Alexandros Koumparoulis * update hf_dataset Signed-off-by: Alexandros Koumparoulis * Add lora_dtype for models with non-FP weights Signed-off-by: Alexandros Koumparoulis * Add load_in_4bit option Signed-off-by: Alexandros Koumparoulis * add default_dtype Signed-off-by: Alexandros Koumparoulis * add load_in_4bit to llm collection Signed-off-by: Alexandros Koumparoulis * rm import Signed-off-by: Alexandros Koumparoulis * fix asset path Signed-off-by: Alexandros Koumparoulis * move vlm test Signed-off-by: Alexandros Koumparoulis * move data offline Signed-off-by: Alexandros Koumparoulis * use signel gpu Signed-off-by: Alexandros Koumparoulis * pylint fix Signed-off-by: Alexandros Koumparoulis * pylint Signed-off-by: Alexandros Koumparoulis * pylint Signed-off-by: Alexandros Koumparoulis * drop align_labels Signed-off-by: Alexandros Koumparoulis * remove align_labels from llm too Signed-off-by: Alexandros Koumparoulis * use loss * mask instead of loss[mask == 1] Signed-off-by: Alexandros Koumparoulis * fix path Signed-off-by: Alexandros Koumparoulis * Apply isort and black reformatting Signed-off-by: akoumpa --------- Signed-off-by: 
Alexandros Koumparoulis Signed-off-by: akoumpa Signed-off-by: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com> Co-authored-by: akoumpa --- .github/workflows/cicd-main.yml | 11 + examples/llm/peft/hf.py | 6 +- examples/vlm/hf/peft.py | 127 ++++++++ .../gpt/model/hf_auto_model_for_causal_lm.py | 18 +- nemo/collections/llm/peft/lora.py | 21 +- nemo/collections/vlm/__init__.py | 4 + nemo/collections/vlm/hf/data/hf_dataset.py | 281 ++++++++++++++++++ .../hf_auto_model_for_image_text_to_text.py | 191 ++++++++++++ tests/collections/vlm/hf/peft.py | 128 ++++++++ 9 files changed, 777 insertions(+), 10 deletions(-) create mode 100644 examples/vlm/hf/peft.py create mode 100644 nemo/collections/vlm/hf/data/hf_dataset.py create mode 100644 nemo/collections/vlm/hf/model/hf_auto_model_for_image_text_to_text.py create mode 100644 tests/collections/vlm/hf/peft.py diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 310d580e43f6..fce4ef2acfbd 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -3600,6 +3600,16 @@ jobs: inference.repetition_penalty=1.0 \ inference.outfile_path=/tmp/nlp_mcore_t5_lora_tuning_tp2/out.jsonl + L2_VLM_HF_Transformer_PEFT: + needs: [ cicd-test-container-setup ] + uses: ./.github/workflows/_test_template.yml + if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_VLM_HF_Transformer_PEFT') || needs.cicd-test-container-setup.outputs.all == 'true' + with: + RUNNER: self-hosted-azure-gpus-1 + SCRIPT: | + TRANSFORMERS_OFFLINE=1 python tests/collections/vlm/hf/peft.py --model /home/TestData/vlm/qwen2-2b/ --max-steps 3 --disable-ckpt + AFTER_SCRIPT: | + rm -rf nemo_experiments L2_HF_Transformer_PEFT: needs: [ cicd-test-container-setup ] @@ -4863,6 +4873,7 @@ jobs: - L2_HF_Transformer_SFT - L2_HF_Transformer_SFT_nemorun - L2_HF_Transformer_SFT_2gpu + - L2_VLM_HF_Transformer_PEFT - L2_HF_Transformer_SFT_2gpu_nemorun - L2_HF_Transformer_SFT_TE_Acceleration - L2_NeMo_2_SSM_Pretraining diff --git a/examples/llm/peft/hf.py b/examples/llm/peft/hf.py index 45675398a421..3137a542ae01 100644 --- a/examples/llm/peft/hf.py +++ b/examples/llm/peft/hf.py @@ -40,7 +40,11 @@ def formatting_prompts_func(examples): output = output[0] text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN ans = tokenizer(text) - ans['labels'] = ans['input_ids'] + # 'input_ids' is a list, we want to remove EOS_TOKEN from input_ids and the first token from + # labels to align the two: + ans['labels'] = list(ans['input_ids'][1:]) + ans['input_ids'] = ans['input_ids'][:-1] + ans['attention_mask'] = ans['attention_mask'][:-1] return ans tokenizer = getattr(tokenizer, 'tokenizer', tokenizer) diff --git a/examples/vlm/hf/peft.py b/examples/vlm/hf/peft.py new file mode 100644 index 000000000000..d51984677a74 --- /dev/null +++ b/examples/vlm/hf/peft.py @@ -0,0 +1,127 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import fiddle as fdl +import torch +from lightning.pytorch.loggers import WandbLogger + +from nemo import lightning as nl +from nemo.collections import llm, vlm + + +def mk_hf_vlm_dataset(processor, mbs, gbs): + """Creates vlm dataset""" + skipped_tokens = vlm.HFAutoModelForImageTextToText.extract_skipped_token_ids(processor) + + def collate_fn(examples, processor): + def fmt(sample): + instruction = "Describe accurately the given image." + conversation = [ + { + "role": "user", + "content": [{"type": "text", "text": instruction}, {"type": "image", "image": sample["image"]}], + }, + {"role": "assistant", "content": [{"type": "text", "text": sample["text"]}]}, + ] + return {"conversation": conversation, "images": [sample['image']]} + + text = [] + images = [] + for example in map(fmt, examples): + text.append( + processor.apply_chat_template( + example["conversation"], + tokenize=False, + add_generation_prompt=False, + ) + ) + images += example['images'] + + # Tokenize the text and process the images + batch = processor( + text=text, + images=images, + padding=True, + return_tensors="pt", + ) + + assert batch["input_ids"].ndim == 2, 'Expected input_ids to be 2D' + batch["pixel_values"] = batch["pixel_values"].to(torch.bfloat16) + labels = batch["input_ids"].clone() + labels[torch.isin(labels, skipped_tokens)] = -100 + batch["labels"] = labels[:, 1:] + batch["input_ids"] = batch["input_ids"][:, :-1] + return batch + + return vlm.HFDatasetDataModule( + "quintend/rdr-items", + split="train", + micro_batch_size=mbs, + global_batch_size=gbs, + collate_fn=lambda x: collate_fn(x, processor=processor), + ) + + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument('--model', default='Qwen/Qwen2-VL-2B-Instruct') + parser.add_argument('--strategy', type=str, default='auto', choices=['auto', 'ddp', 'fsdp']) + parser.add_argument('--devices', default=1) + parser.add_argument('--mbs', default=1) + parser.add_argument('--gbs', default=1) + parser.add_argument('--accelerator', default='gpu', choices=['gpu']) + parser.add_argument('--max-steps', type=int, default=100) + parser.add_argument('--wandb-project', type=str, default=None) + args = parser.parse_args() + + wandb = None + if args.wandb_project is not None: + model = '_'.join(args.model.split('/')[-2:]) + wandb = WandbLogger( + project=args.wandb_project, + name=f'{model}_dev{args.devices}_strat_{args.strategy}', + ) + grad_clip = 0.5 + if args.strategy == 'fsdp': + # See: + # https://github.com/Lightning-AI/pytorch-lightning/blob/8ad3e29816a63d8ce5c00ac104b14729a4176f4f/src/lightning/pytorch/plugins/precision/fsdp.py#L81 + grad_clip = None + use_dist_samp = False + processor = vlm.HFAutoModelForImageTextToText.configure_processor(args.model) + + llm.api.finetune( + model=vlm.HFAutoModelForImageTextToText(args.model), + data=mk_hf_vlm_dataset(processor, args.mbs, args.gbs), + trainer=nl.Trainer( + devices=args.devices, + max_steps=args.max_steps, + accelerator=args.accelerator, + strategy=args.strategy, + log_every_n_steps=1, + limit_val_batches=0.0, + num_sanity_val_steps=0, + accumulate_grad_batches=10, + gradient_clip_val=grad_clip, + use_distributed_sampler=use_dist_samp, + logger=wandb, + ), + optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(lr=1e-5)), + log=None, + peft=llm.peft.LoRA( + target_modules=['*_proj'], + dim=16, + ), + ) diff --git a/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py b/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py index 
2d8b32964767..cea7264543ff 100644 --- a/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py +++ b/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py @@ -42,6 +42,7 @@ def __init__( model_accelerator=None, trust_remote_code=False, default_dtype=torch.bfloat16, + load_in_4bit=False, ): super().__init__() self.save_hyperparameters() @@ -55,6 +56,7 @@ def __init__( self.model_accelerator = model_accelerator self.trust_remote_code = trust_remote_code self.default_dtype = default_dtype + self.load_in_4bit = load_in_4bit @property def tokenizer(self): @@ -75,7 +77,10 @@ def configure_model(self): # create all your layers here if self.load_pretrained_weights: self.model = AutoModelForCausalLM.from_pretrained( - self.model_name, torch_dtype='auto', trust_remote_code=self.trust_remote_code + self.model_name, + torch_dtype='auto', + trust_remote_code=self.trust_remote_code, + load_in_4bit=self.load_in_4bit, ) else: from transformers import AutoConfig @@ -108,9 +113,10 @@ def training_step(self, batch): # Prepare for loss calculation logits = outputs.logits.float() n_cls = logits.shape[-1] - logits, labels = logits.view(-1, n_cls), labels.view(-1) - assert logits.shape[-2] == labels.shape[-1] + logits = logits.view(-1, n_cls) + labels = labels.view(-1) + assert logits.shape[-2] == labels.shape[-1], "Expected logits & labels to have the same length" loss = self.loss_fn(logits, labels, loss_mask) self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True) return loss @@ -127,12 +133,12 @@ def validation_step(self, batch, batch_idx): outputs = self.forward(**batch) - # Prepare for loss calculation logits = outputs.logits.float() n_cls = logits.shape[-1] - logits, labels = logits.view(-1, n_cls), labels.view(-1) - assert logits.shape[-2] == labels.shape[-1] + logits = logits.view(-1, n_cls) + labels = labels.view(-1) + assert logits.shape[-2] == labels.shape[-1], "Expected logits & labels to have the same length" loss = self.loss_fn(logits, labels, loss_mask) self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True) diff --git a/nemo/collections/llm/peft/lora.py b/nemo/collections/llm/peft/lora.py index 766b8993bf35..0ce6138d1c6b 100644 --- a/nemo/collections/llm/peft/lora.py +++ b/nemo/collections/llm/peft/lora.py @@ -45,7 +45,14 @@ class LinearAdapter(nn.Module): """ def __init__( - self, orig_linear, dim=8, alpha=32, dropout=0.1, dropout_position='post', lora_A_init_method='xavier' + self, + orig_linear, + dim=8, + alpha=32, + dropout=0.1, + dropout_position='post', + lora_A_init_method='xavier', + lora_dtype=None, ): super(LinearAdapter, self).__init__() assert isinstance(orig_linear, nn.Linear) @@ -62,7 +69,8 @@ def __init__( in_features = self.orig_linear.in_features out_features = self.orig_linear.out_features - dtype = self.orig_linear.weight.dtype + dtype = lora_dtype or self.orig_linear.weight.dtype + self.lora_a = nn.Parameter(torch.zeros((in_features, dim), dtype=dtype, device=device)) self.lora_b = nn.Parameter(torch.zeros((dim, out_features), dtype=dtype, device=device)) if lora_A_init_method == 'xavier': @@ -112,6 +120,7 @@ class LoRA(PEFT): dropout_position (Literal['pre', 'post'], optional): Position for applying dropout. Can be 'pre' (before the low-rank projection) or 'post' (after). Defaults to 'pre'. a2a_experimental (bool): Enables the experimental All-to-All (A2A) communication strategy. Defaults to False. + lora_dtype (torch.dtype): Parameter data type for LoRA weights. Default None (will use model's dtype).
Example: -------- @@ -140,6 +149,7 @@ class LoRA(PEFT): lora_A_init_method: str = "xavier" lora_B_init_method: str = "zero" a2a_experimental: bool = False + lora_dtype: torch.dtype = None def transform(self, m: nn.Module, name=None, prefix=None): """ @@ -159,7 +169,12 @@ def transform(self, m: nn.Module, name=None, prefix=None): if name in self.target_modules or any(wildcard_match(pattern, full_name) for pattern in self.target_modules): if isinstance(m, nn.Linear): return LinearAdapter( - m, dim=self.dim, alpha=self.alpha, dropout=self.dropout, lora_A_init_method=self.lora_A_init_method + m, + dim=self.dim, + alpha=self.alpha, + dropout=self.dropout, + lora_A_init_method=self.lora_A_init_method, + lora_dtype=self.lora_dtype, ) input_is_parallel, in_features, out_features = get_adapter_attributes_from_linear(m) diff --git a/nemo/collections/vlm/__init__.py b/nemo/collections/vlm/__init__.py index b5e693830fa5..3e9eebe47cbe 100644 --- a/nemo/collections/vlm/__init__.py +++ b/nemo/collections/vlm/__init__.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from nemo.collections.vlm.hf.data.hf_dataset import HFDatasetDataModule +from nemo.collections.vlm.hf.model.hf_auto_model_for_image_text_to_text import HFAutoModelForImageTextToText from nemo.collections.vlm.llava_next.data import LlavaNextMockDataModule, LlavaNextTaskEncoder from nemo.collections.vlm.llava_next.model.base import LlavaNextConfig from nemo.collections.vlm.llava_next.model.llava_next import LlavaNextConfig7B, LlavaNextConfig13B, LlavaNextModel @@ -51,6 +53,8 @@ from nemo.collections.vlm.recipes import * __all__ = [ + "HFDatasetDataModule", + "HFAutoModelForImageTextToText", "NevaMockDataModule", "NevaLazyDataModule", "MLlamaMockDataModule", diff --git a/nemo/collections/vlm/hf/data/hf_dataset.py b/nemo/collections/vlm/hf/data/hf_dataset.py new file mode 100644 index 000000000000..a73e6d3e3504 --- /dev/null +++ b/nemo/collections/vlm/hf/data/hf_dataset.py @@ -0,0 +1,281 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import lightning.pytorch as pl +import torch +from datasets import Dataset, DatasetDict, load_dataset +from torch.utils.data import DataLoader + +from nemo.lightning.pytorch.plugins import MegatronDataSampler +from nemo.utils import logging + + +def clean_split(name): + """name="train[:100]" returns "train" """ + if '[' in name: + return name.split('[')[0] + return name + + +def make_dataset_splits(dataset, split, split_aliases): + """ + Given a dataset (e.g. from datasets.load_dataset or datasets.Dataset.from_dict) it + returns a dictionary containing the corresponding dataset splits. 
+ + For example: + + $ ds = load_dataset("dataset-id") + $ ans = make_dataset_splits(ds) + + # `ds` contains the following + $ print(ds) + > DatasetDict({ + > train: Dataset({ + > features: ['id', 'title', 'context', 'question', 'answers'], + > num_rows: 87599 + > }) + > validation: Dataset({ + > features: ['id', 'title', 'context', 'question', 'answers'], + > num_rows: 10570 + > }) + > }) + + # In this case the value of `ans` (returned value) will be: + $ print(ans) + > { + > "train": Dataset .. (with 87599 rows), + > "val": Dataset .. (with 10570 rows), + > } + """ + valid_split_names = ['train', 'test', 'val'] + dataset_splits = {_split: None for _split in valid_split_names} + + alias_to_split = {} + for split_name, _split_aliases in split_aliases.items(): + assert split_name in valid_split_names + for alias in _split_aliases: + alias_to_split[alias] = split_name + + if isinstance(dataset, Dataset): + assert isinstance(split, str), "Expected split to be a string, but got " + str(type(split)) + split = clean_split(split) + dataset_splits[split] = dataset + elif isinstance(dataset, DatasetDict): + dataset_split_names = dataset.keys() + logging.info(f"HF dataset has the following splits: {dataset_split_names}") + for alias_split_name, split in dataset.items(): + split_name = alias_to_split[alias_split_name] + assert dataset_splits[split_name] is None + dataset_splits[split_name] = split + elif isinstance(split, list): + logging.info(f"Loaded HF dataset will use " + str(split) + " splits.") + assert isinstance(dataset, list) + for i, alias_split_name in enumerate(map(clean_split, split)): + split_name = alias_to_split[alias_split_name] + assert dataset_splits[split_name] is None + dataset_splits[split_name] = dataset[i] + elif isinstance(split, str): + logging.info(f"Loaded HF dataset has a single split.") + assert not isinstance(dataset, list) + alias_split_name = split + if '+' in alias_split_name: + raise ValueError("Split concatenation not supported") + elif '[' in alias_split_name: + alias_split_name = alias_split_name.split('[')[0] + split_name = alias_to_split[alias_split_name] + assert dataset_splits[split_name] is None + dataset_splits[split_name] = dataset + else: + raise ValueError("Expected split name to be None, str or a list") + + assert set(valid_split_names) == set(dataset_splits.keys()), dataset_splits.keys() + num_init_splits = sum(map(lambda x: x is not None, dataset_splits.values())) + assert num_init_splits > 0, f"Expected at least one split to have been initialized {num_init_splits}" + return dataset_splits + + +class HFDatasetDataModule(pl.LightningDataModule): + """HFDatasetDataModule wraps HF's load_dataset (datasets library) + so that it can be used within NeMo. + Users can select whether to use an mcore-sampler via use_mcore_sampler arg. 
+ + Usage examples: + + - loading a single split (train) from a dataset + llm.HFDatasetDataModule("rajpurkar/squad", split="train") + + - loading multiple splits (train, validation) from a dataset + llm.HFDatasetDataModule("rajpurkar/squad", split=["train", "validation"]) + """ + + def __init__( + self, + path_or_dataset, + split=None, + collate_fn=None, + num_workers=2, + pin_memory=True, + persistent_workers=True, + seq_length=1024, + micro_batch_size=2, + global_batch_size=2, + pad_token_id=0, + use_mcore_sampler=False, + mcore_dataloader_type='cyclic', + train_aliases=["train", "training"], + test_aliases=["test", "testing"], + val_aliases=["val", "validation", "valid", "eval"], + **kwargs, + ) -> None: + super().__init__() + assert pad_token_id is not None + # A dataset usually will have several splits (e.g. train, val, test, etc). + # We map synonym names to canonical names (train, test, val). + # A synonym can be a prefix/suffixed word e.g. train <> training. + split_aliases = {'train': train_aliases, 'test': test_aliases, 'val': val_aliases} + + # self.dataset_splits will hold the actual dataset for each split. + if isinstance(path_or_dataset, str): + logging.info(f"Loading HF dataset from {path_or_dataset}") + dataset = load_dataset(path_or_dataset, split=split, **kwargs) + elif isinstance(path_or_dataset, Dataset) or isinstance(path_or_dataset, DatasetDict): + logging.info(f"Using passed HF dataset {str(path_or_dataset)}") + dataset = path_or_dataset + else: + raise ValueError( + "Expected `path_or_dataset` to be str, Dataset, DatasetDict, but got " + str(type(path_or_dataset)) + ) + + self.dataset_splits = make_dataset_splits(dataset, split, split_aliases) + + if collate_fn is None: + self._collate_fn = lambda x: HFDatasetDataModule.collate_fn(x, pad_token_id=self.pad_token_id) + else: + self._collate_fn = collate_fn + + self.num_workers = num_workers + self.pin_memory = pin_memory + self.persistent_workers = persistent_workers + self.seq_length = seq_length + self.micro_batch_size = micro_batch_size + self.global_batch_size = global_batch_size + self.pad_token_id = pad_token_id + + self.use_mcore_sampler = use_mcore_sampler + self.mcore_dataloader_type = mcore_dataloader_type + + @staticmethod + def from_dict(dataset_dict, split, **kwargs): + """Creates a Dataset from a dictionary""" + dataset = Dataset.from_dict(dataset_dict) + return HFDatasetDataModule(path_or_dataset=dataset, split=split, **kwargs) + + @staticmethod + def collate_fn(batch, pad_token_id=0): + """Collate for VLM data""" + + def batchify(tensor): + if tensor.ndim == 1: + return tensor.unsqueeze_(0) + return tensor + + def extract_key_from_dicts(batch, key): + return list(map(lambda x: x[key], batch)) + + def pad_within_micro(batch, pad_token_id): + max_len = max(map(len, batch)) + return [item + [pad_token_id] * (max_len - len(item)) for item in batch] + + return { + key: batchify( + torch.LongTensor( + pad_within_micro( + extract_key_from_dicts(batch, key), + pad_token_id, + ) + ) + ) + for key in batch[0].keys() + } + + def setup(self, stage: str): + """PTL hook""" + if not self.use_mcore_sampler: + return + self.data_sampler = MegatronDataSampler( + seq_len=self.seq_length, + micro_batch_size=self.micro_batch_size, + global_batch_size=self.global_batch_size, + dataloader_type=self.mcore_dataloader_type, + ) + + def _make_dataloader(self, dataset, collate_fn=None): + """Creates a dataloader""" + assert dataset is not None + + if collate_fn is None: + + def collate_fn(x): + return 
HFDatasetDataModule.collate_fn(x, pad_token_id=self.pad_token_id) + + return DataLoader( + dataset, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + persistent_workers=self.persistent_workers, + collate_fn=collate_fn, + batch_size=self.micro_batch_size, + ) + + @property + def train(self): + """Train data split""" + return self.dataset_splits['train'] + + @property + def val(self): + """Validation data split""" + return self.dataset_splits['val'] + + @property + def test(self): + """Testing data split""" + return self.dataset_splits['test'] + + def train_dataloader(self): + """Creates a dataloader for the train split""" + return self._make_dataloader(self.train, self._collate_fn) + + def val_dataloader(self): + """Creates a dataloader for the validation split""" + return self._make_dataloader(self.val, self._collate_fn) + + def test_dataloader(self): + """Creates a dataloader for the test split""" + return self._make_dataloader(self.test, self._collate_fn) + + def map(self, function=None, split_names=None, **kwargs): + """Maps a function to all/selected splits + Additional arguments can be passed down to dataset's map via kwargs""" + if isinstance(split_names, str): + dataset_splits = {split_names: self.dataset_splits[split_names]} + elif isinstance(split_names, list): + dataset_splits = {k: self.dataset_splits[k] for k in split_names} + else: + dataset_splits = self.dataset_splits + + for split_name, subset in dataset_splits.items(): + if subset is None: + continue + dataset_splits[split_name] = subset.map(function, **kwargs) diff --git a/nemo/collections/vlm/hf/model/hf_auto_model_for_image_text_to_text.py b/nemo/collections/vlm/hf/model/hf_auto_model_for_image_text_to_text.py new file mode 100644 index 000000000000..33ad04970d35 --- /dev/null +++ b/nemo/collections/vlm/hf/model/hf_auto_model_for_image_text_to_text.py @@ -0,0 +1,191 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import lightning.pytorch as pl +import torch +import torch.nn.functional as F +from transformers import AutoConfig, AutoModelForImageTextToText, AutoProcessor + +from nemo.collections.llm import fn +from nemo.lightning import io +from nemo.utils import logging + + +def masked_cross_entropy(logits, targets, mask=None): + """Cross entropy with optional mask""" + if mask is not None: + loss = F.cross_entropy(logits, targets, reduction='none') + return torch.mean(loss * mask) + else: + return F.cross_entropy(logits, targets) + + +class HFAutoModelForImageTextToText(pl.LightningModule, io.IOMixin, fn.FNMixin): + """Wraps HF's AutoModelForImageTextToText in a pl.LightningModule + for use within NeMo""" + + def __init__( + self, + model_name='gpt2', + load_pretrained_weights=True, + processor=None, + loss_fn=masked_cross_entropy, + model_transform=None, + trust_remote_code=False, + default_dtype=torch.bfloat16, + load_in_4bit=False, + ): + super().__init__() + self.save_hyperparameters() + self.model_name = model_name + self._processor = processor + self.tokenizer = None + self.model = None + self.loss_fn = loss_fn + self.load_pretrained_weights = load_pretrained_weights + self.is_hf_model = True + self.model_transform = model_transform + self.trust_remote_code = trust_remote_code + self.load_in_4bit = load_in_4bit + + @property + def processor(self): + """Returns the module's processor""" + if self._processor is None: + self._processor = HFAutoModelForImageTextToText.configure_processor( + self.model_name, trust_remote_code=self.trust_remote_code + ) + return self._processor + + @processor.setter + def processor(self, value): + """Sets the module's processor""" + assert self._processor is None + self._processor = value + + @staticmethod + def configure_processor(model_name, trust_remote_code=False): + """Initializes an AutoProcessor and returns the instance""" + return AutoProcessor.from_pretrained(model_name, trust_remote_code=trust_remote_code) + + def configure_model(self): + """Instantiates the model""" + # create all your layers here + if self.load_pretrained_weights: + self.model = AutoModelForImageTextToText.from_pretrained( + self.model_name, + torch_dtype='auto', + trust_remote_code=self.trust_remote_code, + load_in_4bit=self.load_in_4bit, + ) + else: + config = AutoConfig.from_pretrained(self.model_name, trust_remote_code=self.trust_remote_code) + dtype = getattr(config, 'torch_dtype', self.default_dtype) + self.model = AutoModelForImageTextToText.from_config( + config, torch_dtype=dtype, trust_remote_code=self.trust_remote_code + ) + self.model.train() + + def forward(self, batch): + """Runs forward with the model""" + return self.model(**batch) + + def training_step(self, batch): + """Run one training step""" + labels = batch.pop('labels').to(self.model.device) + loss_mask = batch.pop('loss_mask', None) + + outputs = self.forward(batch) + + logits = outputs.logits.float() + n_cls = logits.shape[-1] + logits = logits.view(-1, n_cls) + labels = labels.view(-1) + + assert logits.shape[-2] == labels.shape[-1], "Expected logits & labels to have the same length" + loss = self.loss_fn(logits, labels, loss_mask) + self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True) + return loss + + @torch.no_grad + def validation_step(self, batch, batch_idx): + """Run one validation step""" + labels = batch.pop('labels').to(self.model.device) + loss_mask = batch.pop('loss_mask', None) + + outputs = self.forward(**batch) + + logits = outputs.logits.float() + n_cls = logits.shape[-1] + logits
= logits.view(-1, n_cls) + labels = labels.view(-1) + + assert logits.shape[-2] == labels.shape[-1], "Expected logits & labels to have the same length" + loss = self.loss_fn(logits, labels, loss_mask) + + self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True) + + def save_pretrained(self, path): + """Saves checkpoint using HF""" + assert self.model is not None, "Model has to be created first." + self.model.save_pretrained(path) + if self._processor is not None: + self._processor.save_pretrained(path) + else: + logging.warning("A processor wasn't created before to save.") + + @staticmethod + def extract_skipped_token_ids(tokenizer): + """Returns list of tokens to mask in labels""" + # qwen2-2b + QWEN_TOKENS = [ + '<|im_start|>', + '<|im_end|>', + '<|vision_start|>', + '<|vision_end|>', + '<|vision_pad|>', + '<|image_pad|>', + '<|video_pad|>', + '<|im_start|>', + '<|im_end|>', + '<|vision_start|>', + '<|vision_end|>', + '<|vision_pad|>', + '<|image_pad|>', + '<|video_pad|>', + ] + # llava-1.5-7b-hf, llava-v1.6-mistral-7b-hf + LLAVA_TOKENS = [ + "<image>", + "<pad>", + ] + LLAMA_TOKENS = [ + '<|begin_of_text|>', + '<|end_of_text|>', + '<|finetune_right_pad_id|>', + '<|step_id|>', + '<|start_header_id|>', + '<|end_header_id|>', + '<|eom_id|>', + '<|eot_id|>', + '<|python_tag|>', + '<|image|>', + ] + PAD_TOKENS = set(QWEN_TOKENS + LLAVA_TOKENS + LLAMA_TOKENS) + tokenizer = getattr(tokenizer, 'tokenizer', tokenizer) + skipped_token_ids = [] + for key, val in tokenizer.added_tokens_decoder.items(): + if str(val) in PAD_TOKENS: + skipped_token_ids.append(key) + return torch.IntTensor(list(set(skipped_token_ids))) diff --git a/tests/collections/vlm/hf/peft.py b/tests/collections/vlm/hf/peft.py new file mode 100644 index 000000000000..109bccfcfa1f --- /dev/null +++ b/tests/collections/vlm/hf/peft.py @@ -0,0 +1,128 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import fiddle as fdl +import torch +from lightning.pytorch.loggers import WandbLogger + +from nemo import lightning as nl +from nemo.collections import llm, vlm + +DATA_PATH = "/home/TestData/vlm/rdr-items" + + +def mk_hf_vlm_dataset(processor, mbs, gbs): + skipped_tokens = vlm.HFAutoModelForImageTextToText.extract_skipped_token_ids(processor) + + def collate_fn(examples, processor): + def fmt(sample): + instruction = "Describe accurately the given image."
+ conversation = [ + { + "role": "user", + "content": [{"type": "text", "text": instruction}, {"type": "image", "image": sample["image"]}], + }, + {"role": "assistant", "content": [{"type": "text", "text": sample["text"]}]}, + ] + return {"conversation": conversation, "images": [sample['image']]} + + text = [] + images = [] + for example in map(fmt, examples): + text.append( + processor.apply_chat_template( + example["conversation"], + tokenize=False, + add_generation_prompt=False, + ) + ) + images += example['images'] + + # Tokenize the text and process the images + batch = processor( + text=text, + images=images, + padding=True, + return_tensors="pt", + ) + + batch["pixel_values"] = batch["pixel_values"].to(torch.bfloat16) + + labels = batch["input_ids"].clone() + labels[torch.isin(labels, skipped_tokens)] = -100 + batch["labels"] = labels + return batch + + return vlm.HFDatasetDataModule( + DATA_PATH, + split="train[:10]", + micro_batch_size=mbs, + global_batch_size=gbs, + collate_fn=lambda x: collate_fn(x, processor=processor), + ) + + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument('--model', default='Qwen/Qwen2-VL-2B-Instruct') + parser.add_argument('--strategy', type=str, default='auto', choices=['auto', 'ddp', 'fsdp']) + parser.add_argument('--devices', default=1) + parser.add_argument('--mbs', default=1) + parser.add_argument('--gbs', default=1) + parser.add_argument('--accelerator', default='gpu', choices=['gpu']) + parser.add_argument('--max-steps', type=int, default=100) + parser.add_argument('--wandb-project', type=str, default=None) + parser.add_argument('--disable-ckpt', action='store_false') + args = parser.parse_args() + + wandb = None + if args.wandb_project is not None: + model = '_'.join(args.model.split('/')[-2:]) + wandb = WandbLogger( + project=args.wandb_project, + name=f'{model}_dev{args.devices}_strat_{args.strategy}', + ) + grad_clip = 0.5 + if args.strategy == 'fsdp': + # See: https://github.com/Lightning-AI/pytorch-lightning/blob/8ad3e29816a63d8ce5c00ac104b14729a4176f4f/src/lightning/pytorch/plugins/precision/fsdp.py#L81 + grad_clip = None + use_dist_samp = False + processor = vlm.HFAutoModelForImageTextToText.configure_processor(args.model) + + llm.api.finetune( + model=vlm.HFAutoModelForImageTextToText(args.model), + data=mk_hf_vlm_dataset(processor, args.mbs, args.gbs), + trainer=nl.Trainer( + devices=args.devices, + max_steps=args.max_steps, + accelerator=args.accelerator, + strategy=args.strategy, + log_every_n_steps=1, + limit_val_batches=0.0, + num_sanity_val_steps=0, + accumulate_grad_batches=10, + gradient_clip_val=grad_clip, + use_distributed_sampler=use_dist_samp, + logger=wandb, + enable_checkpointing=args.disable_ckpt, + ), + optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(lr=1e-5)), + log=None, + peft=llm.peft.LoRA( + target_modules=['*_proj'], + dim=16, + ), + ) From c160abbd180ee741f1479151550ef3048a6fe728 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?oliver=20k=C3=B6nig?= Date: Tue, 17 Dec 2024 23:47:18 +0100 Subject: [PATCH 08/20] ci: Bump release workflow (#11635) Signed-off-by: Oliver Koenig --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index dcaac34901cd..2ddad31e159e 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -28,7 +28,7 @@ on: jobs: release: - uses: NVIDIA/NeMo-FW-CI-templates/.github/workflows/_release_library.yml@v0.15.1 + uses: 
NVIDIA/NeMo-FW-CI-templates/.github/workflows/_release_library.yml@v0.17.3 with: release-ref: ${{ inputs.release-ref }} image-name: nemo_container From 0cb318b14a7dd9d446241aef3cf4a6486d92b940 Mon Sep 17 00:00:00 2001 From: Somshubra Majumdar Date: Tue, 17 Dec 2024 17:43:34 -0800 Subject: [PATCH 09/20] Add fix docstring for speech commands (#11638) Signed-off-by: smajumdar --- tutorials/asr/Speech_Commands.ipynb | 1 + 1 file changed, 1 insertion(+) diff --git a/tutorials/asr/Speech_Commands.ipynb b/tutorials/asr/Speech_Commands.ipynb index c8a54e5135b2..927c0a15b76c 100644 --- a/tutorials/asr/Speech_Commands.ipynb +++ b/tutorials/asr/Speech_Commands.ipynb @@ -65,6 +65,7 @@ "!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[asr]\n", "\n", "## Install TorchAudio\n", + "## NOTE: TorchAudio installation may not work in all environments, please use Google Colab for best experience\n", "!pip install torchaudio>=0.13.0 -f https://download.pytorch.org/whl/torch_stable.html\n", "\n", "## Grab the config we'll use in this example\n", From 97129886a712ae176cc8dd94a0597547206cabf7 Mon Sep 17 00:00:00 2001 From: Weiqing Wang <164252040+weiqingw4ng@users.noreply.github.com> Date: Tue, 17 Dec 2024 23:42:04 -0800 Subject: [PATCH 10/20] Fixing Multi_Task_Adapters.ipynb by replacing canary2 with canary_custom (#11641) Signed-off-by: Weiqing Wang --- tutorials/asr/asr_adapters/Multi_Task_Adapters.ipynb | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tutorials/asr/asr_adapters/Multi_Task_Adapters.ipynb b/tutorials/asr/asr_adapters/Multi_Task_Adapters.ipynb index 0d35feb11a9a..978793ae4a06 100644 --- a/tutorials/asr/asr_adapters/Multi_Task_Adapters.ipynb +++ b/tutorials/asr/asr_adapters/Multi_Task_Adapters.ipynb @@ -433,7 +433,7 @@ "outputs": [], "source": [ "@registered_prompt_format_fn\n", - "def canary2(cuts, tokenizer, inference: bool):\n", + "def canary_custom(cuts, tokenizer, inference: bool):\n", " \"\"\" Users can implement this as needed \"\"\"\n", " raise NotImplementedError()\n", "\n", @@ -449,7 +449,7 @@ }, "outputs": [], "source": [ - "temp = get_prompt_format_fn('canary2')\n", + "temp = get_prompt_format_fn('canary_custom')\n", "temp.__name__" ] }, @@ -549,7 +549,7 @@ "class CanaryPromptFormatterV2(model.prompt.__class__):\n", "\n", " # make sure to provide a new name\n", - " NAME: str = \"canary2\"\n", + " NAME: str = \"canary_custom\"\n", "\n", " # Make any changes as necessary.\n", " # For this demonstration, we will not change anything other than the name" @@ -565,7 +565,7 @@ "outputs": [], "source": [ "# Next, lets update the model's prompt formatter\n", - "model.change_prompt(\"canary2\")" + "model.change_prompt(\"canary_custom\")" ] }, { @@ -577,9 +577,9 @@ "source": [ "---\n", "\n", - "We have now successfully changed the prompt format to `canary2`.\n", + "We have now successfully changed the prompt format to `canary_custom`.\n", "\n", - "**Note**: It is important to know that when changing the prompt format, the name of the new prompt format class (`canary2` in this case) **has to match** the name of the prompt function registered with `@registered_prompt_format_fn`!" + "**Note**: It is important to know that when changing the prompt format, the name of the new prompt format class (`canary_custom` in this case) **has to match** the name of the prompt function registered with `@registered_prompt_format_fn`!" 
] }, { From 0133deb268812109eab5211f902d7c4b99b6ee95 Mon Sep 17 00:00:00 2001 From: nasretdinovr Date: Wed, 18 Dec 2024 17:52:15 +0400 Subject: [PATCH 11/20] fixed config name in online augmentation tutorial (#11628) Signed-off-by: Rauf --- .../Speech_Enhancement_with_Online_Augmentation.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tutorials/audio/speech_enhancement/Speech_Enhancement_with_Online_Augmentation.ipynb b/tutorials/audio/speech_enhancement/Speech_Enhancement_with_Online_Augmentation.ipynb index e8b734537a41..41936e79675a 100644 --- a/tutorials/audio/speech_enhancement/Speech_Enhancement_with_Online_Augmentation.ipynb +++ b/tutorials/audio/speech_enhancement/Speech_Enhancement_with_Online_Augmentation.ipynb @@ -540,10 +540,10 @@ "config_dir = root_dir / 'conf'\n", "config_dir.mkdir(exist_ok=True)\n", "\n", - "config_path = config_dir / 'masking_online_aug.yaml'\n", + "config_path = config_dir / 'masking_with_online_augmentation.yaml'\n", "\n", "if not config_path.is_file():\n", - " !wget https://raw.githubusercontent.com/{GIT_USER}/NeMo/{BRANCH}/examples/audio/conf/masking_online_aug.yaml -P {config_dir.as_posix()}\n", + " !wget https://raw.githubusercontent.com/{GIT_USER}/NeMo/{BRANCH}/examples/audio/conf/masking_with_online_augmentation.yaml -P {config_dir.as_posix()}\n", "\n", "config = OmegaConf.load(config_path)\n", "config = OmegaConf.to_container(config, resolve=True)\n", From b68274b9d2fa2468bd955e215e9193688a8b577f Mon Sep 17 00:00:00 2001 From: Ao Tang Date: Wed, 18 Dec 2024 10:47:27 -0500 Subject: [PATCH 12/20] fix default nodes (#11632) --- nemo/collections/llm/recipes/gemma2_27b.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/llm/recipes/gemma2_27b.py b/nemo/collections/llm/recipes/gemma2_27b.py index d6b41c0a221c..c03ae6c21aa3 100644 --- a/nemo/collections/llm/recipes/gemma2_27b.py +++ b/nemo/collections/llm/recipes/gemma2_27b.py @@ -62,7 +62,7 @@ def pretrain_recipe( virtual_pipeline_parallelism: Optional[int] = None, context_parallelism: int = 1, sequence_parallelism: bool = False, - num_nodes: int = 1, + num_nodes: int = 2, num_gpus_per_node: int = 8, max_steps: int = 1168251, precision: str = "bf16-mixed", From 1121289cd81eaf736a9afd71a8abf51dd1e73984 Mon Sep 17 00:00:00 2001 From: Dmytro Pykhtar <37850217+dimapihtar@users.noreply.github.com> Date: Wed, 18 Dec 2024 19:44:08 +0200 Subject: [PATCH 13/20] add renormalize_blend_weights param (#11647) Signed-off-by: dimapihtar --- examples/nlp/language_modeling/conf/megatron_gpt_config.yaml | 1 + .../nlp/models/language_modeling/megatron_gpt_model.py | 1 + 2 files changed, 2 insertions(+) diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml index da160390b431..e70f3ca418c4 100755 --- a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml @@ -293,6 +293,7 @@ model: shuffle_documents: True # Set to False to disable documents shuffling. Sample index will still be shuffled exchange_indices_distributed: False # Set to True to exchange indices via torch.distributed instead of filesystem data_cache_generation_only: False # Set to True to generate only the data cache and stop the training script + renormalize_blend_weights: False # Renormalize the blend weights to account for mid-level dataset oversampling done to ensure fulfillment of the requested number of samples.
# Nsys profiling options nsys_profile: diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index a8ed1ee7d28f..c2c3431070a6 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -1660,6 +1660,7 @@ def build_train_valid_test_datasets(self): "mmap_bin_files": self.cfg.data.get("mmap_bin_files", True), "drop_last_partial_validation_sequence": self.cfg.data.get("validation_drop_last", True), "num_dataset_builder_threads": self.cfg.data.get("num_dataset_builder_threads", 1), + "renormalize_blend_weights": self.cfg.data.get("renormalize_blend_weights", False), "add_extra_token_to_sequence": add_extra_token, } From b7478b6f86eafe5c90be90d5fde5c0fd9a7fd9bd Mon Sep 17 00:00:00 2001 From: Taejin Park Date: Wed, 18 Dec 2024 09:46:48 -0800 Subject: [PATCH 14/20] Sortformer Diarizer 4spk v1 model PR Part 3: Speaker Diarization Mixin (#11511) * Adding diarization mixin for one click inference Signed-off-by: taejinp * Apply isort and black reformatting Signed-off-by: tango4j * Resolving CodeQL and Pylint Signed-off-by: taejinp * Resolving CodeQL and Pylint - unsaved files resolved Signed-off-by: taejinp * Apply isort and black reformatting Signed-off-by: tango4j * Unused package manifest_utils Signed-off-by: taejinp * Resolved diarization mixin test issues Signed-off-by: taejinp * Apply isort and black reformatting Signed-off-by: tango4j * Removed commented lines Signed-off-by: taejinp * updating mixins code Signed-off-by: ipmedenn <65592416+ipmedenn@users.noreply.github.com> * Apply isort and black reformatting Signed-off-by: ipmedenn * fixing test_diarizartion.py Signed-off-by: ipmedenn <65592416+ipmedenn@users.noreply.github.com> * moving diarization postprocessing-related stuff to vad_utils.py Signed-off-by: ipmedenn <65592416+ipmedenn@users.noreply.github.com> * Apply isort and black reformatting Signed-off-by: ipmedenn * Resolving PyLint Signed-off-by: ipmedenn <65592416+ipmedenn@users.noreply.github.com> * Apply isort and black reformatting Signed-off-by: ipmedenn * fixing batch_idx issue in sortformer_diar_models.py Signed-off-by: ipmedenn <65592416+ipmedenn@users.noreply.github.com> * adding sync_dist=True for sortformer validation Signed-off-by: ipmedenn <65592416+ipmedenn@users.noreply.github.com> * Apply isort and black reformatting Signed-off-by: tango4j * Reflecting the comments from PR Signed-off-by: taejinp * Reflecting the comments from PR 2nd Signed-off-by: taejinp * Apply isort and black reformatting Signed-off-by: tango4j * Resolved a codeQL unused variable Signed-off-by: taejinp * Now moved existance check after Signed-off-by: taejinp * Apply isort and black reformatting Signed-off-by: tango4j --------- Signed-off-by: taejinp Signed-off-by: tango4j Signed-off-by: ipmedenn <65592416+ipmedenn@users.noreply.github.com> Signed-off-by: ipmedenn Co-authored-by: tango4j Co-authored-by: ipmedenn <65592416+ipmedenn@users.noreply.github.com> Co-authored-by: ipmedenn --- .../neural_diarizer/e2e_diarize_speech.py | 105 +--- .../asr/data/audio_to_diar_label.py | 5 + .../asr/models/sortformer_diar_models.py | 172 +++++- .../asr/parts/mixins/diarization.py | 493 ++++++++++++++++++ nemo/collections/asr/parts/utils/vad_utils.py | 108 ++++ .../common/parts/preprocessing/collections.py | 61 ++- .../speaker_tasks/mixins/test_diarization.py | 271 ++++++++++ 7 files changed, 1108 insertions(+), 107 
deletions(-) create mode 100644 nemo/collections/asr/parts/mixins/diarization.py create mode 100644 tests/collections/speaker_tasks/mixins/test_diarization.py diff --git a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py index 147d7a3aa002..60600b59db59 100644 --- a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py +++ b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py @@ -24,7 +24,7 @@ Usage for diarization inference: -The end-to-end speaker diarization model can be specified by either "model_path" or "pretrained_name". +The end-to-end speaker diarization model can be specified by "model_path". Data for diarization is fed through the "dataset_manifest". By default, post-processing is bypassed, and only binarization is performed. If you want to reproduce DER scores reported on NeMo model cards, you need to apply post-processing steps. @@ -45,45 +45,32 @@ import lightning.pytorch as pl import optuna import torch -import yaml from omegaconf import OmegaConf from pytorch_lightning import seed_everything -from tqdm import tqdm from nemo.collections.asr.metrics.der import score_labels from nemo.collections.asr.models import SortformerEncLabelModel -from nemo.collections.asr.parts.utils.speaker_utils import audio_rttm_map, timestamps_to_pyannote_object -from nemo.collections.asr.parts.utils.vad_utils import ts_vad_post_processing +from nemo.collections.asr.parts.utils.speaker_utils import ( + audio_rttm_map, + get_uniqname_from_filepath, + timestamps_to_pyannote_object, +) +from nemo.collections.asr.parts.utils.vad_utils import ( + PostProcessingParams, + load_postprocessing_from_yaml, + predlist_to_timestamps, +) from nemo.core.config import hydra_runner seed_everything(42) torch.backends.cudnn.deterministic = True -@dataclass -class PostProcessingParams: - """ - Postprocessing parameters for end-to-end speaker diarization models. - These parameters can significantly affect DER performance depending on the evaluation style and the dataset. - It is recommended to tune these parameters based on the evaluation style and the dataset - to achieve the desired DER performance. - """ - - onset: float = 0.5 # Onset threshold for detecting the beginning and end of a speech - offset: float = 0.5 # Offset threshold for detecting the end of a speech - pad_onset: float = 0.0 # Adding durations before each speech segment - pad_offset: float = 0.0 # Adding durations after each speech segment - min_duration_on: float = 0.0 # Threshold for small non-speech deletion - min_duration_off: float = 0.0 # Threshold for short speech segment deletion - - @dataclass class DiarizationConfig: """Diarization configuration parameters for inference.""" model_path: Optional[str] = None # Path to a .nemo file - pretrained_name: Optional[str] = None # Name of a pretrained model - audio_dir: Optional[str] = None # Path to a directory which contains audio files dataset_manifest: Optional[str] = None # Path to dataset's JSON manifest postprocessing_yaml: Optional[str] = None # Path to a yaml file for postprocessing configurations @@ -114,36 +101,6 @@ class DiarizationConfig: optuna_n_trials: int = 100000 -def load_postprocessing_from_yaml(postprocessing_yaml: PostProcessingParams = None) -> PostProcessingParams: - """ - Load postprocessing parameters from a YAML file. - - Args: - postprocessing_yaml (str): - Path to a YAML file for postprocessing configurations. 
- - Returns: - postprocessing_params (dataclass): - Postprocessing parameters loaded from the YAML file. - """ - # Add PostProcessingParams as a field - postprocessing_params = OmegaConf.structured(PostProcessingParams()) - if postprocessing_yaml is None: - logging.info( - f"No postprocessing YAML file has been provided. Default postprocessing configurations will be applied." - ) - else: - # Load postprocessing params from the provided YAML file - with open(postprocessing_yaml, 'r') as file: - yaml_params = yaml.safe_load(file)['parameters'] - # Update the postprocessing_params with the loaded values - logging.info(f"Postprocessing YAML file '{postprocessing_yaml}' has been loaded.") - for key, value in yaml_params.items(): - if hasattr(postprocessing_params, key): - setattr(postprocessing_params, key, value) - return postprocessing_params - - def optuna_suggest_params(postprocessing_cfg: PostProcessingParams, trial: optuna.Trial) -> PostProcessingParams: """ Suggests hyperparameters for postprocessing using Optuna. @@ -303,26 +260,19 @@ def convert_pred_mat_to_segments( """ batch_pred_ts_segs, all_hypothesis, all_reference, all_uems = [], [], [], [] cfg_vad_params = OmegaConf.structured(postprocessing_cfg) - pp_message = "Bypass PP, Running Binarization" if bypass_postprocessing else "Running post-processing" - for sample_idx, (uniq_id, audio_rttm_values) in tqdm( - enumerate(audio_rttm_map_dict.items()), total=len(audio_rttm_map_dict), desc=pp_message - ): - spk_ts = [] - offset, duration = audio_rttm_values['offset'], audio_rttm_values['duration'] - speaker_assign_mat = batch_preds_list[sample_idx].squeeze(dim=0) - speaker_timestamps = [[] for _ in range(speaker_assign_mat.shape[-1])] - for spk_id in range(speaker_assign_mat.shape[-1]): - ts_mat = ts_vad_post_processing( - speaker_assign_mat[:, spk_id], - cfg_vad_params=cfg_vad_params, - unit_10ms_frame_count=unit_10ms_frame_count, - bypass_postprocessing=bypass_postprocessing, - ) - ts_mat = ts_mat + offset - ts_mat = torch.clamp(ts_mat, min=offset, max=(offset + duration)) - ts_seg_list = ts_mat.tolist() - speaker_timestamps[spk_id].extend(ts_seg_list) - spk_ts.append(ts_seg_list) + total_speaker_timestamps = predlist_to_timestamps( + batch_preds_list=batch_preds_list, + audio_rttm_map_dict=audio_rttm_map_dict, + cfg_vad_params=cfg_vad_params, + unit_10ms_frame_count=unit_10ms_frame_count, + bypass_postprocessing=bypass_postprocessing, + ) + for sample_idx, (uniq_id, audio_rttm_values) in enumerate(audio_rttm_map_dict.items()): + speaker_timestamps = total_speaker_timestamps[sample_idx] + if audio_rttm_values.get("uniq_id", None) is not None: + uniq_id = audio_rttm_values["uniq_id"] + else: + uniq_id = get_uniqname_from_filepath(audio_rttm_values["audio_filepath"]) all_hypothesis, all_reference, all_uems = timestamps_to_pyannote_object( speaker_timestamps, uniq_id, @@ -332,7 +282,6 @@ def convert_pred_mat_to_segments( all_uems, out_rttm_dir, ) - batch_pred_ts_segs.append(spk_ts) return all_hypothesis, all_reference, all_uems @@ -348,10 +297,8 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: if cfg.random_seed: pl.seed_everything(cfg.random_seed) - if cfg.model_path is None and cfg.pretrained_name is None: - raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None!") - if cfg.audio_dir is None and cfg.dataset_manifest is None: - raise ValueError("Both cfg.audio_dir and cfg.dataset_manifest cannot be None!") + if cfg.model_path is None: + raise ValueError("cfg.model_path cannot be None. 
Please specify the path to the model.") # setup GPU torch.set_float32_matmul_precision(cfg.matmul_precision) diff --git a/nemo/collections/asr/data/audio_to_diar_label.py b/nemo/collections/asr/data/audio_to_diar_label.py index 3f4ae61e0d08..1dbe68589c0a 100644 --- a/nemo/collections/asr/data/audio_to_diar_label.py +++ b/nemo/collections/asr/data/audio_to_diar_label.py @@ -1120,6 +1120,11 @@ def parse_rttm_for_targets_and_lens(self, rttm_file, offset, duration, target_le Example of seg_target: [[0., 1.], [0., 1.], [1., 1.], [1., 0.], [1., 0.], ..., [0., 1.]] """ + if rttm_file in [None, '']: + num_seg = torch.max(target_len) + targets = torch.zeros(num_seg, self.max_spks) + return targets + with open(rttm_file, 'r') as f: rttm_lines = f.readlines() diff --git a/nemo/collections/asr/models/sortformer_diar_models.py b/nemo/collections/asr/models/sortformer_diar_models.py index 71de10cc2f79..483ff5328ad0 100644 --- a/nemo/collections/asr/models/sortformer_diar_models.py +++ b/nemo/collections/asr/models/sortformer_diar_models.py @@ -13,23 +13,29 @@ # limitations under the License. import itertools +import os import random from collections import OrderedDict -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union +import numpy as np import torch from hydra.utils import instantiate from omegaconf import DictConfig from pytorch_lightning import Trainer +from torch.utils.data import DataLoader from tqdm import tqdm from nemo.collections.asr.data.audio_to_diar_label import AudioToSpeechE2ESpkDiarDataset from nemo.collections.asr.data.audio_to_diar_label_lhotse import LhotseAudioToSpeechE2ESpkDiarDataset from nemo.collections.asr.metrics.multi_binary_acc import MultiBinaryAccuracy from nemo.collections.asr.models.asr_model import ExportableEncDecModel +from nemo.collections.asr.parts.mixins.diarization import DiarizeConfig, SpkDiarizationMixin from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations from nemo.collections.asr.parts.utils.asr_multispeaker_utils import get_ats_targets, get_pil_targets +from nemo.collections.asr.parts.utils.speaker_utils import generate_diarization_output_lines +from nemo.collections.asr.parts.utils.vad_utils import ts_vad_post_processing from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config from nemo.core.classes import ModelPT from nemo.core.classes.common import PretrainedModelInfo @@ -40,7 +46,7 @@ __all__ = ['SortformerEncLabelModel'] -class SortformerEncLabelModel(ModelPT, ExportableEncDecModel): +class SortformerEncLabelModel(ModelPT, ExportableEncDecModel, SpkDiarizationMixin): """ Encoder class for Sortformer diarization model. Model class creates training, validation methods for setting up data performing model forward pass. 
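For orientation, the following is a minimal usage sketch of the one-click `diarize()` interface that `SortformerEncLabelModel` gains from `SpkDiarizationMixin` (the mixin and the method body appear further below in this patch). The checkpoint and audio paths are placeholders rather than assets shipped with this PR, and the exact content of each returned entry depends on `generate_diarization_output_lines`.

from nemo.collections.asr.models import SortformerEncLabelModel

# Restore a trained Sortformer diarizer from a local .nemo checkpoint (placeholder path).
diar_model = SortformerEncLabelModel.restore_from(restore_path="/path/to/sortformer_diarizer_4spk.nemo")
diar_model.eval()

# `audio` can be a single path, a list of paths, or a path to a JSON/JSONL manifest.
segment_lines_per_file = diar_model.diarize(
    audio=["/path/to/session1.wav", "/path/to/session2.wav"],
    batch_size=2,
)
for lines in segment_lines_per_file:
    print(lines)  # speaker-attributed segment lines for one recording

# Raw per-frame speaker probabilities can also be returned alongside the segments.
segment_lines_per_file, preds = diar_model.diarize(
    audio="/path/to/manifest.json",
    include_tensor_outputs=True,
)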
@@ -108,7 +114,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.streaming_mode = self._cfg.get("streaming_mode", False) self.save_hyperparameters("cfg") self._init_eval_metrics() - speaker_inds = list(range(self._cfg.max_num_of_spks)) self.speaker_permutations = torch.tensor(list(itertools.permutations(speaker_inds))) # Get all permutations @@ -119,7 +124,6 @@ def _init_loss_weights(self): raise ValueError(f"weights for PIL {pil_weight} and ATS {ats_weight} cannot sum to 0") self.pil_weight = pil_weight / (pil_weight + ats_weight) self.ats_weight = ats_weight / (pil_weight + ats_weight) - logging.info(f"Normalized weights for PIL {self.pil_weight} and ATS {self.ats_weight}") def _init_eval_metrics(self): """ @@ -269,6 +273,113 @@ def forward_infer(self, emb_seq): preds = self.sortformer_modules.forward_speaker_sigmoids(trans_emb_seq) return preds + def _diarize_forward(self, batch: Any): + """ + A counterpart of the `_transcribe_forward` function in ASR. + This function is a wrapper for forward pass functions for compatibility + with the existing classes. + + Args: + batch (Any): The input batch containing audio signal and audio signal length. + + Returns: + preds (torch.Tensor): Sorted tensor containing Sigmoid values for predicted speaker labels. + Shape: (batch_size, diar_frame_count, num_speakers) + """ + with torch.no_grad(): + preds = self.forward(audio_signal=batch[0], audio_signal_length=batch[1]) + preds = preds.to('cpu') + torch.cuda.empty_cache() + return preds + + def _diarize_output_processing( + self, outputs, uniq_ids, diarcfg: DiarizeConfig + ) -> Union[List[List[str]], Tuple[List[List[str]], List[torch.Tensor]]]: + """ + Processes the diarization outputs and generates RTTM (Rich Transcription Time Marked) files. + TODO: Currently, this function is not included in mixin test because of + `ts_vad_post_processing` function. + (1) Implement a test-compatible function + (2) `vad_utils.py` has `predlist_to_timestamps` function that is close to this function. + Needs to consolidate the differences and implement the test-compatible function. + + Args: + outputs (torch.Tensor): Sorted tensor containing Sigmoid values for predicted speaker labels. + Shape: (batch_size, diar_frame_count, num_speakers) + uniq_ids (List[str]): List of unique identifiers for each audio file. + diarcfg (DiarizeConfig): Configuration object for diarization. + + Returns: + diar_output_lines_list (List[List[str]]): A list of lists, where each inner list contains + the RTTM lines for a single audio file. + preds_list (List[torch.Tensor]): A list of tensors containing the diarization outputs + for each audio file.
+ + """ + preds_list, diar_output_lines_list = [], [] + if outputs.shape[0] == 1: # batch size = 1 + preds_list.append(outputs) + else: + preds_list.extend(torch.split(outputs, [1] * outputs.shape[0])) + + for sample_idx, uniq_id in enumerate(uniq_ids): + offset = self._diarize_audio_rttm_map[uniq_id]['offset'] + speaker_assign_mat = preds_list[sample_idx].squeeze(dim=0) + speaker_timestamps = [[] for _ in range(speaker_assign_mat.shape[-1])] + for spk_id in range(speaker_assign_mat.shape[-1]): + ts_mat = ts_vad_post_processing( + speaker_assign_mat[:, spk_id], + cfg_vad_params=diarcfg.postprocessing_params, + unit_10ms_frame_count=int(self._cfg.encoder.subsampling_factor), + bypass_postprocessing=False, + ) + ts_mat = ts_mat + offset + ts_seg_raw_list = ts_mat.tolist() + ts_seg_list = [[round(stt, 2), round(end, 2)] for (stt, end) in ts_seg_raw_list] + speaker_timestamps[spk_id].extend(ts_seg_list) + + diar_output_lines = generate_diarization_output_lines( + speaker_timestamps=speaker_timestamps, model_spk_num=len(speaker_timestamps) + ) + diar_output_lines_list.append(diar_output_lines) + if diarcfg.include_tensor_outputs: + return (diar_output_lines_list, preds_list) + else: + return diar_output_lines_list + + def _setup_diarize_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader': + """ + Setup function for a temporary data loader which wraps the provided audio file. + + Args: + config: A python dictionary which contains the following keys: + - manifest_filepath: Path to the manifest file containing audio file paths + and corresponding speaker labels. + + Returns: + A pytorch DataLoader for the given audio file(s). + """ + if 'manifest_filepath' in config: + manifest_filepath = config['manifest_filepath'] + batch_size = config['batch_size'] + else: + manifest_filepath = os.path.join(config['temp_dir'], 'manifest.json') + batch_size = min(config['batch_size'], len(config['paths2audio_files'])) + + dl_config = { + 'manifest_filepath': manifest_filepath, + 'sample_rate': self.preprocessor._sample_rate, + 'num_spks': config.get('num_spks', self._cfg.max_num_of_spks), + 'batch_size': batch_size, + 'shuffle': False, + 'soft_label_thres': 0.5, + 'session_len_sec': config['session_len_sec'], + 'num_workers': config.get('num_workers', min(batch_size, os.cpu_count() - 1)), + 'pin_memory': True, + } + temporary_datalayer = self.__setup_dataloader_from_config(config=DictConfig(dl_config)) + return temporary_datalayer + def process_signal(self, audio_signal, audio_signal_length): """ Extract audio features from time-series signal for further processing in the model. @@ -291,7 +402,7 @@ def process_signal(self, audio_signal, audio_signal_length): - processed_signal_length (torch.Tensor): The length of each processed signal. Shape: (batch_size,) """ - audio_signal = audio_signal.to(self.device) + audio_signal, audio_signal_length = audio_signal.to(self.device), audio_signal_length.to(self.device) audio_signal = (1 / (audio_signal.max() + self.eps)) * audio_signal processed_signal, processed_signal_length = self.preprocessor( input_signal=audio_signal, length=audio_signal_length @@ -372,7 +483,7 @@ def _get_aux_train_evaluations(self, preds, targets, target_lens) -> dict: } return train_metrics - def training_step(self, batch: list) -> dict: + def training_step(self, batch: list, batch_idx: int) -> dict: """ Performs a single training step. @@ -382,6 +493,7 @@ def training_step(self, batch: list) -> dict: - audio_signal_length (torch.Tensor): The length of each audio signal in the batch. 
- targets (torch.Tensor): The target labels for the batch. - target_lens (torch.Tensor): The length of each target sequence in the batch. + batch_idx (int): The index of the current batch. Returns: (dict): A dictionary containing the 'loss' key with the calculated loss value. @@ -439,7 +551,7 @@ def _get_aux_validation_evaluations(self, preds, targets, target_lens) -> dict: } return val_metrics - def validation_step(self, batch: list, dataloader_idx: int = 0): + def validation_step(self, batch: list, batch_idx: int, dataloader_idx: int = 0): """ Performs a single validation step. @@ -571,10 +683,46 @@ def test_batch( logging.info(f"Batch Recall MEAN: {torch.mean(torch.tensor(self.batch_recall_list))}") logging.info(f"Batch ATS F1Acc. MEAN: {torch.mean(torch.tensor(self.batch_f1_accs_ats_list))}") + def on_validation_epoch_end(self) -> Optional[dict[str, dict[str, torch.Tensor]]]: + """Run validation with sync_dist=True.""" + return super().on_validation_epoch_end(sync_metrics=True) + + @torch.no_grad() def diarize( self, - ): - """One-clieck runner function for diarization.""" - # TODO: A direct one-click runner function that generates - # speaker labels from audio file path lists. - raise NotImplementedError + audio: Union[str, List[str], np.ndarray, DataLoader], + batch_size: int = 1, + include_tensor_outputs: bool = False, + postprocessing_yaml: Optional[str] = None, + num_workers: int = 0, + verbose: bool = True, + override_config: Optional[DiarizeConfig] = None, + ) -> Union[List[List[str]], Tuple[List[List[str]], List[torch.Tensor]]]: + """One-click runner function for diarization. + + Args: + audio: (a single or list) of paths to audio files or path to a manifest file. + batch_size: (int) Batch size to use during inference. + Bigger will result in better throughput performance but would use more memory. + include_tensor_outputs: (bool) Include raw speaker activity probabilities to the output. + See Returns: for more details. + postprocessing_yaml: Optional(str) Path to .yaml file with postprocessing parameters. + num_workers: (int) Number of workers for DataLoader. + verbose: (bool) Whether to display tqdm progress bar. + override_config: (Optional[DiarizeConfig]) A config to override the default config. + + Returns: + *if include_tensor_outputs is False: A list of lists of speech segments with a corresponding speaker index, + in format "[begin_seconds, end_seconds, speaker_index]". + *if include_tensor_outputs is True: A tuple of the above list + and list of tensors of raw speaker activity probabilities. + """ + return super().diarize( + audio=audio, + batch_size=batch_size, + include_tensor_outputs=include_tensor_outputs, + postprocessing_yaml=postprocessing_yaml, + num_workers=num_workers, + verbose=verbose, + override_config=override_config, + ) diff --git a/nemo/collections/asr/parts/mixins/diarization.py b/nemo/collections/asr/parts/mixins/diarization.py new file mode 100644 index 000000000000..fe8f6bbecb21 --- /dev/null +++ b/nemo/collections/asr/parts/mixins/diarization.py @@ -0,0 +1,493 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import os +import tempfile +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + +from nemo.collections.asr.parts.utils.speaker_utils import audio_rttm_map, get_uniqname_from_filepath +from nemo.collections.asr.parts.utils.vad_utils import PostProcessingParams, load_postprocessing_from_yaml +from nemo.collections.common.data.utils import move_data_to_device +from nemo.utils import logging + +GenericDiarizationType = Union[List[Any], List[List[Any]], Tuple[Any], Tuple[List[Any]]] + + +@dataclass +class InternalDiarizeConfig: + """Internal diarization configuration parameters for diarization inference.""" + + # Internal values + device: Optional[torch.device] = None + dtype: Optional[torch.dtype] = None + training_mode: bool = False + logging_level: Optional[Any] = None + + # Preprocessor values + dither_value: float = 0.0 + pad_to_value: int = 0 + + # Scratch space + temp_dir: Optional[str] = None + manifest_filepath: Optional[str] = None + + +@dataclass +class DiarizeConfig: + """Configuration parameters for diarization inference.""" + + session_len_sec: float = -1 # End-to-end diarization session length limit in seconds + batch_size: int = 1 + num_workers: int = 1 + postprocessing_yaml: Optional[str] = None # Path to a yaml file for postprocessing configurations + verbose: bool = True + include_tensor_outputs: bool = False + postprocessing_params: PostProcessingParams = None + + # Utility + _internal: Optional[InternalDiarizeConfig] = None + + +def get_value_from_diarization_config(diarcfg, key, default): + """ + Utility function to get a value from the diarization config. + If the value is not present in the diarization config, the default value is returned. + + Args: + diarcfg: A dataclass that represents the diarization config. + key: The name of the arg to retrieve. + default: The default value to return if the key is not present in the diarization config. + + Returns: + The value of the key in the diarization config or the default value. + """ + if hasattr(diarcfg, key): + return getattr(diarcfg, key) + else: + logging.debug( + f"Using default value of {default} for {key} because it is not present \ + in the diarization config {diarcfg}." + ) + return default + + +class SpkDiarizationMixin(ABC): + """ + An abstract class for diarize-able models. + + Creates a template function `diarize()` that provides an interface to perform transcription of audio tensors or + filepaths. + + The following abstract classes must be implemented by the subclass: + + - `_setup_diarize_dataloader()`: + Setup the dataloader for diarization. Receives the output from + `_diarize_input_manifest_processing()`. + + - `_diarize_forward()`: + Implements the model's custom forward pass to return outputs that are processed by + `_diarize_output_processing()`. + + - `_diarize_output_processing()`: + Implements the post processing of the model's outputs to return the results to + the user. 
The result can be a list of objects, list of list of objects, tuple of objects, tuple of list of + objects, or a dict of list of objects. + + """ + + def __init__(self): + self._diarize_audio_rttm_map = {} + + @torch.inference_mode() + def diarize( + self, + audio: Union[str, List[str], np.ndarray, DataLoader], + batch_size: int = 1, + include_tensor_outputs: bool = False, + postprocessing_yaml: Optional[str] = None, + num_workers: int = 1, + verbose: bool = False, + override_config: Optional[DiarizeConfig] = None, + **config_kwargs, + ) -> GenericDiarizationType: + """ + Takes paths to audio files and returns speaker labels + """ + + if override_config is None: + postprocessing_params = load_postprocessing_from_yaml(postprocessing_yaml) + diarize_cfg = DiarizeConfig( + batch_size=batch_size, + num_workers=num_workers, + verbose=verbose, + include_tensor_outputs=include_tensor_outputs, + postprocessing_yaml=postprocessing_yaml, + postprocessing_params=postprocessing_params, + **config_kwargs, + ) + else: + if not hasattr(override_config, '_internal'): + raise ValueError( + "`diarize_cfg must have an `_internal` argument, which must be of an object of type " + "InternalDiarizeConfig or its subclass." + ) + + if override_config._internal is None: + override_config._internal = InternalDiarizeConfig() + + diarize_cfg = override_config + + # Add new internal config + if diarize_cfg._internal is None: + diarize_cfg._internal = InternalDiarizeConfig() + else: + # Check if internal config is valid + if not isinstance(diarize_cfg._internal, InternalDiarizeConfig): + raise ValueError( + "`diarize_cfg._internal` must be of an object of type InternalDiarizeConfig or " "its subclass" + ) + + # Hold the results here + results = None + + try: + generator = self.diarize_generator(audio, override_config=diarize_cfg) + + for processed_outputs in generator: + # Store results + if isinstance(processed_outputs, list): + # Create a results of the same type as each element in processed_outputs + if results is None: + results = [] + + results.extend(processed_outputs) + + elif isinstance(processed_outputs, tuple): + # Create a results of the same type as each element in processed_outputs + if results is None: + results = tuple([[] for _ in processed_outputs]) + + # If nested list structure + if isinstance(processed_outputs[0], list): + for i, processed_output in enumerate(processed_outputs): + results[i].extend(processed_output) + else: + # If flat list structure + if len(processed_outputs) != len(results): + raise RuntimeError( + f"The number of elements in the result ({len(results)}) does not " + f"match the results of the current batch ({len(processed_outputs)})." + ) + + for i, processed_output in enumerate(processed_outputs): + results[i].append(processed_output) + + except StopIteration: + pass + + return results + + def diarize_generator(self, audio, override_config: Optional[DiarizeConfig]): + """ + A generator version of `diarize` function. + """ + if override_config is None: + override_config = DiarizeConfig() + + if not hasattr(override_config, '_internal'): + raise ValueError( + "`diarize_cfg must have an `_internal` argument, which must be of an object of type " + "InternalDiarizeConfig or its subclass." 
+ ) + + # Add new internal config + if override_config._internal is None: + override_config._internal = InternalDiarizeConfig() + else: + # Check if internal config is valid + if not isinstance(override_config._internal, InternalDiarizeConfig): + raise ValueError( + "`diarize_cfg._internal` must be of an object of type InternalDiarizeConfig or " "its subclass" + ) + + diarize_cfg = override_config + + try: + # Initialize and assert the diarization environment + self._diarize_on_begin(audio, diarize_cfg) + + # Work in tmp directory - will store manifest file there + with tempfile.TemporaryDirectory() as tmpdir: + diarize_cfg._internal.temp_dir = tmpdir + + # Create a DataLoader if not already present + if not isinstance(audio, DataLoader): + dataloader = self._diarize_input_processing(audio, diarize_cfg) + else: + dataloader = audio + + if hasattr(diarize_cfg, 'verbose'): + verbose = diarize_cfg.verbose + else: + verbose = True + + for batch_idx, test_batch in enumerate(tqdm(dataloader, desc="Diarizing", disable=not verbose)): + # Move batch to device + test_batch = move_data_to_device(test_batch, diarize_cfg._internal.device) + uniq_ids = list(self._diarize_audio_rttm_map.keys())[ + batch_idx * diarize_cfg.batch_size : (batch_idx + 1) * diarize_cfg.batch_size + ] + + # Run forward pass + pred_outputs = self._diarize_forward(test_batch) + processed_outputs = self._diarize_output_processing(pred_outputs, uniq_ids, diarize_cfg) + + # Yield results if generator + yield processed_outputs + + # clear up memory + del test_batch, pred_outputs, processed_outputs + torch.cuda.empty_cache() + + finally: + # set mode back to its original value + self._diarize_on_end(diarize_cfg) + + def _input_audio_to_rttm_processing(self, audio_files: List[str]) -> List[Dict[str, Union[str, float]]]: + """ + Generate manifest style dict if `audio` is a list of paths to audio files. + + Args: + audio_files: A list of paths to audio files. + + Returns: + audio_rttm_map_dict A list of manifest style dicts. + """ + audio_rttm_map_dict = {} + for audio_file in audio_files: + uniq_id = get_uniqname_from_filepath(audio_file) + entry = { + 'uniq_id': uniq_id, + 'audio_filepath': audio_file, + 'offset': 0.0, + 'duration': None, + 'text': '-', + 'label': 'infer', + } + audio_rttm_map_dict[uniq_id] = entry + return audio_rttm_map_dict + + def _diarize_on_begin(self, audio: Union[str, List[str]], diarcfg: DiarizeConfig): + """ + Internal function to setup the model for diarization. Perform all setup and pre-checks here. + + Args: + audio (Union[str, List[str]]): Of type `GenericDiarizationType` + diarcfg (DiarizeConfig): An instance of `DiarizeConfig`. 
+ """ + if audio is None: + return {} + + if isinstance(audio, str): + audio = [audio] + + if isinstance(audio, list) and len(audio) == 0: + return {} + + # Set num_workers + num_workers = get_value_from_diarization_config(diarcfg, 'num_workers', default=1) + + if num_workers is None: + _batch_size = get_value_from_diarization_config(diarcfg, 'batch_size', default=1) + num_workers = min(_batch_size, os.cpu_count() - 1) + + # Assign num_workers if available as key in diarcfg + if hasattr(diarcfg, 'num_workers'): + diarcfg.num_workers = num_workers + + # Model's mode and device + diarcfg._internal.training_mode = self.training + + # Switch model to evaluation mode + if hasattr(self, 'preprocessor'): + if hasattr(self.preprocessor, 'featurizer') and hasattr(self.preprocessor.featurizer, 'dither'): + diarcfg._internal.dither_value = self.preprocessor.featurizer.dither + self.preprocessor.featurizer.dither = 0.0 + + if hasattr(self.preprocessor, 'featurizer') and hasattr(self.preprocessor.featurizer, 'pad_to'): + diarcfg._internal.pad_to_value = self.preprocessor.featurizer.pad_to + self.preprocessor.featurizer.pad_to = 0 + + # Switch model to evaluation mode + self.eval() + + # Disable logging + diarcfg._internal.logging_level = logging.get_verbosity() + logging.set_verbosity(logging.WARNING) + + def _diarize_input_processing(self, audio, diarcfg: DiarizeConfig): + """ + Internal function to process the input audio data and return a DataLoader. This function is called by + `diarize()` and `diarize_generator()` to setup the input data for diarization. + + Args: + audio: Of type `GenericDiarizationType` + diarcfg: The diarization config dataclass. Subclasses can change this to a different dataclass if needed. + + Returns: + A DataLoader object that is used to iterate over the input audio data. + """ + if isinstance(audio, (list, tuple)): + if len(audio) == 0: + raise ValueError("Input `audio` is empty") + else: + # Assume it is a single variable, so wrap it in a list + audio = [audio] + + # Check if audio is a list of strings (filepaths or manifests) + if isinstance(audio[0], str): + if len(audio) == 1 and audio[0].endswith('.json') or audio[0].endswith('.jsonl'): + # Assume it is a path to a manifest file + diarcfg._internal.manifest_filepath = audio[0] + self._diarize_audio_rttm_map = audio_rttm_map(audio[0]) + audio_files = [] + for uniq_id, meta_dict in self._diarize_audio_rttm_map.items(): + audio_files.append(meta_dict['audio_filepath']) + else: + # Make `audio_files` a list of audio file paths + audio_files = list(audio) + self._diarize_audio_rttm_map = self._input_audio_to_rttm_processing(audio_files=audio_files) + + tmp_dir = diarcfg._internal.temp_dir + ds_config = self._diarize_input_manifest_processing(audio_files, tmp_dir, diarcfg) + + temp_dataloader = self._setup_diarize_dataloader(ds_config) + return temp_dataloader + + else: + raise ValueError( + f"Input `audio` is of type {type(audio[0])}. " "Only `str` (path to audio file) is supported as input." + ) + + def _diarize_input_manifest_processing( + self, audio_files: List[str], temp_dir: str, diarcfg: DiarizeConfig + ) -> Dict[str, Any]: + """ + Internal function to process the input audio filepaths and return a config dict for the dataloader. + + Args: + audio_files: A list of string filepaths for audio files. + temp_dir: A temporary directory to store intermediate files. + diarcfg: The diarization config dataclass. Subclasses can change this to a different dataclass if needed. 
+ + Returns: + A config dict that is used to setup the dataloader for diarization. + """ + with open(os.path.join(temp_dir, 'manifest.json'), 'w', encoding='utf-8') as fp: + for audio_file in audio_files: + if isinstance(audio_file, str): + entry = {'audio_filepath': audio_file, 'duration': 100000, 'text': ''} + fp.write(json.dumps(entry) + '\n') + elif isinstance(audio_file, dict): + fp.write(json.dumps(audio_file) + '\n') + else: + raise ValueError( + f"Input `audio` is of type {type(audio_file)}. " + "Only `str` (path to audio file) or `dict` are supported as input." + ) + + ds_config = { + 'paths2audio_files': audio_files, + 'batch_size': get_value_from_diarization_config(diarcfg, 'batch_size', 1), + 'temp_dir': temp_dir, + 'session_len_sec': get_value_from_diarization_config(diarcfg, 'session_len_sec', diarcfg.session_len_sec), + 'num_workers': get_value_from_diarization_config(diarcfg, 'num_workers', 1), + } + + return ds_config + + @abstractmethod + def _setup_diarize_dataloader(self, config: Dict) -> DataLoader: + """ + Internal function to setup the dataloader for diarization. This function is called by + `diarize()` and `diarize_generator()` to setup the input data for diarization. + + Args: + config: A config dict that is used to setup the dataloader for diarization. + It can be generated by `_diarize_input_manifest_processing()`. + + Returns: + A DataLoader object that is used to iterate over the input audio data. + """ + pass + + @abstractmethod + def _diarize_forward(self, batch: Any): + """ + Internal function to perform the model's custom forward pass to return outputs that are processed by + `_diarize_output_processing()`. + This function is called by `diarize()` and `diarize_generator()` to perform the model's forward pass. + + Args: + batch: A batch of input data from the data loader that is used to perform the model's forward pass. + + Returns: + The model's outputs that are processed by `_diarize_output_processing()`. + """ + pass + + @abstractmethod + def _diarize_output_processing(self, outputs, uniq_ids, diarcfg: DiarizeConfig) -> GenericDiarizationType: + """ + Internal function to process the model's outputs to return the results to the user. This function is called by + `diarize()` and `diarize_generator()` to process the model's outputs. + + Args: + outputs: The model's outputs that are processed by `_diarize_forward()`. + uniq_ids: List of unique recording identifiers in the batch + diarcfg: The diarization config dataclass. Subclasses can change this to a different dataclass if needed. + + Returns: + The output can be a list of + objects, list of list of objects, tuple of objects, tuple of list of objects. + Its type is defined in `GenericDiarizationType`. + """ + pass + + def _diarize_on_end(self, diarcfg: DiarizeConfig): + """ + Internal function to tear down the model after diarization. Perform all teardown and post-checks here. + + Args: + diarcfg: The diarization config dataclass. Subclasses can change this to a different dataclass if needed.
+ """ + # set mode back to its original value + self.train(mode=diarcfg._internal.training_mode) + + if hasattr(self, 'preprocessor'): + if hasattr(self.preprocessor, 'featurizer') and hasattr(self.preprocessor.featurizer, 'dither'): + self.preprocessor.featurizer.dither = diarcfg._internal.dither_value + + if hasattr(self.preprocessor, 'featurizer') and hasattr(self.preprocessor.featurizer, 'pad_to'): + self.preprocessor.featurizer.pad_to = diarcfg._internal.pad_to_value + + if diarcfg._internal.logging_level is not None: + logging.set_verbosity(diarcfg._internal.logging_level) diff --git a/nemo/collections/asr/parts/utils/vad_utils.py b/nemo/collections/asr/parts/utils/vad_utils.py index 83a811ee4adb..fc29129295c0 100644 --- a/nemo/collections/asr/parts/utils/vad_utils.py +++ b/nemo/collections/asr/parts/utils/vad_utils.py @@ -18,6 +18,7 @@ import multiprocessing import os import shutil +from dataclasses import dataclass from itertools import repeat from math import ceil, floor from pathlib import Path @@ -29,6 +30,7 @@ import numpy as np import pandas as pd import torch +import yaml from omegaconf import DictConfig, OmegaConf from pyannote.core import Annotation, Segment from pyannote.metrics import detection @@ -44,6 +46,53 @@ """ +@dataclass +class PostProcessingParams: + """ + Postprocessing parameters for end-to-end speaker diarization models. + These parameters can significantly affect DER performance depending on the evaluation style and the dataset. + It is recommended to tune these parameters based on the evaluation style and the dataset + to achieve the desired DER performance. + """ + + onset: float = 0.5 # Onset threshold for detecting the beginning and end of a speech + offset: float = 0.5 # Offset threshold for detecting the end of a speech + pad_onset: float = 0.0 # Adding durations before each speech segment + pad_offset: float = 0.0 # Adding durations after each speech segment + min_duration_on: float = 0.0 # Threshold for small non-speech deletion + min_duration_off: float = 0.0 # Threshold for short speech segment deletion + + +def load_postprocessing_from_yaml(postprocessing_yaml: str = None) -> PostProcessingParams: + """ + Load postprocessing parameters from a YAML file. + + Args: + postprocessing_yaml (str): + Path to a YAML file for postprocessing configurations. + + Returns: + postprocessing_params (dataclass): + Postprocessing parameters loaded from the YAML file. + """ + # Add PostProcessingParams as a field + postprocessing_params = OmegaConf.structured(PostProcessingParams()) + if postprocessing_yaml is None: + logging.info( + f"No postprocessing YAML file has been provided. Default postprocessing configurations will be applied." + ) + else: + # Load postprocessing params from the provided YAML file + with open(postprocessing_yaml, 'r') as file: + yaml_params = yaml.safe_load(file)['parameters'] + # Update the postprocessing_params with the loaded values + logging.info(f"Postprocessing YAML file '{postprocessing_yaml}' has been loaded.") + for key, value in yaml_params.items(): + if hasattr(postprocessing_params, key): + setattr(postprocessing_params, key, value) + return postprocessing_params + + def prepare_manifest(config: dict) -> str: """ Perform VAD on long audio snippet might cause CUDA out of memory issue. 
@@ -1785,3 +1834,62 @@ def ts_vad_post_processing( cfg_vad_params.pad_offset = 0.0 speech_segments = binarization(ts_vad_binary_frames, cfg_vad_params) return speech_segments + + +def predlist_to_timestamps( + batch_preds_list: List[torch.Tensor], + audio_rttm_map_dict: Dict[str, Dict[str, Union[float, int]]], + cfg_vad_params: OmegaConf, + unit_10ms_frame_count: int, + bypass_postprocessing: bool = False, + precision: int = 2, +) -> List[List[float]]: + """ + Converts floating point number tensor diarization results to timestamps using VAD style + post-processing methods. + + Args: + batch_preds_list (List[Tensor]): + Tensor diarization results for each sample. + Dimension: [(num_frames, num_speakers), ...] + audio_rttm_map_dict (Dict[str, Dict[str, Union[float, int]]]): + Dictionary mapping unique audio file names to their rttm file entries. + cfg_vad_params (OmegaConf): + Configuration (omega config) of VAD parameters. + unit_10ms_frame_count (int): + an integer indicating the number of 10ms frames in a unit. + For example, if unit_10ms_frame_count is 8, then each frame is 0.08 seconds. + bypass_postprocessing (bool, optional): + If True, diarization post-processing will be bypassed. + precision (int, optional): + The number of decimal places to round the timestamps. Defaults to 2. + + Returns: + total_speaker_timestamps (List[List[List[float]]]): + A list of lists of timestamp tensors for each session (utterance) + Levels: + - Session-level (uniq_id) [session1_list, session2_list,...] + - Segment-level: [[start1, end1], [start2, end2],...]] + - List of start and end timestamp [start, end] + """ + total_speaker_timestamps = [] + pp_message = "Binarization" if bypass_postprocessing else "Post-processing" + for sample_idx, (uniq_id, audio_rttm_values) in tqdm( + enumerate(audio_rttm_map_dict.items()), total=len(audio_rttm_map_dict), desc=pp_message + ): + offset = audio_rttm_values['offset'] + speaker_assign_mat = batch_preds_list[sample_idx].squeeze(dim=0) + speaker_timestamps = [[] for _ in range(speaker_assign_mat.shape[-1])] + for spk_id in range(speaker_assign_mat.shape[-1]): + ts_mat = ts_vad_post_processing( + speaker_assign_mat[:, spk_id], + cfg_vad_params=cfg_vad_params, + unit_10ms_frame_count=unit_10ms_frame_count, + bypass_postprocessing=bypass_postprocessing, + ) + ts_mat = ts_mat + offset + ts_seg_raw_list = ts_mat.tolist() + ts_seg_list = [[round(stt, precision), round(end, precision)] for (stt, end) in ts_seg_raw_list] + speaker_timestamps[spk_id].extend(ts_seg_list) + total_speaker_timestamps.append(speaker_timestamps) + return total_speaker_timestamps diff --git a/nemo/collections/common/parts/preprocessing/collections.py b/nemo/collections/common/parts/preprocessing/collections.py index 12f5a9b3ecff..afd35e01c993 100644 --- a/nemo/collections/common/parts/preprocessing/collections.py +++ b/nemo/collections/common/parts/preprocessing/collections.py @@ -22,6 +22,7 @@ import pandas as pd from nemo.collections.common.parts.preprocessing import manifest, parsers +from nemo.collections.common.parts.preprocessing.manifest import get_full_path from nemo.utils import logging, logging_mode @@ -1513,14 +1514,6 @@ def __init__( for item in manifest.item_iter(manifests_files, parse_func=self.__parse_item_rttm): # Training mode - rttm_labels = [] - with open(item['rttm_file'], 'r') as f: - for index, rttm_line in enumerate(f.readlines()): - rttm = rttm_line.strip().split() - start = round(float(rttm[3]), round_digits) - end = round(float(rttm[4]), round_digits) + 
round(float(rttm[3]), round_digits) - speaker = rttm[7] - rttm_labels.append('{} {} {}'.format(start, end, speaker)) audio_files.append(item['audio_file']) uniq_ids.append(item['uniq_id']) durations.append(item['duration']) @@ -1540,6 +1533,13 @@ def __init__( def __parse_item_rttm(self, line: str, manifest_file: str) -> Dict[str, Any]: """Parse each rttm file and save it to in Dict format""" item = json.loads(line) + + if 'offset' not in item or item['offset'] is None: + item['offset'] = 0 + + # If the name `audio_file` is not present in the manifest file, replace it. + if 'audio_file' in item: + pass if 'audio_filename' in item: item['audio_file'] = item.pop('audio_filename') elif 'audio_filepath' in item: @@ -1548,25 +1548,54 @@ def __parse_item_rttm(self, line: str, manifest_file: str) -> Dict[str, Any]: raise ValueError( f"Manifest file has invalid json line " f"structure: {line} without proper audio file key." ) + + # Audio file handling depending on the types if isinstance(item['audio_file'], list): - item['audio_file'] = [os.path.expanduser(audio_file_path) for audio_file_path in item['audio_file']] + for single_audio_file in item['audio_file']: + audio_file_list.append(get_full_path(audio_file=single_audio_file, manifest_file=manifest_file)) + item['audio_file'] = audio_file_list + elif isinstance(item['audio_file'], str): + item['audio_file'] = get_full_path(audio_file=item['audio_file'], manifest_file=manifest_file) + if not os.path.exists(item['audio_file']): + raise FileNotFoundError(f"Audio file not found: {item['audio_file']}") + else: + raise ValueError( + f"Manifest file has invalid json line " + f"structure: {line} without proper audio file value: {item['audio_file']}." + ) + + # If the name `rttm_file` is not present in the manifest file, replace it or assign None. 
+ if 'rttm_file' in item: + pass + elif 'rttm_filename' in item: + item['rttm_file'] = item.pop('rttm_filename') + elif 'rttm_filepath' in item: + item['rttm_file'] = item.pop('rttm_filepath') else: - item['audio_file'] = os.path.expanduser(item['audio_file']) + item['rttm_file'] = None + + # If item['rttm_file'] is not None and the RTTM file exists, get the full path + if item['rttm_file'] is not None: + item['rttm_file'] = get_full_path(audio_file=item['rttm_file'], manifest_file=manifest_file) + if not os.path.exists(item['rttm_file']): + raise FileNotFoundError(f"RTTM file not found: {item['rttm_file']}") + + # Handling `uniq_id` string + if 'uniq_id' not in item: + item['uniq_id'] = os.path.splitext(os.path.basename(item['audio_file']))[0] - if not isinstance(item['audio_file'], list): - if 'uniq_id' not in item: - item['uniq_id'] = os.path.splitext(os.path.basename(item['audio_file']))[0] - elif 'uniq_id' not in item: + if not isinstance(item['uniq_id'], str): raise ValueError(f"Manifest file has invalid json line " f"structure: {line} without proper uniq_id key.") if 'duration' not in item: raise ValueError(f"Manifest file has invalid json line " f"structure: {line} without proper duration key.") + item = dict( audio_file=item['audio_file'], uniq_id=item['uniq_id'], duration=item['duration'], - rttm_file=item['rttm_filepath'], - offset=item.get('offset', None), + rttm_file=item['rttm_file'], + offset=item.get('offset', 0), ) return item diff --git a/tests/collections/speaker_tasks/mixins/test_diarization.py b/tests/collections/speaker_tasks/mixins/test_diarization.py new file mode 100644 index 000000000000..84ec29d84437 --- /dev/null +++ b/tests/collections/speaker_tasks/mixins/test_diarization.py @@ -0,0 +1,271 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import json +import os +from dataclasses import dataclass +from typing import Any, Dict, List + +import pytest +import torch +from torch.utils.data import DataLoader, Dataset + +from nemo.collections.asr.parts.mixins.diarization import DiarizeConfig, SpkDiarizationMixin + + +class DummyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.encoder = torch.nn.Linear(1, 1) + + self.execution_count = 0 + self.flag_begin = False + + def forward(self, x): + # Input: [1, 1] Output = [1, 1 + out = self.encoder(x) + return out + + +@pytest.mark.with_downloads() +@pytest.fixture() +def audio_files(test_data_dir): + """ + Returns a list of audio files for testing. 
+ """ + import soundfile as sf + + audio_file1 = os.path.join(test_data_dir, "an4_speaker", "an4", "wav", "an4_clstk", "fash", "an251-fash-b.wav") + audio_file2 = os.path.join(test_data_dir, "an4_speaker", "an4", "wav", "an4_clstk", "ffmm", "cen1-ffmm-b.wav") + + audio1, _ = sf.read(audio_file1, dtype='float32') + audio2, _ = sf.read(audio_file2, dtype='float32') + + return audio1, audio2 + + +class DiarizableDummy(DummyModel, SpkDiarizationMixin): + def _diarize_on_begin(self, audio, diarcfg: DiarizeConfig): + super()._diarize_on_begin(audio, diarcfg) + self.flag_begin = True + + def _diarize_input_manifest_processing(self, audio_files: List[str], temp_dir: str, diarcfg: DiarizeConfig): + # Create a dummy manifest + manifest_path = os.path.join(temp_dir, 'dummy_manifest.json') + with open(manifest_path, 'w', encoding='utf-8') as fp: + for audio_file in audio_files: + entry = {'audio_filepath': audio_file, 'duration': 100000, 'text': ''} + fp.write(json.dumps(entry) + '\n') + + ds_config = { + 'paths2audio_files': audio_files, + 'batch_size': diarcfg.batch_size, + 'temp_dir': temp_dir, + 'session_len_sec': diarcfg.session_len_sec, + 'num_workers': diarcfg.num_workers, + } + return ds_config + + def _setup_diarize_dataloader(self, config: Dict) -> DataLoader: + class DummyDataset(Dataset): + def __init__(self, audio_files: List[str], config: Dict): + self.audio_files = audio_files + self.config = config + + def __getitem__(self, index): + data = self.audio_files[index] + data = torch.tensor([float(data)]).view(1) + return data + + def __len__(self): + return len(self.audio_files) + + dataset = DummyDataset(config['paths2audio_files'], config) + + return DataLoader( + dataset=dataset, + batch_size=config['batch_size'], + num_workers=config['num_workers'], + pin_memory=False, + drop_last=False, + ) + + def _diarize_forward(self, batch: Any): + output = self(batch) + return output + + def _diarize_output_processing(self, outputs, uniq_ids, diarcfg: DiarizeConfig): + self.execution_count += 1 + + result = [] + for output in outputs: + result.append(float(output.item())) + + if hasattr(diarcfg, 'output_type') and diarcfg.output_type == 'dict': + results = {'output': result} + return results + + if hasattr(diarcfg, 'output_type') and diarcfg.output_type == 'dict2': + results = [{'output': res} for res in result] + return results + + if hasattr(diarcfg, 'output_type') and diarcfg.output_type == 'tuple': + result = tuple(result) + return result + + # Pass list of results by default + return result + + +class DummyDataset(Dataset): + def __init__(self, audio_tensors: List[str], config: Dict = None): + self.audio_tensors = audio_tensors + self.config = config + + def __getitem__(self, index): + data = self.audio_tensors[index] + samples = torch.tensor(data) + # Calculate seq length + seq_len = torch.tensor(samples.shape[0], dtype=torch.long) + + # Dummy text tokens + targets = torch.tensor([0], dtype=torch.long) + targets_len = torch.tensor(1, dtype=torch.long) + return (samples, seq_len, targets, targets_len) + + def __len__(self): + return len(self.audio_tensors) + + +@pytest.fixture() +def dummy_model(): + return DiarizableDummy() + + +class TestSpkDiarizationMixin: + @pytest.mark.unit + def test_constructor_non_instance(self): + model = DummyModel() + assert not isinstance(model, SpkDiarizationMixin) + assert not hasattr(model, 'diarize') + + @pytest.mark.unit + def test_diarize(self, dummy_model): + dummy_model = dummy_model.eval() + dummy_model.encoder.weight.data.fill_(1.0) + 
dummy_model.encoder.bias.data.fill_(0.0) + + audio = ['1.0', '2.0', '3.0'] + outputs = dummy_model.diarize(audio, batch_size=1) + assert len(outputs) == 3 + assert outputs[0] == 1.0 + assert outputs[1] == 2.0 + assert outputs[2] == 3.0 + + @pytest.mark.unit + def test_diarize_generator(self, dummy_model): + dummy_model = dummy_model.eval() + dummy_model.encoder.weight.data.fill_(1.0) + dummy_model.encoder.bias.data.fill_(0.0) + + audio = ['1.0', '2.0', '3.0'] + + diarize_config = DiarizeConfig(batch_size=1) + generator = dummy_model.diarize_generator(audio, override_config=diarize_config) + + outputs = [] + index = 1 + for result in generator: + outputs.extend(result) + assert len(result) == 1 + assert len(outputs) == index + index += 1 + + assert len(outputs) == 3 + assert outputs[0] == 1.0 + assert outputs[1] == 2.0 + assert outputs[2] == 3.0 + + @pytest.mark.unit + def test_diarize_generator_explicit_stop_check(self, dummy_model): + dummy_model = dummy_model.eval() + dummy_model.encoder.weight.data.fill_(1.0) + dummy_model.encoder.bias.data.fill_(0.0) + + audio = ['1.0', '2.0', '3.0'] + + diarize_config = DiarizeConfig(batch_size=1) + generator = dummy_model.diarize_generator(audio, override_config=diarize_config) + + outputs = [] + index = 1 + while True: + try: + result = next(generator) + except StopIteration: + break + outputs.extend(result) + assert len(result) == 1 + assert len(outputs) == index + index += 1 + + assert len(outputs) == 3 + assert outputs[0] == 1.0 + assert outputs[1] == 2.0 + assert outputs[2] == 3.0 + + @pytest.mark.unit + def test_diarize_check_flags(self, dummy_model): + dummy_model = dummy_model.eval() + + audio = ['1.0', '2.0', '3.0'] + dummy_model.diarize(audio, batch_size=1) + assert dummy_model.flag_begin + + @pytest.mark.unit + def test_transribe_override_config_incorrect(self, dummy_model): + # Not subclassing DiarizeConfig + @dataclass + class OverrideConfig: + batch_size: int = 1 + output_type: str = 'dict' + + dummy_model = dummy_model.eval() + + audio = [1.0, 2.0, 3.0] + override_cfg = OverrideConfig(batch_size=1, output_type='dict') + with pytest.raises(ValueError): + _ = dummy_model.diarize(audio, override_config=override_cfg) + + @pytest.mark.unit + def test_transribe_override_config_correct(self, dummy_model): + @dataclass + class OverrideConfig(DiarizeConfig): + output_type: str = 'dict' + verbose: bool = False + + dummy_model = dummy_model.eval() + dummy_model.encoder.weight.data.fill_(1.0) + dummy_model.encoder.bias.data.fill_(0.0) + + audio = ['1.0', '2.0', '3.0'] + override_cfg = OverrideConfig(batch_size=1, output_type='list') + outputs = dummy_model.diarize(audio, override_config=override_cfg) + + assert isinstance(outputs, list) + assert len(outputs) == 3 + assert outputs[0] == 1.0 + assert outputs[1] == 2.0 + assert outputs[2] == 3.0 From b74b6ec9e126259ff0d6efbb8ae6e6c5c9296b7a Mon Sep 17 00:00:00 2001 From: Chen Cui Date: Wed, 18 Dec 2024 16:47:44 -0500 Subject: [PATCH 15/20] Fix peft inference (#11568) * fix peft inference (trainer not attached) Signed-off-by: Chen Cui * enable greedy generation Signed-off-by: Chen Cui * add ci test for PEFT inference Signed-off-by: Chen Cui * Apply isort and black reformatting Signed-off-by: cuichenx * typo Signed-off-by: Chen Cui * fix test Signed-off-by: Chen Cui * handle remove_special_tokens Signed-off-by: Chen Cui * move llama3configci to common file Signed-off-by: Chen Cui * Apply isort and black reformatting Signed-off-by: cuichenx * incoming commit Signed-off-by: Chen Cui * address comment 
Signed-off-by: Chen Cui --------- Signed-off-by: Chen Cui Signed-off-by: cuichenx Co-authored-by: cuichenx --- .github/workflows/cicd-main.yml | 72 +++++++++++++++--------- nemo/collections/llm/inference/base.py | 7 ++- nemo/lightning/pytorch/callbacks/peft.py | 5 +- nemo/lightning/pytorch/utils.py | 12 ++++ scripts/llm/generate.py | 11 +++- tests/collections/llm/common.py | 10 ++++ tests/collections/llm/gpt_finetuning.py | 11 +--- tests/collections/llm/peft/lora_merge.py | 11 ---- 8 files changed, 87 insertions(+), 52 deletions(-) diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index fce4ef2acfbd..25e0c5252100 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -4254,7 +4254,7 @@ jobs: SCRIPT: | python tests/collections/llm/gpt_finetuning.py \ - --restore_path /home/TestData/nemo2_ckpt/llama_68M \ + --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \ --devices 2 \ --max_steps 3 \ --experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \ @@ -4264,7 +4264,7 @@ jobs: --mbs 1 python tests/collections/llm/gpt_finetuning.py \ - --restore_path /home/TestData/nemo2_ckpt/llama_68M \ + --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \ --devices 2 \ --max_steps 6 \ --experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \ @@ -4283,7 +4283,7 @@ jobs: SCRIPT: | python tests/collections/llm/gpt_finetuning.py \ - --restore_path /home/TestData/nemo2_ckpt/llama_68M \ + --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \ --devices 2 \ --max_steps 3 \ --experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \ @@ -4293,7 +4293,7 @@ jobs: --mbs 2 python tests/collections/llm/gpt_finetuning.py \ - --restore_path /home/TestData/nemo2_ckpt/llama_68M \ + --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \ --devices 2 \ --max_steps 6 \ --experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \ @@ -4312,7 +4312,7 @@ jobs: SCRIPT: | python tests/collections/llm/gpt_finetuning.py \ - --restore_path /home/TestData/nemo2_ckpt/llama_68M \ + --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \ --devices 2 \ --max_steps 3 \ --experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \ @@ -4322,7 +4322,7 @@ jobs: --mbs 2 python tests/collections/llm/gpt_finetuning.py \ - --restore_path /home/TestData/nemo2_ckpt/llama_68M \ + --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \ --devices 2 \ --max_steps 6 \ --experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \ @@ -4341,7 +4341,7 @@ jobs: SCRIPT: | python tests/collections/llm/gpt_finetuning.py \ - --restore_path /home/TestData/nemo2_ckpt/llama_68M \ + --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \ --devices 2 \ --max_steps 3 \ --experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \ @@ -4351,7 +4351,7 @@ jobs: --mbs 2 python tests/collections/llm/gpt_finetuning.py \ - --restore_path /home/TestData/nemo2_ckpt/llama_68M \ + --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \ --devices 2 \ --max_steps 6 \ --experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \ @@ -4370,7 +4370,7 @@ jobs: SCRIPT: | python tests/collections/llm/gpt_finetuning.py \ - --restore_path /home/TestData/nemo2_ckpt/llama_68M \ + --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \ --devices 2 \ --max_steps 3 \ --experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \ @@ -4380,7 +4380,7 @@ jobs: --mbs 1 --packed python tests/collections/llm/gpt_finetuning.py \ - --restore_path /home/TestData/nemo2_ckpt/llama_68M \ + --restore_path 
/home/TestData/nemo2_ckpt/llama_68M_v2 \ --devices 2 \ --max_steps 6 \ --experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \ @@ -4399,7 +4399,7 @@ jobs: SCRIPT: | python tests/collections/llm/gpt_finetuning.py \ - --restore_path /home/TestData/nemo2_ckpt/llama_68M \ + --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \ --devices 2 \ --max_steps 3 \ --experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \ @@ -4409,7 +4409,7 @@ jobs: --mbs 1 python tests/collections/llm/gpt_finetuning.py \ - --restore_path /home/TestData/nemo2_ckpt/llama_68M \ + --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \ --devices 2 \ --max_steps 6 \ --experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \ @@ -4428,7 +4428,7 @@ jobs: SCRIPT: | python tests/collections/llm/gpt_finetuning.py \ - --restore_path /home/TestData/nemo2_ckpt/llama_68M \ + --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \ --devices 2 \ --max_steps 3 \ --experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \ @@ -4438,7 +4438,7 @@ jobs: --mbs 2 python tests/collections/llm/gpt_finetuning.py \ - --restore_path /home/TestData/nemo2_ckpt/llama_68M \ + --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \ --devices 2 \ --max_steps 6 \ --experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \ @@ -4457,7 +4457,7 @@ jobs: SCRIPT: | python tests/collections/llm/gpt_finetuning.py \ - --restore_path /home/TestData/nemo2_ckpt/llama_68M \ + --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \ --devices 2 \ --max_steps 3 \ --experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \ @@ -4467,7 +4467,7 @@ jobs: --mbs 2 python tests/collections/llm/gpt_finetuning.py \ - --restore_path /home/TestData/nemo2_ckpt/llama_68M \ + --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \ --devices 2 \ --max_steps 6 \ --experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \ @@ -4486,7 +4486,7 @@ jobs: SCRIPT: | python tests/collections/llm/gpt_finetuning.py \ - --restore_path /home/TestData/nemo2_ckpt/llama_68M \ + --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \ --devices 2 \ --max_steps 3 \ --experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \ @@ -4496,7 +4496,7 @@ jobs: --mbs 2 python tests/collections/llm/gpt_finetuning.py \ - --restore_path /home/TestData/nemo2_ckpt/llama_68M \ + --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \ --devices 2 \ --max_steps 6 \ --experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \ @@ -4514,7 +4514,7 @@ jobs: SCRIPT: | python tests/collections/llm/gpt_finetuning.py \ - --restore_path /home/TestData/nemo2_ckpt/llama_68M \ + --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \ --devices 2 \ --max_steps 3 \ --experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \ @@ -4524,7 +4524,7 @@ jobs: --mbs 1 --packed python tests/collections/llm/gpt_finetuning.py \ - --restore_path /home/TestData/nemo2_ckpt/llama_68M \ + --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \ --devices 2 \ --max_steps 6 \ --experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \ @@ -4542,7 +4542,7 @@ jobs: SCRIPT: | python tests/collections/llm/gpt_finetuning.py \ - --restore_path /home/TestData/nemo2_ckpt/llama_68M \ + --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \ --devices 2 \ --max_steps 3 \ --experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \ @@ -4552,7 +4552,7 @@ jobs: --mbs 1 --packed python tests/collections/llm/gpt_finetuning.py \ - --restore_path /home/TestData/nemo2_ckpt/llama_68M \ + --restore_path 
/home/TestData/nemo2_ckpt/llama_68M_v2 \ --devices 2 \ --max_steps 6 \ --experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \ @@ -4569,7 +4569,7 @@ jobs: RUNNER: self-hosted-azure SCRIPT: | python tests/collections/llm/gpt_finetuning.py \ - --restore_path /home/TestData/nemo2_ckpt/llama_68M \ + --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \ --devices 2 \ --max_steps 3 \ --experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \ @@ -4579,7 +4579,7 @@ jobs: --mbs 1 --packed python tests/collections/llm/gpt_finetuning.py \ - --restore_path /home/TestData/nemo2_ckpt/llama_68M \ + --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \ --devices 2 \ --max_steps 6 \ --experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \ @@ -4597,7 +4597,7 @@ jobs: SCRIPT: | python tests/collections/llm/gpt_finetuning.py \ - --restore_path /home/TestData/nemo2_ckpt/llama_68M \ + --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \ --devices 2 \ --max_steps 3 \ --experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \ @@ -4608,7 +4608,7 @@ jobs: --chat_dataset_path /home/TestData/nemo2_data/chat python tests/collections/llm/gpt_finetuning.py \ - --restore_path /home/TestData/nemo2_ckpt/llama_68M \ + --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \ --devices 2 \ --max_steps 6 \ --experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \ @@ -4702,9 +4702,26 @@ jobs: SCRIPT: | python tests/collections/llm/peft/lora_merge.py \ - --lora_checkpoint_path=/home/TestData/nemo2_ckpt/llama_lora_ci_checkpoint/ \ + --lora_checkpoint_path=/home/TestData/nemo2_ckpt/llama_lora_ci_checkpoint_v2/ \ --output_path=/tmp/nemo2_lora_merge/${{ github.run_id }} + L2_NEMO_2_LoRA_Inference: + needs: [cicd-test-container-setup] + uses: ./.github/workflows/_test_template.yml + if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NEMO_2_LoRA_Inference') || needs.cicd-test-container-setup.outputs.all == 'true' + with: + RUNNER: self-hosted-azure-gpus-1 + SCRIPT: | + + python scripts/llm/generate.py \ + --model_path /home/TestData/nemo2_ckpt/llama_lora_ci_checkpoint_v2/ \ + --tp 1 \ + --pp 1 \ + --devices 1 \ + --top_p 0.0 \ + --top_k 1 \ + --num_tokens_to_generate 3 + L2_NeMo_2_NeMo_Mcore_Mixtral_bitexact: needs: [cicd-test-container-setup] uses: ./.github/workflows/_test_template.yml @@ -4900,6 +4917,7 @@ jobs: - L2_NeMo_2_Mistral_LoRA_TP1PP1_MBS1 - L2_NeMo_2_Mistral_LoRA_TP2PP1_MBS1 - L2_NEMO_2_LoRA_MERGE + - L2_NEMO_2_LoRA_Inference - L2_NeMo_2_Mixtral_Pretraining - L2_PTQ_Llama2_FP8 - L2_Community_LLM_Checkpoints_tests_Llama3 diff --git a/nemo/collections/llm/inference/base.py b/nemo/collections/llm/inference/base.py index 6c89a1b42b15..dd53d97b21ad 100644 --- a/nemo/collections/llm/inference/base.py +++ b/nemo/collections/llm/inference/base.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import inspect import json from pathlib import Path from typing import Optional, Union @@ -61,7 +61,10 @@ def detokenize(self, tokens, remove_special_tokens=False): Returns: str: The detokenized string. 
""" - return self.tokenizer.ids_to_text(tokens, remove_special_tokens) + if 'remove_special_tokens' in inspect.signature(self.tokenizer.ids_to_text).parameters: + return self.tokenizer.ids_to_text(tokens, remove_special_tokens) + else: + return self.tokenizer.ids_to_text(tokens) def tokenize(self, prompt): """ diff --git a/nemo/lightning/pytorch/callbacks/peft.py b/nemo/lightning/pytorch/callbacks/peft.py index 09a0885ead17..d138117e4599 100644 --- a/nemo/lightning/pytorch/callbacks/peft.py +++ b/nemo/lightning/pytorch/callbacks/peft.py @@ -32,6 +32,7 @@ from nemo.lightning.megatron_parallel import MegatronParallel from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule +from nemo.lightning.pytorch.utils import is_trainer_attached from nemo.utils import logging from nemo.utils.callbacks.dist_ckpt_io import AsyncCompatibleCheckpointIO @@ -105,7 +106,7 @@ def __call__(self, model: nn.Module) -> nn.Module: else: model.walk(self.transform) - if hasattr(model, "trainer") and model.trainer.state.fn != TrainerFn.FITTING: + if is_trainer_attached(model) and model.trainer.state.fn != TrainerFn.FITTING: self.freeze_model(model) return model @@ -128,7 +129,7 @@ def freeze_model(self, model: nn.Module) -> None: model.module.freeze() else: model.freeze() - if hasattr(model, "trainer") and model.trainer.state.fn == TrainerFn.FITTING: + if is_trainer_attached(model) and model.trainer.state.fn == TrainerFn.FITTING: model.train(mode=True) def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None: diff --git a/nemo/lightning/pytorch/utils.py b/nemo/lightning/pytorch/utils.py index 045cf79b5777..77fd702da410 100644 --- a/nemo/lightning/pytorch/utils.py +++ b/nemo/lightning/pytorch/utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import lightning.pytorch as pl import torch @@ -55,3 +56,14 @@ def dtype_from_hf(config): return dtype_from_str(torch_dtype) else: raise ValueError("torch_dtype is not of type str/torch.dtype") + + +def is_trainer_attached(model: pl.LightningModule): + """ + Returns true if trainer is attached to a model + """ + try: + trainer = model.trainer + return True + except (AttributeError, RuntimeError): + return False diff --git a/scripts/llm/generate.py b/scripts/llm/generate.py index 56653aa3bbb5..f01c384604a2 100644 --- a/scripts/llm/generate.py +++ b/scripts/llm/generate.py @@ -72,6 +72,12 @@ def get_args(): default=0.95, help="""top_p to be used in megatron.core.inference.common_inference_params.CommonInferenceParams""", ) + parser.add_argument( + "--top_k", + type=float, + default=0, + help="""top_k to be used in megatron.core.inference.common_inference_params.CommonInferenceParams""", + ) parser.add_argument( "--num_tokens_to_generate", type=int, @@ -118,7 +124,10 @@ def get_args(): prompts=prompts, trainer=trainer, inference_params=CommonInferenceParams( - temperature=args.temperature, top_p=args.top_p, num_tokens_to_generate=args.num_tokens_to_generate + temperature=args.temperature, + top_p=args.top_p, + top_k=args.top_k, + num_tokens_to_generate=args.num_tokens_to_generate, ), text_only=True, ) diff --git a/tests/collections/llm/common.py b/tests/collections/llm/common.py index 8e93c9c84c9e..f8015950aa93 100644 --- a/tests/collections/llm/common.py +++ b/tests/collections/llm/common.py @@ -13,6 +13,7 @@ # limitations under the License. 
import os +from dataclasses import dataclass import lightning.pytorch as pl import nemo_run as run @@ -191,3 +192,12 @@ def verify_precision(tensor: torch.Tensor) -> None: assert tensor.dtype == precision return verify_precision + + +@dataclass +class Llama3ConfigCI(llm.Llama3Config8B): + seq_length: int = 2048 + num_layers: int = 2 + hidden_size: int = 768 + ffn_hidden_size: int = 3072 + num_attention_heads: int = 8 diff --git a/tests/collections/llm/gpt_finetuning.py b/tests/collections/llm/gpt_finetuning.py index be5331c32f3b..384faa383435 100644 --- a/tests/collections/llm/gpt_finetuning.py +++ b/tests/collections/llm/gpt_finetuning.py @@ -22,17 +22,10 @@ from nemo.collections import llm from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer - -## NOTE: This script is present for github-actions testing only. +from tests.collections.llm.common import Llama3ConfigCI -@dataclass -class Llama3ConfigCI(llm.Llama3Config8B): - seq_length: int = 2048 - num_layers: int = 2 - hidden_size: int = 768 - ffn_hidden_size: int = 3072 - num_attention_heads: int = 8 +## NOTE: This script is present for github-actions testing only. def get_args(): diff --git a/tests/collections/llm/peft/lora_merge.py b/tests/collections/llm/peft/lora_merge.py index 2ca7390ea7e6..0e0c9361c4f5 100644 --- a/tests/collections/llm/peft/lora_merge.py +++ b/tests/collections/llm/peft/lora_merge.py @@ -12,20 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse -from dataclasses import dataclass - from nemo.collections import llm -@dataclass -class Llama3ConfigCI(llm.Llama3Config8B): - seq_length: int = 2048 - num_layers: int = 2 - hidden_size: int = 768 - ffn_hidden_size: int = 3072 - num_attention_heads: int = 8 - - def get_args(): parser = argparse.ArgumentParser(description='Merge LoRA weights with base LLM') parser.add_argument('--lora_checkpoint_path', type=str, help="Path to finetuned LORA checkpoint") From c46ba6f95f6c4e181c6b15e0e9a80b55731b272a Mon Sep 17 00:00:00 2001 From: Somshubra Majumdar Date: Wed, 18 Dec 2024 18:53:06 -0800 Subject: [PATCH 16/20] Add fix docstring for speech commands (#11659) Signed-off-by: smajumdar --- tutorials/asr/Voice_Activity_Detection.ipynb | 1 + 1 file changed, 1 insertion(+) diff --git a/tutorials/asr/Voice_Activity_Detection.ipynb b/tutorials/asr/Voice_Activity_Detection.ipynb index fb3cef1b44ea..aa81b79ebd94 100644 --- a/tutorials/asr/Voice_Activity_Detection.ipynb +++ b/tutorials/asr/Voice_Activity_Detection.ipynb @@ -34,6 +34,7 @@ "!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[asr]\n", "\n", "## Install TorchAudio\n", + "## NOTE: TorchAudio installation may not work in all environments, please use Google Colab for best experience\n", "!pip install torchaudio>=0.13.0 -f https://download.pytorch.org/whl/torch_stable.html\n", "\n", "## Grab the config we'll use in this example\n", From 093ffc40fadb051a0659e114116ceee19ee18923 Mon Sep 17 00:00:00 2001 From: Huiying Date: Wed, 18 Dec 2024 19:21:10 -0800 Subject: [PATCH 17/20] update nemo container version for notebooks (#11651) Signed-off-by: Huiying Li --- tutorials/llm/llama-3/nemo2-sft-peft/README.rst | 4 ++-- tutorials/llm/llama-3/nemo2-sft-peft/nemo2-peft.ipynb | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tutorials/llm/llama-3/nemo2-sft-peft/README.rst 
b/tutorials/llm/llama-3/nemo2-sft-peft/README.rst index 7adf2777db2c..d1bd7b87759c 100644 --- a/tutorials/llm/llama-3/nemo2-sft-peft/README.rst +++ b/tutorials/llm/llama-3/nemo2-sft-peft/README.rst @@ -20,7 +20,7 @@ Requirements * Software Requirements * Use the latest [NeMo Framework Container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo/tags) . Note that you must be logged in to the container registry to view this page. - * This notebook uses the container: `nvcr.io/nvidia/nemo:dev`. + * This notebook is tested on the container: `nvcr.io/nvidia/nemo:24.12-rc0`. * Get your Hugging Face [access token](https://huggingface.co/docs/hub/en/security-tokens), which will be used to obtain the tokenizer required during training. * NeMo 2.0 and NeMo-Run @@ -42,7 +42,7 @@ Start the NeMo Framework Container --rm -it \ -v ${PWD}:/workspace \ -w /workspace \ - nvcr.io/nvidia/nemo:dev bash + nvcr.io/nvidia/nemo:24.12-rc0 bash Once you are inside the container, you can run `nvidia-smi` to verify that the GPUs are accessible. diff --git a/tutorials/llm/llama-3/nemo2-sft-peft/nemo2-peft.ipynb b/tutorials/llm/llama-3/nemo2-sft-peft/nemo2-peft.ipynb index aa463e2b84be..b3393d133a45 100644 --- a/tutorials/llm/llama-3/nemo2-sft-peft/nemo2-peft.ipynb +++ b/tutorials/llm/llama-3/nemo2-sft-peft/nemo2-peft.ipynb @@ -533,7 +533,7 @@ "\n", "2. [NeMo-Run GitHub repo](https://github.com/NVIDIA/NeMo-Run/)\n", "\n", - "3. NeMo Framework Container: `nvcr.io/nvidia/nemo:dev`\n", + "3. NeMo Framework Container: `nvcr.io/nvidia/nemo:24.12-rc0`\n", "\n", "\n", "\n", From a121c592c80c25d8bd85531551b571d6167b00ee Mon Sep 17 00:00:00 2001 From: Ao Tang Date: Wed, 18 Dec 2024 23:28:02 -0500 Subject: [PATCH 18/20] Fix Optimizer & LR scheduler & Consume Samples when Resuming in PEFT (#11631) * Fix Optimizer & LR scheduler Resume * fix unit test Signed-off-by: Chen Cui * Apply isort and black reformatting Signed-off-by: cuichenx * typo Signed-off-by: Chen Cui * Fix consume samples * Fix unit tests * Apply isort and black reformatting Signed-off-by: suiyoubi --------- Signed-off-by: Chen Cui Signed-off-by: cuichenx Signed-off-by: suiyoubi Co-authored-by: Chen Cui Co-authored-by: cuichenx Co-authored-by: suiyoubi --- nemo/collections/llm/gpt/data/fine_tuning.py | 7 ++++--- nemo/lightning/pytorch/callbacks/peft.py | 7 ++++++- tests/lightning/test_data.py | 19 ++++++++++++++++--- 3 files changed, 26 insertions(+), 7 deletions(-) diff --git a/nemo/collections/llm/gpt/data/fine_tuning.py b/nemo/collections/llm/gpt/data/fine_tuning.py index 0d866bb600fe..a22ed72f4656 100644 --- a/nemo/collections/llm/gpt/data/fine_tuning.py +++ b/nemo/collections/llm/gpt/data/fine_tuning.py @@ -93,6 +93,7 @@ def __init__( self.packed_sequence_size = -1 if not packed_sequence_specs else packed_sequence_specs.packed_sequence_size self.validate_batch_size_for_packed_sequence() self.dataset_kwargs = dataset_kwargs or {} + self.init_global_step = 0 def validate_batch_size_for_packed_sequence(self): """ @@ -163,9 +164,7 @@ def state_dict(self) -> Dict[str, Any]: A dictionary containing datamodule state. 
""" - consumed_samples = self.data_sampler.compute_consumed_samples( - self.trainer.global_step - self.data_sampler.init_global_step - ) + consumed_samples = self.data_sampler.compute_consumed_samples(self.trainer.global_step - self.init_global_step) return {"consumed_samples": consumed_samples} def load_state_dict(self, state_dict: Dict[str, Any]) -> None: @@ -240,6 +239,8 @@ def _create_dataset(self, path, is_test=False, **kwargs): def _create_dataloader(self, dataset, mode, **kwargs) -> DataLoader: # pylint: disable=C0115,C0116 + self.init_global_step = self.trainer.global_step + self.data_sampler.init_global_step = self.init_global_step return WrappedDataLoader( mode=mode, dataset=dataset, diff --git a/nemo/lightning/pytorch/callbacks/peft.py b/nemo/lightning/pytorch/callbacks/peft.py index d138117e4599..d2e93fe9ab42 100644 --- a/nemo/lightning/pytorch/callbacks/peft.py +++ b/nemo/lightning/pytorch/callbacks/peft.py @@ -204,7 +204,12 @@ def apply_transform(self, trainer): ) trainer.strategy.load_model_state_dict(adapter_state, strict=False) if trainer.state.fn == TrainerFn.FITTING: - trainer.strategy.load_optimizer_state_dict(adapter_state, selective_restore=True) + # Load optimizer + trainer.strategy.load_optimizer_state_dict(adapter_state, selective_restore=False) + # Load lr scheduler + if (lr_schedulers := adapter_state.get('lr_schedulers', None)) is not None: + for config, lrs_state in zip(trainer.lr_scheduler_configs, lr_schedulers): + config.scheduler.load_state_dict(lrs_state) for cb in trainer.callbacks[::-1]: if isinstance(cb, MegatronOptimizerModule): diff --git a/tests/lightning/test_data.py b/tests/lightning/test_data.py index 2519616766f4..b848bec3dae9 100644 --- a/tests/lightning/test_data.py +++ b/tests/lightning/test_data.py @@ -15,11 +15,18 @@ from pathlib import Path from unittest.mock import MagicMock, patch +import pytest + + +@pytest.fixture +def trainer(): + return MagicMock() + @patch( 'nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset.GPTSFTDataset.__init__', return_value=None ) -def test_finetuning_module(mock_gpt_sft_dataset) -> None: +def test_finetuning_module(mock_gpt_sft_dataset, trainer) -> None: from nemo.collections.llm.gpt.data import FineTuningDataModule dataset_root = 'random_root' @@ -30,6 +37,8 @@ def test_finetuning_module(mock_gpt_sft_dataset) -> None: global_batch_size=8, seed=1234, ) + datamodule.trainer = trainer + datamodule.setup(stage='train') datamodule.train_dataloader() mock_gpt_sft_dataset.assert_called_once() @@ -38,7 +47,7 @@ def test_finetuning_module(mock_gpt_sft_dataset) -> None: @patch( 'nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset.GPTSFTDataset.__init__', return_value=None ) -def test_dolly_module(mock_gpt_sft_dataset) -> None: +def test_dolly_module(mock_gpt_sft_dataset, trainer) -> None: from nemo.collections.llm.gpt.data import DollyDataModule datamodule = DollyDataModule( @@ -47,6 +56,8 @@ def test_dolly_module(mock_gpt_sft_dataset) -> None: global_batch_size=8, seed=1234, ) + datamodule.trainer = trainer + datamodule.setup(stage='train') datamodule.train_dataloader() mock_gpt_sft_dataset.assert_called_once() @@ -55,7 +66,7 @@ def test_dolly_module(mock_gpt_sft_dataset) -> None: @patch( 'nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset.GPTSFTDataset.__init__', return_value=None ) -def test_squad_module(mock_gpt_sft_dataset) -> None: +def test_squad_module(mock_gpt_sft_dataset, trainer) -> None: from nemo.collections.llm.gpt.data import SquadDataModule datamodule 
= SquadDataModule( @@ -64,6 +75,8 @@ def test_squad_module(mock_gpt_sft_dataset) -> None: global_batch_size=8, seed=1234, ) + datamodule.trainer = trainer + datamodule.setup(stage='train') datamodule.train_dataloader() mock_gpt_sft_dataset.assert_called_once() From 54615868cd78c47a092a4d5510ad1e9b0115b93f Mon Sep 17 00:00:00 2001 From: meatybobby Date: Wed, 18 Dec 2024 22:08:14 -0800 Subject: [PATCH 19/20] Add vlm generation function (#11063) * Add vlm inference * Add init * Apply isort and black reformatting Signed-off-by: meatybobby * Add KV cache and xattn cache in inference * Fix position id for KV cache * Apply isort and black reformatting Signed-off-by: meatybobby * Add doc string * pylint fix * Remove max_output_len in inference controller * Modify generate script * Apply isort and black reformatting Signed-off-by: meatybobby * Rename wrapped model * Rename var --------- Signed-off-by: meatybobby Co-authored-by: meatybobby --- nemo/collections/vlm/inference/__init__.py | 15 ++ nemo/collections/vlm/inference/base.py | 129 ++++++++++++++++++ nemo/collections/vlm/inference/vlm_engine.py | 52 +++++++ .../vlm/inference/vlm_inference_controller.py | 79 +++++++++++ .../vlm/inference/vlm_inference_wrapper.py | 119 ++++++++++++++++ nemo/collections/vlm/mllama/model/base.py | 16 ++- nemo/collections/vlm/mllama/model/language.py | 2 +- scripts/vlm/mllama_generate.py | 54 +++----- 8 files changed, 426 insertions(+), 40 deletions(-) create mode 100644 nemo/collections/vlm/inference/__init__.py create mode 100644 nemo/collections/vlm/inference/base.py create mode 100644 nemo/collections/vlm/inference/vlm_engine.py create mode 100644 nemo/collections/vlm/inference/vlm_inference_controller.py create mode 100644 nemo/collections/vlm/inference/vlm_inference_wrapper.py diff --git a/nemo/collections/vlm/inference/__init__.py b/nemo/collections/vlm/inference/__init__.py new file mode 100644 index 000000000000..6c338b383c73 --- /dev/null +++ b/nemo/collections/vlm/inference/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.vlm.inference.base import generate, setup_inference_wrapper, setup_model_and_tokenizer diff --git a/nemo/collections/vlm/inference/base.py b/nemo/collections/vlm/inference/base.py new file mode 100644 index 000000000000..bbc85a8ee4a8 --- /dev/null +++ b/nemo/collections/vlm/inference/base.py @@ -0,0 +1,129 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import pytorch_lightning as pl +import torch +import torch.distributed +from megatron.core.inference.common_inference_params import CommonInferenceParams +from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig +from transformers import AutoProcessor + +import nemo.lightning as nl +from nemo.collections import vlm +from nemo.collections.vlm.inference.vlm_engine import VLMEngine +from nemo.collections.vlm.inference.vlm_inference_controller import VLMTextGenerationController +from nemo.collections.vlm.inference.vlm_inference_wrapper import VLMInferenceWrapper + + +def _setup_trainer_and_restore_model(path: str, trainer: nl.Trainer, model: pl.LightningModule): + """Setup trainer and restore model from path""" + fabric = trainer.to_fabric() + model = fabric.load_model(path, model) + return model + + +def setup_inference_wrapper( + model, + tokenizer, + params_dtype: torch.dtype = torch.bfloat16, + inference_batch_times_seqlen_threshold: int = 1000, +): + """Set up inference wrapper for the model""" + config = model.config + + mcore_model = model.module.cuda() + mcore_model = mcore_model.to(params_dtype) + + inference_wrapped_model = VLMInferenceWrapper( + mcore_model, + InferenceWrapperConfig( + hidden_size=config.language_model_config.hidden_size, + params_dtype=params_dtype, + inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold, + padded_vocab_size=tokenizer.vocab_size, + ), + ) + + return inference_wrapped_model + + +def setup_model_and_tokenizer( + path: str, + trainer: Optional[nl.Trainer] = None, + params_dtype: torch.dtype = torch.bfloat16, + inference_batch_times_seqlen_threshold: int = 1000, +): + """Set up model and tokenizer""" + model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct" + + processor = AutoProcessor.from_pretrained(model_id) + tokenizer = processor.tokenizer + config = vlm.MLlamaConfig11BInstruct() + model = vlm.MLlamaModel(config, tokenizer=tokenizer) + _setup_trainer_and_restore_model(path=path, trainer=trainer, model=model) + + inference_wrapped_model = setup_inference_wrapper( + model, tokenizer, params_dtype, inference_batch_times_seqlen_threshold + ) + + return inference_wrapped_model, processor + + +def generate( + wrapped_model: VLMInferenceWrapper, + tokenizer, + image_processor, + prompts: list[str], + images: list, + max_batch_size: int = 4, + random_seed: Optional[int] = None, + inference_params: Optional[CommonInferenceParams] = None, +) -> dict: + """ + Generates text using a NeMo VLM model. + Args: + wrapped_model (VLMInferenceWrapper): The model inference wrapper. + tokenizer: tokenizer for the input text, + image_processor: image processor for the input image, + prompts (list[str]): The list of prompts to generate text for. + images (list): The list of images to generate text for. + max_batch_size (int, optional): The maximum batch size. Defaults to 4. + random_seed (Optional[int], optional): The random seed. Defaults to None. + inference_params (Optional["CommonInferenceParams"], optional): The inference parameters defined in + Mcore's CommonInferenceParams. Defaults to None. + + Returns: + list[Union["InferenceRequest", str]]: A list of generated text, + either as a string or as an InferenceRequest object. 
+ """ + text_generation_controller = VLMTextGenerationController( + inference_wrapped_model=wrapped_model, + tokenizer=tokenizer, + image_processor=image_processor, + ) + mcore_engine = VLMEngine( + text_generation_controller=text_generation_controller, max_batch_size=max_batch_size, random_seed=random_seed + ) + + common_inference_params = inference_params or CommonInferenceParams(num_tokens_to_generate=50) + + results = mcore_engine.generate( + prompts=prompts, + images=images, + common_inference_params=common_inference_params, + ) + + return results diff --git a/nemo/collections/vlm/inference/vlm_engine.py b/nemo/collections/vlm/inference/vlm_engine.py new file mode 100644 index 000000000000..bce373e7a2f5 --- /dev/null +++ b/nemo/collections/vlm/inference/vlm_engine.py @@ -0,0 +1,52 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +import torch +from megatron.core.inference.common_inference_params import CommonInferenceParams +from megatron.core.inference.engines.mcore_engine import MCoreEngine +from megatron.core.inference.inference_request import InferenceRequest +from PIL.Image import Image + + +class VLMEngine(MCoreEngine): + # pylint: disable=C0115,C0116 + def generate( + self, + prompts: List[str], + images: List[Image] = None, + common_inference_params: CommonInferenceParams = None, + ) -> dict: + # pylint: disable=C0115,C0116 + if self.random_seed: + torch.random.manual_seed(self.random_seed) + + for i in range(len(prompts)): + prompt = prompts[i] + image = images[i] if images is not None else None + prompt_tokens, image_dict = self.text_generation_controller.tokenize_prompt(prompt, image) + + # Reuse encoder_prompt from scheduler to pass image + self.scheduler.add_request( + prompt=prompt, + prompt_tokens=prompt_tokens, + encoder_prompt=image_dict, + inference_parameters=common_inference_params, + ) + + self.run_engine() + + result: List[InferenceRequest] = self.scheduler.completed_request_pool.values() + return result diff --git a/nemo/collections/vlm/inference/vlm_inference_controller.py b/nemo/collections/vlm/inference/vlm_inference_controller.py new file mode 100644 index 000000000000..9db1ce24031d --- /dev/null +++ b/nemo/collections/vlm/inference/vlm_inference_controller.py @@ -0,0 +1,79 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import OrderedDict + +import torch + +from megatron.core.inference.inference_request import InferenceRequest +from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import ( + SimpleTextGenerationController, +) + + +class TokenizerWrapper: + # pylint: disable=C0115,C0116 + def __init__(self, tokenizer): + self.eod = tokenizer.eos_token_id + self.vocab_size = None + self._tokenizer = tokenizer + + def detokenize(self, tokens): + # pylint: disable=C0115,C0116 + return self._tokenizer.decode(tokens, skip_special_tokens=True) + + def tokenize(self, prompt): + # pylint: disable=C0115,C0116 + return self._tokenizer.encode(prompt, add_special_tokens=False) + + +class VLMTextGenerationController(SimpleTextGenerationController): + # pylint: disable=C0115,C0116 + def __init__(self, inference_wrapped_model, tokenizer, image_processor): + super().__init__(inference_wrapped_model, TokenizerWrapper(tokenizer)) + self.image_processor = image_processor + + def tokenize_prompt(self, prompt: str, image): + # pylint: disable=C0115,C0116 + tokens = self.tokenizer.tokenize(prompt) + if image is None: + image_dict = dict( + pixel_values=torch.zeros( + 1, 4, 3, self.image_processor.size['height'], self.image_processor.size['width'] + ), + aspect_ratio_ids=torch.tensor([0], dtype=torch.long), + num_tiles=[0], + ) + else: + image_dict = self.image_processor.preprocess(image, return_tensors='pt') + image_dict = { + k: v[0] for k, v in image_dict.items() if k in ["pixel_values", "aspect_ratio_ids", "num_tiles"] + } + return tokens, image_dict + + def prep_model_for_inference( + self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[int, InferenceRequest] + ): + """Preparing batch for inference, using respective wrapper's prep_model_for_inference method + + Args: + prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length] + active_requests (OrderedDict[int, InferenceRequest]): The input active requests + """ + images = list(map(lambda request: request.encoder_prompt, active_requests.values())) + + self.inference_wrapped_model.prep_model_for_inference( + prompts_tokens=prompts_tokens, + image_dict=images, + ) diff --git a/nemo/collections/vlm/inference/vlm_inference_wrapper.py b/nemo/collections/vlm/inference/vlm_inference_wrapper.py new file mode 100644 index 000000000000..29d7d83a9d54 --- /dev/null +++ b/nemo/collections/vlm/inference/vlm_inference_wrapper.py @@ -0,0 +1,119 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from argparse import Namespace +from typing import Dict, List + +import torch +import torch.nn.functional as F +from megatron.core import tensor_parallel +from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import ( + AbstractModelInferenceWrapper, +) +from megatron.core.inference_params import InferenceParams +from torch.utils.data import default_collate + +from nemo.collections.vlm.mllama.model.utils import create_vision_mask_tensor + + +class VLMInferenceWrapper(AbstractModelInferenceWrapper): + """Constructor for the model inference wrapper + + The wrapper prepares the model for inference, provides the required input + data, and runs the forward pass + + Args: + model (MllamaModel): The Mllama model + args (Namespace): The command line arguments that were passed + """ + + def __init__(self, model, args: Namespace): + super().__init__(model, args) + + def prep_model_for_inference( + self, + prompts_tokens: torch.Tensor, + image_dict: List[Dict] = None, + ): + # pylint: disable=C0115,C0116 + super().prep_model_for_inference(prompts_tokens=prompts_tokens) + max_num_concurrent_media = max(instance['pixel_values'].shape[0] for instance in image_dict) + for instance in image_dict: + pad_num_images = max_num_concurrent_media - instance['pixel_values'].shape[0] + instance['pixel_values'] = F.pad( + instance['pixel_values'], (0, 0, 0, 0, 0, 0, 0, 0, 0, pad_num_images), 'constant', 0 + ) + instance['aspect_ratio_ids'] = F.pad( + instance['aspect_ratio_ids'], (0, max(pad_num_images - 1, 0)), 'constant', 0 + ) + instance['num_tiles'] = F.pad( + torch.tensor(instance['num_tiles']), (0, max(pad_num_images - 1, 0)), 'constant', 0 + ) + batch = default_collate(image_dict) + + batch_size = prompts_tokens.size(0) + seq_length = prompts_tokens.size(1) + self.position_ids = ( + torch.arange(seq_length, dtype=torch.long, device=prompts_tokens.device) + .unsqueeze(0) + .expand_as(prompts_tokens) + ) + self.pixel_values = batch['pixel_values'].cuda(non_blocking=True) + self.num_tiles = batch['num_tiles'] + self.aspect_ratio_ids = batch['aspect_ratio_ids'].cuda(non_blocking=True) + + self.inference_params = InferenceParams(batch_size, seq_length) + self.inference_params.xattn_caches = None + self.inference_params.cross_attention_masks = None + self.inference_params.full_text_row_masked_out_mask = None + + def get_batch_for_context_window(self, context_start_position: int, context_end_position: int) -> List: + # pylint: disable=C0115,C0116 + tokens2use = self.prompts_tokens[:, context_start_position:context_end_position] + positions2use = self.position_ids[:, context_start_position:context_end_position] + data_at_step_idx = [tokens2use, positions2use] + + return data_at_step_idx + + def forward_pass_without_pipeline_parallel(self, inference_input: List) -> torch.Tensor: + """Utility to carry out simple forward pass for TP or no model parallel models + + Runs a very simple forward pass for model. Used in the case of models without + any parallelism or only tensor parallelism. 
+ + Args: + inference_input (List): A list containing the inputs for the vlm + model [tokens, position ids] + + Returns: + torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size] + """ + tokens2use, positions2use = inference_input + batch_masks = [create_vision_mask_tensor(tokens2use[0], 128256)] * tokens2use.size(0) + logits = self.model( + batch_images=self.pixel_values, + batch_masks=batch_masks, + num_chunks=self.num_tiles, + aspect_ratio_ids=self.aspect_ratio_ids, + tokens=tokens2use, + position_ids=positions2use, + xattn_caches=self.inference_params.xattn_caches, + cross_attention_masks=self.inference_params.cross_attention_masks, + full_text_row_masked_out_mask=self.inference_params.full_text_row_masked_out_mask, + inference_params=self.inference_params, + ) + logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) + self.inference_params.sequence_len_offset += tokens2use.size(1) + + return logits diff --git a/nemo/collections/vlm/mllama/model/base.py b/nemo/collections/vlm/mllama/model/base.py index 9279936e23d7..1e8bb8d5adcf 100644 --- a/nemo/collections/vlm/mllama/model/base.py +++ b/nemo/collections/vlm/mllama/model/base.py @@ -22,6 +22,7 @@ import torch.distributed from einops import rearrange from megatron.core.enums import ModelType +from megatron.core.inference_params import InferenceParams from megatron.core.models.vision.multimodal_projector import MultimodalProjector from megatron.core.optimizer import OptimizerConfig from megatron.core.tensor_parallel.layers import ColumnParallelLinear @@ -425,6 +426,7 @@ def forward( cross_attention_masks: Optional[torch.Tensor] = None, full_text_row_masked_out_mask: Optional[torch.Tensor] = None, xattn_caches: Optional[List] = None, + inference_params: InferenceParams = None, ) -> torch.Tensor: """Forward.""" if xattn_caches is None: @@ -467,6 +469,15 @@ def forward( total_len=position_ids.shape[1], ) + xattn_mask_index = position_ids[0] + + if inference_params is not None: + inference_params.xattn_caches = xattn_caches + inference_params.cross_attention_masks = cross_attention_masks + inference_params.full_text_row_masked_out_mask = full_text_row_masked_out_mask + else: + xattn_mask_index = [cross_attention_masks.shape[2] - 1] + assert self.add_decoder, "Language model required for forward pass."
language_embeddings = None if self.pre_process: @@ -474,7 +485,7 @@ def forward( language_embeddings = language_embeddings.transpose(1, 0).contiguous() # [text_seq_len, b, h_language] full_text_row_masked_out_mask = ( - full_text_row_masked_out_mask[:, :, position_ids[0]].permute(2, 0, 1, 3).squeeze(2) + full_text_row_masked_out_mask[:, :, xattn_mask_index].permute(2, 0, 1, 3).squeeze(2) if cross_attention_masks is not None else None ) @@ -485,10 +496,11 @@ def forward( decoder_input=language_embeddings, attention_mask=None, cross_attention_masks=( - cross_attention_masks[:, :, position_ids[0]] if cross_attention_masks is not None else None + cross_attention_masks[:, :, xattn_mask_index] if cross_attention_masks is not None else None ), full_text_row_masked_out_mask=full_text_row_masked_out_mask, xattn_caches=xattn_caches, + inference_params=inference_params, ) return output diff --git a/nemo/collections/vlm/mllama/model/language.py b/nemo/collections/vlm/mllama/model/language.py index 5d4cc2e09f21..bec3ec526f6e 100644 --- a/nemo/collections/vlm/mllama/model/language.py +++ b/nemo/collections/vlm/mllama/model/language.py @@ -346,7 +346,7 @@ def forward( full_text_row_masked_out_mask=full_text_row_masked_out_mask, rotary_pos_emb=rotary_pos_emb, cross_attention_bias=cross_attention_bias, - inference_params=inference_params, + inference_params=None, # Skip inference_params for xattn packed_seq_params=packed_seq_params, ) hidden_states, context = layer( diff --git a/scripts/vlm/mllama_generate.py b/scripts/vlm/mllama_generate.py index c97a0a81d5b9..10dc197f63a0 100644 --- a/scripts/vlm/mllama_generate.py +++ b/scripts/vlm/mllama_generate.py @@ -21,12 +21,14 @@ import requests import torch +from megatron.core.inference.common_inference_params import CommonInferenceParams from PIL import Image from transformers import AutoProcessor from nemo import lightning as nl from nemo.collections import vlm -from nemo.collections.vlm.mllama.model.utils import create_vision_mask_tensor +from nemo.collections.vlm.inference import generate as vlm_generate +from nemo.collections.vlm.inference import setup_inference_wrapper model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct" @@ -54,44 +56,22 @@ def generate(model, processor, image, text): } ] input_text = processor.apply_chat_template(messages, add_generation_prompt=True) - batch = processor(image, input_text, add_special_tokens=False, return_tensors="pt") - input_ids = batch["input_ids"].cuda(non_blocking=True) - position_ids = ( - torch.arange(input_ids.size(1), dtype=torch.long, device=input_ids.device).unsqueeze(0).expand_as(input_ids) + model = setup_inference_wrapper(model, processor.tokenizer) + + prompts = [input_text] + images = [image] + params = CommonInferenceParams(top_k=1, top_p=0, num_tokens_to_generate=100) + result = vlm_generate( + model, + processor.tokenizer, + processor.image_processor, + prompts, + images, + inference_params=params, ) - num_tiles = processor.image_processor.preprocess(image, return_tensors='pt')["num_tiles"] - - min_prompt_len = position_ids.shape[-1] - - input_ids = input_ids[:, :min_prompt_len] - generated_ids = input_ids.clone() - - from tqdm import tqdm - - for cur_pos in tqdm(range(min_prompt_len, min_prompt_len + 100)): - with torch.no_grad(): - position_ids = torch.arange(0, cur_pos, dtype=torch.long, device="cuda").reshape(1, -1) - batch_masks = create_vision_mask_tensor(generated_ids[0]) - - output = model( - batch_images=batch["pixel_values"].cuda(non_blocking=True), - batch_masks=[batch_masks], - 
num_chunks=torch.tensor(num_tiles), - aspect_ratio_ids=batch["aspect_ratio_ids"].cuda(non_blocking=True), - tokens=generated_ids, - position_ids=position_ids, - ) - - next_token_ids = torch.argmax(output[:, -1], dim=-1, keepdim=True) - # Broadcast the tensor from rank 0 to all other ranks - torch.distributed.broadcast(next_token_ids, src=0) - generated_ids = torch.cat([generated_ids, next_token_ids], dim=-1) - if (next_token_ids == tokenizer.eos_token_id).all(): - break - - generated_ids = generated_ids.tolist() - generated_texts = tokenizer.decode(generated_ids[0][min_prompt_len:]) + + generated_texts = list(result)[0].generated_text if torch.distributed.get_rank() == 0: print("======== GENERATED TEXT OUTPUT ========") From 5faf1a947d5caa39a6906c53c4f8298e3b5b3d15 Mon Sep 17 00:00:00 2001 From: oliver könig Date: Thu, 19 Dec 2024 14:05:31 +0100 Subject: [PATCH 20/20] ci: Small pylint fix (#11667) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: oliver könig --- .github/workflows/code-formatting.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/code-formatting.yml b/.github/workflows/code-formatting.yml index 0eaad048b3a5..3730e0bcf955 100644 --- a/.github/workflows/code-formatting.yml +++ b/.github/workflows/code-formatting.yml @@ -139,11 +139,12 @@ jobs: echo "Will run on these files: ${FILTERED[@]}" - set +xe + set +e LOG=$(pylint ${FILTERED[@]}) EXIT_CODE=$? set -e + set +x echo "OUTPUT<<EOF" >> $GITHUB_ENV echo "$LOG" >> $GITHUB_ENV echo "EOF" >> $GITHUB_ENV