Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Properly deal with const template parameters #409

Merged
merged 3 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ libcxxwrap_julia_jll = "3eaa8342-bff7-56a5-9981-c04077f7cee7"
[compat]
MacroTools = "0.5.9"
julia = "1.6"
libcxxwrap_julia_jll = "0.12.0"
libcxxwrap_julia_jll = "0.12.1"

[extras]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Expand Down
53 changes: 42 additions & 11 deletions src/CxxWrap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ ConstCxxPtr, ConstCxxRef, CxxRef, CxxPtr,
CppEnum, ConstArray, CxxBool, CxxLong, CxxULong, CxxChar, CxxChar16, CxxChar32, CxxWchar, CxxUChar, CxxSignedChar,
CxxLongLong, CxxULongLong, ptrunion, gcprotect, gcunprotect, isnull, libcxxwrapversion

const libcxxwrap_version_range = (v"0.12.0", v"0.13")
const libcxxwrap_version_range = (v"0.12.1", v"0.13")

using libcxxwrap_julia_jll # for libcxxwrap_julia and libcxxwrap_julia_stl

Expand Down Expand Up @@ -129,6 +129,10 @@ Base.flipsign(x::T, y::T) where {T <: CxxSigned} = reinterpret(T, flipsign(to_ju
struct IsCxxType end
struct IsNormalType end

struct CxxConst{T}
cpp_object::Ptr{T}
end

@inline cpp_trait_type(::Type) = IsNormalType

# Enum type interface
Expand All @@ -152,29 +156,48 @@ Base.show(io::IO, x::SmartPointer) = print(io, "C++ smart pointer of type ", typ
allocated_type(t::Type) = Any
dereferenced_type(t::Type) = Any

__cxxwrap_smartptr_dereference(p::SmartPointer{T}) where {T} = __cxxwrap_smartptr_dereference(CxxRef(p))
__cxxwrap_smartptr_construct_from_other(t::Type{<:SmartPointer{T}}, p::SmartPointer{T}) where {T} = __cxxwrap_smartptr_construct_from_other(t,CxxRef(p))
__cxxwrap_smartptr_cast_to_base(p::SmartPointer{T}) where {T} = __cxxwrap_smartptr_cast_to_base(CxxRef(p))
function __cxxwrap_smartptr_dereference end
function __cxxwrap_smartptr_construct_from_other end
function __cxxwrap_smartptr_cast_to_base end
function __cxxwrap_make_const_smartptr end

function Base.getindex(p::SmartPointer{T}) where {T}
return __cxxwrap_smartptr_dereference(p)
return __cxxwrap_smartptr_dereference(CxxRef(p))
end

# No conversion if source and target type are identical
Base.convert(::Type{T}, p::T) where {PT,T <: SmartPointer{PT}} = p

# Conversion from non-const to const
Base.convert(::Type{CT}, p::T) where {PT,T <: SmartPointer{PT},CT <: SmartPointer{CxxConst{PT}}} = __cxxwrap_make_const_smartptr(ConstCxxRef(p))

# Construct from a related pointer, e.g. a std::weak_ptr from std::shared_ptr
function Base.convert(::Type{T1}, p::SmartPointer{T}) where {T, T1 <: SmartPointer{T}}
return __cxxwrap_smartptr_construct_from_other(T1, p)
return __cxxwrap_smartptr_construct_from_other(T1, CxxRef(p))
end

# Construct from a related pointer, e.g. a std::weak_ptr from std::shared_ptr. Const versions
function Base.convert(::Type{T1}, p::SmartPointer{CxxConst{T}}) where {T, T1 <: SmartPointer{CxxConst{T}}}
return __cxxwrap_smartptr_construct_from_other(T1, CxxRef(p))
end
# Avoid improper method resolution on Julia < 1.10
function Base.convert(::Type{T1}, p::T1) where {T, T1 <: SmartPointer{CxxConst{T}}}
return p
end

# upcast to base class
function Base.convert(::Type{T1}, p::SmartPointer{<:BaseT}) where {BaseT, T1 <: SmartPointer{BaseT}}
function _base_convert_impl(::Type{T1}, p) where{T1}
# First convert to base type
base_p = __cxxwrap_smartptr_cast_to_base(p)
base_p = __cxxwrap_smartptr_cast_to_base(ConstCxxRef(p))
return convert(T1, base_p)
end

# upcast to base class, non-const version
Base.convert(::Type{T1}, p::SmartPointer{<:BaseT}) where {BaseT, T1 <: SmartPointer{BaseT}} = _base_convert_impl(T1, p)
# upcast to base class, non-const to const version
Base.convert(::Type{T1}, p::SmartPointer{<:BaseT}) where {BaseT, T1 <: SmartPointer{CxxConst{BaseT}}} = _base_convert_impl(T1, p)
# upcast to base, const version
Base.convert(::Type{T1}, p::SmartPointer{CxxConst{SuperT}}) where {BaseT, SuperT<:BaseT, T1<:SmartPointer{CxxConst{BaseT}}} = _base_convert_impl(T1, p)

struct StrictlyTypedNumber{NumberT}
value::NumberT
@static if Sys.iswindows()
Expand Down Expand Up @@ -362,7 +385,7 @@ function _register_function_pointers(func, precompiling)
if haskey(__global_method_map, mkey)
existing = __global_method_map[mkey]
if existing[3] == precompiling
error("Double registration for method $mkey")
error("Double registration for method $mkey: $(func.name); $(func.argument_types); $(func.return_type)")
end
end
__global_method_map[mkey] = fptrs
Expand Down Expand Up @@ -502,6 +525,7 @@ function ptrunion(::Type{T}) where {T}
return result
end

# valuetype is the non-reference, non-pointer type that is used in the method argument list. Overloading this allows mapping these types to something more useful
valuetype(t::Type) = valuetype(Base.invokelatest(cpp_trait_type,t), t)
valuetype(::Type{IsNormalType}, ::Type{T}) where {T} = T
function valuetype(::Type{IsCxxType}, ::Type{T}) where {T}
Expand All @@ -511,10 +535,16 @@ function valuetype(::Type{IsCxxType}, ::Type{T}) where {T}
end
return T
end
# Smart pointer arguments can also take subclass arguments
function valuetype(::Type{<:SmartPointer{T}}) where {T}
result{T2 <: T} = SmartPointer{T2}
return result
end
# Smart pointers with const arguments can also take both const and non-const subclass arguments
function valuetype(::Type{<:SmartPointer{CxxConst{T}}}) where {T}
result{T2 <: T} = Union{SmartPointer{T2},SmartPointer{CxxConst{T2}}}
return result
end

map_julia_arg_type(t::Type) = Union{valuetype(t), argument_overloads(t)...}
map_julia_arg_type(a::Type{StrictlyTypedNumber{T}}) where {T} = T
Expand Down Expand Up @@ -555,7 +585,8 @@ const __excluded_names = Set([
:cxxupcast,
:__cxxwrap_smartptr_dereference,
:__cxxwrap_smartptr_construct_from_other,
:__cxxwrap_smartptr_cast_to_base
:__cxxwrap_smartptr_cast_to_base,
:__cxxwrap_make_const_smartptr,
])

function Base.cconvert(to_type::Type{<:CxxBaseRef{T}}, x) where {T}
Expand Down
12 changes: 11 additions & 1 deletion test/basic_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,19 @@ end

@testset "$(basename(@__FILE__)[1:end-3])" begin

function compare_collections(a, b)
equalities = a .== b
result = all(equalities)
if !result
neqs = (!).(equalities)
println("collections differ: $(a[neqs]) ≠ $(b[neqs])")
end
return result
end

let funcs = CxxWrap.CxxWrapCore.get_module_functions(CxxWrap.StdLib)
@test CxxWrap.StdLib.__cxxwrap_methodkeys[1] == CxxWrap.CxxWrapCore.methodkey(funcs[1])
@test all(CxxWrap.StdLib.__cxxwrap_methodkeys .== CxxWrap.CxxWrapCore.methodkey.(funcs))
@test compare_collections(CxxWrap.StdLib.__cxxwrap_methodkeys, CxxWrap.CxxWrapCore.methodkey.(funcs))
end

let a = BasicTypes.A(2,3)
Expand Down
17 changes: 10 additions & 7 deletions test/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,6 @@ w_copy = copy(w)

# Destroy w: w and w_assigned should be dead, w_copy alive
finalize(w)
#finalize(w_lambda)
if !(Sys.iswindows() && Sys.WORD_SIZE == 32)
@test_throws ErrorException CppTypes.greet(w)
@test_throws ErrorException CppTypes.greet(w_assigned)
#@test_throws ErrorException CppTypes.greet(w_lambda)
end
@test CppTypes.greet(w_copy) == "constructed"
println("completed copy test")

Expand Down Expand Up @@ -241,7 +235,11 @@ empty!(warr1)

@test bench_greet() == 1000*length(CppTypes.greet(CppTypes.World()))
_, _, _, _, memallocs = @timed bench_greet()
@test 0 < memallocs.poolalloc < 100
@show memallocs.poolalloc
@test 0 < memallocs.poolalloc < 400 # Jumped from +/- 6 to 360 in Julia 1.12
if memallocs.poolalloc > 100
@warn "Abnormally high number of allocations: $(memallocs.poolalloc)"
end

if isdefined(CppTypes, :IntDerived)
Base.promote_rule(::Type{<:CppTypes.IntDerived}, ::Type{<:Number}) = Int
Expand Down Expand Up @@ -299,3 +297,8 @@ let cd1 = CppTypes.UseCustomDelete(), cd2 = CppTypes.UseCustomClassDelete()
finalize(cd2)
@test CppTypes.get_custom_class_nb_deletes() == 1
end

let v = CppTypes.shared_vector_factory(), cv = CppTypes.shared_const_vector_factory()
@test CppTypes.get_shared_vector_msg(v) == "shared vector hello"
@test CppTypes.get_shared_vector_msg(cv) == "shared vector const hello from const overload"
end
Loading