Skip to content

Commit

Permalink
Merge pull request #50 from invenia/rf/handler-hierarchy
Browse files Browse the repository at this point in the history
Introduce a Handler hierarchy and DictHandler
  • Loading branch information
rofinn authored Feb 16, 2023
2 parents ee797f5 + 27e02eb commit 214d355
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 64 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Checkpoints"
uuid = "b4a3413d-e481-5afc-88ff-bdfbd6a50dce"
authors = "Invenia Technical Computing Corporation"
version = "0.3.20"
version = "0.3.21"

[deps]
AWSS3 = "1c724243-ef5b-51ab-93f4-b0a88ac62a95"
Expand Down
22 changes: 11 additions & 11 deletions src/Checkpoints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ __init__() = Memento.register(LOGGER)

include("handler.jl")

const CHECKPOINTS = Dict{String, Union{Nothing, String, Handler}}()
const CHECKPOINTS = Dict{String, Union{Nothing, String, AbstractHandler}}()
@contextvar CONTEXT_TAGS::Tuple{Vararg{Pair{Symbol, Any}}} = Tuple{}()

include("session.jl")
Expand Down Expand Up @@ -75,7 +75,7 @@ available() = collect(keys(CHECKPOINTS))
Returns a vector of all enabled ([`config`](@ref)ured) and not [`deprecate`](@ref)d checkpoints.
Use [`deprecated_checkpoints`](@ref) to retrieve a mapping of old / deprecated checkpoints.
"""
enabled_checkpoints() = filter(k -> CHECKPOINTS[k] isa Handler, available())
enabled_checkpoints() = filter(k -> CHECKPOINTS[k] isa AbstractHandler, available())

"""
deprecated_checkpoints() -> Dict{String, String}
Expand Down Expand Up @@ -130,31 +130,31 @@ function checkpoint(prefix::Union{Module, String}, name::String, args...; tags..
end

"""
config(handler::Handler, labels::Vector{String})
config(handler::Handler, prefix::String)
config(handler::AbstractHandler, labels::Vector{String})
config(handler::AbstractHandler, prefix::String)
config(labels::Vector{String}, args...; kwargs...)
config(prefix::String, args...; kwargs...)
Configures the specified checkpoints with a `Handler`.
If the first argument is not a `Handler` then all `args` and `kwargs` are passed to a
`Handler` constructor for you.
Configures the specified checkpoints with an `AbstractHandler`.
If the first argument is not an `AbstractHandler` then all `args` and `kwargs` are
passed to a `JLSOHandler` constructor for you.
"""
function config(handler::Handler, names::Vector{String})
function config(handler::AbstractHandler, names::Vector{String})
for n in names
_config(handler, n)
end
end

function config(handler::Handler, prefix::Union{Module, String})
function config(handler::AbstractHandler, prefix::Union{Module, String})
config(handler, filter(l -> startswith(l, prefix), available()))
end

function config(names::Vector{String}, args...; kwargs...)
config(Handler(args...; kwargs...), names)
config(JLSOHandler(args...; kwargs...), names)
end

function config(prefix::Union{Module, String}, args...; kwargs...)
config(Handler(args...; kwargs...), prefix)
config(JLSOHandler(args...; kwargs...), prefix)
end

# To avoid collisions with `prefix` method above, which should probably use
Expand Down
2 changes: 2 additions & 0 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@ function checkpoint_deprecation(tags...)
:checkpoint
)
end

Base.@deprecate_binding Handler JLSOHandler
125 changes: 86 additions & 39 deletions src/handler.jl
Original file line number Diff line number Diff line change
@@ -1,56 +1,82 @@
struct Handler{P<:AbstractPath}
path::P
settings # Could be Vector or Pairs on 0.6 or 1.0 respectively
end
# Root of the handler type hierarchy. Concrete handlers (`JLSOHandler`, `DictHandler`)
# subtype it, and generic methods such as `getkey` and `stage!` dispatch on it.
abstract type AbstractHandler end

"""
Handler(path::Union{String, AbstractPath}; kwargs...)
Handler(bucket::String, prefix::String; kwargs...)
getkey(handler, name, separator="/") -> String
Handles iteratively saving JLSO file to the specified path location.
FilePath are used to abstract away differences between paths on S3 or locally.
Combine the `CONTEXT_TAGS` and `name` into a unique checkpoint key as a string.
If the checkpoint name includes `.`, usually representing nested modules, these are
also replaced with the provided separator.
"""
Handler(path::AbstractPath; kwargs...) = Handler(path, kwargs)
Handler(path::String; kwargs...) = Handler(Path(path), kwargs)
Handler(bucket::String, prefix::String; kwargs...) = Handler(S3Path("s3://$bucket/$prefix"), kwargs)
function getkey(::AbstractHandler, name::String, separator="/")::String
    # Leading segments: one "key=value" entry per context tag currently in scope.
    segments = String[]
    for (key, val) in CONTEXT_TAGS[]
        push!(segments, "$key=$val")
    end

    # Trailing segments: the checkpoint name broken apart on '.'.
    append!(segments, split(name, '.'))

    # `Base.join` is spelled out explicitly, as in the rest of this file, where a bare
    # `join` is applied to paths.
    return Base.join(segments, separator)
