Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Searching through Vector{IndexEntry} #46

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Checkpoints"
uuid = "b4a3413d-e481-5afc-88ff-bdfbd6a50dce"
authors = "Invenia Technical Computing Corporation"
version = "0.3.17"
version = "0.3.18"

[deps]
AWSS3 = "1c724243-ef5b-51ab-93f4-b0a88ac62a95"
Expand Down
47 changes: 41 additions & 6 deletions src/Checkpoints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ using Memento
using OrderedCollections

export checkpoint, with_checkpoint_tags # creating stuff
export enabled_checkpoints
export enabled_checkpoints, deprecated_checkpoints
# indexing stuff
export IndexEntry, index_checkpoint_files, index_files
export checkpoint_fullname, checkpoint_name, checkpoint_path, prefixes, tags
Expand All @@ -29,7 +29,7 @@ __init__() = Memento.register(LOGGER)

include("handler.jl")

const CHECKPOINTS = Dict{String, Union{Nothing, Handler}}()
const CHECKPOINTS = Dict{String, Union{Nothing, String, Handler}}()
@contextvar CONTEXT_TAGS::Tuple{Vararg{Pair{Symbol, Any}}} = Tuple{}()

include("session.jl")
Expand Down Expand Up @@ -74,7 +74,16 @@ available() = collect(keys(CHECKPOINTS))

Returns a vector of all enabled ([`config`](@ref)ured) checkpoints.
"""
enabled_checkpoints() = filter(k -> CHECKPOINTS[k] !== nothing, available())
enabled_checkpoints() = filter(k -> CHECKPOINTS[k] isa Handler, available())

"""
deprecated_checkpoints() -> Dict{String, String}

Returns a Dict mapping deprecated checkpoints to the corresponding new names.
"""
function deprecated_checkpoints()
return Dict{String, String}(filter(p -> last(p) isa String, CHECKPOINTS))
end

"""
checkpoint([prefix], name, data)
Expand Down Expand Up @@ -131,9 +140,7 @@ If the first argument is not a `Handler` then all `args` and `kwargs` are passed
"""
function config(handler::Handler, names::Vector{String})
for n in names
haskey(CHECKPOINTS, n) || warn(LOGGER, "$n is not a registered checkpoint")
debug(LOGGER, "Checkpoint $n set to use $(handler)")
CHECKPOINTS[n] = handler
_config(handler, n)
end
end

Expand All @@ -149,6 +156,20 @@ function config(prefix::Union{Module, String}, args...; kwargs...)
config(Handler(args...; kwargs...), prefix)
end

# To avoid collisions with `prefix` method above, which should probably use
# a regex / glob syntax
function _config(handler, name::String)
haskey(CHECKPOINTS, name) || warn(LOGGER, "$name is not a registered checkpoint")

# Warn about deprecated checkpoints and recurse if necessary
if CHECKPOINTS[name] isa String
Base.depwarn("$name has been deprecated to $(CHECKPOINTS[name])", :config)
return _config(handler, CHECKPOINTS[name])
else
debug(LOGGER, "Checkpoint $name set to use $(handler)")
return setindex!(CHECKPOINTS, handler, name)
end
end

"""
register([prefix], labels)
Expand All @@ -171,4 +192,18 @@ function register(prefix::Union{Module, String}, labels::Vector{String})
register(map(l -> join([prefix, l], "."), labels))
end


"""
deprecate([prefix], prev, curr)

Deprecate a checkpoint that has been renamed.
"""
function deprecate end

deprecate(prev, curr) = setindex!(CHECKPOINTS, curr, prev)

function deprecate(prefix::Union{Module, String}, prev, curr)
deprecate(join([prefix, prev], "."), join([prefix, curr], "."))
end

end # module
91 changes: 85 additions & 6 deletions src/indexing.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
IndexEntry(checkpoint_path, base_dir)
IndexEntry(checkpoint_path, base_dir, depnames_lookup=_depnames_lookup())
IndexEntry(checkpoint_path, checkpoint_name, prefixes, tags)

