Skip to content

Commit

Permalink
ndarray: change internal api of plus to help autograd
Browse files Browse the repository at this point in the history
  • Loading branch information
iblislin committed Dec 7, 2017
1 parent f8d4f62 commit d6e7b2a
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -587,8 +587,10 @@ Summation. Multiple arguments of either scalar or `NDArray` could be
added together. Note at least the first or second argument needs to be an
`NDArray` to avoid ambiguity of built-in summation.
"""
+(x::NDArray, ys::NDArrayOrReal...) = add_to!(copy(x, context(x)), ys...)
+(x::Real, y::NDArray, zs::NDArrayOrReal...) = add_to!(copy(y, context(y)), x, zs...)
+(x::NDArray) = x
+(x::NDArray, y::NDArray) = _plus(x, y)
+(x::NDArray, y::Real) = _plus_scalar(x, scalar = y)
+(y::Real, x::NDArray) = _plus_scalar(x, scalar = y)

broadcast_(::typeof(+), x::NDArray, y::NDArrayOrReal) = x + y
broadcast_(::typeof(+), x::Real, y::NDArray) = x + y
Expand Down Expand Up @@ -1176,20 +1178,16 @@ function _get_ndarray_function_def(name :: String)
args = MX_handle[]
end

if length(output_vars) > 0
output_handles = map((x) -> Base.cconvert(MX_handle, x), output_vars)
# XXX: Julia 0.4 has bug: [Array{MX_handle}] == Array{MX_handle}
output_handles_pp = Array{Array{MX_handle}}(1)
output_handles_pp[1] = Base.cconvert(Ptr{MX_handle}, output_handles)
output_handles_pp = if length(output_vars) > 0
[map(x -> x.handle, output_vars)]
else
output_handles_pp = [Base.convert(Ptr{MX_handle}, 0)]
[Ptr{MX_handle}(C_NULL)]
end
num_outputs_p = [convert(Cint, num_outputs)]

kw_keys_str = String[string(x[1]) for x in kwargs]
kw_vals_str = String[dump_mx_param(x[2]) for x in kwargs]

#op_handle = _get_cached_libmx_op_handle($(QuoteNode(name)))
op_handle = _get_cached_libmx_op_handle($(name))
@mxcall(:MXImperativeInvoke,
(MX_handle, Cint, Ptr{MX_handle},
Expand Down

0 comments on commit d6e7b2a

Please sign in to comment.