From a4e4834d9d5322e76ad17bb729e5b26b5b812530 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 19 Feb 2024 15:01:27 +0100 Subject: [PATCH] Fix type-instabilities due to function composition Fixes Leaves iterator type-unstable? #141 --- src/base.jl | 2 +- src/cursors.jl | 30 +++++++++++++++--------------- src/iteration.jl | 12 ++++++------ test/trees.jl | 36 ++++++++++++++++++------------------ 4 files changed, 40 insertions(+), 40 deletions(-) diff --git a/src/base.jl b/src/base.jl index aa62063..b6bdbba 100644 --- a/src/base.jl +++ b/src/base.jl @@ -100,7 +100,7 @@ or `parent(root, x) ≡ nothing`. That is, while any node is the root of some t true for nodes which have parents which cannot be obtained with the `AbstractTrees` interface. """ isroot(root, x) = isnothing(parent(root, x)) -isroot(x) = (isnothing ∘ parent)(x) +isroot(x) = isnothing(parent(x)) """ intree(node, root; equiv=(≡)) diff --git a/src/cursors.jl b/src/cursors.jl index 0229bcc..7d3ab17 100644 --- a/src/cursors.jl +++ b/src/cursors.jl @@ -60,10 +60,10 @@ parenttype(csr::TreeCursor) = parenttype(typeof(csr)) # this is a fallback and may not always be the case Base.IteratorSize(::Type{<:TreeCursor{N,P}}) where {N,P} = Base.IteratorSize(childrentype(N)) -Base.length(tc::TreeCursor) = (length ∘ children ∘ nodevalue)(tc) +Base.length(tc::TreeCursor) = length(children(nodevalue(tc))) # this is needed in case an iterator declares IteratorSize to be HasSize -Base.size(tc::TreeCursor) = (size ∘ children ∘ nodevalue)(tc) +Base.size(tc::TreeCursor) = size(children(nodevalue(tc))) Base.IteratorEltype(::Type{<:TreeCursor}) = EltypeUnknown() @@ -112,7 +112,7 @@ end TrivialCursor(node) = TrivialCursor(parent(node), node) function Base.iterate(csr::TrivialCursor, s=InitialState()) - cs = (children ∘ nodevalue)(csr) + cs = children(nodevalue(csr)) r = s isa InitialState ? iterate(cs) : iterate(cs, s) isnothing(r) && return nothing (n′, s′) = r @@ -159,12 +159,12 @@ function Base.eltype(::Type{ImplicitCursor{N,P,S}}) where {N,P,S} end function Base.eltype(csr::ImplicitCursor) - cst = (childstatetype ∘ parent ∘ nodevalue)(csr) + cst = childstatetype(parent(nodevalue(csr))) ImplicitCursor{childtype(nodevalue(csr)),nodevaluetype(csr),cst} end function Base.iterate(csr::ImplicitCursor, s=InitialState()) - cs = (children ∘ nodevalue)(csr) + cs = children(nodevalue(csr)) # do NOT just write an iterate(x, ::InitialState) method, it's an ambiguity nightmare r = s isa InitialState ? iterate(cs) : iterate(cs, s) isnothing(r) && return nothing @@ -202,22 +202,22 @@ struct IndexedCursor{N,P} <: TreeCursor{N,P} IndexedCursor(p::Union{Nothing,IndexedCursor}, n, idx::Integer=1) = new{typeof(n),typeof(nodevalue(p))}(p, n, idx) end -IndexedCursor(node) = IndexedCursor(nothing, node) +IndexedCursor(node) = IndexedCursor(nothing, node) Base.IteratorSize(::Type{<:IndexedCursor}) = HasLength() Base.eltype(::Type{IndexedCursor{N,P}}) where {N,P} = IndexedCursor{childtype(N),N} Base.eltype(csr::IndexedCursor) = IndexedCursor{childtype(nodevalue(csr)),nodevaluetype(csr)} -Base.length(csr::IndexedCursor) = (length ∘ children ∘ nodevalue)(csr) +Base.length(csr::IndexedCursor) = length(children(nodevalue(csr))) function Base.getindex(csr::IndexedCursor, idx) - cs = (children ∘ nodevalue)(csr) + cs = children(nodevalue(csr)) IndexedCursor(csr, cs[idx], idx) end function Base.iterate(csr::IndexedCursor, idx=1) idx > length(csr) && return nothing - (csr[idx], idx+1) + (csr[idx], idx + 1) end function nextsibling(csr::IndexedCursor) @@ -254,7 +254,7 @@ Base.IteratorEltype(::Type{<:SiblingCursor}) = HasEltype() Base.eltype(::Type{SiblingCursor{N,P}}) where {N,P} = SiblingCursor{childtype(N),N} function Base.iterate(csr::SiblingCursor, s=InitialState()) - cs = (children ∘ nodevalue)(csr) + cs = children(nodevalue(csr)) r = s isa InitialState ? iterate(cs) : iterate(cs, s) isnothing(r) && return nothing (n′, s′) = r @@ -283,7 +283,7 @@ struct StableCursor{N,S} <: TreeCursor{N,N} # note that this very deliberately takes childstatetype(n) and *not* childstatetype(p) # this is because p may be nothing StableCursor(::Nothing, n, st) = new{typeof(n),childstatetype(n)}(nothing, n, st) - + # this method is important for eliminating expensive calls to childstatetype StableCursor(p::StableCursor{N,S}, n, st) where {N,S} = new{N,S}(p, n, st) end @@ -295,7 +295,7 @@ Base.IteratorEltype(::Type{<:StableCursor}) = HasEltype() Base.eltype(::Type{T}) where {T<:StableCursor} = T function Base.iterate(csr::StableCursor, s=InitialState()) - cs = (children ∘ nodevalue)(csr) + cs = children(nodevalue(csr)) r = s isa InitialState ? iterate(cs) : iterate(cs, s) isnothing(r) && return nothing (n′, s′) = r @@ -331,16 +331,16 @@ Base.IteratorEltype(::Type{<:StableIndexedCursor}) = HasEltype() Base.eltype(::Type{T}) where {T<:StableIndexedCursor} = T -Base.length(csr::StableIndexedCursor) = (length ∘ children ∘ nodevalue)(csr) +Base.length(csr::StableIndexedCursor) = length(children(nodevalue(csr))) function Base.getindex(csr::StableIndexedCursor, idx) - cs = (children ∘ nodevalue)(csr) + cs = children(nodevalue(csr)) StableIndexedCursor(csr, cs[idx], idx) end function Base.iterate(csr::StableIndexedCursor, idx=1) idx > length(csr) && return nothing - (csr[idx], idx+1) + (csr[idx], idx + 1) end function nextsibling(csr::StableIndexedCursor) diff --git a/src/iteration.jl b/src/iteration.jl index 32a3ac0..fe01eac 100644 --- a/src/iteration.jl +++ b/src/iteration.jl @@ -100,7 +100,7 @@ struct PreOrderState{T<:TreeCursor} <: IteratorState{T} PreOrderState(csr::TreeCursor) = new{typeof(csr)}(csr) end -PreOrderState(node) = (PreOrderState ∘ TreeCursor)(node) +PreOrderState(node) = PreOrderState(TreeCursor(node)) initial(::Type{PreOrderState}, node) = PreOrderState(node) @@ -175,9 +175,9 @@ struct PostOrderState{T<:TreeCursor} <: IteratorState{T} PostOrderState(csr::TreeCursor) = new{typeof(csr)}(csr) end -PostOrderState(node) = (PostOrderState ∘ TreeCursor)(node) +PostOrderState(node) = PostOrderState(TreeCursor(node)) -initial(::Type{PostOrderState}, node) = (PostOrderState ∘ descendleft ∘ TreeCursor)(node) +initial(::Type{PostOrderState}, node) = PostOrderState(descendleft(TreeCursor(node))) function next(s::PostOrderState) n = nextsibling(s.cursor) @@ -226,9 +226,9 @@ struct LeavesState{T<:TreeCursor} <: IteratorState{T} LeavesState(csr::TreeCursor) = new{typeof(csr)}(csr) end -LeavesState(node) = (LeavesState ∘ TreeCursor)(node) +LeavesState(node) = LeavesState(TreeCursor(node)) -initial(::Type{LeavesState}, node) = (LeavesState ∘ descendleft ∘ TreeCursor)(node) +initial(::Type{LeavesState}, node) = LeavesState(descendleft(TreeCursor(node))) function next(s::LeavesState) csr = s.cursor @@ -278,7 +278,7 @@ struct SiblingState{T<:TreeCursor} <: IteratorState{T} SiblingState(csr::TreeCursor) = new{typeof(csr)}(csr) end -SiblingState(node) = (SiblingState ∘ TreeCursor)(node) +SiblingState(node) = SiblingState(TreeCursor(node)) initial(::Type{SiblingState}, node) = SiblingState(node) diff --git a/test/trees.jl b/test/trees.jl index c3f80ba..9dc629b 100644 --- a/test/trees.jl +++ b/test/trees.jl @@ -1,7 +1,7 @@ using AbstractTrees using Test -include(joinpath(@__DIR__,"examples","idtree.jl")) +include(joinpath(@__DIR__, "examples", "idtree.jl")) @testset "IDTree" begin tree = IDTree(1 => [ @@ -25,9 +25,9 @@ include(joinpath(@__DIR__,"examples","idtree.jl")) # Node/subtree properties # 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 - @test treesize.(nodes) == [16, 4, 1, 2, 1, 1, 9, 8, 1, 1, 4, 1, 1, 1, 1, 1] + @test treesize.(nodes) == [16, 4, 1, 2, 1, 1, 9, 8, 1, 1, 4, 1, 1, 1, 1, 1] @test treebreadth.(nodes) == [10, 2, 1, 1, 1, 1, 6, 6, 1, 1, 3, 1, 1, 1, 1, 1] - @test treeheight.(nodes) == [ 4, 2, 0, 1, 0, 0, 3, 2, 0, 0, 1, 0, 0, 0, 0, 0] + @test treeheight.(nodes) == [4, 2, 0, 1, 0, 0, 3, 2, 0, 0, 1, 0, 0, 0, 0, 0] # Child/descendant checking @test ischild(nodes[2], nodes[1]) @@ -61,44 +61,44 @@ include(joinpath(@__DIR__,"examples","idtree.jl")) @test [n.id for n in Leaves(tree.root)] == [3, 5, 6, 9, 10, 12, 13, 14, 15, 16] end -include(joinpath(@__DIR__,"examples","onenode.jl")) +include(joinpath(@__DIR__, "examples", "onenode.jl")) @testset "OneNode" begin - ot = OneNode([2,3,4,0], 1) + ot = OneNode([2, 3, 4, 0], 1) @inferred collect(Leaves(ot)) @test nodevalue.(collect(Leaves(ot))) == [0] @test eltype(nodevalue.(collect(Leaves(ot)))) ≡ Int - @test nodevalue.(collect(PreOrderDFS(ot))) == [2,3,4,0] - @test nodevalue.(collect(PostOrderDFS(ot))) == [0,4,3,2] + @test nodevalue.(collect(PreOrderDFS(ot))) == [2, 3, 4, 0] + @test nodevalue.(collect(PostOrderDFS(ot))) == [0, 4, 3, 2] end -include(joinpath(@__DIR__,"examples","onetree.jl")) +include(joinpath(@__DIR__, "examples", "onetree.jl")) @testset "OneTree" begin - ot = OneTree([2,3,4,0]) + ot = OneTree([2, 3, 4, 0]) n = IndexNode(ot) @inferred collect(Leaves(n)) @test nodevalue.(collect(Leaves(n))) == [0] @test eltype(nodevalue.(collect(Leaves(n)))) ≡ Int - @test nodevalue.(collect(PreOrderDFS(n))) == [2,3,4,0] - @test nodevalue.(collect(PostOrderDFS(n))) == [0,4,3,2] + @test nodevalue.(collect(PreOrderDFS(n))) == [2, 3, 4, 0] + @test nodevalue.(collect(PostOrderDFS(n))) == [0, 4, 3, 2] end -include(joinpath(@__DIR__,"examples","fstree.jl")) +include(joinpath(@__DIR__, "examples", "fstree.jl")) @testset "FSNode" begin Base.VERSION >= v"1.6" && mk_tree_test_dir() do path tree = Directory(".") - ls = nodevalue.((collect ∘ Leaves)(tree)) + ls = nodevalue.(collect(Leaves(tree))) # use set so we don't have to guarantee ordering - @test Set(ls) == Set([joinpath(".","A","f2"), joinpath(".","B"), joinpath(".","f1")]) + @test Set(ls) == Set([joinpath(".", "A", "f2"), joinpath(".", "B"), joinpath(".", "f1")]) @test treeheight(tree) == 2 end end -include(joinpath(@__DIR__,"examples","binarytree.jl")) +include(joinpath(@__DIR__, "examples", "binarytree.jl")) @testset "BinaryNode" begin t = binarynode_example() @@ -107,12 +107,12 @@ include(joinpath(@__DIR__,"examples","binarytree.jl")) @test nodevalue.(ls) == [3, 2] predfs = @inferred collect(PreOrderDFS(t)) - @test nodevalue.(predfs) == [0,1,3,2] + @test nodevalue.(predfs) == [0, 1, 3, 2] postdfs = @inferred collect(PostOrderDFS(t)) - @test nodevalue.(postdfs) == [3,1,2,0] + @test nodevalue.(postdfs) == [3, 1, 2, 0] sbfs = @inferred collect(StatelessBFS(t)) - @test nodevalue.(sbfs) == [0,1,2,3] + @test nodevalue.(sbfs) == [0, 1, 2, 3] end