This is an index entry describing the output file from a checkpoint.
Expand All @@ -22,9 +22,10 @@ struct IndexEntry
checkpoint_name::AbstractString
prefixes::NTuple{<:Any, AbstractString}
tags::NTuple{<:Any, Pair{Symbol, <:AbstractString}}
deprecated_names::NTuple{<:Any, AbstractString}
end

function IndexEntry(filepath::AbstractPath, base_dir)
function IndexEntry(filepath::AbstractPath, base_dir, depnames_lookup=_depnames_lookup())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason we're passing this in rather than calling it in the function?

if dirname(filepath) == base_dir
# workaround for relpath erroring on equal S3Paths
# https://github.com/rofinn/FilePathsBase.jl/issues/156
Expand All @@ -39,8 +40,23 @@ function IndexEntry(filepath::AbstractPath, base_dir)
return Symbol(tag)=>val
end
end

checkpoint_name = filename(filepath)
return IndexEntry(filepath, checkpoint_name, prefixes, tags)
checkpoint_fullname = join((prefixes..., checkpoint_name), ".")
if haskey(depnames_lookup, checkpoint_fullname)
deprecated_names = Tuple(depnames_lookup[checkpoint_fullname])
else
filtered = filter(e -> checkpoint_fullname in last(e), depnames_lookup)
if isempty(filtered)
deprecated_names = ()
else
k, v = only(filtered)
checkpoint_name = last(split(k, "."))
deprecated_names = Tuple(v)
end
end

return IndexEntry(filepath, checkpoint_name, prefixes, tags, deprecated_names)
end


Expand Down Expand Up @@ -103,11 +119,27 @@ Note that if the tags are unique, then their values call also be accessed via a
"""
tags(x::IndexEntry) = getfield(x, :tags)

"""
deprecated_names(x::IndexEntry)

Previous `checkpoint_name`s that have since been renamed.
If the checkpoint was previously saved used `checkpoint(Forecasters, "predictions", ...)`,
but has since been renamed to `checkpoint(Forecasters, "forecasts", ...)` then
predictions" would live in this list.
"""
deprecated_names(x::IndexEntry) = getfield(x, :deprecated_names)

_tag_names(x::IndexEntry) = first.(tags(x))

#Tables.columnnames(x::IndexEntry) = propertynames(x)
function Base.propertynames(x::IndexEntry)
return [:prefixes, :checkpoint_name, _tag_names(x)..., :checkpoint_path]
return [
:prefixes,
:checkpoint_name,
_tag_names(x)...,
:checkpoint_path,
:deprecated_names,
]
end

function Base.getproperty(x::IndexEntry, name::Symbol)
Expand Down Expand Up @@ -162,8 +194,10 @@ You can also work with it directly, say you wanted to get all checkpoints files
"""
function index_checkpoint_files(dir::AbstractPath)
isdir(dir) || throw(ArgumentError("Need an existing directory."))
depnames_lookup = _depnames_lookup()

map(Iterators.filter(==("jlso") ∘ extension, walkpath(dir))) do checkpoint_path
return IndexEntry(checkpoint_path, dir)
return IndexEntry(checkpoint_path, dir, depnames_lookup)
end
end

Expand All @@ -176,9 +210,54 @@ Constructs a index for all the files located within `dir`.
Same as [`index_checkpoint_files`] except not restricted to files created by Checkpoints.jl.
"""
function index_files(dir::AbstractPath)
depnames_lookup = _depnames_lookup()
map(Iterators.filter(isfile, walkpath(dir))) do path
return IndexEntry(path, dir)
return IndexEntry(path, dir, depnames_lookup)
end
end

index_files(dir) = index_files(Path(dir))

"""
search(name::AbstractString, index)

Returns elements where `name` matches either the full checkpoint name or deprecated names.
If the `name` is deprecated then a deprecation warning is thrown.

