diff --git a/Project.toml b/Project.toml index 3e3722c..1cd6885 100644 --- a/Project.toml +++ b/Project.toml @@ -1,14 +1,16 @@ name = "ParallelUtilities" uuid = "fad6cfc8-4f83-11e9-06cc-151124046ad0" authors = ["Jishnu Bhattacharya"] -version = "0.8.5" +version = "0.8.6" [deps] DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" +SplittablesBase = "171d559e-b47b-412a-8079-5efa626c420e" [compat] DataStructures = "0.17, 0.18" +SplittablesBase = "0.1" julia = "1.2" [extras] diff --git a/src/ParallelUtilities.jl b/src/ParallelUtilities.jl index 7acebb8..713d604 100644 --- a/src/ParallelUtilities.jl +++ b/src/ParallelUtilities.jl @@ -1,6 +1,7 @@ module ParallelUtilities using Distributed +using SplittablesBase export pmapreduce export pmapreduce_productsplit diff --git a/src/mapreduce.jl b/src/mapreduce.jl index 6b73f60..fca8de1 100644 --- a/src/mapreduce.jl +++ b/src/mapreduce.jl @@ -37,21 +37,44 @@ getiterators(h::Hold) = getiterators(h.iterators) Base.length(h::Hold) = length(h.iterators) -check_knownsize(iterators::Tuple) = _check_knownsize(first(iterators)) & check_knownsize(Base.tail(iterators)) -check_knownsize(::Tuple{}) = true -function _check_knownsize(iterator) +function check_knownsize(iterator) itsz = Base.IteratorSize(iterator) itsz isa Base.HasLength || itsz isa Base.HasShape end -function zipsplit(iterators::Tuple, np::Integer, p::Integer) - check_knownsize(iterators) - itzip = zip(iterators...) +struct ZipSplit{Z, I} + z :: Z + it :: I + skip :: Int + N :: Int +end + +# This constructor differs from zipsplit, as it uses skipped and retained elements +# and not p and np. This type is added to increase compatibility with SplittablesBase +function ZipSplit(itzip, skipped_elements::Integer, elements_on_proc::Integer) + it = Iterators.take(Iterators.drop(itzip, skipped_elements), elements_on_proc) + ZipSplit{typeof(itzip), typeof(it)}(itzip, it, skipped_elements, elements_on_proc) +end + +Base.length(zs::ZipSplit) = length(zs.it) +Base.eltype(zs::ZipSplit) = eltype(zs.it) +Base.iterate(z::ZipSplit, i...) = iterate(takedrop(z), i...) +takedrop(zs::ZipSplit) = zs.it + +function SplittablesBase.halve(zs::ZipSplit) + nleft = zs.N ÷ 2 + ZipSplit(zs.z, zs.skip, nleft), ZipSplit(zs.z, zs.skip + nleft, zs.N - nleft) +end + +zipsplit(iterators::Tuple, np::Integer, p::Integer) = zipsplit(zip(iterators...), np, p) + +function zipsplit(itzip::Iterators.Zip, np::Integer, p::Integer) + check_knownsize(itzip) d,r = divrem(length(itzip), np) skipped_elements = d*(p-1) + min(r,p-1) lastind = d*p + min(r,p) elements_on_proc = lastind - skipped_elements - Iterators.take(Iterators.drop(itzip, skipped_elements), elements_on_proc) + ZipSplit(itzip, skipped_elements, elements_on_proc) end _split_iterators(iterators, np, p) = (zipsplit(iterators, np, p),) diff --git a/src/productsplit.jl b/src/productsplit.jl index a5dffc3..b332ac0 100644 --- a/src/productsplit.jl +++ b/src/productsplit.jl @@ -217,6 +217,21 @@ Base.lastindex(ps::AbstractConstrainedProduct) = lastindexglobal(ps) - firstinde firstindexglobal(ps::AbstractConstrainedProduct) = ProductSection(ps).firstind lastindexglobal(ps::AbstractConstrainedProduct) = ProductSection(ps).lastind +# SplittablesBase interface +function SplittablesBase.halve(ps::AbstractConstrainedProduct) + iter = getiterators(ps) + firstind = firstindexglobal(ps) + lastind = lastindexglobal(ps) + nleft = length(ps) ÷ 2 + firstindleft = firstind + lastindleft = firstind + nleft - 1 + firstindright = lastindleft + 1 + lastindright = lastind + tl = togglelevels(ps) + ProductSection(iter, tl, firstindleft, lastindleft), + ProductSection(iter, tl, firstindright, lastindright) +end + """ childindex(ps::AbstractConstrainedProduct, ind) diff --git a/test/Project.toml b/test/Project.toml index 817c485..ea8f759 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -5,6 +5,7 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +SplittablesBase = "171d559e-b47b-412a-8079-5efa626c420e" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] diff --git a/test/misctests_singleprocess.jl b/test/misctests_singleprocess.jl index ea45f44..beec03a 100644 --- a/test/misctests_singleprocess.jl +++ b/test/misctests_singleprocess.jl @@ -9,7 +9,11 @@ ProductSplit, SegmentedOrderedBinaryTree import ParallelUtilities.ClusterQueryUtils: chooseworkers @testset "Project quality" begin - Aqua.test_all(ParallelUtilities) + if VERSION < v"1.6.0" + Aqua.test_all(ParallelUtilities, ambiguities=false) + else + Aqua.test_all(ParallelUtilities) + end end DocMeta.setdocmeta!(ParallelUtilities, :DocTestSetup, :(using ParallelUtilities); recursive=true) diff --git a/test/productsplit.jl b/test/productsplit.jl index 0ee7f8e..e11c3b4 100644 --- a/test/productsplit.jl +++ b/test/productsplit.jl @@ -1,10 +1,11 @@ using Distributed using Test using ParallelUtilities -import ParallelUtilities: ProductSplit, ProductSection, +import ParallelUtilities: ProductSplit, ProductSection, ZipSplit, zipsplit, minimumelement, maximumelement, extremaelement, nelements, dropleading, indexinproduct, extremadims, localindex, extrema_commonlastdim, whichproc, procrange_recast, whichproc_localindex, getiterators, _niterators +using SplittablesBase macro testsetwithinfo(str, ex) quote @@ -423,6 +424,17 @@ end @test nelements(ps, dims = 3) == 1 end + @testset "SplittablesBase" begin + for iters in [(1:4, 1:3), (1:4, 1:4)] + for ps = Any[ProductSplit(iters, 3, 2), ProductSection(iters, 3:8)] + l, r = SplittablesBase.halve(ps) + lc, rc = SplittablesBase.halve(collect(ps)) + @test collect(l) == lc + @test collect(r) == rc + end + end + end + @test ParallelUtilities._checknorollover((), (), ()) end; @@ -453,3 +465,14 @@ end; @test a <= b end end; + +@testset "ZipSplit" begin + @testset "SplittablesBase" begin + for ps in [zipsplit((1:4, 1:4), 3, 2), zipsplit((1:5, 1:5), 3, 2)] + l, r = SplittablesBase.halve(ps) + lc, rc = SplittablesBase.halve(collect(ps)) + @test collect(l) == lc + @test collect(r) == rc + end + end +end