Showing 3 changed files with 113 additions and 4 deletions.
dynamo_hell.py:
@@ -1,17 +1,17 @@
import torch
from torch._dynamo import optimize
import torch._inductor.config

torch._inductor.config.debug = True          # dump inductor debug output (generated code)
torch._dynamo.config.suppress_errors = True  # fall back to eager instead of raising on compile errors

# Speed up this function with TorchDynamo + Inductor.
def fn(x):
    a = torch.sin(x).cuda()
    b = torch.sin(a).cuda()
    return b

new_fn = optimize("inductor")(fn)  # new_fn is the compiled version of fn
input_tensor = torch.randn(10000).to(device="cuda:0")
a = new_fn(input_tensor)
print("run dynamo_hell.py successfully !!!")
resnet18 benchmark (new file):
@@ -0,0 +1,72 @@
import torch
import time
from torchvision.models import resnet18

# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. Wall-clock timing with time.time(); torch.cuda.synchronize()
# ensures all queued CUDA kernels have finished before the clock is read,
# since CUDA kernel launches are asynchronous.
def timed(fn):
    torch.cuda.synchronize()
    a = time.time()
    result = fn()
    torch.cuda.synchronize()
    b = time.time()
    return result, (b - a)
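# The commented-out lines in the original version hinted at CUDA events, which
# time the work on the GPU itself. A sketch of that variant (assumes a CUDA
# device; torch.cuda.Event reports elapsed time in milliseconds):
def timed_cuda(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()  # wait until `end` has actually been recorded
    return result, start.elapsed_time(end) / 1000.0  # convert ms -> seconds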

# Generates random input and targets data for the model, where `b` is
# batch size.
def generate_data(b):
    return (
        torch.randn(b, 3, 128, 128).to(torch.float32).cuda(),
        torch.randint(1000, (b,)).cuda(),
    )

def init_model():
    return resnet18().to(torch.float32).cuda()

def evaluate(mod, inp):
    return mod(inp)

if __name__ == "__main__":
    model = init_model()

    # Reset dynamo state since we are using a different mode.
    import torch._dynamo
    torch._dynamo.reset()

    # "reduce-overhead" trades extra memory for lower per-call launch overhead;
    # other public modes include "default" and "max-autotune".
    evaluate_opt = torch.compile(evaluate, mode="reduce-overhead")

    # Sanity-check both paths once (uncomment to verify):
    # inp = generate_data(16)[0]
    # print("eager:", timed(lambda: evaluate(model, inp))[1])
    # print("compile:", timed(lambda: evaluate_opt(model, inp))[1])

    N_ITERS = 10

    eager_times = []
    for i in range(N_ITERS):
        inp = generate_data(16)[0]
        _, eager_time = timed(lambda: evaluate(model, inp))
        eager_times.append(eager_time)
        print(f"eager eval time {i}: {eager_time}")

    print("~" * 10)

    compile_times = []
    for i in range(N_ITERS):
        inp = generate_data(16)[0]
        _, compile_time = timed(lambda: evaluate_opt(model, inp))
        compile_times.append(compile_time)
        print(f"compile eval time {i}: {compile_time}")
    print("~" * 10)

    import numpy as np
    eager_med = np.median(eager_times)
    compile_med = np.median(compile_times)
    speedup = eager_med / compile_med
    print(f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
    print("~" * 10)
torch2_demo.py (new file):
@@ -0,0 +1,37 @@
import torch

def demo_1():
    def foo(x, y):
        a = torch.sin(x)
        b = torch.cos(y)
        return a + b

    opt_foo1 = torch.compile(foo)
    print(opt_foo1(torch.randn(10, 10), torch.randn(10, 10)))

def demo_2():
    @torch.compile
    def opt_foo2(x, y):
        a = torch.sin(x)
        b = torch.cos(y)
        return a + b
    print(opt_foo2(torch.randn(10, 10), torch.randn(10, 10)))

def demo_3():
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = torch.nn.Linear(100, 10)

        def forward(self, x):
            return torch.nn.functional.relu(self.lin(x))

    mod = MyModule()
    opt_mod = torch.compile(mod)
    print(opt_mod(torch.randn(10, 100)))

if __name__ == "__main__":
    # demo_1()
    # demo_2()
    demo_3()
    print("run torch2_demo.py successfully !!!")