Make normalization more AD friendly (Diffractor) #148

Merged on Sep 5, 2022 (1 commit)

avik-pal (Member) commented Sep 5, 2022

Diffractor support relies on JuliaDiff/Diffractor.jl#89.

Some benchmarks (tested on Julia master):

ResNet18

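using Diffractor, Zygote, Boltz, Lux

# Build ResNet18 from Boltz and move parameters and states to the GPU
model, ps, st = resnet(:resnet18);
ps, st = (ps, st) .|> gpu;

# One 224x224 RGB image (batch size 1) on the GPU
x = randn(Float32, 224, 224, 3, 1) |> gpu;

# Run the forward pass once
model(x, ps, st);

# Scalar loss: sum of the model output (first element of the returned tuple)
loss_function(model, x, ps, st) = sum(model(x, ps, st)[1])
l(x, p) = loss_function(model, x, p, st)

l(x, ps)

# First-call gradient timings; the trailing comments are the reported @time output
@time Diffractor.gradient(l, x, ps);  # 26.025503 seconds (31.82 M allocations: 2.033 GiB, 3.84% gc time, 92.70% compilation time)
@time Zygote.gradient(l, x, ps);  # 123.329143 seconds (59.74 M allocations: 3.654 GiB, 1.52% gc time, 92.13% compilation time)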

ResNet50

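using Diffractor, Zygote, Boltz, Lux

# Build ResNet50 from Boltz and move parameters and states to the GPU
model, ps, st = resnet(:resnet50);
ps, st = (ps, st) .|> gpu;

# One 224x224 RGB image (batch size 1) on the GPU
x = randn(Float32, 224, 224, 3, 1) |> gpu;

# Run the forward pass once
model(x, ps, st);

# Scalar loss: sum of the model output (first element of the returned tuple)
loss_function(model, x, ps, st) = sum(model(x, ps, st)[1])
l(x, p) = loss_function(model, x, p, st)

l(x, ps)

# First-call gradient timings; the trailing comments are the reported @time output
@time Diffractor.gradient(l, x, ps);  # 39.250322 seconds (34.66 M allocations: 2.212 GiB, 2.55% gc time, 95.59% compilation time)
@time Zygote.gradient(l, x, ps);  # 352.995727 seconds (71.11 M allocations: 4.367 GiB, 0.70% gc time, 99.72% compilation time)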
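
Note that every `@time` line above reports more than 90% compilation time, so these numbers mostly reflect time to first gradient rather than steady-state run time. A minimal sketch of how the two could be separated, using only Base Julia and a hypothetical stand-in loss `toy_loss` (an assumption for illustration; it is not the ResNet loss and uses neither AD package):

```julia
# Hypothetical stand-in for the loss; NOT the ResNet loss from the benchmarks.
toy_loss(x) = sum(abs2, x)

# Same input shape as the benchmarks, but kept on the CPU.
x = randn(Float32, 224, 224, 3, 1)

# First call: includes JIT compilation of toy_loss for this argument type.
t_first = @elapsed toy_loss(x)

# Second call: the method is already compiled, so this measures run time only.
t_second = @elapsed toy_loss(x)

println("first call:  ", t_first, " s")
println("second call: ", t_second, " s")
```

The same pattern would presumably apply to the gradient calls: timing a second `Diffractor.gradient(l, x, ps)` or `Zygote.gradient(l, x, ps)` call would show run time without compilation. For GPU arrays the timed expression would likely also need explicit synchronization (for example `CUDA.@sync`) to be meaningful.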

codecov bot commented Sep 5, 2022

Codecov Report

Merging #148 (73a790b) into main (3d6c75c) will decrease coverage by 0.73%.
The diff coverage is 50.00%.

@@            Coverage Diff             @@
##             main     #148      +/-   ##
==========================================
- Coverage   84.82%   84.08%   -0.74%     
==========================================
  Files          17       17              
  Lines        1140     1150      +10     
==========================================
  Hits          967      967              
- Misses        173      183      +10     
| Impacted Files | Coverage Δ |
|---|---|
| src/autodiff.jl | 62.85% <0.00%> (-4.84%) ⬇️ |
| src/utils.jl | 90.90% <0.00%> (-1.40%) ⬇️ |
| src/layers/normalize.jl | 87.50% <68.75%> (-2.37%) ⬇️ |


avik-pal merged commit 647498d into main on Sep 5, 2022
avik-pal deleted the ap/diffractor branch on September 5, 2022 19:21
avik-pal added a commit that referenced this pull request on Sep 8, 2022