Skip to content

Commit

Permalink
Support erroring when overwriting checkpoints.
Browse files Browse the repository at this point in the history
  • Loading branch information
rofinn committed Feb 16, 2023
1 parent c6d43fb commit 27e02eb
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 3 deletions.
16 changes: 13 additions & 3 deletions src/handler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,16 +106,26 @@ the checkpoint tags and name.
"""
struct DictHandler <: AbstractHandler
objects::Dict{String, Dict}
DictHandler() = new(Dict{String, Dict}())
force::Bool
end

commit!(handler::DictHandler, k::AbstractString, data) = setindex!(handler.objects, data, k)
DictHandler(; objects=Dict{String, Dict}(), force=false) = DictHandler(objects, force)

function commit!(handler::DictHandler, k::AbstractString, data)
if handler.force
return setindex!(handler.objects, data, k)
else
res = get!(handler.objects, k, data)
isequal(res, data) || throw(ArgumentError("$k has already been stored"))
return res
end
end

function checkpoint(handler::DictHandler, name::String, data::Dict{Symbol}; tags...)
# TODO: Remove duplicate wrapper code
checkpoint_deprecation(tags...)
with_checkpoint_tags(tags...) do
debug(LOGGER, "Checkpoint $name triggered, with context: $(join(CONTEXT_TAGS[], ", ")).")
handler.objects[getkey(handler, name)] = data
commit!(handler, getkey(handler, name), data)
end
end
10 changes: 10 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,16 @@ Distributed.addprocs(5)

@test haskey(objects, "TestPkg/qux_b")
@test objects["TestPkg/qux_b"][:data] == b

# Test that rerunning a function and overwriting a checkpoint fails by default
@test_throws ArgumentError TestPkg.foo(x, rand(10, 10))

# Retry after setting `force=true`
handler = DictHandler(; objects=objects, force=true)
Checkpoints.config(handler, "TestPkg")
TestPkg.foo(x, rand(10, 10))
@test objects["TestPkg/foo"][:x] == x
@test objects["TestPkg/foo"][:y] != y
end
end
end

0 comments on commit 27e02eb

Please sign in to comment.