Skip to content

Commit

Permalink
Merge pull request #142 from mofeing/fix-type-instability-due-composi…
Browse files Browse the repository at this point in the history
…tion

Fix type-instabilities due to function composition
  • Loading branch information
oscardssmith authored Feb 19, 2024
2 parents a3e43a8 + a4e4834 commit 0356cc1
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 40 deletions.
2 changes: 1 addition & 1 deletion src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ or `parent(root, x) ≡ nothing`. That is, while any node is the root of some t
true for nodes which have parents which cannot be obtained with the `AbstractTrees` interface.
"""
isroot(root, x) = isnothing(parent(root, x))
isroot(x) = (isnothing parent)(x)
isroot(x) = isnothing(parent(x))

"""
intree(node, root; equiv=(≡))
Expand Down
30 changes: 15 additions & 15 deletions src/cursors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ parenttype(csr::TreeCursor) = parenttype(typeof(csr))
# this is a fallback and may not always be the case
Base.IteratorSize(::Type{<:TreeCursor{N,P}}) where {N,P} = Base.IteratorSize(childrentype(N))

Base.length(tc::TreeCursor) = (length children nodevalue)(tc)
Base.length(tc::TreeCursor) = length(children(nodevalue(tc)))

# this is needed in case an iterator declares IteratorSize to be HasSize
Base.size(tc::TreeCursor) = (size children nodevalue)(tc)
Base.size(tc::TreeCursor) = size(children(nodevalue(tc)))

Base.IteratorEltype(::Type{<:TreeCursor}) = EltypeUnknown()

Expand Down Expand Up @@ -112,7 +112,7 @@ end
TrivialCursor(node) = TrivialCursor(parent(node), node)

function Base.iterate(csr::TrivialCursor, s=InitialState())
cs = (children nodevalue)(csr)
cs = children(nodevalue(csr))
r = s isa InitialState ? iterate(cs) : iterate(cs, s)
isnothing(r) && return nothing
(n′, s′) = r
Expand Down Expand Up @@ -159,12 +159,12 @@ function Base.eltype(::Type{ImplicitCursor{N,P,S}}) where {N,P,S}
end

function Base.eltype(csr::ImplicitCursor)
cst = (childstatetype parent nodevalue)(csr)
cst = childstatetype(parent(nodevalue(csr)))
ImplicitCursor{childtype(nodevalue(csr)),nodevaluetype(csr),cst}
end

function Base.iterate(csr::ImplicitCursor, s=InitialState())
cs = (children nodevalue)(csr)
cs = children(nodevalue(csr))
# do NOT just write an iterate(x, ::InitialState) method, it's an ambiguity nightmare
r = s isa InitialState ? iterate(cs) : iterate(cs, s)
isnothing(r) && return nothing
Expand Down Expand Up @@ -202,22 +202,22 @@ struct IndexedCursor{N,P} <: TreeCursor{N,P}
IndexedCursor(p::Union{Nothing,IndexedCursor}, n, idx::Integer=1) = new{typeof(n),typeof(nodevalue(p))}(p, n, idx)
end

IndexedCursor(node) = IndexedCursor(nothing, node)
IndexedCursor(node) = IndexedCursor(nothing, node)

Base.IteratorSize(::Type{<:IndexedCursor}) = HasLength()

Base.eltype(::Type{IndexedCursor{N,P}}) where {N,P} = IndexedCursor{childtype(N),N}
Base.eltype(csr::IndexedCursor) = IndexedCursor{childtype(nodevalue(csr)),nodevaluetype(csr)}
Base.length(csr::IndexedCursor) = (length children nodevalue)(csr)
Base.length(csr::IndexedCursor) = length(children(nodevalue(csr)))

function Base.getindex(csr::IndexedCursor, idx)
cs = (children nodevalue)(csr)
cs = children(nodevalue(csr))
IndexedCursor(csr, cs[idx], idx)
end

function Base.iterate(csr::IndexedCursor, idx=1)
idx > length(csr) && return nothing
(csr[idx], idx+1)
(csr[idx], idx + 1)
end

function nextsibling(csr::IndexedCursor)
Expand Down Expand Up @@ -254,7 +254,7 @@ Base.IteratorEltype(::Type{<:SiblingCursor}) = HasEltype()
Base.eltype(::Type{SiblingCursor{N,P}}) where {N,P} = SiblingCursor{childtype(N),N}

