
Commit 94228c5

Merge pull request #53 from abap34/develop
add some error message
abap34 authored Sep 2, 2023
2 parents 9592f0e + ce950bd commit 94228c5
Showing 8 changed files with 31 additions and 21 deletions.
2 changes: 1 addition & 1 deletion src/core/autodiff/broadcast/sum_to.jl
@@ -20,7 +20,7 @@ function sum_to(x::T, out_shape) where {T <: AbstractArray}
         dims = (findall(in_shape[1:(end - lead)] .!= out_shape)..., lead_axis...)
         return dropdims(sum(x, dims = dims), dims = lead_axis)
     else
-        # TODO:implement error
+        throw(DimensionMismatch("Input shape $in_shape cannot be reduced to $out_shape"))
     end
 end

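For context, sum_to reduces an array to a broadcast-compatible target shape by summing over the extra axes, and the new branch reports the two incompatible shapes instead of falling through a TODO. Below is a minimal standalone sketch of the idea under the assumption that only leading axes are summed away; it is not the project's implementation, and sum_to_leading is an illustrative name.

    # Simplified stand-in for the sum-to-shape idea guarded by the new error.
    # Only the "sum away leading axes" case is handled here.
    function sum_to_leading(x::AbstractArray, out_shape::Tuple)
        in_shape = size(x)
        lead = length(in_shape) - length(out_shape)
        if lead < 0 || in_shape[(lead + 1):end] != out_shape
            throw(DimensionMismatch("Input shape $in_shape cannot be reduced to $out_shape"))
        end
        lead == 0 && return copy(x)
        return dropdims(sum(x, dims = Tuple(1:lead)), dims = Tuple(1:lead))
    end

    sum_to_leading(ones(4, 2, 3), (2, 3))   # 2×3 matrix of 4.0
    sum_to_leading(ones(2, 3), (4, 5))      # throws the DimensionMismatch above
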
2 changes: 1 addition & 1 deletion src/core/autodiff/device.jl
@@ -6,7 +6,7 @@ struct GPU <: Device
     idx::Int64
     function GPU(idx::Int64)
         if idx < 0
-            # TODO: implement Error
+            throw(ArgumentError("GPU index must be non-negative. Passed idx: $idx"))
         end
         return new(idx)
     end
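
The check sits in the inner constructor, so every GPU value is guaranteed to carry a valid index; there is no way to build one that skips the validation. A self-contained sketch of the same pattern (GPUSketch is an illustrative stand-in, not the project's Device type):

    # Inner-constructor validation pattern in isolation.
    struct GPUSketch
        idx::Int64
        function GPUSketch(idx::Int64)
            if idx < 0
                throw(ArgumentError("GPU index must be non-negative. Passed idx: $idx"))
            end
            return new(idx)
        end
    end

    GPUSketch(0)    # ok
    GPUSketch(-1)   # throws ArgumentError with the message above
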
8 changes: 2 additions & 6 deletions src/core/autodiff/propagation.jl
@@ -3,17 +3,13 @@ using GPUArraysCore
 
 using ..JITrench
 
-struct NotImplementedError <: Exception end
-
-# TODO: Better error massage
-Base.showerror(io::IO, e::NotImplementedError) = print(io, "Not Implemented")
 
 function forward(args...)
-    throw(NotImplementedError())
+    throw(ArgumentError("Not Implemented forward function. args: $args"))
 end
 
 function backward(args...)
-    throw(NotImplementedError())
+    throw(ArgumentError("Not Implemented backward function. args: $args"))
 end
 
 function out_to_tensor(
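
These generic forward/backward methods are catch-all fallbacks: each concrete operator is expected to define its own methods, and a call that reaches the fallback now reports the arguments it received instead of a bare NotImplementedError. Using the stock ArgumentError also removes the need for the custom exception type and its showerror method. A sketch of the dispatch pattern with illustrative names:

    # Fallback-method pattern: specific operators shadow the generic method,
    # which only fires when an overload is missing.
    forward_sketch(args...) =
        throw(ArgumentError("Not Implemented forward function. args: $args"))

    struct Square end
    forward_sketch(::Square, x) = x .^ 2   # a concrete operator supplies its own method

    forward_sketch(Square(), [1.0, 2.0])       # [1.0, 4.0]
    forward_sketch("unknown op", [1.0, 2.0])   # ArgumentError listing the offending args
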
11 changes: 6 additions & 5 deletions src/nn/data/dataloader.jl
@@ -6,16 +6,17 @@ struct DataLoader
     batch_size :: Int
     shuffle :: Bool
     index :: Vector{Int}
-    function DataLoader(dataset; batch_size=1, shuffle=false)
+    function DataLoader(dataset; batch_size::Int=1, shuffle=false)
+        if loader.batch_size > length(loader.dataset)
+            throw(DomainError("Batch size must be less than or equal to the length of dataset. Batch size: $(loader.batch_size), Dataset length: $(length(loader.dataset))"))
+        elseif loader.batch_size < 1
+            throw(DomainError("Batch size must be greater than or equal to 1. Batch size: $(loader.batch_size)"))
+        end
         new(dataset, batch_size, shuffle, zeros(Int, length(dataset)))
     end
 end
 
 function Base.iterate(loader::DataLoader)
-    if loader.batch_size > length(loader.dataset)
-        # TODO: better error
-        throw(DomainError("batch size > data length error"))
-    end
     loader.index .= randperm(length(loader.dataset))
     data = loader.dataset[1:loader.batch_size]
     head = loader.batch_size + 1
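
The batch-size validation moves from Base.iterate into the constructor, so an invalid DataLoader can no longer be built and only fail at iteration time. Note that the committed checks read loader.batch_size and loader.dataset inside the constructor, where no loader binding exists; here is a standalone sketch of the same validation written against the constructor's own arguments (illustrative names, using DomainError's two-argument value/message form):

    # Constructor-time validation sketch; DataLoaderSketch is illustrative.
    struct DataLoaderSketch
        dataset
        batch_size::Int
        shuffle::Bool
        index::Vector{Int}
        function DataLoaderSketch(dataset; batch_size::Int = 1, shuffle::Bool = false)
            if batch_size > length(dataset)
                throw(DomainError(batch_size, "Batch size must be less than or equal to the length of dataset ($(length(dataset)))."))
            elseif batch_size < 1
                throw(DomainError(batch_size, "Batch size must be greater than or equal to 1."))
            end
            return new(dataset, batch_size, shuffle, zeros(Int, length(dataset)))
        end
    end

    DataLoaderSketch(rand(10); batch_size = 4)    # ok
    DataLoaderSketch(rand(10); batch_size = 32)   # throws DomainError at construction
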
3 changes: 1 addition & 2 deletions src/nn/function/metrics.jl
@@ -1,6 +1,5 @@
 function accuracy(::AbstractArray{<:AbstractFloat}, ::AbstractArray{<:AbstractFloat})
-    # TODO impl error
-    throw(DomainError(""))
+    throw(DomainError("Accuracy is not defined for floating point arrays."))
 end


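This method is a guard: accuracy compares discrete predictions against discrete labels, so raw floating-point outputs should be converted first (for example by argmax over class scores). A hedged sketch of the intended split; the integer method and names here are illustrative, not necessarily the project's other accuracy methods:

    # Guard for float inputs plus an illustrative discrete-label method.
    accuracy_sketch(::AbstractArray{<:AbstractFloat}, ::AbstractArray{<:AbstractFloat}) =
        throw(DomainError("Accuracy is not defined for floating point arrays."))

    accuracy_sketch(y_pred::AbstractArray{<:Integer}, y_true::AbstractArray{<:Integer}) =
        count(y_pred .== y_true) / length(y_true)

    scores = [0.2 0.7; 0.9 0.4]                  # rows are samples, columns are classes
    labels = [2, 1]
    preds  = [argmax(scores[i, :]) for i in 1:size(scores, 1)]
    accuracy_sketch(preds, labels)                # 1.0
    accuracy_sketch(scores, Float64.(labels))     # throws the DomainError above
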
14 changes: 13 additions & 1 deletion src/nn/layer/layer.jl
@@ -1,7 +1,19 @@
 using ..JITrench
 using DataStructures: OrderedDict, DefaultDict
 
-Parameter = OrderedDict{String, Dict{String, <: AbstractTensor}}
+struct Parameter
+    weight :: OrderedDict{String, Dict{String, <: AbstractTensor}}
+    layer_names :: Vector{String}
+    meta :: Dict{String, Any}
+    function Parameter(weight::OrderedDict{String, Dict{String, <: AbstractTensor}})
+        layer_names = Vector{String}(undef, length(weight))
+        for (i, key) in enumerate(keys(weight))
+            layer_names[i] = key
+        end
+        return new(weight, layer_names, Dict{String, Any}())
+    end
+end
 
 
 abstract type Layer end

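Parameter changes from a plain type alias into a wrapper that also caches the layer names in insertion order (the explicit loop is equivalent to collect(keys(weight))) and carries a meta dictionary. A standalone sketch of the same wrapping, with Array{Float64} standing in for JITrench's AbstractTensor:

    using DataStructures: OrderedDict

    # Wrapper sketch; ParameterSketch and the Array value type are stand-ins.
    struct ParameterSketch
        weight::OrderedDict{String, Dict{String, Array{Float64}}}
        layer_names::Vector{String}
        meta::Dict{String, Any}
        function ParameterSketch(weight::OrderedDict{String, Dict{String, Array{Float64}}})
            return new(weight, collect(keys(weight)), Dict{String, Any}())
        end
    end

    w = OrderedDict(
        "linear1" => Dict{String, Array{Float64}}("W" => randn(4, 2), "b" => zeros(4)),
        "linear2" => Dict{String, Array{Float64}}("W" => randn(1, 4), "b" => zeros(1)),
    )
    params = ParameterSketch(w)
    params.layer_names   # ["linear1", "linear2"], in insertion order
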
3 changes: 1 addition & 2 deletions src/nn/layer/linear.jl
@@ -32,8 +32,7 @@ function (linear::Linear)(initializer::Initializer)
     device = initializer.device
     if !(linear.in_dim isa Nothing)
         if in_dim != linear.in_dim
-            # TODO: impl Error
-            throw(DimensionMismatch(""))
+            throw(DimensionMismatch("Input dimension $in_dim does not match the expected dimension $(linear.in_dim)"))
         end
     end
     out_dim = linear.out_dim
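
Linear's in_dim is optional: when it is left as nothing the dimension is accepted at initialization, and when it is set the incoming dimension must agree with it; the mismatch now names both numbers. A trimmed-down sketch of that check (illustrative names; only the check is reproduced):

    # Optional in_dim check in isolation; LinearSketch is a stand-in.
    struct LinearSketch
        in_dim::Union{Int, Nothing}
        out_dim::Int
    end

    function check_in_dim(linear::LinearSketch, in_dim::Int)
        if !(linear.in_dim isa Nothing) && in_dim != linear.in_dim
            throw(DimensionMismatch("Input dimension $in_dim does not match the expected dimension $(linear.in_dim)"))
        end
        return in_dim
    end

    check_in_dim(LinearSketch(nothing, 8), 3)   # 3 — dimension accepted lazily
    check_in_dim(LinearSketch(4, 8), 3)         # throws the DimensionMismatch above
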
9 changes: 6 additions & 3 deletions src/nn/layer/parameters.jl
@@ -1,17 +1,20 @@
 function iterate_layer(params::Parameter)
-    return params
+    return params.weight
 end
 
 function iterate_all(params::Parameter)
-    return Base.Iterators.map(x -> x.second, Iterators.flatten(values(params)))
+    return Base.Iterators.map(x -> x.second, Iterators.flatten(values(params.weight)))
 end
 
 function cleargrads!(params::Parameter)
-    for param in iterate_all(params)
+    for param in iterate_all(params.weight)
         JITrench.AutoDiff.cleargrad!(param)
     end
 end
 
+function layer_names(params::Parameter)
+    return params.layer_names
+end
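
With Parameter now a wrapper, the accessors unwrap .weight before iterating, and layer_names exposes the cached name list. One caveat: iterate_all is declared for ::Parameter, yet the committed cleargrads! passes params.weight to it; the sketch below, built on ParameterSketch from the layer.jl example above, keeps both calls on the wrapper. Names with the _sketch suffix are illustrative, and cleargrad! is stubbed out.

    # Accessor sketch over ParameterSketch (defined after the layer.jl hunk above).
    iterate_layer_sketch(params::ParameterSketch) = params.weight

    iterate_all_sketch(params::ParameterSketch) =
        Base.Iterators.map(x -> x.second, Iterators.flatten(values(params.weight)))

    layer_names_sketch(params::ParameterSketch) = params.layer_names

    cleargrad_sketch!(x) = nothing   # stand-in for JITrench.AutoDiff.cleargrad!
    function cleargrads_sketch!(params::ParameterSketch)
        for param in iterate_all_sketch(params)   # the wrapper itself, not .weight
            cleargrad_sketch!(param)
        end
    end

    layer_names_sketch(params)                      # ["linear1", "linear2"]
    length(collect(iterate_all_sketch(params)))     # 4 — the W and b arrays of both layers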


