From c6d43fb031c7657ed3c5e29281a6af2e5307b79e Mon Sep 17 00:00:00 2001 From: Rory Finnegan Date: Tue, 14 Feb 2023 12:19:25 -0800 Subject: [PATCH 1/2] Introduce a Handler hierarchy and DictHandler. --- Project.toml | 2 +- src/Checkpoints.jl | 22 ++++----- src/deprecated.jl | 2 + src/handler.jl | 115 ++++++++++++++++++++++++++++++--------------- src/session.jl | 38 ++++++++++----- test/runtests.jl | 36 ++++++++++++++ 6 files changed, 151 insertions(+), 64 deletions(-) diff --git a/Project.toml b/Project.toml index 423048d..58fa0e7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Checkpoints" uuid = "b4a3413d-e481-5afc-88ff-bdfbd6a50dce" authors = "Invenia Technical Computing Corporation" -version = "0.3.20" +version = "0.3.21" [deps] AWSS3 = "1c724243-ef5b-51ab-93f4-b0a88ac62a95" diff --git a/src/Checkpoints.jl b/src/Checkpoints.jl index 8893617..5944ad6 100644 --- a/src/Checkpoints.jl +++ b/src/Checkpoints.jl @@ -29,7 +29,7 @@ __init__() = Memento.register(LOGGER) include("handler.jl") -const CHECKPOINTS = Dict{String, Union{Nothing, String, Handler}}() +const CHECKPOINTS = Dict{String, Union{Nothing, String, AbstractHandler}}() @contextvar CONTEXT_TAGS::Tuple{Vararg{Pair{Symbol, Any}}} = Tuple{}() include("session.jl") @@ -75,7 +75,7 @@ available() = collect(keys(CHECKPOINTS)) Returns a vector of all enabled ([`config`](@ref)ured) and not [`deprecate`](@ref)d checkpoints. Use [`deprecated_checkpoints`](@ref) to retrieve a mapping of old / deprecated checkpoints. """ -enabled_checkpoints() = filter(k -> CHECKPOINTS[k] isa Handler, available()) +enabled_checkpoints() = filter(k -> CHECKPOINTS[k] isa AbstractHandler, available()) """ deprecated_checkpoints() -> Dict{String, String} @@ -130,31 +130,31 @@ function checkpoint(prefix::Union{Module, String}, name::String, args...; tags.. end """ - config(handler::Handler, labels::Vector{String}) - config(handler::Handler, prefix::String) + config(handler::AbstractHandler, labels::Vector{String}) + config(handler::AbstractHandler, prefix::String) config(labels::Vector{String}, args...; kwargs...) config(prefix::String, args...; kwargs...) -Configures the specified checkpoints with a `Handler`. -If the first argument is not a `Handler` then all `args` and `kwargs` are passed to a -`Handler` constructor for you. +Configures the specified checkpoints with a `AbstractHandler`. +If the first argument is not an `AbstractHandler` then all `args` and `kwargs` are +passed to a `JLSOHandler` constructor for you. """ -function config(handler::Handler, names::Vector{String}) +function config(handler::AbstractHandler, names::Vector{String}) for n in names _config(handler, n) end end -function config(handler::Handler, prefix::Union{Module, String}) +function config(handler::AbstractHandler, prefix::Union{Module, String}) config(handler, filter(l -> startswith(l, prefix), available())) end function config(names::Vector{String}, args...; kwargs...) - config(Handler(args...; kwargs...), names) + config(JLSOHandler(args...; kwargs...), names) end function config(prefix::Union{Module, String}, args...; kwargs...) - config(Handler(args...; kwargs...), prefix) + config(JLSOHandler(args...; kwargs...), prefix) end # To avoid collisions with `prefix` method above, which should probably use diff --git a/src/deprecated.jl b/src/deprecated.jl index c781d5a..a138050 100644 --- a/src/deprecated.jl +++ b/src/deprecated.jl @@ -11,3 +11,5 @@ function checkpoint_deprecation(tags...) :checkpoint ) end + +Base.@deprecate_binding Handler JLSOHandler diff --git a/src/handler.jl b/src/handler.jl index f8da4cb..1e4566b 100644 --- a/src/handler.jl +++ b/src/handler.jl @@ -1,56 +1,82 @@ -struct Handler{P<:AbstractPath} - path::P - settings # Could be Vector or Pairs on 0.6 or 1.0 respectively -end +abstract type AbstractHandler end """ - Handler(path::Union{String, AbstractPath}; kwargs...) - Handler(bucket::String, prefix::String; kwargs...) + getkey(handler, name, separator="/") -> String -Handles iteratively saving JLSO file to the specified path location. -FilePath are used to abstract away differences between paths on S3 or locally. +Combine the `CONTEXT_TAGS` and `name` into a unique checkpoint key as a string. +If the checkpoint name includes `.`, usually representing nested modules, these are +also replaced with the provided separator. """ -Handler(path::AbstractPath; kwargs...) = Handler(path, kwargs) -Handler(path::String; kwargs...) = Handler(Path(path), kwargs) -Handler(bucket::String, prefix::String; kwargs...) = Handler(S3Path("s3://$bucket/$prefix"), kwargs) +function getkey(::AbstractHandler, name::String, separator="/")::String + prefix = ["$key=$val" for (key, val) in CONTEXT_TAGS[]] + parts = split(name, '.') # Split up the name by '.' + return Base.join(vcat(prefix, parts), separator) +end -""" - path(handler, name) +path(args...) = Path(getkey(args...)) -Determines the path to save to based on the handlers path prefix, name, and context. -Tags are used to dynamically prefix the named file with the handler's path. -Names with a '.' separators will be used to form subdirectories -(e.g., "Foo.bar.x" will be saved to "\$prefix/Foo/bar/x.jlso"). """ -function path(handler::Handler{P}, name::String) where P - prefix = ["$key=$val" for (key,val) in CONTEXT_TAGS[]] + stage!(handler::AbstractHandler, objects, data::Dict{Symbol}) - # Split up the name by '.' and add the jlso extension - parts = split(name, '.') - parts[end] = string(parts[end], ".jlso") +Update the objects with the new data. +By default all handlers assume objects implements the associative interface. +""" +function stage!(handler::AbstractHandler, objects, data::Dict{Symbol}) + for (k, v) in data + objects[k] = v + end - return join(handler.path, prefix..., parts...) + return objects end """ - stage!(handler::Handler, jlso::JLSOFIle, data::Dict{Symbol}) + commit!(handler, prefix, objects) -Update the JLSOFile with the new data. +Serialize and write objects to a given path/prefix/key as defined by the handler. """ -function stage!(handler::Handler, jlso::JLSO.JLSOFile, data::Dict{Symbol}) - for (k, v) in data - jlso[k] = v +commit! + +#= +Define our no-op conditions just to be safe +=# +function checkpoint(handler::Nothing, name::String, data::Dict{Symbol}; tags...) + checkpoint_deprecation(tags...) + with_checkpoint_tags(tags...) do + debug(LOGGER, "Checkpoint $name triggered, but no handler has been set.") + nothing end +end + - return jlso +struct JLSOHandler{P<:AbstractPath} <: AbstractHandler + path::P + settings # Could be Vector or Pairs on 0.6 or 1.0 respectively end """ - commit!(handler, path, jlso) + JLSOHandler(path::Union{String, AbstractPath}; kwargs...) + JLSOHandler(bucket::String, prefix::String; kwargs...) + +Handles iteratively saving JLSO file to the specified path location. +FilePath are used to abstract away differences between paths on S3 or locally. +""" +JLSOHandler(path::AbstractPath; kwargs...) = JLSOHandler(path, kwargs) +JLSOHandler(path::String; kwargs...) = JLSOHandler(Path(path), kwargs) +JLSOHandler(bucket::String, prefix::String; kwargs...) = JLSOHandler(S3Path("s3://$bucket/$prefix"), kwargs) + +""" + path(handler, name) -Write the JLSOFile to the path as bytes. +Determines the path to save to based on the handlers path prefix, name, and context. +Tags are used to dynamically prefix the named file with the handler's path. +Names with a '.' separators will be used to form subdirectories +(e.g., "Foo.bar.x" will be saved to "\$prefix/Foo/bar/x.jlso"). """ -function commit!(handler::Handler{P}, path::P, jlso::JLSO.JLSOFile) where P <: AbstractPath +function path(handler::JLSOHandler{P}, name::String) where P + return join(handler.path, getkey(handler, name) * ".jlso") +end + +function commit!(handler::JLSOHandler{P}, path::P, jlso::JLSO.JLSOFile) where P <: AbstractPath # NOTE: This is only necessary because FilePathsBase.FileBuffer needs to support # write(::FileBuffer, ::UInt8) # https://github.com/rofinn/FilePathsBase.jl/issues/45 @@ -61,7 +87,7 @@ function commit!(handler::Handler{P}, path::P, jlso::JLSO.JLSOFile) where P <: A write(path, bytes) end -function checkpoint(handler::Handler, name::String, data::Dict{Symbol}; tags...) +function checkpoint(handler::JLSOHandler, name::String, data::Dict{Symbol}; tags...) checkpoint_deprecation(tags...) with_checkpoint_tags(tags...) do debug(LOGGER, "Checkpoint $name triggered, with context: $(join(CONTEXT_TAGS[], ", ")).") @@ -72,13 +98,24 @@ function checkpoint(handler::Handler, name::String, data::Dict{Symbol}; tags...) end end -#= -Define our no-op conditions just to be safe -=# -function checkpoint(handler::Nothing, name::String, data::Dict{Symbol}; tags...) +""" + DictHandler(objects) + +Saves checkpointed objects into a dictionary where the keys are strings generated from +the checkpoint tags and name. +""" +struct DictHandler <: AbstractHandler + objects::Dict{String, Dict} + DictHandler() = new(Dict{String, Dict}()) +end + +commit!(handler::DictHandler, k::AbstractString, data) = setindex!(handler.objects, data, k) + +function checkpoint(handler::DictHandler, name::String, data::Dict{Symbol}; tags...) + # TODO: Remove duplicate wrapper code checkpoint_deprecation(tags...) with_checkpoint_tags(tags...) do - debug(LOGGER, "Checkpoint $name triggered, but no handler has been set.") - nothing + debug(LOGGER, "Checkpoint $name triggered, with context: $(join(CONTEXT_TAGS[], ", ")).") + handler.objects[getkey(handler, name)] = data end end diff --git a/src/session.jl b/src/session.jl index 8c52d94..fa5b411 100644 --- a/src/session.jl +++ b/src/session.jl @@ -1,4 +1,4 @@ -struct Session{H<:Union{Nothing, Handler}} +struct Session{H<:Union{Nothing, AbstractHandler}} name::String handler::H objects::DefaultDict @@ -8,11 +8,7 @@ function Session(name::String) # Create our objects dictionary which defaults to returning # an empty JLSOFile handler = CHECKPOINTS[name] - - objects = DefaultDict{AbstractPath, JLSO.JLSOFile}() do - JLSO.JLSOFile(Dict{Symbol, Vector{UInt8}}(); handler.settings...) - end - + objects = session_objects(handler) Session{typeof(handler)}(name, handler, objects) end @@ -34,29 +30,45 @@ function Session(f::Function, prefix::Union{Module, String}, names::Vector{Strin Session(f, map(n -> "$prefix.$n", names)) end +function session_objects(handler) + return DefaultDict{AbstractString, Dict}() do + Dict{Symbol, Any}() + end +end + +function session_objects(handler::JLSOHandler) + return DefaultDict{AbstractPath, JLSO.JLSOFile}() do + JLSO.JLSOFile(Dict{Symbol, Vector{UInt8}}(); handler.settings...) + end +end + """ commit!(session) -Write all staged JLSOFiles to the respective paths. +Write all staged objects to the respective keys. """ function commit!(session::Session) # No-ops skip when handler is nothing session.handler === nothing && return nothing - for (p, jlso) in session.objects - commit!(session.handler, p, jlso) + for (k, v) in session.objects + commit!(session.handler, k, v) end end function checkpoint(session::Session, data::Dict{Symbol}; tags...) checkpoint_deprecation(tags...) with_checkpoint_tags(tags...) do + handler = session.handler + name = session.name + K = keytype(session.objects) + # No-ops skip when handler is nothing - session.handler === nothing && return nothing + handler === nothing && return nothing - p = path(session.handler, session.name) - jlso = session.objects[p] - session.objects[p] = stage!(session.handler, jlso, data) + # Our handler may not always be storing data in filepaths + k = K <: AbstractPath ? path(handler, name) : getkey(handler, name) + session.objects[k] = stage!(handler, session.objects[k], data) end end diff --git a/test/runtests.jl b/test/runtests.jl index 3509119..7066a75 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,6 +10,7 @@ using Test using AWS: AWSConfig using AWSS3: S3Path, s3_put, s3_list_buckets, s3_create_bucket using Tables: Tables +using Checkpoints: JLSOHandler, DictHandler Distributed.addprocs(5) @everywhere using Checkpoints @@ -256,5 +257,40 @@ Distributed.addprocs(5) @test data[:data] == b end end + + # We're largely reusing the same code for different handlers, but make sure + # that saving to a dict also works. + @testset "DictHandler" begin + a = Dict(zip( + map(x -> Symbol(randstring(4)), 1:10), + map(x -> rand(10), 1:10) + )) + b = rand(10) + handler = DictHandler() + objects = handler.objects + Checkpoints.config(handler, "TestPkg") + + @test isempty(handler.objects) + TestPkg.foo(x, y) + @test haskey(objects, "TestPkg/foo") + @test issetequal(keys(objects["TestPkg/foo"]), [:x, :y]) + @test objects["TestPkg/foo"][:x] == x + @test objects["TestPkg/foo"][:y] == y + + TestPkg.bar(b) + @test haskey(objects, "date=2017-01-01/TestPkg/bar") + @test objects["date=2017-01-01/TestPkg/bar"][:data] == b + + TestPkg.baz(a) + @test haskey(objects, "TestPkg/baz") + @test objects["TestPkg/baz"] == a + + TestPkg.qux(a, b) + @test haskey(objects, "TestPkg/qux_a") + @test objects["TestPkg/qux_a"] == a + + @test haskey(objects, "TestPkg/qux_b") + @test objects["TestPkg/qux_b"][:data] == b + end end end From 27e02eb572b32c088e5df7cf996ad220eef280d5 Mon Sep 17 00:00:00 2001 From: Rory Finnegan Date: Thu, 16 Feb 2023 10:45:10 -0800 Subject: [PATCH 2/2] Support erroring when overwriting checkpoints. --- src/handler.jl | 16 +++++++++++++--- test/runtests.jl | 10 ++++++++++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/src/handler.jl b/src/handler.jl index 1e4566b..d5095a8 100644 --- a/src/handler.jl +++ b/src/handler.jl @@ -106,16 +106,26 @@ the checkpoint tags and name. """ struct DictHandler <: AbstractHandler objects::Dict{String, Dict} - DictHandler() = new(Dict{String, Dict}()) + force::Bool end -commit!(handler::DictHandler, k::AbstractString, data) = setindex!(handler.objects, data, k) +DictHandler(; objects=Dict{String, Dict}(), force=false) = DictHandler(objects, force) + +function commit!(handler::DictHandler, k::AbstractString, data) + if handler.force + return setindex!(handler.objects, data, k) + else + res = get!(handler.objects, k, data) + isequal(res, data) || throw(ArgumentError("$k has already been stored")) + return res + end +end function checkpoint(handler::DictHandler, name::String, data::Dict{Symbol}; tags...) # TODO: Remove duplicate wrapper code checkpoint_deprecation(tags...) with_checkpoint_tags(tags...) do debug(LOGGER, "Checkpoint $name triggered, with context: $(join(CONTEXT_TAGS[], ", ")).") - handler.objects[getkey(handler, name)] = data + commit!(handler, getkey(handler, name), data) end end diff --git a/test/runtests.jl b/test/runtests.jl index 7066a75..553be00 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -291,6 +291,16 @@ Distributed.addprocs(5) @test haskey(objects, "TestPkg/qux_b") @test objects["TestPkg/qux_b"][:data] == b + + # Test that rerunning a function and overwriting a checkpoint fails by default + @test_throws ArgumentError TestPkg.foo(x, rand(10, 10)) + + # Retry after setting `force=true` + handler = DictHandler(; objects=objects, force=true) + Checkpoints.config(handler, "TestPkg") + TestPkg.foo(x, rand(10, 10)) + @test objects["TestPkg/foo"][:x] == x + @test objects["TestPkg/foo"][:y] != y end end end