function Base.iterate(csr::SiblingCursor, s=InitialState())
cs = (children nodevalue)(csr)
cs = children(nodevalue(csr))
r = s isa InitialState ? iterate(cs) : iterate(cs, s)
isnothing(r) && return nothing
(n′, s′) = r
Expand Down Expand Up @@ -283,7 +283,7 @@ struct StableCursor{N,S} <: TreeCursor{N,N}
# note that this very deliberately takes childstatetype(n) and *not* childstatetype(p)
# this is because p may be nothing
StableCursor(::Nothing, n, st) = new{typeof(n),childstatetype(n)}(nothing, n, st)

# this method is important for eliminating expensive calls to childstatetype
StableCursor(p::StableCursor{N,S}, n, st) where {N,S} = new{N,S}(p, n, st)
end
Expand All @@ -295,7 +295,7 @@ Base.IteratorEltype(::Type{<:StableCursor}) = HasEltype()
Base.eltype(::Type{T}) where {T<:StableCursor} = T

function Base.iterate(csr::StableCursor, s=InitialState())
cs = (children nodevalue)(csr)
cs = children(nodevalue(csr))
r = s isa InitialState ? iterate(cs) : iterate(cs, s)
isnothing(r) && return nothing
(n′, s′) = r
Expand Down Expand Up @@ -331,16 +331,16 @@ Base.IteratorEltype(::Type{<:StableIndexedCursor}) = HasEltype()

Base.eltype(::Type{T}) where {T<:StableIndexedCursor} = T

Base.length(csr::StableIndexedCursor) = (length children nodevalue)(csr)
Base.length(csr::StableIndexedCursor) = length(children(nodevalue(csr)))

function Base.getindex(csr::StableIndexedCursor, idx)
cs = (children nodevalue)(csr)
cs = children(nodevalue(csr))
StableIndexedCursor(csr, cs[idx], idx)
end

function Base.iterate(csr::StableIndexedCursor, idx=1)
idx > length(csr) && return nothing
(csr[idx], idx+1)
(csr[idx], idx + 1)
end

function nextsibling(csr::StableIndexedCursor)
Expand Down
12 changes: 6 additions & 6 deletions src/iteration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ struct PreOrderState{T<:TreeCursor} <: IteratorState{T}
PreOrderState(csr::TreeCursor) = new{typeof(csr)}(csr)
end

PreOrderState(node) = (PreOrderState TreeCursor)(node)
PreOrderState(node) = PreOrderState(TreeCursor(node))

initial(::Type{PreOrderState}, node) = PreOrderState(node)

Expand Down Expand Up @@ -175,9 +175,9 @@ struct PostOrderState{T<:TreeCursor} <: IteratorState{T}
PostOrderState(csr::TreeCursor) = new{typeof(csr)}(csr)
end

PostOrderState(node) = (PostOrderState TreeCursor)(node)
PostOrderState(node) = PostOrderState(TreeCursor(node))

initial(::Type{PostOrderState}, node) = (PostOrderState descendleft TreeCursor)(node)
initial(::Type{PostOrderState}, node) = PostOrderState(descendleft(TreeCursor(node)))

function next(s::PostOrderState)
n = nextsibling(s.cursor)
Expand Down Expand Up @@ -226,9 +226,9 @@ struct LeavesState{T<:TreeCursor} <: IteratorState{T}
LeavesState(csr::TreeCursor) = new{typeof(csr)}(csr)
end

LeavesState(node) = (LeavesState TreeCursor)(node)
LeavesState(node) = LeavesState(TreeCursor(node))

initial(::Type{LeavesState}, node) = (LeavesState descendleft TreeCursor)(node)
initial(::Type{LeavesState}, node) = LeavesState(descendleft(TreeCursor(node)))

function next(s::LeavesState)
csr = s.cursor
Expand Down Expand Up @@ -278,7 +278,7 @@ struct SiblingState{T<:TreeCursor} <: IteratorState{T}
SiblingState(csr::TreeCursor) = new{typeof(csr)}(csr)
end

SiblingState(node) = (SiblingState TreeCursor)(node)
SiblingState(node) = SiblingState(TreeCursor(node))

initial(::Type{SiblingState}, node) = SiblingState(node)

Expand Down
36 changes: 18 additions & 18 deletions test/trees.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using AbstractTrees
using Test

include(joinpath(@__DIR__,"examples","idtree.jl"))
include(joinpath(@__DIR__, "examples", "idtree.jl"))

