Skip to content

Commit

Permalink
Properly deal with const template parameters
Browse files Browse the repository at this point in the history
Issue #405
  • Loading branch information
barche committed Feb 11, 2024
1 parent 447cca2 commit d446d22
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 16 deletions.
51 changes: 41 additions & 10 deletions src/CxxWrap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,10 @@ Base.flipsign(x::T, y::T) where {T <: CxxSigned} = reinterpret(T, flipsign(to_ju
struct IsCxxType end
struct IsNormalType end

struct CxxConst{T}
cpp_object::Ptr{T}
end

@inline cpp_trait_type(::Type) = IsNormalType

# Enum type interface
Expand All @@ -152,29 +156,48 @@ Base.show(io::IO, x::SmartPointer) = print(io, "C++ smart pointer of type ", typ
allocated_type(t::Type) = Any
dereferenced_type(t::Type) = Any

__cxxwrap_smartptr_dereference(p::SmartPointer{T}) where {T} = __cxxwrap_smartptr_dereference(CxxRef(p))
__cxxwrap_smartptr_construct_from_other(t::Type{<:SmartPointer{T}}, p::SmartPointer{T}) where {T} = __cxxwrap_smartptr_construct_from_other(t,CxxRef(p))
__cxxwrap_smartptr_cast_to_base(p::SmartPointer{T}) where {T} = __cxxwrap_smartptr_cast_to_base(CxxRef(p))
function __cxxwrap_smartptr_dereference end
function __cxxwrap_smartptr_construct_from_other end
function __cxxwrap_smartptr_cast_to_base end
function __cxxwrap_make_const_smartptr end

function Base.getindex(p::SmartPointer{T}) where {T}
return __cxxwrap_smartptr_dereference(p)
return __cxxwrap_smartptr_dereference(CxxRef(p))
end

# No conversion if source and target type are identical
Base.convert(::Type{T}, p::T) where {PT,T <: SmartPointer{PT}} = p

# Conversion from non-const to const
Base.convert(::Type{CT}, p::T) where {PT,T <: SmartPointer{PT},CT <: SmartPointer{CxxConst{PT}}} = __cxxwrap_make_const_smartptr(ConstCxxRef(p))

# Construct from a related pointer, e.g. a std::weak_ptr from std::shared_ptr
function Base.convert(::Type{T1}, p::SmartPointer{T}) where {T, T1 <: SmartPointer{T}}
return __cxxwrap_smartptr_construct_from_other(T1, p)
return __cxxwrap_smartptr_construct_from_other(T1, CxxRef(p))
end

# Construct from a related pointer, e.g. a std::weak_ptr from std::shared_ptr. Const versions
function Base.convert(::Type{T1}, p::SmartPointer{CxxConst{T}}) where {T, T1 <: SmartPointer{CxxConst{T}}}
return __cxxwrap_smartptr_construct_from_other(T1, CxxRef(p))
end
# Avoid improper method resolution on Julia < 1.10
function Base.convert(::Type{T1}, p::T1) where {T, T1 <: SmartPointer{CxxConst{T}}}
return p
end

# upcast to base class
function Base.convert(::Type{T1}, p::SmartPointer{<:BaseT}) where {BaseT, T1 <: SmartPointer{BaseT}}
function _base_convert_impl(::Type{T1}, p) where{T1}
# First convert to base type
base_p = __cxxwrap_smartptr_cast_to_base(p)
base_p = __cxxwrap_smartptr_cast_to_base(ConstCxxRef(p))
return convert(T1, base_p)
end

# upcast to base class, non-const version
Base.convert(::Type{T1}, p::SmartPointer{<:BaseT}) where {BaseT, T1 <: SmartPointer{BaseT}} = _base_convert_impl(T1, p)
# upcast to base class, non-const to const version
Base.convert(::Type{T1}, p::SmartPointer{<:BaseT}) where {BaseT, T1 <: SmartPointer{CxxConst{BaseT}}} = _base_convert_impl(T1, p)
# upcast to base, const version
Base.convert(::Type{T1}, p::SmartPointer{CxxConst{SuperT}}) where {BaseT, SuperT<:BaseT, T1<:SmartPointer{CxxConst{BaseT}}} = _base_convert_impl(T1, p)

struct StrictlyTypedNumber{NumberT}
value::NumberT
@static if Sys.iswindows()
Expand Down Expand Up @@ -362,7 +385,7 @@ function _register_function_pointers(func, precompiling)
if haskey(__global_method_map, mkey)
existing = __global_method_map[mkey]
if existing[3] == precompiling
error("Double registration for method $mkey")
error("Double registration for method $mkey: $(func.name); $(func.argument_types); $(func.return_type)")
end
end
__global_method_map[mkey] = fptrs
Expand Down Expand Up @@ -502,6 +525,7 @@ function ptrunion(::Type{T}) where {T}
return result
end

# valuetype is the non-reference, non-pointer type that is used in the method argument list. Overloading this allows mapping these types to something more useful
valuetype(t::Type) = valuetype(Base.invokelatest(cpp_trait_type,t), t)
valuetype(::Type{IsNormalType}, ::Type{T}) where {T} = T
function valuetype(::Type{IsCxxType}, ::Type{T}) where {T}
Expand All @@ -511,10 +535,16 @@ function valuetype(::Type{IsCxxType}, ::Type{T}) where {T}
end
return T
end
# Smart pointer arguments can also take subclass arguments
function valuetype(::Type{<:SmartPointer{T}}) where {T}
result{T2 <: T} = SmartPointer{T2}
return result
end
# Smart pointers with const arguments can also take both const and non-const subclass arguments
function valuetype(::Type{<:SmartPointer{CxxConst{T}}}) where {T}
result{T2 <: T} = Union{SmartPointer{T2},SmartPointer{CxxConst{T2}}}
return result
end

map_julia_arg_type(t::Type) = Union{valuetype(t), argument_overloads(t)...}
map_julia_arg_type(a::Type{StrictlyTypedNumber{T}}) where {T} = T
Expand Down Expand Up @@ -555,7 +585,8 @@ const __excluded_names = Set([
:cxxupcast,
:__cxxwrap_smartptr_dereference,
:__cxxwrap_smartptr_construct_from_other,
:__cxxwrap_smartptr_cast_to_base
:__cxxwrap_smartptr_cast_to_base,
:__cxxwrap_make_const_smartptr,
])

function Base.cconvert(to_type::Type{<:CxxBaseRef{T}}, x) where {T}
Expand Down
11 changes: 5 additions & 6 deletions test/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,6 @@ w_copy = copy(w)

# Destroy w: w and w_assigned should be dead, w_copy alive
finalize(w)
#finalize(w_lambda)
if !(Sys.iswindows() && Sys.WORD_SIZE == 32)
@test_throws ErrorException CppTypes.greet(w)
@test_throws ErrorException CppTypes.greet(w_assigned)
#@test_throws ErrorException CppTypes.greet(w_lambda)
end
@test CppTypes.greet(w_copy) == "constructed"
println("completed copy test")

Expand Down Expand Up @@ -299,3 +293,8 @@ let cd1 = CppTypes.UseCustomDelete(), cd2 = CppTypes.UseCustomClassDelete()
finalize(cd2)
@test CppTypes.get_custom_class_nb_deletes() == 1
end

let v = CppTypes.shared_vector_factory(), cv = CppTypes.shared_const_vector_factory()
@test CppTypes.get_shared_vector_msg(v) == "shared vector hello"
@test CppTypes.get_shared_vector_msg(cv) == "shared vector const hello from const overload"
end

0 comments on commit d446d22

Please sign in to comment.