diff --git a/ext/SteadyStateDiffEqExt.jl b/ext/SteadyStateDiffEqExt.jl index dea85c01..5887e072 100644 --- a/ext/SteadyStateDiffEqExt.jl +++ b/ext/SteadyStateDiffEqExt.jl @@ -5,7 +5,7 @@ using HarmonicBalance: HarmonicBalance, steady_state_sweep using SteadyStateDiffEq: solve, NonlinearProblem, SteadyStateProblem, DynamicSS, remake using LinearAlgebra: norm, eigvals -using SteadyStateDiffEq.SciMLBase.SciMLStructures: isscimlstructure, Tunable, replace +using SteadyStateDiffEq.SciMLBase.SciMLStructures: Tunable, replace function HarmonicBalance.steady_state_sweep( prob::SteadyStateProblem, alg::DynamicSS; varied::Pair, kwargs... diff --git a/ext/TimeEvolution/TimeEvolution.jl b/ext/TimeEvolution/TimeEvolution.jl index 2e9e185e..c496c19e 100644 --- a/ext/TimeEvolution/TimeEvolution.jl +++ b/ext/TimeEvolution/TimeEvolution.jl @@ -28,10 +28,9 @@ include("ODEProblem.jl") include("hysteresis_sweep.jl") include("plotting.jl") -export FFT export ParameterSweep -export transform_solutions, plot, plot!, is_stable +export transform_solutions, plot, plot! export ODEProblem, solve -export plot_1D_solutions_branch, follow_branch +export follow_branch end diff --git a/src/HarmonicBalance.jl b/src/HarmonicBalance.jl index 5d44ee81..58bf9d6d 100644 --- a/src/HarmonicBalance.jl +++ b/src/HarmonicBalance.jl @@ -24,8 +24,44 @@ function set_imaginary_tolerance(x::Float64) @eval(IM_TOL::Float64 = $x) end -include("Symbolics_customised.jl") -include("Symbolics_utils.jl") +using SymbolicUtils: + SymbolicUtils, + Postwalk, + Sym, + BasicSymbolic, + isterm, + ispow, + isadd, + isdiv, + ismul, + add_with_div, + frac_maketerm, + @compactified, + issym + +using Symbolics: + Symbolics, + Num, + unwrap, + wrap, + get_variables, + simplify, + expand_derivatives, + build_function, + Equation, + Differential, + @variables, + arguments, + simplify_fractions, + substitute, + term, + expand, + operation + +include("Symbolics/Symbolics_utils.jl") +include("Symbolics/exponentials.jl") +include("Symbolics/fourier.jl") +include("Symbolics/drop_powers.jl") include("modules/extention_functions.jl") include("utils.jl") @@ -74,14 +110,14 @@ export get_krylov_equations include("modules/FFTWExt.jl") using .FFTWExt -@setup_workload begin - # Putting some things in `@setup_workload` instead of `@compile_workload` can reduce the size of the - # precompile file and potentially make loading faster. - @compile_workload begin - # all calls in this block will be precompiled, regardless of whether - # they belong to your package or not (on Julia 1.8 and higher) - include("precompilation.jl") - end -end +# @setup_workload begin +# # Putting some things in `@setup_workload` instead of `@compile_workload` can reduce the size of the +# # precompile file and potentially make loading faster. +# @compile_workload begin +# # all calls in this block will be precompiled, regardless of whether +# # they belong to your package or not (on Julia 1.8 and higher) +# include("precompilation.jl") +# end +# end end # module diff --git a/src/HarmonicEquation.jl b/src/HarmonicEquation.jl index 61ae2c42..968ed495 100644 --- a/src/HarmonicEquation.jl +++ b/src/HarmonicEquation.jl @@ -63,7 +63,8 @@ function slow_flow!(eom::HarmonicEquation; fast_time::Num, slow_time::Num, degre drop = [d(var, fast_time, degree) => 0 for var in get_variables(eom)] eom.equations = substitute_all(substitute_all(eom.equations, drop), replace) - return eom.variables = substitute_all(eom.variables, replace) + eom.variables = substitute_all(eom.variables, replace) + return nothing end """ diff --git a/src/HarmonicVariable.jl b/src/HarmonicVariable.jl index 0276c317..de84214c 100644 --- a/src/HarmonicVariable.jl +++ b/src/HarmonicVariable.jl @@ -51,3 +51,33 @@ Symbolics.get_variables(var::HarmonicVariable)::Num = Num(first(get_variables(va Base.isequal(v1::HarmonicVariable, v2::HarmonicVariable)::Bool = isequal(v1.symbol, v2.symbol) + +"The derivative of f w.r.t. x of degree deg" +function d(f::Num, x::Num, deg=1)::Num + return isequal(deg, 0) ? f : (Differential(x)^deg)(f) +end +d(funcs::Vector{Num}, x::Num, deg=1) = Num[d(f, x, deg) for f in funcs] + +"Declare a variable in the the currently active namespace" +function declare_variable(name::String) + var_sym = Symbol(name) + @eval($(var_sym) = first(Symbolics.@variables $var_sym)) + return eval(var_sym) +end + +"Declare a variable that is a function of another variable in the the current namespace" +function declare_variable(name::String, independent_variable::Num) + # independent_variable = declare_variable(independent_variable) convert string into Num + var_sym = Symbol(name) + new_var = Symbolics.@variables $var_sym(independent_variable) + @eval($(var_sym) = first($new_var)) # store the variable under "name" in this namespace + return eval(var_sym) +end + +"Return the name of a variable (excluding independent variables)" +function var_name(x::Num) + var = Symbolics._toexpr(x) + return var isa Expr ? String(var.args[1]) : String(var) +end +# var_name(x::Term) = String(Symbolics._toexpr(x).args[1]) +var_name(x::Sym) = String(x.name) diff --git a/src/Symbolics/Symbolics_utils.jl b/src/Symbolics/Symbolics_utils.jl new file mode 100644 index 00000000..719897e1 --- /dev/null +++ b/src/Symbolics/Symbolics_utils.jl @@ -0,0 +1,128 @@ + +expand_all(x::Num) = Num(expand_all(x.val)) +_apply_termwise(f, x::Num) = wrap(_apply_termwise(f, unwrap(x))) + +"Expands using SymbolicUtils.expand and expand_exp_power (changes exp(x)^n to exp(x*n)" +function expand_all(x) + result = Postwalk(expand_exp_power)(SymbolicUtils.expand(x)) + return isnothing(result) ? x : result +end +expand_all(x::Complex{Num}) = expand_all(x.re) + im * expand_all(x.im) + +"Apply a function f on every member of a sum or a product" +function _apply_termwise(f, x::BasicSymbolic) + @compactified x::BasicSymbolic begin + Add => sum([f(arg) for arg in arguments(x)]) + Mul => prod([f(arg) for arg in arguments(x)]) + Div => _apply_termwise(f, x.num) / _apply_termwise(f, x.den) + _ => f(x) + end +end + +simplify_complex(x::Complex) = isequal(x.im, 0) ? x.re : x.re + im * x.im +simplify_complex(x) = x +function simplify_complex(x::BasicSymbolic) + @compactified x::BasicSymbolic begin + Add => _apply_termwise(simplify_complex, x) + Mul => _apply_termwise(simplify_complex, x) + Div => _apply_termwise(simplify_complex, x) + _ => x + end +end + +""" +$(TYPEDSIGNATURES) + +Perform substitutions in `rules` on `x`. +`include_derivatives=true` also includes all derivatives of the variables of the keys of `rules`. +""" +subtype = Union{Num,Equation,BasicSymbolic} +function substitute_all(x::subtype, rules::Dict; include_derivatives=true) + if include_derivatives + rules = merge( + rules, + Dict([Differential(var) => Differential(rules[var]) for var in keys(rules)]), + ) + end + return substitute(x, rules) +end +"Variable substitution - dictionary" +function substitute_all(dict::Dict, rules::Dict)::Dict + new_keys = substitute_all.(keys(dict), rules) + new_values = substitute_all.(values(dict), rules) + return Dict(zip(new_keys, new_values)) +end +Collections = Union{Dict,Pair,Vector,OrderedDict} +substitute_all(v::AbstractArray, rules) = [substitute_all(x, rules) for x in v] +substitute_all(x::subtype, rules::Collections) = substitute_all(x, Dict(rules)) +# Collections = Union{Dict,OrderedDict} +# function substitute_all(x, rules::Collections; include_derivatives=true) +# if include_derivatives +# rules = merge( +# rules, +# Dict([Differential(var) => Differential(rules[var]) for var in keys(rules)]), +# ) +# end +# return substitute(x, rules) +# end +# "Variable substitution - dictionary" +# function substitute_all(dict::Dict, rules::Dict)::Dict +# new_keys = substitute_all.(keys(dict), rules) +# new_values = substitute_all.(values(dict), rules) +# return Dict(zip(new_keys, new_values)) +# end +# substitute_all(v::AbstractArray, rules::Collections) = [substitute_all(x, rules) for x in v] + +get_independent(x::Num, t::Num) = get_independent(x.val, t) +function get_independent(x::Complex{Num}, t::Num) + return get_independent(x.re, t) + im * get_independent(x.im, t) +end +get_independent(v::Vector{Num}, t::Num) = [get_independent(el, t) for el in v] +get_independent(x, t::Num) = x + +function get_independent(x::BasicSymbolic, t::Num) + @compactified x::BasicSymbolic begin + Add => sum([get_independent(arg, t) for arg in arguments(x)]) + Mul => prod([get_independent(arg, t) for arg in arguments(x)]) + Div => !is_function(x.den, t) ? get_independent(x.num, t) / x.den : 0 + Pow => !is_function(x.base, t) && !is_function(x.exp, t) ? x : 0 + Term => !is_function(x, t) ? x : 0 + Sym => !is_function(x, t) ? x : 0 + _ => x + end +end + +"Return all the terms contained in `x`" +get_all_terms(x::Num) = unique(_get_all_terms(Symbolics.expand(x).val)) +function get_all_terms(x::Equation) + return unique(cat(get_all_terms(Num(x.lhs)), get_all_terms(Num(x.rhs)); dims=1)) +end +function _get_all_terms(x::BasicSymbolic) + @compactified x::BasicSymbolic begin + Add => vcat([_get_all_terms(term) for term in SymbolicUtils.arguments(x)]...) + Mul => Num.(SymbolicUtils.arguments(x)) + Div => Num.([_get_all_terms(x.num)..., _get_all_terms(x.den)...]) + _ => Num(x) + end +end +_get_all_terms(x) = Num(x) + +function is_harmonic(x::Num, t::Num)::Bool + all_terms = get_all_terms(x) + t_terms = setdiff(all_terms, get_independent(all_terms, t)) + isempty(t_terms) && return true + trigs = is_trig.(t_terms) + + if !prod(trigs) + return false + else + powers = [max_power(first(term.val.arguments), t) for term in t_terms[trigs]] + return all(isone, powers) + end +end + +is_harmonic(x::Equation, t::Num) = is_harmonic(x.lhs, t) && is_harmonic(x.rhs, t) +is_harmonic(x, t) = is_harmonic(Num(x), Num(t)) + +"Return true if `f` is a function of `var`." +is_function(f, var) = any(isequal.(get_variables(f), var)) diff --git a/src/Symbolics/drop_powers.jl b/src/Symbolics/drop_powers.jl new file mode 100644 index 00000000..7c28b101 --- /dev/null +++ b/src/Symbolics/drop_powers.jl @@ -0,0 +1,70 @@ +""" +$(SIGNATURES) +Remove parts of `expr` where the combined power of `vars` is => `deg`. + +# Example +```julia-repl +julia> @variables x,y; +julia>drop_powers((x+y)^2, x, 2) +y^2 + 2*x*y +julia>drop_powers((x+y)^2, [x,y], 2) +0 +julia>drop_powers((x+y)^2 + (x+y)^3, [x,y], 3) +x^2 + y^2 + 2*x*y +``` +""" +function drop_powers(expr::Num, vars::Vector{Num}, deg::Int) + Symbolics.@variables ϵ + subs_expr = deepcopy(expr) + rules = Dict([var => ϵ * var for var in unique(vars)]) + subs_expr = Symbolics.expand(substitute_all(subs_expr, rules)) + max_deg = max_power(subs_expr, ϵ) + removal = Dict([ϵ^d => Num(0) for d in deg:max_deg]) + res = substitute_all(substitute_all(subs_expr, removal), Dict(ϵ => Num(1))) + return Symbolics.expand(res) +end + +function drop_powers(expr::Vector{Num}, var::Vector{Num}, deg::Int) + return [drop_powers(x, var, deg) for x in expr] +end + +# calls the above for various types of the first argument +function drop_powers(eq::Equation, var::Vector{Num}, deg::Int) + return drop_powers(eq.lhs, var, deg) .~ drop_powers(eq.lhs, var, deg) +end +function drop_powers(eqs::Vector{Equation}, var::Vector{Num}, deg::Int) + return [ + Equation(drop_powers(eq.lhs, var, deg), drop_powers(eq.rhs, var, deg)) for eq in eqs + ] +end +drop_powers(expr, var::Num, deg::Int) = drop_powers(expr, [var], deg) +drop_powers(x, vars, deg::Int) = drop_powers(Num(x), vars, deg) + +"Return the highest power of `y` occuring in the term `x`." +function max_power(x::Num, y::Num) + terms = get_all_terms(x) + powers = power_of.(terms, y) + return maximum(powers) +end + +max_power(x::Vector{Num}, y::Num) = maximum(max_power.(x, y)) +max_power(x::Complex, y::Num) = maximum(max_power.([x.re, x.im], y)) +max_power(x, t) = max_power(Num(x), Num(t)) + +"Return the power of `y` in the term `x`" +function power_of(x::Num, y::Num) + issym(y.val) ? nothing : error("power of " * string(y) * " is ambiguous") + return power_of(x.val, y.val) +end + +function power_of(x::BasicSymbolic, y::BasicSymbolic) + if ispow(x) && issym(y) + return isequal(x.base, y) ? x.exp : 0 + elseif issym(x) && issym(y) + return isequal(x, y) ? 1 : 0 + else + return 0 + end +end + +power_of(x, y) = 0 diff --git a/src/Symbolics/exponentials.jl b/src/Symbolics/exponentials.jl new file mode 100644 index 00000000..74423ecb --- /dev/null +++ b/src/Symbolics/exponentials.jl @@ -0,0 +1,41 @@ +expand_exp_power(expr::Num) = expand_exp_power(expr.val) +simplify_exp_products(x::Num) = simplify_exp_products(x.val) + +"Returns true if expr is an exponential" +isexp(expr) = isterm(expr) && expr.f == exp + +"Expand powers of exponential such that exp(x)^n => exp(x*n) " +function expand_exp_power(expr::BasicSymbolic) + @compactified expr::BasicSymbolic begin + Add => sum([expand_exp_power(arg) for arg in arguments(expr)]) + Mul => prod([expand_exp_power(arg) for arg in arguments(expr)]) + _ => ispow(expr) && isexp(expr.base) ? exp(expr.base.arguments[1] * expr.exp) : expr + end +end +expand_exp_power(expr) = expr + +"Simplify products of exponentials such that exp(a)*exp(b) => exp(a+b) +This is included in SymbolicUtils as of 17.0 but the method here avoid other simplify calls" +function simplify_exp_products(expr::BasicSymbolic) + @compactified expr::BasicSymbolic begin + Add => _apply_termwise(simplify_exp_products, expr) + Div => _apply_termwise(simplify_exp_products, expr) + Mul => simplify_exp_products_mul(expr) + _ => expr + end +end +function simplify_exp_products(x::Complex{Num}) + return Complex{Num}(simplify_exp_products(x.re.val), simplify_exp_products(x.im.val)) +end +function simplify_exp_products_mul(expr) + ind = findall(x -> isexp(x), arguments(expr)) + rest_ind = setdiff(1:length(arguments(expr)), ind) + rest = isempty(rest_ind) ? 1 : prod(arguments(expr)[rest_ind]) + total = isempty(ind) ? 0 : sum(getindex.(arguments.(arguments(expr)[ind]), 1)) + if SymbolicUtils.is_literal_number(total) + (total == 0 && return rest) + else + return rest * exp(total) + end +end +simplify_exp_products(x) = x diff --git a/src/Symbolics/fourier.jl b/src/Symbolics/fourier.jl new file mode 100644 index 00000000..272a65f1 --- /dev/null +++ b/src/Symbolics/fourier.jl @@ -0,0 +1,120 @@ + +"Expand all sin/cos powers in `x`." +function trig_reduce(x) + x = add_div(x) # a/b + c/d = (ad + bc)/bd + x = expand(x) # open all brackets + x = trig_to_exp(x) + x = expand_all(x) # expand products of exponentials + x = simplify_exp_products(x) # simplify products of exps + x = exp_to_trig(x) + x = Num(simplify_complex(expand(x))) + return simplify_fractions(x) # (a*c^2 + b*c)/c^2 = (a*c + b)/c +end + +"Return true if `f` is a sin or cos." +function is_trig(f::Num) + f = ispow(f.val) ? f.val.base : f.val + isterm(f) && SymbolicUtils.operation(f) ∈ [cos, sin] && return true + return false +end + +""" +$(TYPEDSIGNATURES) +Returns the coefficient of cos(ωt) in `x`. +""" +function fourier_cos_term(x, ω, t) + return _fourier_term(x, ω, t, cos) +end + +"Simplify fraction a/b + c/d = (ad + bc)/bd" +add_div(x) = wrap(Postwalk(add_with_div; maketerm=frac_maketerm)(unwrap(x))) + +""" +$(TYPEDSIGNATURES) +Returns the coefficient of sin(ωt) in `x`. +""" +function fourier_sin_term(x, ω, t) + return _fourier_term(x, ω, t, sin) +end + +function _fourier_term(x::Equation, ω, t, f) + return Equation(_fourier_term(x.lhs, ω, t, f), _fourier_term(x.rhs, ω, t, f)) +end + +"Return the coefficient of f(ωt) in `x` where `f` is a cos or sin." +function _fourier_term(x, ω, t, f) + term = x * f(ω * t) + term = trig_reduce(term) + indep = get_independent(term, t) + ft = Num(simplify_complex(Symbolics.expand(indep))) + ft = !isequal(ω, 0) ? 2 * ft : ft # extra factor in case ω = 0 ! + return Symbolics.expand(ft) +end + +"Convert all sin/cos terms in `x` into exponentials." +function trig_to_exp(x::Num) + all_terms = get_all_terms(x) + trigs = filter(z -> is_trig(z), all_terms) + + rules = [] + for trig in trigs + is_pow = ispow(trig.val) # trig is either a trig or a power of a trig + power = is_pow ? trig.val.exp : 1 + arg = is_pow ? arguments(trig.val.base)[1] : arguments(trig.val)[1] + type = is_pow ? operation(trig.val.base) : operation(trig.val) + + if type == cos + term = Complex{Num}((exp(im * arg) + exp(-im * arg))^power * (1//2)^power, 0) + elseif type == sin + term = + (1 * im^power) * + Complex{Num}(((exp(-im * arg) - exp(im * arg)))^power * (1//2)^power, 0) + end + # avoid Complex{Num} where possible as this causes bugs + # instead, the Nums store SymbolicUtils Complex types + term = Num(Symbolics.expand(term.re.val + im * term.im.val)) + append!(rules, [trig => term]) + end + + result = Symbolics.substitute(x, Dict(rules)) + return convert_to_Num(result) +end +convert_to_Num(x::Complex{Num})::Num = Num(first(x.re.val.arguments)) +convert_to_Num(x::Num)::Num = x + +function exp_to_trig(x::BasicSymbolic) + if isadd(x) || isdiv(x) || ismul(x) + return _apply_termwise(exp_to_trig, x) + elseif isterm(x) && x.f == exp + arg = first(x.arguments) + trigarg = Symbolics.expand(-im * arg) # the argument of the to-be trig function + trigarg = simplify_complex(trigarg) + + # put arguments of trigs into a standard form such that sin(x) = -sin(-x), cos(x) = cos(-x) are recognized + if isadd(trigarg) + first_symbol = minimum( + cat(string.(arguments(trigarg)), string.(arguments(-trigarg)); dims=1) + ) + + # put trigarg => -trigarg the lowest alphabetic argument of trigarg is lower than that of -trigarg + # this is a meaningless key but gives unique signs to all sums + is_first = minimum(string.(arguments(trigarg))) == first_symbol + return if is_first + cos(-trigarg) - im * sin(-trigarg) + else + cos(trigarg) + im * sin(trigarg) + end + end + return if ismul(trigarg) && trigarg.coeff < 0 + cos(-trigarg) - im * sin(-trigarg) + else + cos(trigarg) + im * sin(trigarg) + end + else + return x + end +end + +exp_to_trig(x) = x +exp_to_trig(x::Num) = exp_to_trig(x.val) +exp_to_trig(x::Complex{Num}) = exp_to_trig(x.re) + im * exp_to_trig(x.im) diff --git a/src/Symbolics_customised.jl b/src/Symbolics_customised.jl deleted file mode 100644 index b7964d2f..00000000 --- a/src/Symbolics_customised.jl +++ /dev/null @@ -1,176 +0,0 @@ -using SymbolicUtils: - SymbolicUtils, - Postwalk, - Sym, - BasicSymbolic, - isterm, - ispow, - isadd, - isdiv, - ismul, - add_with_div, - frac_maketerm, #, @compactified - issym - -using Symbolics: - Symbolics, - Num, - unwrap, - get_variables, - simplify, - expand_derivatives, - build_function, - Equation, - Differential, - @variables, - arguments, - simplify_fractions, - substitute, - term, - expand, - operation - -"Returns true if expr is an exponential" -is_exp(expr) = isterm(expr) && expr.f == exp - -"Expand powers of exponential such that exp(x)^n => exp(x*n) " -expand_exp_power(expr) = - ispow(expr) && is_exp(expr.base) ? exp(expr.base.arguments[1] * expr.exp) : expr -expand_exp_power_add(expr) = sum([expand_exp_power(arg) for arg in arguments(expr)]) -expand_exp_power_mul(expr) = prod([expand_exp_power(arg) for arg in arguments(expr)]) -expand_exp_power(expr::Num) = expand_exp_power(expr.val) - -function expand_exp_power(expr::BasicSymbolic) - if isadd(expr) - return expand_exp_power_add(expr) - elseif ismul(expr) - return expand_exp_power_mul(expr) - else - return if ispow(expr) && is_exp(expr.base) - exp(expr.base.arguments[1] * expr.exp) - else - expr - end - end -end - -"Expands using SymbolicUtils.expand and expand_exp_power (changes exp(x)^n to exp(x*n)" -expand_all(x) = Postwalk(expand_exp_power)(SymbolicUtils.expand(x)) -expand_all(x::Complex{Num}) = expand_all(x.re) + im * expand_all(x.im) -expand_all(x::Num) = Num(expand_all(x.val)) - -"Apply a function f on every member of a sum or a product" -_apply_termwise(f, x) = f(x) -function _apply_termwise(f, x::BasicSymbolic) - if isadd(x) - return sum([f(arg) for arg in arguments(x)]) - elseif ismul(x) - return prod([f(arg) for arg in arguments(x)]) - elseif isdiv(x) - return _apply_termwise(f, x.num) / _apply_termwise(f, x.den) - else - return f(x) - end -end -# We could use @compactified to do the achive thing wit a speed-up. Neverthless, it yields less readable code. -# @compactified is what SymbolicUtils uses internally -# function _apply_termwise(f, x::BasicSymbolic) -# @compactified x::BasicSymbolic begin -# Add => sum([f(arg) for arg in arguments(x)]) -# Mul => prod([f(arg) for arg in arguments(x)]) -# Div => _apply_termwise(f, x.num) / _apply_termwise(f, x.den) -# _ => f(x) -# end -# end - -simplify_complex(x::Complex) = isequal(x.im, 0) ? x.re : x.re + im * x.im -simplify_complex(x) = x -function simplify_complex(x::BasicSymbolic) - if isadd(x) || ismul(x) || isdiv(x) - return _apply_termwise(simplify_complex, x) - else - return x - end -end - -"Simplify products of exponentials such that exp(a)*exp(b) => exp(a+b) -This is included in SymbolicUtils as of 17.0 but the method here avoid other simplify calls" -function simplify_exp_products_mul(expr) - ind = findall(x -> is_exp(x), arguments(expr)) - rest_ind = setdiff(1:length(arguments(expr)), ind) - rest = isempty(rest_ind) ? 1 : prod(arguments(expr)[rest_ind]) - total = isempty(ind) ? 0 : sum(getindex.(arguments.(arguments(expr)[ind]), 1)) - if SymbolicUtils.is_literal_number(total) - (total == 0 && return rest) - else - return rest * exp(total) - end -end - -function simplify_exp_products(x::Complex{Num}) - return Complex{Num}(simplify_exp_products(x.re.val), simplify_exp_products(x.im.val)) -end -simplify_exp_products(x::Num) = simplify_exp_products(x.val) -simplify_exp_products(x) = x - -function simplify_exp_products(expr::BasicSymbolic) - if isadd(expr) || isdiv(expr) - return _apply_termwise(simplify_exp_products, expr) - elseif ismul(expr) - return simplify_exp_products_mul(expr) - else - return expr - end -end - -function exp_to_trig(x::BasicSymbolic) - if isadd(x) || isdiv(x) || ismul(x) - return _apply_termwise(exp_to_trig, x) - elseif isterm(x) && x.f == exp - arg = first(x.arguments) - trigarg = Symbolics.expand(-im * arg) # the argument of the to-be trig function - trigarg = simplify_complex(trigarg) - - # put arguments of trigs into a standard form such that sin(x) = -sin(-x), cos(x) = cos(-x) are recognized - if isadd(trigarg) - first_symbol = minimum( - cat(string.(arguments(trigarg)), string.(arguments(-trigarg)); dims=1) - ) - - # put trigarg => -trigarg the lowest alphabetic argument of trigarg is lower than that of -trigarg - # this is a meaningless key but gives unique signs to all sums - is_first = minimum(string.(arguments(trigarg))) == first_symbol - return if is_first - cos(-trigarg) - im * sin(-trigarg) - else - cos(trigarg) + im * sin(trigarg) - end - end - return if ismul(trigarg) && trigarg.coeff < 0 - cos(-trigarg) - im * sin(-trigarg) - else - cos(trigarg) + im * sin(trigarg) - end - else - return x - end -end - -exp_to_trig(x) = x -exp_to_trig(x::Num) = exp_to_trig(x.val) -exp_to_trig(x::Complex{Num}) = exp_to_trig(x.re) + im * exp_to_trig(x.im) - -# sometimes, expressions get stored as Complex{Num} with no way to decode what real(x) and imag(x) -# this overloads the Num constructor to return a Num if x.re and x.im have similar arguments -function Num(x::Complex{Num})::Num - if x.re.val isa Float64 && x.im.val isa Float64 - return Num(x.re.val) - else - if isequal(x.re.val.arguments, x.im.val.arguments) - Num(first(x.re.val.arguments)) - else - error("Cannot convert Complex{Num} " * string(x) * " to Num") - end - end -end -# ^ This function commits type-piracy with Symbolics.jl. We should change this. diff --git a/src/Symbolics_utils.jl b/src/Symbolics_utils.jl deleted file mode 100644 index aae6e8fb..00000000 --- a/src/Symbolics_utils.jl +++ /dev/null @@ -1,305 +0,0 @@ -"The derivative of f w.r.t. x of degree deg" -function d(f::Num, x::Num, deg=1)::Num - return isequal(deg, 0) ? f : (Differential(x)^deg)(f) -end -d(funcs::Vector{Num}, x::Num, deg=1) = Num[d(f, x, deg) for f in funcs] - -"Declare a variable in the the currently active namespace" -function declare_variable(name::String) - var_sym = Symbol(name) - @eval($(var_sym) = first(Symbolics.@variables $var_sym)) - return eval(var_sym) -end - -"Declare a variable that is a function of another variable in the the current namespace" -function declare_variable(name::String, independent_variable::Num) - # independent_variable = declare_variable(independent_variable) convert string into Num - var_sym = Symbol(name) - new_var = Symbolics.@variables $var_sym(independent_variable) - @eval($(var_sym) = first($new_var)) # store the variable under "name" in this namespace - return eval(var_sym) -end - -"Return the name of a variable (excluding independent variables)" -function var_name(x::Num) - var = Symbolics._toexpr(x) - return var isa Expr ? String(var.args[1]) : String(var) -end -# var_name(x::Term) = String(Symbolics._toexpr(x).args[1]) -var_name(x::Sym) = String(x.name) - -""" -$(TYPEDSIGNATURES) - -Perform substitutions in `rules` on `x`. -`include_derivatives=true` also includes all derivatives of the variables of the keys of `rules`. -""" -function substitute_all( - x::T, rules::Dict; include_derivatives=true -)::T where {T<:Union{Equation,Num}} - if include_derivatives - rules = merge( - rules, - Dict([Differential(var) => Differential(rules[var]) for var in keys(rules)]), - ) - end - return substitute(x, rules) -end - -"Variable substitution - dictionary" -function substitute_all(dict::Dict, rules::Dict)::Dict - new_keys = substitute_all.(keys(dict), rules) - new_values = substitute_all.(values(dict), rules) - return Dict(zip(new_keys, new_values)) -end - -function substitute_all( - v::Union{Array{Num},Array{Equation}}, rules::Union{Dict,Pair,Vector} -) - return [substitute_all(x, rules) for x in v] -end -function substitute_all(x::Union{Num,Equation}, rules::Union{Pair,Vector,OrderedDict}) - return substitute_all(x, Dict(rules)) -end -function substitute_all(x::Complex{Num}, rules::Union{Pair,Vector,OrderedDict,Dict}) - return substitute_all(Num(x.re.val.arguments[1]), rules) -end -substitute_all(x, rules) = substitute_all(Num(x), rules) - -""" -$(SIGNATURES) -Remove parts of `expr` where the combined power of `vars` is => `deg`. - -# Example -```julia-repl -julia> @variables x,y; -julia>drop_powers((x+y)^2, x, 2) -y^2 + 2*x*y -julia>drop_powers((x+y)^2, [x,y], 2) -0 -julia>drop_powers((x+y)^2 + (x+y)^3, [x,y], 3) -x^2 + y^2 + 2*x*y -``` -""" -function drop_powers(expr::Num, vars::Vector{Num}, deg::Int) - Symbolics.@variables ϵ - subs_expr = deepcopy(expr) - rules = Dict([var => ϵ * var for var in unique(vars)]) - subs_expr = Symbolics.expand(substitute_all(subs_expr, rules)) - max_deg = max_power(subs_expr, ϵ) - removal = Dict([ϵ^d => Num(0) for d in deg:max_deg]) - res = substitute_all(substitute_all(subs_expr, removal), Dict(ϵ => Num(1))) - return Symbolics.expand(res) - #res isa Complex ? Num(res.re.val.arguments[1]) : res -end - -function drop_powers(expr::Vector{Num}, var::Vector{Num}, deg::Int) - return [drop_powers(x, var, deg) for x in expr] -end - -# calls the above for various types of the first argument -function drop_powers(eq::Equation, var::Vector{Num}, deg::Int) - return drop_powers(eq.lhs, var, deg) .~ drop_powers(eq.lhs, var, deg) -end -function drop_powers(eqs::Vector{Equation}, var::Vector{Num}, deg::Int) - return [ - Equation(drop_powers(eq.lhs, var, deg), drop_powers(eq.rhs, var, deg)) for eq in eqs - ] -end -drop_powers(expr, var::Num, deg::Int) = drop_powers(expr, [var], deg) -drop_powers(x, vars, deg::Int) = drop_powers(Num(x), vars, deg) - -flatten(a) = collect(Iterators.flatten(a)) - -### -# STUFF BELOW IS MAINLY FOR FOURIER-TRANSFORMING -### - -get_independent(x::Num, t::Num) = get_independent(x.val, t) -function get_independent(x::Complex{Num}, t::Num) - return get_independent(x.re, t) + im * get_independent(x.im, t) -end -get_independent(v::Vector{Num}, t::Num) = [get_independent(el, t) for el in v] -get_independent(x, t::Num) = x - -function get_independent(x::BasicSymbolic, t::Num) - if isadd(x) - return sum([get_independent(arg, t) for arg in arguments(x)]) - elseif ismul(x) - return prod([get_independent(arg, t) for arg in arguments(x)]) - elseif ispow(x) - return !is_function(x.base, t) && !is_function(x.exp, t) ? x : 0 - elseif isdiv(x) - return !is_function(x.den, t) ? get_independent(x.num, t) / x.den : 0 - elseif isterm(x) || issym(x) - return !is_function(x, t) ? x : 0 - else - return x - end -end - -"Return all the terms contained in `x`" -get_all_terms(x::Num) = unique(_get_all_terms(Symbolics.expand(x).val)) -function get_all_terms(x::Equation) - return unique(cat(get_all_terms(Num(x.lhs)), get_all_terms(Num(x.rhs)); dims=1)) -end - -_get_all_terms_mul(x) = Num.(SymbolicUtils.arguments(x)) -_get_all_terms_div(x) = Num.([_get_all_terms(x.num)..., _get_all_terms(x.den)...]) -_get_all_terms(x) = Num(x) - -function _get_all_terms_add(x)::Vector{Num} - list = [] - for term in keys(x.dict) - list = cat(list, _get_all_terms(term); dims=1) - end - return list -end - -function _get_all_terms(x::BasicSymbolic) - if isadd(x) - return _get_all_terms_add(x) - elseif ismul(x) - return _get_all_terms_mul(x) - elseif isdiv(x) - return _get_all_terms_div(x) - else - return Num(x) - end -end - -function is_harmonic(x::Num, t::Num)::Bool - all_terms = get_all_terms(x) - t_terms = setdiff(all_terms, get_independent(all_terms, t)) - isempty(t_terms) && return true - trigs = is_trig.(t_terms) - - if !prod(trigs) - return false - else - powers = [max_power(first(term.val.arguments), t) for term in t_terms[trigs]] - return all(isone, powers) - end -end - -is_harmonic(x::Equation, t::Num) = is_harmonic(x.lhs, t) && is_harmonic(x.rhs, t) -is_harmonic(x, t) = is_harmonic(Num(x), Num(t)) - -"Convert all sin/cos terms in `x` into exponentials." -function trig_to_exp(x::Num) - all_terms = get_all_terms(x) - trigs = filter(z -> is_trig(z), all_terms) - - rules = [] - for trig in trigs - is_pow = ispow(trig.val) # trig is either a trig or a power of a trig - power = is_pow ? trig.val.exp : 1 - arg = is_pow ? arguments(trig.val.base)[1] : arguments(trig.val)[1] - type = is_pow ? operation(trig.val.base) : operation(trig.val) - - if type == cos - term = Complex{Num}((exp(im * arg) + exp(-im * arg))^power * (1//2)^power, 0) - elseif type == sin - term = - (1 * im^power) * - Complex{Num}(((exp(-im * arg) - exp(im * arg)))^power * (1//2)^power, 0) - end - # avoid Complex{Num} where possible as this causes bugs - # instead, the Nums store SymbolicUtils Complex types - term = Num(Symbolics.expand(term.re.val + im * term.im.val)) - append!(rules, [trig => term]) - end - - result = Symbolics.substitute(x, Dict(rules)) - #result = result isa Complex ? Num(first(result.re.val.arguments)) : result - result = Num(result) - return result -end - -"Return true if `f` is a function of `var`." -is_function(f, var) = any(isequal.(get_variables(f), var)) - -"Return true if `f` is a sin or cos." -function is_trig(f::Num) - f = ispow(f.val) ? f.val.base : f.val - isterm(f) && SymbolicUtils.operation(f) ∈ [cos, sin] && return true - return false -end - -"A vector of Sym(0) of length n" -Num_zeros(n::Int64) = [Num(0) for k in 1:n] -Num_zeros(vec::Vector{Any}) = Num_zeros(length(vec)) - -""" -$(TYPEDSIGNATURES) -Returns the coefficient of cos(ωt) in `x`. -""" -function fourier_cos_term(x, ω, t) - return _fourier_term(x, ω, t, cos) -end - -""" -$(TYPEDSIGNATURES) -Returns the coefficient of sin(ωt) in `x`. -""" -function fourier_sin_term(x, ω, t) - return _fourier_term(x, ω, t, sin) -end - -function _fourier_term(x::Equation, ω, t, f) - return Equation(_fourier_term(x.lhs, ω, t, f), _fourier_term(x.rhs, ω, t, f)) -end - -"Return the coefficient of f(ωt) in `x` where `f` is a cos or sin." -function _fourier_term(x, ω, t, f) - term = x * f(ω * t) - term = trig_reduce(term) - indep = get_independent(term, t) - ft = Num(simplify_complex(Symbolics.expand(indep))) - ft = !isequal(ω, 0) ? 2 * ft : ft # extra factor in case ω = 0 ! - return Symbolics.expand(ft) -end - -"Simplify fraction a/b + c/d = (ad + bc)/bd" -add_div(x) = Num(Postwalk(add_with_div; maketerm=frac_maketerm)(unwrap(x))) - -"Expand all sin/cos powers in `x`." -function trig_reduce(x) - x = add_div(x) # a/b + c/d = (ad + bc)/bd - x = expand(x) # open all brackets - x = trig_to_exp(x) - x = expand_all(x) # expand products of exponentials - x = simplify_exp_products(x) # simplify products of exps - x = exp_to_trig(x) - x = Num(simplify_complex(expand(x))) - return simplify_fractions(x) # (a*c^2 + b*c)/c^2 = (a*c + b)/c -end - -"Return the highest power of `y` occuring in the term `x`." -function max_power(x::Num, y::Num) - terms = get_all_terms(x) - powers = power_of.(terms, y) - return maximum(powers) -end - -max_power(x::Vector{Num}, y::Num) = maximum(max_power.(x, y)) -max_power(x::Complex, y::Num) = maximum(max_power.([x.re, x.im], y)) -max_power(x, t) = max_power(Num(x), Num(t)) - -"Return the power of `y` in the term `x`" -function power_of(x::Num, y::Num) - issym(y.val) ? nothing : error("power of " * string(y) * " is ambiguous") - return power_of(x.val, y.val) -end - -function power_of(x::BasicSymbolic, y::BasicSymbolic) - if ispow(x) && issym(y) - return isequal(x.base, y) ? x.exp : 0 - elseif issym(x) && issym(y) - return isequal(x, y) ? 1 : 0 - else - return 0 - end -end - -power_of(x, y) = 0 diff --git a/src/modules/LinearResponse/jacobians.jl b/src/modules/LinearResponse/jacobians.jl index 965d40b9..15e5ae15 100644 --- a/src/modules/LinearResponse/jacobians.jl +++ b/src/modules/LinearResponse/jacobians.jl @@ -40,7 +40,8 @@ function get_Jacobian(eqs::Vector{Num}, vars::Vector{Num}) end function get_Jacobian(eqs::Vector{Equation}, vars::Vector{Num}) - return get_Jacobian(Num.(getfield.(eqs, :lhs) .- getfield.(eqs, :rhs)), vars) + expr = Num[getfield(eq, :lhs) - getfield(eq, :rhs) for eq in eqs] + return get_Jacobian(expr, vars) end """ diff --git a/src/transform_solutions.jl b/src/transform_solutions.jl index fd49d359..619b9870 100644 --- a/src/transform_solutions.jl +++ b/src/transform_solutions.jl @@ -100,8 +100,8 @@ function to_lab_frame(soln, res, times)::Vector{AbstractFloat} timetrace = zeros(length(times)) for var in res.problem.eom.variables - val = unwrap(substitute_all(_remove_brackets(var), soln)) - ω = unwrap(substitute_all(var.ω, soln)) + val = real(substitute_all(unwrap(_remove_brackets(var)), soln)) + ω = real(unwrap(substitute_all(var.ω, soln))) if var.type == "u" timetrace .+= val * cos.(ω * times) elseif var.type == "v" diff --git a/src/utils.jl b/src/utils.jl index d9be907f..59fa7dbf 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,2 +1,4 @@ is_real(x) = abs(imag(x)) / abs(real(x)) < IM_TOL::Float64 || abs(x) < 1e-70 is_real(x::Array) = is_real.(x) + +flatten(a) = collect(Iterators.flatten(a)) diff --git a/test/fourier.jl b/test/fourier.jl deleted file mode 100644 index 34fe3f77..00000000 --- a/test/fourier.jl +++ /dev/null @@ -1,50 +0,0 @@ -using HarmonicBalance: fourier_cos_term, fourier_sin_term -using HarmonicBalance.Symbolics: expand - -@variables f t θ a b - -@test isequal(fourier_cos_term(cos(f * t)^2, f, t), 0) -@test isequal(fourier_sin_term(sin(f * t)^2, f, t), 0) - -@test isequal(fourier_cos_term(cos(f * t)^2, 2 * f, t), 1//2) -@test isequal(fourier_sin_term(cos(f * t)^2, 2 * f, t), 0) -@test isequal(fourier_cos_term(sin(f * t)^2, 2 * f, t), -1//2) -@test isequal(fourier_sin_term(sin(f * t)^2, 2 * f, t), 0) - -@test isequal(fourier_cos_term(cos(f * t), f, t), 1) -@test isequal(fourier_sin_term(sin(f * t), f, t), 1) - -@test isequal(fourier_cos_term(cos(f * t + θ), f, t), cos(θ)) -@test isequal(fourier_sin_term(cos(f * t + θ), f, t), -sin(θ)) - -term = - (a * sin(f * t) + b * cos(f * t)) * - (a * sin(2 * f * t) + b * cos(2 * f * t)) * - (a * sin(3 * f * t) + b * cos(3 * f * t)) -fourier_cos_term(term, 2 * f, t) -@test isequal(fourier_cos_term(term, 2 * f, t), expand(1//4 * (a^2 * b + b^3))) -@test isequal(fourier_cos_term(term, 4 * f, t), expand(1//4 * (a^2 * b + b^3))) -@test isequal(fourier_cos_term(term, 6 * f, t), expand(1//4 * (-3 * a^2 * b + b^3))) -@test isequal(fourier_sin_term(term, 2 * f, t), expand(1//4 * (a^3 + a * b^2))) -@test isequal(fourier_sin_term(term, 4 * f, t), expand(1//4 * (a^3 + a * b^2))) -@test isequal(fourier_sin_term(term, 6 * f, t), expand(1//4 * (-a^3 + 3 * a * b^2))) - -# try something harder! -term = (a + b * cos(f * t + θ)^2)^3 * sin(f * t) -@test isequal( - fourier_sin_term(term, f, t), - expand( - a^3 + a^2 * b * 3//2 + 9//8 * a * b^2 + 5//16 * b^3 - - 3//64 * b * (16 * a^2 + 16 * a * b + 5 * b^2) * cos(2 * θ), - ), -) -@test isequal( - fourier_cos_term(term, f, t), - expand(-3//64 * b * (16 * a^2 + 16 * a * b + 5 * b^2) * sin(2 * θ)), -) - -# FTing at zero : picking out constant terms -@test isequal(fourier_cos_term(cos(f * t), 0, t), 0) -@test isequal(fourier_cos_term(cos(f * t)^3 + 1, 0, t), 1) -@test isequal(fourier_cos_term(cos(f * t)^2 + 1, 0, t), 3//2) -@test isequal(fourier_cos_term((cos(f * t)^2 + cos(f * t))^3, 0, t), 23//16) diff --git a/test/harmonics.jl b/test/harmonics.jl deleted file mode 100644 index c59b173c..00000000 --- a/test/harmonics.jl +++ /dev/null @@ -1,11 +0,0 @@ -using HarmonicBalance: is_harmonic - -@variables a, b, c, t, x(t), f, y(t) - -@test is_harmonic(cos(f * t), t) -@test is_harmonic(1, t) -@test !is_harmonic(cos(f * t^2 + a), t) -@test !is_harmonic(a + t, t) - -dEOM = DifferentialEquation([a + x, t^2 + cos(t)], [x, y]) -@test !is_harmonic(dEOM, t) diff --git a/test/powers.jl b/test/powers.jl deleted file mode 100644 index 86993720..00000000 --- a/test/powers.jl +++ /dev/null @@ -1,13 +0,0 @@ -using HarmonicBalance: drop_powers, max_power -using HarmonicBalance.Symbolics: expand - -@variables a, b, c - -@test max_power(a^2 + b, a) == 2 -@test max_power(a * ((a + b)^4)^2 + a, a) == 9 - -@test isequal(drop_powers(a^2 + b, a, 1), b) -@test isequal(drop_powers((a + b)^2, a, 1), b^2) -@test isequal(drop_powers((a + b)^2, [a, b], 1), 0) - -@test isequal(drop_powers((a + b)^3 + (a + b)^5, [a, b], 4), expand((a + b)^3)) diff --git a/test/runtests.jl b/test/runtests.jl index c92402d8..a53c2f0c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,7 +7,11 @@ Random.seed!(SEED) @testset "Code quality" begin using ExplicitImports, Aqua + # using ModelingToolkit, OrdinaryDiffEqTsit5, SteadyStateDiffEq ignore_deps = [:Random, :LinearAlgebra, :Printf, :Test, :Pkg] + # TimeEvolution = Base.get_extension(HarmonicBalance, :TimeEvolution) + # ModelingToolkitExt = Base.get_extension(HarmonicBalance, :ModelingToolkitExt) + # SteadyStateDiffEqExt = Base.get_extension(HarmonicBalance, :SteadyStateDiffEqExt) @test check_no_stale_explicit_imports(HarmonicBalance) == nothing @test check_all_explicit_imports_via_owners(HarmonicBalance) == nothing @@ -19,9 +23,22 @@ Random.seed!(SEED) check_extras=(ignore=ignore_deps,), check_weakdeps=(ignore=ignore_deps,), ), - piracies=(treat_as_own=[HarmonicBalance.Num],), ambiguities=false, ) + # for mod in [HarmonicBalance, TimeEvolution, ModelingToolkitExt, SteadyStateDiffEqExt] + # @test check_no_stale_explicit_imports(mod) == nothing + # @test check_all_explicit_imports_via_owners(mod) == nothing + # Aqua.test_ambiguities(mod) + # Aqua.test_all( + # mod; + # deps_compat=false, + # ambiguities=false, + # piracies=false, + # stale_deps=false, + # project_extras=false, + # persistent_tasks=false + # ) + # end end @testset "Code linting" begin @@ -29,14 +46,12 @@ end JET.test_package(HarmonicBalance; target_defined_modules=true) end -@testset "Symbolics customised" begin +@testset "API" begin include("API.jl") end @testset "Symbolics customised" begin - include("powers.jl") - include("harmonics.jl") - include("fourier.jl") + include("symbolics.jl") end @testset "IO" begin diff --git a/test/symbolics.jl b/test/symbolics.jl new file mode 100644 index 00000000..d32b1213 --- /dev/null +++ b/test/symbolics.jl @@ -0,0 +1,179 @@ +using Test +using Symbolics +using HarmonicBalance +using SymbolicUtils: Fixpoint, Prewalk, PassThrough + +macro eqtest(expr) + @assert expr.head == :call && expr.args[1] in [:(==), :(!=)] + return esc( + if expr.args[1] == :(==) + :(@test isequal($(expr.args[2]), $(expr.args[3]))) + else + :(@test !isequal($(expr.args[2]), $(expr.args[3]))) + end, + ) +end + +@testset "exp(x)^n => exp(x*n)" begin + using HarmonicBalance: expand_all, expand_exp_power + @variables a n + + @eqtest expand_exp_power(exp(a)^3) == exp(3 * a) + @eqtest simplify(exp(a)^3) == exp(3 * a) + @eqtest simplify(exp(a)^n) == exp(n * a) + @eqtest expand_all(exp(a)^3) == exp(3 * a) + @eqtest expand_all(exp(a)^3) == exp(3 * a) + @eqtest expand_all(im * exp(a)^5) == im * exp(5 * a) +end + +@testset "exp(a)*exp(b) => exp(a+b)" begin + using HarmonicBalance: simplify_exp_products + @variables a b + + @eqtest simplify_exp_products(exp(a) * exp(b)) == exp(a + b) + @eqtest simplify_exp_products(exp(3a) * exp(4b)) == exp(3a + 4b) + @eqtest simplify_exp_products(im * exp(3a) * exp(4b)) == im * exp(3a + 4b) +end + +@testset "euler" begin + @variables a b + @eqtest cos(a) + im * sin(a) == exp(im * a) + @eqtest exp(a) * cos(b) + im * sin(b) * exp(a) == exp(a + im * b) +end + +@testset "powers" begin + using HarmonicBalance: drop_powers, max_power + using HarmonicBalance.Symbolics: expand + + @variables a, b, c + + @eqtest max_power(a^2 + b, a) == 2 + @eqtest max_power(a * ((a + b)^4)^2 + a, a) == 9 + + @eqtest drop_powers(a^2 + b, a, 1) == b + @eqtest drop_powers((a + b)^2, a, 1) == b^2 + @eqtest drop_powers((a + b)^2, [a, b], 1) == 0 + @eqtest drop_powers((a + b)^3 + (a + b)^5, [a, b], 4) == expand((a + b)^3) +end + +@testset "trig_to_exp and trig_to_exp" begin + using HarmonicBalance: expand_all, trig_to_exp, exp_to_trig + @variables f t + cos_euler(x) = (exp(im * x) + exp(-im * x)) / 2 + sin_euler(x) = (exp(im * x) - exp(-im * x)) / (2 * im) + + trigs = [cos(f * t), sin(f * t)] + for (i, trig) in pairs(trigs) + z = trig_to_exp(trig) + @eqtest expand(exp_to_trig(z)) == trig + end +end + +@testset "harmonic" begin + using HarmonicBalance: is_harmonic + + @variables a, b, c, t, x(t), f, y(t) + + @test is_harmonic(cos(f * t), t) + @test is_harmonic(1, t) + @test !is_harmonic(cos(f * t^2 + a), t) + @test !is_harmonic(a + t, t) + + dEOM = DifferentialEquation([a + x, t^2 + cos(t)], [x, y]) + @test !is_harmonic(dEOM, t) +end + +@testset "fourier" begin + using HarmonicBalance: fourier_cos_term, fourier_sin_term + using HarmonicBalance.Symbolics: expand + + @variables f t θ a b + + @eqtest fourier_cos_term(cos(f * t)^2, f, t) == 0 + @eqtest fourier_sin_term(sin(f * t)^2, f, t) == 0 + + @eqtest fourier_cos_term(cos(f * t)^2, 2 * f, t) == 1//2 + @eqtest fourier_sin_term(cos(f * t)^2, 2 * f, t) == 0 + @eqtest fourier_cos_term(sin(f * t)^2, 2 * f, t) == -1//2 + @eqtest fourier_sin_term(sin(f * t)^2, 2 * f, t) == 0 + + @eqtest fourier_cos_term(cos(f * t), f, t) == 1 + @eqtest fourier_sin_term(sin(f * t), f, t) == 1 + + @eqtest fourier_cos_term(cos(f * t + θ), f, t) == cos(θ) + @eqtest fourier_sin_term(cos(f * t + θ), f, t) == -sin(θ) + + term = + (a * sin(f * t) + b * cos(f * t)) * + (a * sin(2 * f * t) + b * cos(2 * f * t)) * + (a * sin(3 * f * t) + b * cos(3 * f * t)) + + @eqtest fourier_cos_term(term, 2 * f, t) == expand(1//4 * (a^2 * b + b^3)) + @eqtest fourier_cos_term(term, 4 * f, t) == expand(1//4 * (a^2 * b + b^3)) + @eqtest fourier_cos_term(term, 6 * f, t) == expand(1//4 * (-3 * a^2 * b + b^3)) + @eqtest fourier_sin_term(term, 2 * f, t) == expand(1//4 * (a^3 + a * b^2)) + @eqtest fourier_sin_term(term, 4 * f, t) == expand(1//4 * (a^3 + a * b^2)) + @eqtest fourier_sin_term(term, 6 * f, t) == expand(1//4 * (-a^3 + 3 * a * b^2)) + + # try something harder! + term = (a + b * cos(f * t + θ)^2)^3 * sin(f * t) + @eqtest fourier_sin_term(term, f, t) == expand( + a^3 + a^2 * b * 3//2 + 9//8 * a * b^2 + 5//16 * b^3 - + 3//64 * b * (16 * a^2 + 16 * a * b + 5 * b^2) * cos(2 * θ), + ) + + @eqtest fourier_cos_term(term, f, t) == + expand(-3//64 * b * (16 * a^2 + 16 * a * b + 5 * b^2) * sin(2 * θ)) + + # FTing at zero : picking out constant terms + @eqtest fourier_cos_term(cos(f * t), 0, t) == 0 + @eqtest fourier_cos_term(cos(f * t)^3 + 1, 0, t) == 1 + @eqtest fourier_cos_term(cos(f * t)^2 + 1, 0, t) == 3//2 + @eqtest fourier_cos_term((cos(f * t)^2 + cos(f * t))^3, 0, t) == 23//16 +end + +@testset "_apply_termwise" begin + using HarmonicBalance: _apply_termwise + + @variables a, b, c + + @eqtest _apply_termwise(x -> x^2, a + b + c) == a^2 + b^2 + c^2 + @eqtest _apply_termwise(x -> x^2, a * b * c) == a^2 * b^2 * c^2 + @eqtest _apply_termwise(x -> x^2, a / b) == a^2 / b^2 +end + +@testset "simplify_complex" begin + using HarmonicBalance: simplify_complex + @variables a, b, c + z = Complex{Num}(a) + @test simplify_complex(z) isa Num + + z = Complex{Num}(1 + 0 * im) + @test simplify_complex(z) isa Num +end + +@testset "get_all_terms" begin + using HarmonicBalance: get_all_terms + @variables a, b, c + + @eqtest get_all_terms(a + b + c) == [a, b, c] + @eqtest get_all_terms(a * b * c) == [a, b, c] + @eqtest get_all_terms(a / b) == [a, b] + @eqtest get_all_terms(a^2 + b^2 + c^2) == [b^2, a^2, c^2] + @eqtest get_all_terms(a^2 / b^2) == [a^2, b^2] + @eqtest get_all_terms(2 * b^2) == [2, b^2] +end + +@testset "get_independent" begin + using HarmonicBalance: get_independent + @variables a, b, c, t + + @eqtest get_independent(a + b + c, t) == a + b + c + @eqtest get_independent(a * b * c, t) == a * b * c + @eqtest get_independent(a / b, t) == a / b + @eqtest get_independent(a^2 + b^2 + c^2, t) == a^2 + b^2 + c^2 + @eqtest get_independent(a^2 / b^2, t) == a^2 / b^2 + @eqtest get_independent(2 * b^2, t) == 2 * b^2 + @eqtest get_independent(cos(t), t) == 0 + @eqtest get_independent(cos(t)^2 + 5, t) == 5 +end