end

"""
path(handler, name)
# Generic fallback: wrap the key produced by `getkey` in a `Path`.
function path(args...)
    return Path(getkey(args...))
end

Determines the path to save to based on the handlers path prefix, name, and context.
Tags are used to dynamically prefix the named file with the handler's path.
Names with a '.' separators will be used to form subdirectories
(e.g., "Foo.bar.x" will be saved to "\$prefix/Foo/bar/x.jlso").
"""
function path(handler::Handler{P}, name::String) where P
prefix = ["$key=$val" for (key,val) in CONTEXT_TAGS[]]
stage!(handler::AbstractHandler, objects, data::Dict{Symbol})
# Split up the name by '.' and add the jlso extension
parts = split(name, '.')
parts[end] = string(parts[end], ".jlso")
Update the objects with the new data.
By default all handlers assume objects implements the associative interface.
"""
function stage!(handler::AbstractHandler, objects, data::Dict{Symbol})
for (k, v) in data
objects[k] = v
end

return join(handler.path, prefix..., parts...)
return objects
end

"""
stage!(handler::Handler, jlso::JLSOFIle, data::Dict{Symbol})
commit!(handler, prefix, objects)
Update the JLSOFile with the new data.
Serialize and write objects to a given path/prefix/key as defined by the handler.
"""
function stage!(handler::Handler, jlso::JLSO.JLSOFile, data::Dict{Symbol})
for (k, v) in data
jlso[k] = v
commit!

# No-op condition, defined just to be safe: when the handler is `nothing` the
# checkpoint only runs the tag deprecation check and logs a debug message.
function checkpoint(handler::Nothing, name::String, data::Dict{Symbol}; tags...)
    checkpoint_deprecation(tags...)

    # Passed as the first argument, exactly as a `do` block would be.
    function noop()
        debug(LOGGER, "Checkpoint $name triggered, but no handler has been set.")
        return nothing
    end

    return with_checkpoint_tags(noop, tags...)
end

return jlso

# Handler whose checkpoints are written out as JLSO files beneath `path`.
struct JLSOHandler{P<:AbstractPath} <: AbstractHandler
# Root location; `path(handler, name)` joins the checkpoint key plus ".jlso" onto it.
path::P
# Splatted as keyword arguments into `JLSO.JLSOFile` when session objects are created.
settings # Could be Vector or Pairs on 0.6 or 1.0 respectively
end

"""
commit!(handler, path, jlso)
JLSOHandler(path::Union{String, AbstractPath}; kwargs...)
JLSOHandler(bucket::String, prefix::String; kwargs...)
Handles iteratively saving JLSO files to the specified path location.
FilePaths are used to abstract away the differences between S3 and local paths.
"""
function JLSOHandler(path::AbstractPath; kwargs...)
    # The keyword arguments are stored untouched as the handler's `settings`.
    return JLSOHandler(path, kwargs)
end

function JLSOHandler(path::String; kwargs...)
    # Promote the string to a path object before storing it.
    return JLSOHandler(Path(path), kwargs)
end

function JLSOHandler(bucket::String, prefix::String; kwargs...)
    # Bucket and prefix are combined into a single S3 location.
    return JLSOHandler(S3Path("s3://$bucket/$prefix"), kwargs)
end

Write the JLSOFile to the path as bytes.
"""
function commit!(handler::Handler{P}, path::P, jlso::JLSO.JLSOFile) where P <: AbstractPath
path(handler, name)
Determines the path to save to based on the handler's path prefix, name, and context.
Tags are used to dynamically prefix the named file with the handler's path.
Names with '.' separators will be used to form subdirectories
(e.g., "Foo.bar.x" will be saved to "\$prefix/Foo/bar/x.jlso").
"""
function path(handler::JLSOHandler{P}, name::String) where P
    # Key built from the context tags and the name, with the JLSO extension appended.
    filename = string(getkey(handler, name), ".jlso")
    return join(handler.path, filename)
end

function commit!(handler::JLSOHandler{P}, path::P, jlso::JLSO.JLSOFile) where P <: AbstractPath
# NOTE: This is only necessary because FilePathsBase.FileBuffer needs to support
# write(::FileBuffer, ::UInt8)
# https://github.com/rofinn/FilePathsBase.jl/issues/45
Expand All @@ -61,7 +87,7 @@ function commit!(handler::Handler{P}, path::P, jlso::JLSO.JLSOFile) where P <: A
write(path, bytes)
end

function checkpoint(handler::Handler, name::String, data::Dict{Symbol}; tags...)
function checkpoint(handler::JLSOHandler, name::String, data::Dict{Symbol}; tags...)
checkpoint_deprecation(tags...)
with_checkpoint_tags(tags...) do
debug(LOGGER, "Checkpoint $name triggered, with context: $(join(CONTEXT_TAGS[], ", ")).")
Expand All @@ -72,13 +98,34 @@ function checkpoint(handler::Handler, name::String, data::Dict{Symbol}; tags...)
end
end

#=
Define our no-op conditions just to be safe
=#
function checkpoint(handler::Nothing, name::String, data::Dict{Symbol}; tags...)
"""
    DictHandler(objects::Dict{String, Dict}, force::Bool)
    DictHandler(; objects=Dict{String, Dict}(), force=false)

