This repository has been archived by the owner on Aug 10, 2024. It is now read-only.

add some error message #53

Merged 7 commits on Sep 2, 2023
src/core/autodiff/broadcast/sum_to.jl: 2 changes (1 addition, 1 deletion)
@@ -20,7 +20,7 @@
        dims = (findall(in_shape[1:(end - lead)] .!= out_shape)..., lead_axis...)
        return dropdims(sum(x, dims = dims), dims = lead_axis)
    else
-       # TODO:implement error
+       throw(DimensionMismatch("Input shape $in_shape cannot be reduced to $out_shape"))
    end
end

Codecov / codecov/patch: added line src/core/autodiff/broadcast/sum_to.jl#L23 is not covered by tests.

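Since Codecov flags the new throw as untested, here is a standalone sketch of the guard. The enclosing function's name and its exact incompatibility condition are not visible in this hunk, so sum_to_sketch and its length check are assumptions for illustration only:

# Hypothetical stand-in; only the error branch is taken from this diff.
function sum_to_sketch(x::AbstractArray, out_shape::Tuple)
    in_shape = size(x)
    if length(out_shape) > length(in_shape)
        throw(DimensionMismatch("Input shape $in_shape cannot be reduced to $out_shape"))
    end
    return x  # the real function reduces x down to out_shape here
end

sum_to_sketch(ones(2, 3), (2, 3, 4))
# throws DimensionMismatch: Input shape (2, 3) cannot be reduced to (2, 3, 4)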
src/core/autodiff/device.jl: 2 changes (1 addition, 1 deletion)
@@ -6,7 +6,7 @@
    idx::Int64
    function GPU(idx::Int64)
        if idx < 0
-           # TODO: implement Error
+           throw(ArgumentError("GPU index must be non-negative. Passed idx: $idx"))
        end
        return new(idx)
    end

Codecov / codecov/patch: added line src/core/autodiff/device.jl#L9 is not covered by tests.
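The validated constructor is self-contained enough to re-create and run outside JITrench; everything below is taken directly from this hunk, minus the surrounding module:

struct GPU
    idx::Int64
    function GPU(idx::Int64)
        if idx < 0
            throw(ArgumentError("GPU index must be non-negative. Passed idx: $idx"))
        end
        return new(idx)
    end
end

GPU(0)    # ok
GPU(-1)   # throws ArgumentError: GPU index must be non-negative. Passed idx: -1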
src/core/autodiff/propagation.jl: 8 changes (2 additions, 6 deletions)
@@ -3,17 +3,13 @@

using ..JITrench

-struct NotImplementedError <: Exception end
-
-# TODO: Better error massage
-Base.showerror(io::IO, e::NotImplementedError) = print(io, "Not Implemented")

function forward(args...)
-   throw(NotImplementedError())
+   throw(ArgumentError("Not Implemented forward function. args: $args"))
end

function backward(args...)
-   throw(NotImplementedError())
+   throw(ArgumentError("Not Implemented backward function. args: $args"))
end

function out_to_tensor(

Codecov / codecov/patch: added lines src/core/autodiff/propagation.jl#L8 and #L12 are not covered by tests.
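These fallbacks exist so that an operator type which forgets to define its own forward or backward fails loudly, with the offending arguments in the message. A runnable sketch of the dispatch pattern (the fallback bodies are copied from this hunk; the Float64 operator is hypothetical):

# Generic fallbacks, as in the diff: any call that reaches them is a missing method.
forward(args...) = throw(ArgumentError("Not Implemented forward function. args: $args"))
backward(args...) = throw(ArgumentError("Not Implemented backward function. args: $args"))

# A hypothetical operator that overrides forward but forgets backward.
forward(x::Float64) = 2x

forward(3.0)   # 6.0
backward(3.0)  # throws ArgumentError: Not Implemented backward function. args: (3.0,)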
src/nn/data/dataloader.jl: 11 changes (6 additions, 5 deletions)
@@ -6,16 +6,17 @@
    batch_size :: Int
    shuffle :: Bool
    index :: Vector{Int}
-   function DataLoader(dataset; batch_size=1, shuffle=false)
+   function DataLoader(dataset; batch_size::Int=1, shuffle=false)
+       if batch_size > length(dataset)
+           throw(DomainError("Batch size must be less than or equal to the length of dataset. Batch size: $batch_size, Dataset length: $(length(dataset))"))
+       elseif batch_size < 1
+           throw(DomainError("Batch size must be greater than or equal to 1. Batch size: $batch_size"))
+       end
        new(dataset, batch_size, shuffle, zeros(Int, length(dataset)))
    end
end

function Base.iterate(loader::DataLoader)
-   if loader.batch_size > length(loader.dataset)
-       # TODO: better error
-       throw(DomainError("batch size > data length error"))
-   end
    loader.index .= randperm(length(loader.dataset))
    data = loader.dataset[1:loader.batch_size]
    head = loader.batch_size + 1

Codecov / codecov/patch: added lines src/nn/data/dataloader.jl#L9-L13 are not covered by tests.
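Moving the check from Base.iterate into the constructor means an invalid loader can no longer be constructed at all, instead of failing on first iteration. A runnable sketch of the validation, with a plain range standing in for the dataset (the real struct also stores shuffle state and a permutation index):

struct DataLoaderSketch
    dataset
    batch_size::Int
    function DataLoaderSketch(dataset; batch_size::Int = 1)
        if batch_size > length(dataset)
            throw(DomainError("Batch size must be less than or equal to the length of dataset. Batch size: $batch_size, Dataset length: $(length(dataset))"))
        elseif batch_size < 1
            throw(DomainError("Batch size must be greater than or equal to 1. Batch size: $batch_size"))
        end
        new(dataset, batch_size)
    end
end

DataLoaderSketch(1:10; batch_size = 4)   # ok
DataLoaderSketch(1:10; batch_size = 0)   # throws DomainError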
src/nn/function/metrics.jl: 3 changes (1 addition, 2 deletions)
@@ -1,6 +1,5 @@
function accuracy(::AbstractArray{<:AbstractFloat}, ::AbstractArray{<:AbstractFloat})
-   # TODO impl error
-   throw(DomainError(""))
+   throw(DomainError("Accuracy is not defined for floating point arrays."))
end

Codecov / codecov/patch: added line src/nn/function/metrics.jl#L2 is not covered by tests.


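This method is a guard: accuracy over raw floating point outputs is ambiguous without labels or a thresholding rule, so the float/float method refuses outright. It re-creates and runs standalone:

function accuracy(::AbstractArray{<:AbstractFloat}, ::AbstractArray{<:AbstractFloat})
    throw(DomainError("Accuracy is not defined for floating point arrays."))
end

accuracy([0.1, 0.9], [0.0, 1.0])  # throws DomainError

Note that Base's two-argument form DomainError(val, msg) separates the offending value from the message; the one-argument form used here stores the message string as the value.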
src/nn/layer/layer.jl: 14 changes (13 additions, 1 deletion)
@@ -1,7 +1,19 @@
using ..JITrench
using DataStructures: OrderedDict, DefaultDict

-Parameter = OrderedDict{String, Dict{String, <: AbstractTensor}}
+struct Parameter
+   weight :: OrderedDict{String, Dict{String, <: AbstractTensor}}
+   layer_names :: Vector{String}
+   meta :: Dict{String, Any}
+   function Parameter(weight::OrderedDict{String, Dict{String, <: AbstractTensor}})
+       layer_names = Vector{String}(undef, length(weight))
+       for (i, key) in enumerate(keys(weight))
+           layer_names[i] = key
+       end
+       return new(weight, layer_names, Dict{String, Any}())
+   end
+end


abstract type Layer end

Codecov / codecov/patch: added lines src/nn/layer/layer.jl#L8-L13 are not covered by tests.
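Parameter is promoted from a type alias to a struct that caches layer names, which is safe because OrderedDict preserves insertion order. A runnable sketch of that ordering guarantee, with Float64 stand-ins for the AbstractTensor values:

using DataStructures: OrderedDict

# Stand-in weight table; the real values are AbstractTensors.
weight = OrderedDict(
    "linear1" => Dict("W" => 1.0, "b" => 2.0),
    "linear2" => Dict("W" => 3.0),
)

layer_names = Vector{String}(undef, length(weight))
for (i, key) in enumerate(keys(weight))
    layer_names[i] = key
end

@assert layer_names == ["linear1", "linear2"]  # insertion order preserved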
src/nn/layer/linear.jl: 3 changes (1 addition, 2 deletions)
@@ -32,8 +32,7 @@
    device = initializer.device
    if !(linear.in_dim isa Nothing)
        if in_dim != linear.in_dim
-           # TODO: impl Error
-           throw(DimensionMismatch(""))
+           throw(DimensionMismatch("Input dimension $in_dim does not match the expected dimension $(linear.in_dim)"))
        end
    end
    out_dim = linear.out_dim

Codecov / codecov/patch: added line src/nn/layer/linear.jl#L35 is not covered by tests.
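A minimal stand-in for the check: in the real layer, linear.in_dim is either nothing (presumably a lazily sized layer) or an Int fixed at construction. The helper name below is hypothetical:

function check_in_dim(expected::Union{Nothing, Int}, in_dim::Int)
    if !(expected isa Nothing)
        if in_dim != expected
            throw(DimensionMismatch("Input dimension $in_dim does not match the expected dimension $expected"))
        end
    end
    return in_dim
end

check_in_dim(nothing, 5)  # 5: a lazily sized layer accepts any input width
check_in_dim(3, 5)        # throws DimensionMismatch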
src/nn/layer/parameters.jl: 9 changes (6 additions, 3 deletions)
@@ -1,17 +1,20 @@
function iterate_layer(params::Parameter)
-   return params
+   return params.weight
end

function iterate_all(params::Parameter)
-   return Base.Iterators.map(x -> x.second, Iterators.flatten(values(params)))
+   return Base.Iterators.map(x -> x.second, Iterators.flatten(values(params.weight)))
end

function cleargrads!(params::Parameter)
    for param in iterate_all(params)
        JITrench.AutoDiff.cleargrad!(param)
    end
end

+function layer_names(params::Parameter)
+   return params.layer_names
+end

Codecov / codecov/patch: added lines src/nn/layer/parameters.jl#L2, #L6, #L10, and #L15-L16 are not covered by tests.
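A hedged usage sketch of the accessors over a stand-in weight table (Float64 leaves instead of AbstractTensors; cleargrad! itself is not shown in this PR):

using DataStructures: OrderedDict

weight = OrderedDict(
    "linear1" => Dict("W" => 1.0, "b" => 2.0),
    "linear2" => Dict("W" => 3.0),
)

# What iterate_all now does over params.weight: flatten the
# layer => name => value nesting down to the raw values.
vals = collect(Base.Iterators.map(x -> x.second, Iterators.flatten(values(weight))))
@assert sort(vals) == [1.0, 2.0, 3.0]

# layer_names(params) just reads the vector cached by the Parameter
# constructor, which matches the OrderedDict's insertion order.
@assert collect(keys(weight)) == ["linear1", "linear2"]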


