Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Start forward mode AD #389

Draft
wants to merge 34 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
be316ff
Start forward mode prototype
gdalle Nov 24, 2024
deac913
First working autodiff
gdalle Nov 24, 2024
9c96c8d
Docstring
gdalle Nov 24, 2024
136aff6
Apply suggestions from code review
gdalle Nov 24, 2024
f65cc53
Moving files around
gdalle Nov 24, 2024
053a8bb
Primitives already known
gdalle Nov 24, 2024
6d8ec04
Merge branch 'main' into gd/forward
gdalle Nov 25, 2024
a3107a8
Keep pushing forward (pun intended)
gdalle Nov 25, 2024
2836ac8
Still buggy, don't touch
gdalle Nov 25, 2024
09d63bd
Keep instruction mapping one to one
gdalle Nov 26, 2024
fa679eb
Use replace_call
gdalle Nov 26, 2024
a68257c
Ignore code cov
gdalle Nov 27, 2024
7a096ba
No Aqua piracies test
gdalle Nov 27, 2024
46c3e5a
Start control flow
gdalle Nov 28, 2024
ad3f98a
Fix intrinsic
gdalle Nov 28, 2024
9071574
Import
gdalle Nov 28, 2024
dcfe282
Typos
gdalle Nov 28, 2024
e44380d
Co-authored-by: Will Tebbutt <[email protected]>
gdalle Dec 6, 2024
dd89e57
Figure out incremental additions
gdalle Dec 6, 2024
9bdb57f
Initial test case additions
willtebbutt Dec 6, 2024
4bb9911
Formatting
willtebbutt Dec 6, 2024
9b037e7
Add verify_dual_type
willtebbutt Dec 6, 2024
6dea624
test_frule_interface runs
willtebbutt Dec 6, 2024
a614846
Fix ReturnNode
willtebbutt Dec 6, 2024
eadae95
Correctness testing runs
willtebbutt Dec 6, 2024
345b3fd
Add randn_dual
willtebbutt Dec 6, 2024
f58c394
Improve sin and cos frules
willtebbutt Dec 6, 2024
c8d8895
Performance tests run
willtebbutt Dec 6, 2024
578e41b
Tidy up implementation
willtebbutt Dec 6, 2024
b5d34b2
Standard testing infrastructure
willtebbutt Dec 6, 2024
205e716
Fix typos
willtebbutt Dec 6, 2024
d328db0
Fix return node to return dual
gdalle Dec 6, 2024
66a48c8
Handle PiNode
gdalle Dec 6, 2024
e455cf6
Deleted line
gdalle Dec 6, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 21 additions & 149 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,31 +23,31 @@ jobs:
test_group: [
'quality',
'basic',
'rrules/avoiding_non_differentiable_code',
'rrules/blas',
'rrules/builtins',
'rrules/fastmath',
'rrules/foreigncall',
'rrules/functionwrappers',
'rrules/iddict',
'rrules/lapack',
'rrules/linear_algebra',
'rrules/low_level_maths',
'rrules/memory',
'rrules/misc',
'rrules/new',
'rrules/tasks',
'rrules/twice_precision',
# 'rrules/avoiding_non_differentiable_code',
# 'rrules/blas',
# 'rrules/builtins',
# 'rrules/fastmath',
# 'rrules/foreigncall',
# 'rrules/functionwrappers',
# 'rrules/iddict',
# 'rrules/lapack',
# 'rrules/linear_algebra',
# 'rrules/low_level_maths',
# 'rrules/memory',
# 'rrules/misc',
# 'rrules/new',
# 'rrules/tasks',
# 'rrules/twice_precision',
]
version:
- 'lts'
# - 'lts'
- '1'
arch:
- x64
include:
- test_group: 'basic'
version: '1.10'
arch: x86
# include:
# - test_group: 'basic'
# version: '1.10'
# arch: x86
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
Expand All @@ -66,132 +66,4 @@ jobs:
files: lcov.info
token: ${{ secrets.CODECOV_TOKEN }}
fail_ci_if_error: false
extra:
name: ${{matrix.test_group.test_type}}-${{ matrix.test_group.label }}-${{ matrix.version }}-${{ matrix.arch }}
runs-on: ubuntu-latest
if: github.event_name != 'schedule'
strategy:
fail-fast: false
matrix:
test_group: [
{test_type: 'ext', label: 'differentiation_interface'},
{test_type: 'ext', label: 'dynamic_ppl'},
{test_type: 'ext', label: 'luxlib'},
{test_type: 'ext', label: 'nnlib'},
{test_type: 'ext', label: 'special_functions'},
{test_type: 'integration_testing', label: 'array'},
{test_type: 'integration_testing', label: 'bijectors'},
{test_type: 'integration_testing', label: 'diff_tests'},
{test_type: 'integration_testing', label: 'distributions'},
{test_type: 'integration_testing', label: 'gp'},
{test_type: 'integration_testing', label: 'logexpfunctions'},
{test_type: 'integration_testing', label: 'lux'},
{test_type: 'integration_testing', label: 'battery_tests'},
{test_type: 'integration_testing', label: 'misc_abstract_array'},
{test_type: 'integration_testing', label: 'temporalgps'},
{test_type: 'integration_testing', label: 'turing'},
]
version:
- '1'
- 'lts'
arch:
- x64
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
include-all-prereleases: false
- uses: julia-actions/cache@v2
- uses: julia-actions/julia-buildpkg@v1
- run: |
if [ ${{ matrix.test_group.test_type }} == 'ext' ]; then
julia --code-coverage=user --eval 'include("test/run_extra.jl")'
else
julia --eval 'include("test/run_extra.jl")'
fi
env:
LABEL: ${{ matrix.test_group.label }}
TEST_TYPE: ${{ matrix.test_group.test_type }}
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v5
with:
files: lcov.info
token: ${{ secrets.CODECOV_TOKEN }}
fail_ci_if_error: false
perf:
name: "Performance (${{ matrix.perf_group }})"
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
perf_group:
- 'hand_written'
- 'derived'
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: '1'
arch: x64
include-all-prereleases: false
- uses: julia-actions/cache@v2
- uses: julia-actions/julia-buildpkg@v1
- run: julia --project=bench --eval 'include("bench/run_benchmarks.jl"); main()'
env:
PERF_GROUP: ${{ matrix.perf_group }}
shell: bash
compperf:
name: "Performance (inter-AD)"
runs-on: ubuntu-latest
if: github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository
strategy:
fail-fast: false
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: '1'
arch: x64
include-all-prereleases: false
- uses: julia-actions/cache@v2
- uses: julia-actions/julia-buildpkg@v1
- run: mkdir bench_results
- run: julia --project=bench --eval 'include("bench/run_benchmarks.jl"); main()'
env:
PERF_GROUP: 'comparison'
GKSwstype: '100'
shell: bash
- uses: actions/upload-artifact@v4
with:
name: benchmarking-results
path: bench_results/
# Useful code for testing action.
# - run: |
# text="this is line one
# this is line two
# this is line three"
# echo "$text" > benchmark_results.txt
- name: Read file content
id: read-file
run: |
{
echo "table<<EOF"
cat bench/benchmark_results.txt
echo "EOF"
} >> $GITHUB_OUTPUT
- name: Find Comment
uses: peter-evans/find-comment@v3
id: fc
with:
issue-number: ${{ github.event.pull_request.number }}
comment-author: github-actions[bot]
- id: post-report-as-pr-comment
name: Post Report as Pull Request Comment
uses: peter-evans/create-or-update-comment@v4
with:
issue-number: ${{ github.event.pull_request.number }}
body: "Performance Ratio:\nRatio of time to compute gradient and time to compute function.\nWarning: results are very approximate! See [here](https://github.com/compintell/Mooncake.jl/tree/main/bench#inter-framework-benchmarking) for more context.\n```\n${{ steps.read-file.outputs.table }}\n```"
comment-id: ${{ steps.fc.outputs.comment-id }}
edit-mode: replace