Saves checkpointed objects into a dictionary where the keys are strings generated from
the checkpoint tags and name.

When `force` is `false`, committing a key that is already stored with data that is not
`isequal` to the new data throws an `ArgumentError`. When `force` is `true`, the stored
value is overwritten.
"""
struct DictHandler <: AbstractHandler
# Storage for committed checkpoints, keyed by the string produced by `getkey`.
objects::Dict{String, Dict}
# Whether `commit!` may overwrite an existing key.
force::Bool
end

# Keyword constructor: defaults to a fresh, empty store that refuses overwrites.
function DictHandler(; objects=Dict{String, Dict}(), force=false)
    return DictHandler(objects, force)
end

"""
    commit!(handler::DictHandler, k::AbstractString, data)

Store `data` under key `k` in `handler.objects`.

With `handler.force` set, any existing entry is overwritten and the dictionary is
returned. Otherwise `data` is inserted only if `k` is absent, and the stored value is
returned; an `ArgumentError` is thrown if `k` already holds a value that is not
`isequal` to `data`.
"""
function commit!(handler::DictHandler, k::AbstractString, data)
    store = handler.objects

    # Forced handlers overwrite unconditionally.
    handler.force && return setindex!(store, data, k)

    # `get!` inserts `data` only when `k` is missing, then returns whatever is stored.
    existing = get!(store, k, data)
    if !isequal(existing, data)
        throw(ArgumentError("$k has already been stored"))
    end
    return existing
end

function checkpoint(handler::DictHandler, name::String, data::Dict{Symbol}; tags...)
# TODO: Remove duplicate wrapper code
checkpoint_deprecation(tags...)
with_checkpoint_tags(tags...) do
debug(LOGGER, "Checkpoint $name triggered, but no handler has been set.")
nothing
debug(LOGGER, "Checkpoint $name triggered, with context: $(join(CONTEXT_TAGS[], ", ")).")
commit!(handler, getkey(handler, name), data)
end
end
38 changes: 25 additions & 13 deletions src/session.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
struct Session{H<:Union{Nothing, Handler}}
struct Session{H<:Union{Nothing, AbstractHandler}}
name::String
handler::H
objects::DefaultDict
Expand All @@ -8,11 +8,7 @@ function Session(name::String)
# Create our objects dictionary; the empty value it defaults to depends on the
# handler type (see `session_objects`)
handler = CHECKPOINTS[name]

objects = DefaultDict{AbstractPath, JLSO.JLSOFile}() do
JLSO.JLSOFile(Dict{Symbol, Vector{UInt8}}(); handler.settings...)
end

objects = session_objects(handler)
Session{typeof(handler)}(name, handler, objects)
end

Expand All @@ -34,29 +30,45 @@ function Session(f::Function, prefix::Union{Module, String}, names::Vector{Strin
Session(f, map(n -> "$prefix.$n", names))
end

"""
    session_objects(handler)

Build the staging dictionary for a session. For handlers without a more specific
method, each missing key defaults to a fresh, empty `Dict{Symbol, Any}`.
"""
function session_objects(handler)
    # The factory runs on every miss, so each key gets its own dictionary.
    emptyentry() = Dict{Symbol, Any}()
    return DefaultDict{AbstractString, Dict}(emptyentry)
end

"""
    session_objects(handler::JLSOHandler)