# Arguments
- `name`: The full checkpoint name to search for (ie: `"Forecasters.forecasts"`)
- `index`: Iterable of `IndexEntry` elements
"""
function search(name::AbstractString, index)
results = filter(index) do idx
(
checkpoint_fullname(idx) == name ||
name in idx.deprecated_names
)
end

isempty(results) && return results

fullname = checkpoint_fullname(first(results))
name == fullname || Base.depwarn("$name has been deprecated to $fullname.", :search)
return results
end


# Utility function for generating a checkpoint name lookup table from the current registry
function _depnames_lookup()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The amount of added complexity here makes me wonder if we're accounting for deprecations in the correct way. I'm not sure we should allow a deprecation to point to another deprecation.

deps = deprecated_checkpoints()
results = Dict{String, Set{String}}(
x => Set{String}() for x in setdiff(available(), keys(deps))
)

# Simple recursive find_name function to find the original non-deprecated name
find_name(x) = haskey(deps, x) ? find_name(deps[x]) : x

for (prev, curr) in deps
k = find_name(curr)
push!(results[k], prev)
end

return results
end
20 changes: 20 additions & 0 deletions test/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,26 @@
end
end

@testset "Searching for deprecated checkpoint" begin
mktempdir(SystemPath) do path
@test_deprecated Checkpoints.config("TestPkg.quuz", path)
a = Dict(zip(
map(x -> Symbol(randstring(4)), 1:10),
map(x -> rand(10), 1:10)
))
b = rand(10)
TestPkg.qux(a, b)

index = index_checkpoint_files(path)
entry= only(index)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
entry= only(index)
entry = only(index)

@test checkpoint_name(entry) == "qux_b"
@test checkpoint_fullname(entry) == "TestPkg.qux_b"
@test Checkpoints.deprecated_names(entry) == ("TestPkg.quuz",)
res = @test_deprecated Checkpoints.search("TestPkg.quuz", index)
@test res == index
end
end

@testset "files not saved by Checkpoints.jl" begin
mktempdir(SystemPath) do path
Checkpoints.config("TestPkg.bar", path)
Expand Down
21 changes: 21 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,27 @@ Distributed.addprocs(5)

Checkpoints.config("c2", path)
@test enabled_checkpoints() == ["c1", "c2"]

# Manually disable the checkpoint again
Checkpoints.CHECKPOINTS["c1"] = nothing
Checkpoints.CHECKPOINTS["c2"] = nothing
end
end

@testset "deprecated" begin
mktempdir() do path
@test deprecated_checkpoints() == Dict(
"TestPkg.quux" => "TestPkg.qux_a",
"TestPkg.quuz" => "TestPkg.qux_b",
)

@show Checkpoints.CHECKPOINTS
@test_deprecated Checkpoints.config("TestPkg.quux", path)
@show Checkpoints.CHECKPOINTS
@test enabled_checkpoints() == ["TestPkg.qux_a"]

# Manually disable the checkpoint again
Checkpoints.CHECKPOINTS["TestPkg.qux_a"] = nothing
end
end

Expand Down
8 changes: 6 additions & 2 deletions test/testpkg.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
module TestPkg

using Checkpoints: register, checkpoint, with_checkpoint_tags, Session
using Checkpoints: deprecate, register, checkpoint, with_checkpoint_tags, Session

# We aren't using `@__MODULE__` because that would return TestPkg on 0.6 and Main.TestPkg on 0.7
const MODULE = "TestPkg"

__init__() = register(MODULE, ["foo", "bar", "baz", "qux_a", "qux_b", "deprecated"])
function __init__()
register(MODULE, ["foo", "bar", "baz", "qux_a", "qux_b", "deprecated"])
deprecate(MODULE, "quux", "qux_a")
deprecate(MODULE, "quuz", "qux_b")
end

function foo(x::Matrix, y::Matrix)
# Save multiple variables to 1 foo.jlso file by passing in pairs of variables
Expand Down