@testset "IDTree" begin
tree = IDTree(1 => [
Expand All @@ -25,9 +25,9 @@ include(joinpath(@__DIR__,"examples","idtree.jl"))

# Node/subtree properties
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
@test treesize.(nodes) == [16, 4, 1, 2, 1, 1, 9, 8, 1, 1, 4, 1, 1, 1, 1, 1]
@test treesize.(nodes) == [16, 4, 1, 2, 1, 1, 9, 8, 1, 1, 4, 1, 1, 1, 1, 1]
@test treebreadth.(nodes) == [10, 2, 1, 1, 1, 1, 6, 6, 1, 1, 3, 1, 1, 1, 1, 1]
@test treeheight.(nodes) == [ 4, 2, 0, 1, 0, 0, 3, 2, 0, 0, 1, 0, 0, 0, 0, 0]
@test treeheight.(nodes) == [4, 2, 0, 1, 0, 0, 3, 2, 0, 0, 1, 0, 0, 0, 0, 0]

# Child/descendant checking
@test ischild(nodes[2], nodes[1])
Expand Down Expand Up @@ -61,44 +61,44 @@ include(joinpath(@__DIR__,"examples","idtree.jl"))
@test [n.id for n in Leaves(tree.root)] == [3, 5, 6, 9, 10, 12, 13, 14, 15, 16]
end

include(joinpath(@__DIR__,"examples","onenode.jl"))
include(joinpath(@__DIR__, "examples", "onenode.jl"))

@testset "OneNode" begin
ot = OneNode([2,3,4,0], 1)
ot = OneNode([2, 3, 4, 0], 1)
@inferred collect(Leaves(ot))
@test nodevalue.(collect(Leaves(ot))) == [0]
@test eltype(nodevalue.(collect(Leaves(ot)))) Int
@test nodevalue.(collect(PreOrderDFS(ot))) == [2,3,4,0]
@test nodevalue.(collect(PostOrderDFS(ot))) == [0,4,3,2]
@test nodevalue.(collect(PreOrderDFS(ot))) == [2, 3, 4, 0]
@test nodevalue.(collect(PostOrderDFS(ot))) == [0, 4, 3, 2]
end

include(joinpath(@__DIR__,"examples","onetree.jl"))
include(joinpath(@__DIR__, "examples", "onetree.jl"))

@testset "OneTree" begin
ot = OneTree([2,3,4,0])
ot = OneTree([2, 3, 4, 0])
n = IndexNode(ot)

@inferred collect(Leaves(n))
@test nodevalue.(collect(Leaves(n))) == [0]
@test eltype(nodevalue.(collect(Leaves(n)))) Int
@test nodevalue.(collect(PreOrderDFS(n))) == [2,3,4,0]
@test nodevalue.(collect(PostOrderDFS(n))) == [0,4,3,2]
@test nodevalue.(collect(PreOrderDFS(n))) == [2, 3, 4, 0]
@test nodevalue.(collect(PostOrderDFS(n))) == [0, 4, 3, 2]
end

include(joinpath(@__DIR__,"examples","fstree.jl"))
include(joinpath(@__DIR__, "examples", "fstree.jl"))

@testset "FSNode" begin
Base.VERSION >= v"1.6" && mk_tree_test_dir() do path
tree = Directory(".")

ls = nodevalue.((collect Leaves)(tree))
ls = nodevalue.(collect(Leaves(tree)))
# use set so we don't have to guarantee ordering
@test Set(ls) == Set([joinpath(".","A","f2"), joinpath(".","B"), joinpath(".","f1")])
@test Set(ls) == Set([joinpath(".", "A", "f2"), joinpath(".", "B"), joinpath(".", "f1")])
@test treeheight(tree) == 2
end
end

include(joinpath(@__DIR__,"examples","binarytree.jl"))
include(joinpath(@__DIR__, "examples", "binarytree.jl"))

@testset "BinaryNode" begin
t = binarynode_example()
Expand All @@ -107,12 +107,12 @@ include(joinpath(@__DIR__,"examples","binarytree.jl"))
@test nodevalue.(ls) == [3, 2]

predfs = @inferred collect(PreOrderDFS(t))
@test nodevalue.(predfs) == [0,1,3,2]
@test nodevalue.(predfs) == [0, 1, 3, 2]

postdfs = @inferred collect(PostOrderDFS(t))
@test nodevalue.(postdfs) == [3,1,2,0]
@test nodevalue.(postdfs) == [3, 1, 2, 0]

sbfs = @inferred collect(StatelessBFS(t))
@test nodevalue.(sbfs) == [0,1,2,3]
@test nodevalue.(sbfs) == [0, 1, 2, 3]
end

0 comments on commit 0356cc1

Please sign in to comment.