This repository has been archived by the owner on Aug 10, 2024. It is now read-only.

improve document #54

Merged · 15 commits · Sep 21, 2023
7 changes: 5 additions & 2 deletions .github/workflows/ci.yml
@@ -14,7 +14,7 @@ jobs:
       fail-fast: false # don't stop CI even when one of them fails
       matrix:
         version:
-          - '1.8.1' # lowest version supported
+          - '1.9.1' # lowest version supported
           - 'nightly'
         os:
           - ubuntu-latest
@@ -27,6 +27,9 @@ jobs:
           version: ${{ matrix.version }}
           arch: ${{ matrix.arch }}
       - uses: julia-actions/julia-buildpkg@latest
-      - uses: julia-actions/julia-runtest@latest
+      - uses: julia-actions/julia-runtest@latest
+      - uses: julia-actions/julia-docdeploy@latest
+        env:
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
       - uses: julia-actions/julia-processcoverage@v1
       - uses: codecov/codecov-action@v1
23 changes: 12 additions & 11 deletions docs/make.jl
@@ -1,19 +1,20 @@
 using JITrench
 using Documenter
 
-DocMeta.setdocmeta!(JITrench, :DocTestSetup, :(using JITrench); recursive = true)
+DocMeta.setdocmeta!(JITrench, :DocTestSetup, :(using JITrench); recursive=true)
 
 makedocs(;
-    modules = [JITrench],
-    authors = "Yuchi Yamaguchi",
-    repo = "https://github.com/abap34/JITrench.jl/blob/{commit}{path}#{line}",
-    sitename = "JITrench.jl",
-    format = Documenter.HTML(;
-        prettyurls = get(ENV, "CI", "false") == "true",
-        canonical = "https://abap34.github.io/JITrench.jl",
-        assets = String[],
+    modules=[JITrench],
+    authors="Yuchi Yamaguchi",
+    repo="https://github.com/abap34/JITrench.jl/blob/{commit}{path}#{line}",
+    sitename="JITrench.jl",
+    format=Documenter.HTML(;
+        prettyurls=get(ENV, "CI", "false") == "true",
+        canonical="https://abap34.github.io/JITrench.jl",
+        assets=String[]
     ),
-    pages = ["Home" => "index.md"],
+    pages=["Home" => "index.md",
+        "API" => "api.md",]
 )
 
-deploydocs(; repo = "github.com/abap34/JITrench.jl")
+deploydocs(; repo="github.com/abap34/JITrench.jl")
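For anyone previewing these pages locally, here is a minimal build sketch under the standard Documenter layout. It assumes `docs/Project.toml` lists Documenter (not shown in this diff); `deploydocs` deploys nothing outside CI.

```julia
# Run from the repository root.
using Pkg
Pkg.activate("docs")
Pkg.develop(path=".")    # build against the local JITrench checkout
Pkg.instantiate()
include("docs/make.jl")  # runs makedocs; deploydocs is a no-op locally
```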
14 changes: 14 additions & 0 deletions docs/src/api.md
@@ -0,0 +1,14 @@
+# API
+
+## AutoDiff
+
+```@autodocs
+Modules = [AutoDiff]
+```
+
+## Utilities
+
+```@docs
+JITrench.plot_graph
+JITrench.PNGContainer
+```
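As context for the two block types above: `@autodocs` renders every docstring defined in the listed modules, while `@docs` renders only the named bindings. A hypothetical docstring of the kind the `Modules = [AutoDiff]` block would collect:

```julia
module AutoDiff  # stand-in for JITrench's AutoDiff submodule

"""
    square(x)

Return `x^2`. Any docstring defined in this module is picked up by the
`@autodocs` block; `square` itself is a hypothetical example.
"""
square(x) = x^2

end
```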
66 changes: 63 additions & 3 deletions docs/src/index.md
@@ -4,9 +4,69 @@ CurrentModule = JITrench
 
 # JITrench
 
-```@index
+JITrench.jl is a lightweight automatic differentiation & deep learning framework implemented in pure Julia.
+
+# Quick Tour
+
+## Automatic Differentiation
+
+```julia
+julia> using JITrench
+
+julia> f(x) = sin(x) + 1
+f (generic function with 1 method)
+
+julia> JITrench.@diff! f(x)
+f′ (generic function with 1 method)
+
+julia> f′(π)
+-1.0
 ```
 
-```@autodocs
-Modules = [JITrench]
+## Train Neural Network
+
+```julia
+using JITrench
+using JITrench.NN
+using Printf
+
+
+N = 100
+p = 1
+n_iter = 20000
+
+x = rand(N, p)
+y = sin.(2π .* x) + (rand(N, p) / 1)
+
+function model(x)
+    x = NN.Linear(out_dim=10)(x)
+    x = NN.functions.sigmoid.(x)
+    x = NN.Linear(out_dim=1)(x)
+    return NN.result(x)
+end
+
+params = NN.init(model, NN.Initializer((nothing, 1)))
+optimizer = NN.SGD(params, 1e-1)
+
+x = Tensor(x)
+y = Tensor(y)
+
+for iter in 1:n_iter
+    pred = NN.apply(model, x, params)
+    loss = NN.functions.mean_squared_error(y, pred)
+    NN.cleargrads!(params)
+    backward!(loss)
+    NN.optimize!(optimizer)
+    if (iter % 500 == 0)
+        @printf "[iters] %4i [loss] %.4f\n" iter loss.values
+    end
+end
+
+
+NN.save_weight(params, "weight")
 ```
+
+![](https://github.com/abap34/JITrench.jl/blob/master/example/NN/learning.gif)
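The tour drives everything through `@diff!` and the `NN` helpers; underneath, the exported `Scalar`/`Tensor` wrappers and `backward!` (exported in `src/JITrench.jl` below) do the work. A rough sketch of that lower-level flow, assuming arithmetic is overloaded for wrapped values and that the gradient lands in a `grad` field (the field name is an assumption, not confirmed by this diff):

```julia
using JITrench

x = JITrench.Scalar(2.0)   # wrap a plain number for the autodiff graph
y = x * x                  # assumed operator overload; records the graph
backward!(y)               # backpropagate from the output
# x.grad                   # expected to hold dy/dx = 4.0 if the field is named `grad`
```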



1 change: 1 addition & 0 deletions src/JITrench.jl
@@ -14,6 +14,7 @@ export AutoDiff, DiffableFunction, backward!, Scalar, AbstractTensor, Tensor, Cu
 Device = AutoDiff.Device
 GPU = AutoDiff.GPU
 CPU = AutoDiff.CPU
+check_same_device = AutoDiff.check_same_device
 
 include("core/functions.jl")
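The added `check_same_device = AutoDiff.check_same_device` line follows the same aliasing pattern as the `Device`/`GPU`/`CPU` lines above it: bind a submodule function to a parent-module name so callers can write `JITrench.check_same_device`. A self-contained sketch of the pattern with hypothetical module names:

```julia
module Outer

module Inner
check(x) = 2x
end

check = Inner.check   # alias the submodule function at the parent level

end

Outer.check(3)   # 6, without reaching into Outer.Inner
```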
10 changes: 6 additions & 4 deletions src/core/autodiff/autodiff.jl
@@ -2,7 +2,7 @@ module AutoDiff
 
 abstract type JTFunction end
 abstract type Variable end
-abstract type DiffableFunction <: JTFunction end 
+abstract type DiffableFunction <: JTFunction end
 
 
 include("device.jl")
@@ -16,10 +16,12 @@ include("broadcast/sum_to.jl")
 include("broadcast/broadcast_to.jl")
 
 
-const ScalarTypes = Union{Real, Scalar}
-const TensorTypes = Union{AbstractArray, Tensor, CuTensor}
+const ScalarTypes = Union{Real,Scalar}
+const TensorTypes = Union{AbstractArray,Tensor,CuTensor}
 
-export DiffableFunction,
+export
+    check_same_device,
+    DiffableFunction,
     BinaryOperator,
     UnaryOperator,
     GradField,
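The `ScalarTypes`/`TensorTypes` aliases let methods dispatch on "scalar-like" versus "tensor-like" arguments with a single annotation. A sketch of the idea using hypothetical stand-in aliases:

```julia
# Hypothetical stand-ins for the ScalarTypes/TensorTypes aliases above.
const MyScalarTypes = Union{Real,Complex}
const MyTensorTypes = AbstractArray

describe(x::MyScalarTypes) = "scalar-like: $x"
describe(x::MyTensorTypes) = "tensor-like with $(length(x)) elements"

describe(1.5)        # "scalar-like: 1.5"
describe([1, 2, 3])  # "tensor-like with 3 elements"
```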
78 changes: 74 additions & 4 deletions src/core/autodiff/device.jl
@@ -1,7 +1,17 @@
 abstract type Device end
 
+"""
+    CPU <: Device
+
+Represents the CPU device. This is the default device.
+"""
 struct CPU <: Device end
 
+"""
+    GPU(idx::Int64) <: Device
+
+Represents a GPU device. `idx` is the corresponding device index in CUDA.jl.
+"""
 struct GPU <: Device
     idx::Int64
     function GPU(idx::Int64)
@@ -12,14 +22,74 @@
     end
 end
 
-check_same_device(device1::T, device2::T) where {T <: Device} = nothing
+"""
+    NotSameDeviceError <: Exception
+
+Exception thrown when the devices of two tensors are not the same.
+"""
+struct NotSameDeviceError <: Exception
+    same_accelerator::Bool
+    same_gpu_idx::Bool
+    function NotSameDeviceError(; same_accelerator, same_gpu_idx)
+        if (same_accelerator) && (same_gpu_idx)
+            throw(
+                DomainError(
+                    "same_accelerator and same_gpu_idx can never both be true at the same time",
+                ),
+            )
+        end
+        return new(same_accelerator, same_gpu_idx)
+    end
+end
+
+
+function Base.showerror(io::IO, e::NotSameDeviceError)
+    if !(e.same_accelerator)
+        print(
+            io,
+            "Arguments must be on the same device. Arguments are on both the CPU and the GPU.",
+        )
+    end
+
+    if !(e.same_gpu_idx)
+        print(
+            io,
+            "All arguments must be on the same device. Arguments are on different GPUs.",
+        )
+    end
+end
+
+
+"""
+    check_same_device(device1::Device, device2::Device)
+
+Check whether `device1` and `device2` are the same device.
+If they are, return `nothing`; otherwise, throw a `NotSameDeviceError`.
+
+# Arguments
+- `device1`: Device to be compared.
+- `device2`: Device to be compared.
+
+# Example
+```julia-repl
+julia> device1 = JITrench.CPU()
+JITrench.AutoDiff.CPU()
+
+julia> device2 = JITrench.GPU(0)
+JITrench.AutoDiff.GPU(0)
+
+julia> JITrench.check_same_device(device1, device2)
+ERROR: Arguments must be on the same device. Arguments are on both the CPU and the GPU.
+```
+"""
+check_same_device(::T, ::T) where {T<:Device} = nothing
 
-check_same_device(device1::CPU, device2::GPU) = throw(NotSameDeviceError(true, false))
+check_same_device(::CPU, ::GPU) = throw(NotSameDeviceError(same_accelerator=false, same_gpu_idx=true))
 
 function check_same_device(device1::GPU, device2::GPU)
     if device1.idx != device2.idx
-        throw(NotSameDeviceError(same_accelerator = true, same_gpu_idx = false))
+        throw(NotSameDeviceError(same_accelerator=true, same_gpu_idx=false))
     else
-        return device1.idx
+        return nothing
     end
 end
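A usage sketch for the new error path, assuming `NotSameDeviceError` is reachable through the `AutoDiff` submodule (it is defined there, but this diff does not show whether it is exported). The `GPU(0)` construction follows the docstring example above:

```julia
using JITrench

try
    JITrench.check_same_device(JITrench.CPU(), JITrench.GPU(0))
catch e
    if e isa JITrench.AutoDiff.NotSameDeviceError
        println(sprint(showerror, e))  # prints the CPU/GPU mismatch message
    else
        rethrow()
    end
end
```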
92 changes: 92 additions & 0 deletions src/core/autodiff/function.jl
@@ -1,8 +1,18 @@
 
 abstract type AdditionalField end
 
+"""
+    GradField(inputs, output, generation)
+
+Retains the information of the function needed for backpropagation.
+Every DiffableFunction must have this field.
+"""
 struct GradField{T, S}
+    "Tuple of input variables."
     inputs::T
+    "Output variable."
     output::S
+    "Generation of the function. Corresponds to the evaluation-order priority of the function in the backward pass."
     generation::Int
     function GradField(
         inputs::T,
@@ -13,16 +23,98 @@
     end
 end
 
+"""
+    BinaryOperator
+
+DiffableFunction which takes two variables as input.
+"""
 abstract type BinaryOperator <: DiffableFunction end
 
+"""
+    UnaryOperator
+
+DiffableFunction which takes one variable as input.
+"""
 abstract type UnaryOperator <: DiffableFunction end
 
 
 function Base.show(io::IO, f::DiffableFunction)
     print(io, typeof(f))
 end
 
+
+"""
+    _get_gf(f::DiffableFunction)
+
+Get the GradField of the function.
+"""
 _get_gf(f::DiffableFunction) = f.grad_field
 
 
+function Base.show(io::IO, ::MIME"text/plain", f::DiffableFunction)
+    print(io, typeof(f))
+end
+
+
+function _subtypedef(ex)
+    if ex.head != :<:
+        return false
+    end
+
+    if ex.args[2] isa Symbol
+        return eval(ex.args[2]) <: DiffableFunction
+    else
+        return false
+    end
+end
+
+"""
+    @diffable
+
+Check that a struct definition subtypes DiffableFunction and has a `grad_field` field whose type is GradField.
+"""
+macro diffable(ex)
+    if (ex.head != :struct)
+        throw(ArgumentError(
+            "@diffable is a macro that checks a struct definition. The passed argument is not a struct."
+        ))
+    end
+    if (ex.args[2].head != :<:) || !(JITrench.AutoDiff.eval(ex.args[2].args[2]) <: DiffableFunction)
+        throw(ArgumentError(
+            "@diffable is a macro that checks a struct definition, which must be a subtype of DiffableFunction."
+        ))
+    end
+
+    if !(ex.args[3] isa Expr)
+        throw(ArgumentError(
+            "@diffable is a macro that checks a struct definition, which must have a grad_field field."
+        ))
+    end
+
+    for arg in ex.args[3].args
+        if arg isa Expr
+            if arg.head == :(::)
+                if arg.args[1] == :grad_field
+                    if JITrench.AutoDiff.eval(arg.args[2]) <: GradField
+                        return esc(ex)
+                    else
+                        throw(ArgumentError(
+                            "Type of grad_field field must be GradField."
+                        ))
+                    end
+                end
+            end
+        end
+    end
+    throw(ArgumentError(
+        "DiffableFunction must have a `grad_field` field."
+    ))
+end
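A sketch of a definition the `@diffable` check should accept, following the rules the macro enforces: a struct that subtypes `DiffableFunction` and carries a `grad_field::GradField` field. `Square` is a hypothetical name, and whether `@diffable` is exported is not shown in this diff:

```julia
using JITrench.AutoDiff   # brings DiffableFunction and GradField into scope

AutoDiff.@diffable struct Square <: DiffableFunction
    grad_field::GradField
end
```

Any violation (not a struct, wrong supertype, missing or mistyped `grad_field`) raises an `ArgumentError` at macro-expansion time.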