1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ profile.pb.gz
scratch.jl
docs/build/
docs/site/
playground.jl
16 changes: 16 additions & 0 deletions src/Mooncake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ using Base:
twiceprecision
using Base.Experimental: @opaque
using Base.Iterators: product
using Base.Meta: isexpr
using Core:
Intrinsics,
bitcast,
Expand All @@ -50,6 +51,13 @@ using FunctionWrappers: FunctionWrapper
# Needs to be defined before various other things.
function _foreigncall_ end

"""
frule!!(f::Dual, x::Dual...)

Performs AD in forward mode, possibly modifying the inputs, and returns a `Dual`.
"""
function frule!! end

"""
rrule!!(f::CoDual, x::CoDual...)

Expand All @@ -75,8 +83,11 @@ pb!!(1.0)
"""
function rrule!! end

include("interpreter/diffractor_compiler_utils.jl")

include("utils.jl")
include("tangents.jl")
include("dual.jl")
include("fwds_rvs_data.jl")
include("codual.jl")
include("debug_mode.jl")
Expand All @@ -88,6 +99,7 @@ include(joinpath("interpreter", "ir_utils.jl"))
include(joinpath("interpreter", "bbcode.jl"))
include(joinpath("interpreter", "ir_normalisation.jl"))
include(joinpath("interpreter", "zero_like_rdata.jl"))
include(joinpath("interpreter", "s2s_forward_mode_ad.jl"))
include(joinpath("interpreter", "s2s_reverse_mode_ad.jl"))

