Skip to content

Commit

Permalink
[Dynamo] Replace torch._dynamo.optimize() with torch.compile() [5…
Browse files Browse the repository at this point in the history
  • Loading branch information
shink authored and pytorchmergebot committed Nov 18, 2024
1 parent 16bc82a commit a1327fa
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 97 deletions.
24 changes: 12 additions & 12 deletions test/dynamo/test_comptime.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def test_print_graph(self):
FILE = StringIO()
cnt = torch._dynamo.testing.CompileCounter()

@torch._dynamo.optimize(cnt)
@torch.compile(backend=cnt)
def f(x):
y = x * 2

Expand Down Expand Up @@ -105,7 +105,7 @@ def test_print_disas(self):
FILE = StringIO()
cnt = torch._dynamo.testing.CompileCounter()

@torch._dynamo.optimize(cnt)
@torch.compile(backend=cnt)
def f(x):
y = x * 2

Expand Down Expand Up @@ -149,7 +149,7 @@ def _(ctx):

return x

@torch._dynamo.optimize(cnt)
@torch.compile(backend=cnt)
def f(x):
y = x + g(x)

Expand All @@ -169,7 +169,7 @@ def test_print_locals(self):
FILE = StringIO()
cnt = torch._dynamo.testing.CompileCounter()

@torch._dynamo.optimize(cnt)
@torch.compile(backend=cnt)
def f(x):
y = x * 2

Expand All @@ -195,7 +195,7 @@ def _(ctx):
def test_print_direct(self):
cnt = torch._dynamo.testing.CompileCounter()

@torch._dynamo.optimize(cnt)
@torch.compile(backend=cnt)
def f(x, z):
y = x * 2
lambda: z
Expand All @@ -208,7 +208,7 @@ def test_sleep(self):
sleep_time = 5
cnt = torch._dynamo.testing.CompileCounter()

@torch._dynamo.optimize(cnt)
@torch.compile(backend=cnt)
def f(x, z, should_sleep):
if should_sleep:
comptime.sleep(sleep_time)
Expand All @@ -233,7 +233,7 @@ def test_get_local_closure_variable(self):
SELF = self
cnt = torch._dynamo.testing.CompileCounter()

@torch._dynamo.optimize(cnt)
@torch.compile(backend=cnt)
def f(x):
z = 3

Expand Down Expand Up @@ -265,7 +265,7 @@ def _(ctx):

return x + 3

@torch._dynamo.optimize(cnt)
@torch.compile(backend=cnt)
def f(x):
y = x * 2
y = g(y)
Expand All @@ -284,7 +284,7 @@ def test_print_guards(self):
FILE = StringIO()
cnt = torch._dynamo.testing.CompileCounter()

@torch._dynamo.optimize(cnt)
@torch.compile(backend=cnt)
def f(x):
y = x * 2

Expand Down Expand Up @@ -349,7 +349,7 @@ def _(ctx):
def test_graph_break(self):
cnt = torch._dynamo.testing.CompileCounter()

@torch._dynamo.optimize(cnt)
@torch.compile(backend=cnt)
def f(x):
y = x * 2

Expand All @@ -363,7 +363,7 @@ def _(ctx):
self.assertEqual(cnt.frame_count, 1)
cnt.frame_count = 0

@torch._dynamo.optimize(cnt)
@torch.compile(backend=cnt)
def g(x):
y = x * 2

Expand All @@ -386,7 +386,7 @@ def test_get_local(self):
FILE = StringIO()
cnt = torch._dynamo.testing.CompileCounter()

@torch._dynamo.optimize(cnt)
@torch.compile(backend=cnt)
def f(x):
y = x * 2
lit = 2
Expand Down
6 changes: 3 additions & 3 deletions test/dynamo/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def fn(a, b):
with torch._dynamo.config.patch(
automatic_dynamic_shapes=False, assume_static_by_default=True
):
opt_fn = torch._dynamo.optimize(cnt_static)(fn)
opt_fn = torch.compile(fn, backend=cnt_static)
for i in range(2, 12):
opt_fn(torch.randn(i), torch.randn(i))
self.assertEqual(cnt_static.frame_count, 10)
Expand All @@ -35,7 +35,7 @@ def fn(a, b):
with torch._dynamo.config.patch(
automatic_dynamic_shapes=True, assume_static_by_default=True
):
opt_fn = torch._dynamo.optimize(cnt_dynamic)(fn)
opt_fn = torch.compile(fn, backend=cnt_dynamic)
# NB: must not do 0, 1 as they specialized
for i in range(2, 12):
opt_fn(torch.randn(i), torch.randn(i))
Expand All @@ -52,7 +52,7 @@ def fn(a, b):
with torch._dynamo.config.patch(
automatic_dynamic_shapes=True, assume_static_by_default=False
):
opt_fn = torch._dynamo.optimize(cnt_dynamic)(fn)
opt_fn = torch.compile(fn, backend=cnt_dynamic)
# NB: must not do 0, 1 as they specialized
for i in range(2, 12):
opt_fn(torch.randn(i), torch.randn(i))
Expand Down
Loading

0 comments on commit a1327fa

Please sign in to comment.