Fix and enable a few ORTModule Unit Tests #19847

Merged: 3 commits, Mar 12, 2024
Changes from all commits
2 changes: 1 addition & 1 deletion docs/ORTModule_Convergence_Notes.md
Original file line number Diff line number Diff line change
@@ -89,7 +89,7 @@ The limitation of `GlobalSubscriberManager` is, only 'nn.Module's forward output
dump the intermediate tensors in a `nn.Module`'s forward function, refer to the following example:

```diff
+ from onnxruntime.training.utils import inspect_activation
+ from onnxruntime.training.utils.hooks import inspect_activation
class BloomForCausalLM(BloomPreTrainedModel):
def __init__(self, config: BloomConfig):
...
Original file line number Diff line number Diff line change
@@ -14,6 +14,7 @@
import torch

from ._subscriber_base import RuntimeStates, SubscriberBase
from ._subscriber_manager import ORT_NO_INCREASE_GLOBAL_STEP


class _InspectActivation(torch.autograd.Function):
@@ -176,21 +177,23 @@ def _summarize_activations(self, tensor: torch.Tensor, depth: int, name: str, st
display_name = name + " forward run" if is_forward is True else name + " backward run"
output_file_name = name + "_forward" if is_forward is True else name + "_backward"

if tensor is None or not isinstance(tensor, torch.Tensor):
print(f"{display_name} not a torch tensor, value: {tensor}")
return
# Skip dump during model pre-export output schema preparation run and export run.
if ORT_NO_INCREASE_GLOBAL_STEP[0] is False:
if tensor is None or not isinstance(tensor, torch.Tensor):
print(f"{display_name} not a torch tensor, value: {tensor}")
return

step_path = Path(step_folder)
if not step_path.exists():
step_path.mkdir(parents=True, exist_ok=False)
order_file_path = step_path / "order.txt"
tensor_file_path = step_path / output_file_name
step_path = Path(step_folder)
if not step_path.exists():
step_path.mkdir(parents=True, exist_ok=False)
order_file_path = step_path / "order.txt"
tensor_file_path = step_path / output_file_name

with order_file_path.open(mode="a", encoding="utf-8") as f:
f.write(f"{output_file_name}\n")
with order_file_path.open(mode="a", encoding="utf-8") as f:
f.write(f"{output_file_name}\n")

with tensor_file_path.open(mode="w", encoding="utf-8") as f:
_summarize_tensor(display_name, tensor, f, depth, self._run_on_cpu, self._bucket_size)
with tensor_file_path.open(mode="w", encoding="utf-8") as f:
_summarize_tensor(display_name, tensor, f, depth, self._run_on_cpu, self._bucket_size)


def _summarize_tensor(
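The guard added above keys off `ORT_NO_INCREASE_GLOBAL_STEP`, imported from `._subscriber_manager` and indexed as a one-element list. A minimal sketch of that pattern, assuming the flag is toggled around export-style runs (the `no_increase_global_step` context manager below is hypothetical and only illustrates the idea):

```python
from contextlib import contextmanager

# A one-element list works as a module-level mutable flag: other modules import
# it once and still see in-place updates made later.
ORT_NO_INCREASE_GLOBAL_STEP = [False]


@contextmanager
def no_increase_global_step():
    """Hypothetical helper: mark a region (e.g. a pre-export schema-preparation
    run or the export run itself) where dumps and step increments are skipped."""
    ORT_NO_INCREASE_GLOBAL_STEP[0] = True
    try:
        yield
    finally:
        ORT_NO_INCREASE_GLOBAL_STEP[0] = False


def maybe_dump(name, value):
    # Mirrors the guarded dump in the diff: only write statistics for real
    # training steps, not for pre-export or export runs.
    if ORT_NO_INCREASE_GLOBAL_STEP[0] is False:
        print(f"dumping {name}: {value}")


maybe_dump("loss", 0.5)            # dumped
with no_increase_global_step():
    maybe_dump("loss", 0.5)        # skipped during the export-style run
```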
Original file line number Diff line number Diff line change
@@ -417,24 +417,38 @@ def _get_bert_for_sequence_classification_model(
return model


def _get_bert_for_sequence_classification_sample_data(device):
"""Returns sample data to be used with BertForSequenceClassification model"""
def _generate_attention_mask_for_encoder_following_hf(batch_size, seq_length, device, past_key_values_length=0):
"""Generate attention mask for encoder following the implementation in HuggingFace.

input_ids = torch.randint(0, 100, (32, 64), dtype=torch.long, device=device)
input_mask = torch.randint(0, 100, (32, 64), dtype=torch.long, device=device)
labels = torch.randint(0, 1, (32,), dtype=torch.long, device=device)
Note that past_key_values_length is 0 for training.

return input_ids, input_mask, labels
The mask is generated following
https://github.com/huggingface/transformers/blame/4f27ee936a861f56f32ea6db138978b274008006/src/transformers/models/bert/modeling_bert.py#L974C81-L974C81

"""

attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
return attention_mask


def _get_bert_for_sequence_classification_sample_data_with_random_shapes(device):
"""Returns sample data with random shape to be used with BertForSequenceClassification model"""

x = random.randint(1, 100)
y = random.randint(1, 100)
input_ids = torch.randint(0, 100, (x, y), dtype=torch.long, device=device)
input_mask = torch.randint(0, 100, (x, y), dtype=torch.long, device=device)
labels = torch.randint(0, 1, (x,), dtype=torch.long, device=device)
bsz = random.randint(1, 100)
seq_length = random.randint(1, 100)
input_ids = torch.randint(0, 100, (bsz, seq_length), dtype=torch.long, device=device)
input_mask = _generate_attention_mask_for_encoder_following_hf(bsz, seq_length, device)
labels = torch.randint(0, 1, (bsz,), dtype=torch.long, device=device)

return input_ids, input_mask, labels


def _get_bert_for_sequence_classification_sample_data(device):
"""Returns sample data to be used with BertForSequenceClassification model"""

input_ids = torch.randint(0, 100, (32, 64), dtype=torch.long, device=device)
input_mask = _generate_attention_mask_for_encoder_following_hf(32, 64, device)
labels = torch.randint(0, 1, (32,), dtype=torch.long, device=device)

return input_ids, input_mask, labels

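A quick usage sketch of the two helpers above (assuming they are in scope in the same test module, and using a CPU device so the example runs without CUDA):

```python
import torch

# Fixed-shape variant: batch 32, sequence length 64; the mask is all ones,
# matching the HuggingFace-style helper.
input_ids, input_mask, labels = _get_bert_for_sequence_classification_sample_data("cpu")
assert input_ids.shape == (32, 64)
assert input_mask.shape == (32, 64) and bool(torch.all(input_mask == 1))
assert labels.shape == (32,)

# Random-shape variant: batch size and sequence length are redrawn from [1, 100]
# on every call, which is what the dynamic-shape test below relies on.
input_ids, input_mask, labels = _get_bert_for_sequence_classification_sample_data_with_random_shapes("cpu")
assert input_ids.shape == input_mask.shape
assert labels.shape == (input_ids.shape[0],)
```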
@@ -2211,32 +2225,27 @@ def run_step(model, x):
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)


# TODO(askhade): This test is failing with smaller tolerance, need to investigate! Disabling it right now to
# unblock the move to a later version of transformers to resolve security vulnerability.
# (Moving from transformers v4.4.2 to v4.30.0)
# def test_bert_inputs_with_dynamic_shape():
# # create pytorch model with dropout disabled
# pt_model = _get_bert_for_sequence_classification_model(
# "cuda", is_training=True, hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0
# )
# ort_model = ORTModule(copy.deepcopy(pt_model))
def test_bert_inputs_with_dynamic_shape():
# create pytorch model with dropout disabled
pt_model = _get_bert_for_sequence_classification_model(
"cuda", is_training=True, hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0
)
ort_model = ORTModule(copy.deepcopy(pt_model))

# def run_step(model, x, y, z):
# outputs = model(x, y, None, None, None, None, z)
# loss = outputs[0]
# loss.backward()
# return outputs[0]
def run_step(model, x, y, z):
outputs = model(x, y, None, None, None, None, z)
loss = outputs[0]
loss.backward()
return outputs[0]

# for _step in range(10):
# x, y, z = _get_bert_for_sequence_classification_sample_data_with_random_shapes("cuda")
for _step in range(10):
x, y, z = _get_bert_for_sequence_classification_sample_data_with_random_shapes("cuda")

# pt_p = run_step(pt_model, x, y, z)
# ort_p = run_step(ort_model, x, y, z)
pt_p = run_step(pt_model, x, y, z)
ort_p = run_step(ort_model, x, y, z)

# _test_helpers.assert_values_are_close(
# ort_p, pt_p, atol=1e-01
# ) # TODO: this assert is failing with smaller tolerance, need to investigate!!
# # _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) #TODO - enable this check after the investigation
_test_helpers.assert_values_are_close(ort_p, pt_p, atol=1e-01)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)


@pytest.mark.parametrize("device", ["cuda", "cpu"])
@@ -6424,9 +6433,6 @@ def run_step(model, x):
del os.environ["ORTMODULE_CONV_ALGO_SEARCH"]


@pytest.mark.skip(
reason="This test fail because bert forward loss is nan in updated transformers lib, disable for now."
)
def test_bert_result_with_layerwise_recompute():
original_val = os.environ.get("ORTMODULE_MEMORY_OPT_LEVEL", None)
# Create PyTorch model with dropout disabled.
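The truncated hunk above removes the `pytest.mark.skip` so `test_bert_result_with_layerwise_recompute` runs again. For reference, a minimal sketch of the save/set/restore pattern around `ORTMODULE_MEMORY_OPT_LEVEL` that the visible lines suggest (the `run_training_step` callable and the value `"1"` for layerwise recompute are assumptions, not taken from the diff):

```python
import os


def run_with_layerwise_recompute(run_training_step):
    """Hypothetical wrapper: enable the memory-opt env var, run one step,
    then restore whatever value the caller had set."""
    original_val = os.environ.get("ORTMODULE_MEMORY_OPT_LEVEL", None)
    os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] = "1"  # assumption: layerwise recompute level
    try:
        return run_training_step()
    finally:
        if original_val is None:
            os.environ.pop("ORTMODULE_MEMORY_OPT_LEVEL", None)
        else:
            os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] = original_val
```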