Skip to content

Commit

Permalink
Merge pull request #234 from JuliaSymbolics/s/less-parens
Browse files Browse the repository at this point in the history
Less parens in printing
  • Loading branch information
shashi authored Mar 13, 2021
2 parents ce87b71 + 2d3b138 commit 9293647
Showing 1 changed file with 23 additions and 9 deletions.
32 changes: 23 additions & 9 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -423,13 +423,18 @@ setargs(t, args) = Term{symtype(t)}(operation(t), args)
cdrargs(args) = setargs(t, cdr(args))

print_arg(io, x::Union{Complex, Rational}) = print(io, "(", x, ")")
print_arg(io, x) = print(io, x)
print_arg(io, f::typeof(^), x) = print_arg(IOContext(io, :paren=>true), x)
function print_arg(io, x; paren=false)
if paren && isbinop(x)
print(io, "(", x, ")")
else
print(io, x)
end
end
print_arg(io, s::String) = show(io, s)
function print_arg(io, f, x)
f !== (*) && return print_arg(io, x)
if istree(x) && Base.isbinaryoperator(nameof(operation(x)))
print_arg(IOContext(io, :paren=>true), x)
if Base.isbinaryoperator(nameof(f)) && isbinop(x)
print_arg(io, x, paren=true)
else
print_arg(io, x)
end
Expand All @@ -448,11 +453,20 @@ function show_add(io, args)
print_arg(io, -, t)
else
print(io, " - ")
print_arg(IOContext(io, :paren=>true), +, -t)
print_arg(io, -t, paren=true)
end
end
end

isbinop(f) = istree(f) && Base.isbinaryoperator(nameof(operation(f)))
function show_pow(io, args)
base, ex = args

print_arg(io, base, paren=isbinop(base))
print(io, "^")
print_arg(io, ex, paren=isbinop(base))
end

function show_mul(io, args)
length(args) == 1 && return print_arg(io, *, args[1])

Expand Down Expand Up @@ -481,8 +495,8 @@ function show_call(io, f, args)
binary = Base.isbinaryoperator(fname)
if binary
for (i, t) in enumerate(args)
i != 1 && print(io, fname == :^ ? fname : " $fname ")
print_arg(io, (^), t)
i != 1 && print(io, " $fname ")
print_arg(io, t)
end
else
if f isa Sym
Expand All @@ -507,15 +521,15 @@ function show_term(io::IO, t)
f = operation(t)
args = arguments(t)

get(io, :paren, false) && print(io, "(")
if f === (+)
show_add(io, args)
elseif f === (*)
show_mul(io, args)
elseif f === (^)
show_pow(io, args)
else
show_call(io, f, args)
end
get(io, :paren, false) && print(io, ")")

return nothing
end
Expand Down

0 comments on commit 9293647

Please sign in to comment.