Build the staging dictionary for a JLSO-backed session. Each missing path defaults to an
empty `JLSO.JLSOFile` constructed with the handler's `settings` as keyword arguments.
"""
function session_objects(handler::JLSOHandler)
    # The factory runs on every miss, so each path gets its own empty JLSOFile.
    function emptyfile()
        return JLSO.JLSOFile(Dict{Symbol, Vector{UInt8}}(); handler.settings...)
    end
    return DefaultDict{AbstractPath, JLSO.JLSOFile}(emptyfile)
end

"""
commit!(session)
Write all staged JLSOFiles to the respective paths.
Write all staged objects to the respective keys.
"""
function commit!(session::Session)
# No-ops skip when handler is nothing
session.handler === nothing && return nothing

for (p, jlso) in session.objects
commit!(session.handler, p, jlso)
for (k, v) in session.objects
commit!(session.handler, k, v)
end
end

function checkpoint(session::Session, data::Dict{Symbol}; tags...)
checkpoint_deprecation(tags...)
with_checkpoint_tags(tags...) do
handler = session.handler
name = session.name
K = keytype(session.objects)

# No-ops skip when handler is nothing
session.handler === nothing && return nothing
handler === nothing && return nothing

p = path(session.handler, session.name)
jlso = session.objects[p]
session.objects[p] = stage!(session.handler, jlso, data)
# Our handler may not always be storing data in filepaths
k = K <: AbstractPath ? path(handler, name) : getkey(handler, name)
session.objects[k] = stage!(handler, session.objects[k], data)
end
end

Expand Down
46 changes: 46 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ using Test
using AWS: AWSConfig
using AWSS3: S3Path, s3_put, s3_list_buckets, s3_create_bucket
using Tables: Tables
using Checkpoints: JLSOHandler, DictHandler

Distributed.addprocs(5)
@everywhere using Checkpoints
Expand Down Expand Up @@ -256,5 +257,50 @@ Distributed.addprocs(5)
@test data[:data] == b
end
end

# We're largely reusing the same code for different handlers, but make sure
# that saving to a dict also works.
# Exercises `DictHandler` through the `TestPkg` checkpoints. Results are read back from
# the handler's in-memory dictionary rather than from files. The steps are stateful:
# each call adds entries that later assertions (and the overwrite check) rely on.
@testset "DictHandler" begin
# `a` maps 10 random 4-character Symbols to random length-10 vectors.
a = Dict(zip(
map(x -> Symbol(randstring(4)), 1:10),
map(x -> rand(10), 1:10)
))
b = rand(10)
handler = DictHandler()
# `objects` aliases the handler's dictionary, so committed checkpoints appear in it.
objects = handler.objects
Checkpoints.config(handler, "TestPkg")

# NOTE(review): `x` and `y` are not defined in this testset; presumably they are bound
# in the enclosing testset, which is not visible here. Confirm.
@test isempty(handler.objects)
TestPkg.foo(x, y)
# Keys are the context tags (if any) followed by the '.'-split name, joined with "/".
@test haskey(objects, "TestPkg/foo")
@test issetequal(keys(objects["TestPkg/foo"]), [:x, :y])
@test objects["TestPkg/foo"][:x] == x
@test objects["TestPkg/foo"][:y] == y

# NOTE(review): the "date=2017-01-01" prefix presumably comes from a tag set inside
# `TestPkg.bar`; TestPkg is not visible here. Confirm.
TestPkg.bar(b)
@test haskey(objects, "date=2017-01-01/TestPkg/bar")
@test objects["date=2017-01-01/TestPkg/bar"][:data] == b

TestPkg.baz(a)
@test haskey(objects, "TestPkg/baz")
@test objects["TestPkg/baz"] == a

TestPkg.qux(a, b)
@test haskey(objects, "TestPkg/qux_a")
@test objects["TestPkg/qux_a"] == a

@test haskey(objects, "TestPkg/qux_b")
@test objects["TestPkg/qux_b"][:data] == b

# Test that rerunning a function and overwriting a checkpoint fails by default
@test_throws ArgumentError TestPkg.foo(x, rand(10, 10))

# Retry after setting `force=true`
# The same `objects` dictionary is reused, so the earlier "TestPkg/foo" entry is still
# present and must be overwritten rather than rejected.
handler = DictHandler(; objects=objects, force=true)
Checkpoints.config(handler, "TestPkg")
TestPkg.foo(x, rand(10, 10))
@test objects["TestPkg/foo"][:x] == x
@test objects["TestPkg/foo"][:y] != y
end
end
end

2 comments on commit 214d355

@rofinn
Copy link
Member Author

@rofinn rofinn commented on 214d355 Feb 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/77842

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.21 -m "<description of version>" 214d355b2e696a71bc444b332d7dfb85cbbb2d95
git push origin v0.3.21

Please sign in to comment.