diff --git a/docs/ORTModule_Convergence_Notes.md b/docs/ORTModule_Convergence_Notes.md
index 791b6c32c9b48..2374e7b7c538a 100644
--- a/docs/ORTModule_Convergence_Notes.md
+++ b/docs/ORTModule_Convergence_Notes.md
@@ -89,7 +89,7 @@ The limitation of `GlobalSubscriberManager` is, only 'nn.Module's forward output
 dump the intermediate tensors in a `nn.Module`'s forward function, refer to the following example:
 
 ```diff
-+ from onnxruntime.training.utils import inspect_activation
++ from onnxruntime.training.utils.hooks import inspect_activation
 class BloomForCausalLM(BloomPreTrainedModel):
     def __init__(self, config: BloomConfig):
         ...
diff --git a/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py
index 68b78f8df70f1..a8e730488d76d 100644
--- a/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py
+++ b/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py
@@ -14,6 +14,7 @@
 import torch
 
 from ._subscriber_base import RuntimeStates, SubscriberBase
+from ._subscriber_manager import ORT_NO_INCREASE_GLOBAL_STEP
 
 
 class _InspectActivation(torch.autograd.Function):
@@ -176,21 +177,23 @@ def _summarize_activations(self, tensor: torch.Tensor, depth: int, name: str, st
         display_name = name + " forward run" if is_forward is True else name + " backward run"
         output_file_name = name + "_forward" if is_forward is True else name + "_backward"
 
-        if tensor is None or not isinstance(tensor, torch.Tensor):
-            print(f"{display_name} not a torch tensor, value: {tensor}")
-            return
+        # Skip the dump during the model pre-export output schema preparation run and the export run.
+        if ORT_NO_INCREASE_GLOBAL_STEP[0] is False:
+            if tensor is None or not isinstance(tensor, torch.Tensor):
+                print(f"{display_name} not a torch tensor, value: {tensor}")
+                return
 
-        step_path = Path(step_folder)
-        if not step_path.exists():
-            step_path.mkdir(parents=True, exist_ok=False)
-        order_file_path = step_path / "order.txt"
-        tensor_file_path = step_path / output_file_name
+            step_path = Path(step_folder)
+            if not step_path.exists():
+                step_path.mkdir(parents=True, exist_ok=False)
+            order_file_path = step_path / "order.txt"
+            tensor_file_path = step_path / output_file_name
 
-        with order_file_path.open(mode="a", encoding="utf-8") as f:
-            f.write(f"{output_file_name}\n")
+            with order_file_path.open(mode="a", encoding="utf-8") as f:
+                f.write(f"{output_file_name}\n")
 
-        with tensor_file_path.open(mode="w", encoding="utf-8") as f:
-            _summarize_tensor(display_name, tensor, f, depth, self._run_on_cpu, self._bucket_size)
+            with tensor_file_path.open(mode="w", encoding="utf-8") as f:
+                _summarize_tensor(display_name, tensor, f, depth, self._run_on_cpu, self._bucket_size)
 
 
 def _summarize_tensor(
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
index 365c2bb8ebe0e..f0261c776609e 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
@@ -417,24 +417,38 @@ def _get_bert_for_sequence_classification_model(
     return model
 
 
-def _get_bert_for_sequence_classification_sample_data(device):
-    """Returns sample data to be used with BertForSequenceClassification model"""
+def _generate_attention_mask_for_encoder_following_hf(batch_size, seq_length, device, past_key_values_length=0):
+    """Generate an attention mask for the encoder, following the HuggingFace implementation.
 
-    input_ids = torch.randint(0, 100, (32, 64), dtype=torch.long, device=device)
-    input_mask = torch.randint(0, 100, (32, 64), dtype=torch.long, device=device)
-    labels = torch.randint(0, 1, (32,), dtype=torch.long, device=device)
+    Note that past_key_values_length is 0 for training.
 
-    return input_ids, input_mask, labels
+    The mask is generated the same way as
+    https://github.com/huggingface/transformers/blame/4f27ee936a861f56f32ea6db138978b274008006/src/transformers/models/bert/modeling_bert.py#L974C81-L974C81
+
+    """
+
+    attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
+    return attention_mask
 
 
 def _get_bert_for_sequence_classification_sample_data_with_random_shapes(device):
     """Returns sample data with random shape to be used with BertForSequenceClassification model"""
 
-    x = random.randint(1, 100)
-    y = random.randint(1, 100)
-    input_ids = torch.randint(0, 100, (x, y), dtype=torch.long, device=device)
-    input_mask = torch.randint(0, 100, (x, y), dtype=torch.long, device=device)
-    labels = torch.randint(0, 1, (x,), dtype=torch.long, device=device)
+    bsz = random.randint(1, 100)
+    seq_length = random.randint(1, 100)
+    input_ids = torch.randint(0, 100, (bsz, seq_length), dtype=torch.long, device=device)
+    input_mask = _generate_attention_mask_for_encoder_following_hf(bsz, seq_length, device)
+    labels = torch.randint(0, 1, (bsz,), dtype=torch.long, device=device)
+
+    return input_ids, input_mask, labels
+
+
+def _get_bert_for_sequence_classification_sample_data(device):
+    """Returns sample data to be used with BertForSequenceClassification model"""
+
+    input_ids = torch.randint(0, 100, (32, 64), dtype=torch.long, device=device)
+    input_mask = _generate_attention_mask_for_encoder_following_hf(32, 64, device)
+    labels = torch.randint(0, 1, (32,), dtype=torch.long, device=device)
 
     return input_ids, input_mask, labels
 
@@ -2211,32 +2225,27 @@ def run_step(model, x):
     _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
 
 
-# TODO(askhade): This test is failing with smaller tolerance, need to investigate! Disabling it right now to
-# unblock the move to a later version of transformers to resolve security vulnerability.
-# (Moving from transformers v4.4.2 to v4.30.0)
-# def test_bert_inputs_with_dynamic_shape():
-#     # create pytorch model with dropout disabled
-#     pt_model = _get_bert_for_sequence_classification_model(
-#         "cuda", is_training=True, hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0
-#     )
-#     ort_model = ORTModule(copy.deepcopy(pt_model))
+def test_bert_inputs_with_dynamic_shape():
+    # create pytorch model with dropout disabled
+    pt_model = _get_bert_for_sequence_classification_model(
+        "cuda", is_training=True, hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0
+    )
+    ort_model = ORTModule(copy.deepcopy(pt_model))
 
-#     def run_step(model, x, y, z):
-#         outputs = model(x, y, None, None, None, None, z)
-#         loss = outputs[0]
-#         loss.backward()
-#         return outputs[0]
+    def run_step(model, x, y, z):
+        outputs = model(x, y, None, None, None, None, z)
+        loss = outputs[0]
+        loss.backward()
+        return outputs[0]
 
-#     for _step in range(10):
-#         x, y, z = _get_bert_for_sequence_classification_sample_data_with_random_shapes("cuda")
+    for _step in range(10):
+        x, y, z = _get_bert_for_sequence_classification_sample_data_with_random_shapes("cuda")
 
-#         pt_p = run_step(pt_model, x, y, z)
-#         ort_p = run_step(ort_model, x, y, z)
+        pt_p = run_step(pt_model, x, y, z)
+        ort_p = run_step(ort_model, x, y, z)
 
-#         _test_helpers.assert_values_are_close(
-#             ort_p, pt_p, atol=1e-01
-#         )  # TODO: this assert is failing with smaller tolerance, need to investigate!!
-#         # _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) #TODO - enable this check after the investigation
+        _test_helpers.assert_values_are_close(ort_p, pt_p, atol=1e-01)
+        _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
 
 
 @pytest.mark.parametrize("device", ["cuda", "cpu"])
@@ -6424,9 +6433,6 @@ def run_step(model, x):
     del os.environ["ORTMODULE_CONV_ALGO_SEARCH"]
 
 
-@pytest.mark.skip(
-    reason="This test fail because bert forward loss is nan in updated transformers lib, disable for now."
-)
 def test_bert_result_with_layerwise_recompute():
     original_val = os.environ.get("ORTMODULE_MEMORY_OPT_LEVEL", None)
     # Create PyTorch model with dropout disabled.
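
Note: the test-data change above boils down to pairing random `input_ids` with the all-ones attention mask that HuggingFace's BERT builds for training, instead of a mask of random integers in [0, 100). A minimal, self-contained sketch of that pattern in plain PyTorch follows; the function names here are illustrative stand-ins, not the actual test helpers.

```python
import torch


def hf_style_encoder_attention_mask(batch_size, seq_length, device, past_key_values_length=0):
    # All-ones mask: every position is attended to. past_key_values_length is 0
    # for training; it only matters when a key/value cache is present.
    return torch.ones((batch_size, seq_length + past_key_values_length), device=device)


def bert_sample_batch(device, batch_size=32, seq_length=64):
    # Random token ids, an HF-style all-ones attention mask, and all-zero labels,
    # mirroring the shapes and dtypes used by the test helpers in the diff.
    input_ids = torch.randint(0, 100, (batch_size, seq_length), dtype=torch.long, device=device)
    attention_mask = hf_style_encoder_attention_mask(batch_size, seq_length, device)
    labels = torch.randint(0, 1, (batch_size,), dtype=torch.long, device=device)
    return input_ids, attention_mask, labels


if __name__ == "__main__":
    ids, mask, labels = bert_sample_batch("cpu")
    print(ids.shape, mask.shape, labels.shape)  # torch.Size([32, 64]) torch.Size([32, 64]) torch.Size([32])
```

An all-ones mask is what `BertModel` builds internally when no mask is passed (the `modeling_bert.py` line linked in the docstring); the old random-integer mask is presumably what triggered the NaN loss mentioned in the skip reason that the last hunk removes.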