include("tools_for_rules.jl")
Expand Down Expand Up @@ -133,9 +145,13 @@ export primal,
_add_to_primal,
_diff,
_dot,
Dual,
zero_dual,
zero_codual,
codual_type,
frule!!,
rrule!!,
build_frule,
build_rrule,
value_and_gradient!!,
value_and_pullback!!,
Expand Down
1 change: 1 addition & 0 deletions src/debug_mode.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
DebugFRule(rule) = rule # TODO: make it non-trivial

"""
DebugPullback(pb, y, x)
Expand Down
28 changes: 28 additions & 0 deletions src/dual.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
struct Dual{P,T}
primal::P
tangent::T
end

primal(x::Dual) = x.primal
tangent(x::Dual) = x.tangent
Base.copy(x::Dual) = Dual(copy(primal(x)), copy(tangent(x)))
_copy(x::P) where {P<:Dual} = x

zero_dual(x) = Dual(x, zero_tangent(x))

function dual_type(::Type{P}) where {P}
P == DataType && return Dual
P isa Union && return Union{dual_type(P.a),dual_type(P.b)}
P <: UnionAll && return Dual # P is abstract, so we don't know its tangent type.
return isconcretetype(P) ? Dual{P,tangent_type(P)} : Dual
end

function dual_type(p::Type{Type{P}}) where {P}
return @isdefined(P) ? Dual{Type{P},NoTangent} : Dual{_typeof(p),NoTangent}
end

_primal(x) = x
_primal(x::Dual) = primal(x)

make_dual(x) = zero_dual(x)
make_dual(x::Dual) = x
9 changes: 9 additions & 0 deletions src/interpreter/bbcode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -891,6 +891,15 @@ inc_args(x::Expr) = Expr(x.head, map(__inc, x.args)...)
inc_args(x::ReturnNode) = isdefined(x, :val) ? ReturnNode(__inc(x.val)) : x
inc_args(x::IDGotoIfNot) = IDGotoIfNot(__inc(x.cond), x.dest)
inc_args(x::IDGotoNode) = x
function inc_args(x::PhiNode)
new_values = Vector{Any}(undef, length(x.values))
for n in eachindex(x.values)
if isassigned(x.values, n)
new_values[n] = __inc(x.values[n])
end
end
return PhiNode(x.edges, new_values)
end
function inc_args(x::IDPhiNode)
new_values = Vector{Any}(undef, length(x.values))
for n in eachindex(x.values)
Expand Down
Loading
Loading