Skip to content

Commit

Permalink
fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
pengwa committed Mar 25, 2024
1 parent 11c0bbe commit 30454e0
Showing 1 changed file with 4 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6500,7 +6500,7 @@ def run_step(model, x, y, z):


@pytest.mark.parametrize("softmax_compute_type", [torch.float16, torch.float32])
def test_overriden_softmax_export(softmax_compute_type):
def test_overridden_softmax_export(softmax_compute_type):
class CustomSoftmaxExportTest(torch.nn.Module):
def __init__(self):
super().__init__()
Expand All @@ -6510,7 +6510,9 @@ def forward(self, attn_weight):

device = "cuda"
pt_model = CustomSoftmaxExportTest().to(device)
ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(save_onnx=True, onnx_prefix="overriden_softmax_export"))
ort_model = ORTModule(
copy.deepcopy(pt_model), DebugOptions(save_onnx=True, onnx_prefix="overridden_softmax_export")
)

def run_step(model, attn_weight):
prediction = model(attn_weight)
Expand Down

0 comments on commit 30454e0

Please sign in to comment.