diff --git a/src/linalg.jl b/src/linalg.jl index d0f072eb..c5c0bcf3 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -217,17 +217,42 @@ end # Norms _inner_eltype(v::AbstractArray) = isempty(v) ? eltype(v) : _inner_eltype(first(v)) _inner_eltype(x::Number) = typeof(x) -@inline _init_zero(v::StaticArray) = float(norm(zero(_inner_eltype(v)))) +@inline _init_zero(v::AbstractArray) = float(norm(zero(_inner_eltype(v)))) @inline function LinearAlgebra.norm_sqr(v::StaticArray) return mapreduce(LinearAlgebra.norm_sqr, +, v; init=_init_zero(v)) end +@inline maxabs_nested(a::Number) = abs(a) +function maxabs_nested(a::AbstractArray) + prod(size(a)) == 0 && (return _init_zero(a)) + + m = maxabs_nested(a[1]) + for j = 2:prod(size(a)) + m = @fastmath max(m, maxabs_nested(a[j])) + end + + return m +end + +@generated function _norm_scaled(::Size{S}, a::StaticArray) where {S} + expr = :(LinearAlgebra.norm_sqr(a[1]/scale)) + for j = 2:prod(S) + expr = :($expr + LinearAlgebra.norm_sqr(a[$j]/scale)) + end + + return quote + $(Expr(:meta, :inline)) + scale = maxabs_nested(a) + + scale==0 && return _init_zero(a) + return @inbounds scale * sqrt($expr) + end +end + @inline norm(a::StaticArray) = _norm(Size(a), a) @generated function _norm(::Size{S}, a::StaticArray) where {S} - if prod(S) == 0 - return :(_init_zero(a)) - end + prod(S) == 0 && return :(_init_zero(a)) expr = :(LinearAlgebra.norm_sqr(a[1])) for j = 2:prod(S) @@ -236,7 +261,10 @@ end return quote $(Expr(:meta, :inline)) - @inbounds return sqrt($expr) + l = @inbounds sqrt($expr) + + 0