diff --git a/src/handler.jl b/src/handler.jl index 1e4566b..d5095a8 100644 --- a/src/handler.jl +++ b/src/handler.jl @@ -106,16 +106,26 @@ the checkpoint tags and name. """ struct DictHandler <: AbstractHandler objects::Dict{String, Dict} - DictHandler() = new(Dict{String, Dict}()) + force::Bool end -commit!(handler::DictHandler, k::AbstractString, data) = setindex!(handler.objects, data, k) +DictHandler(; objects=Dict{String, Dict}(), force=false) = DictHandler(objects, force) + +function commit!(handler::DictHandler, k::AbstractString, data) + if handler.force + return setindex!(handler.objects, data, k) + else + res = get!(handler.objects, k, data) + isequal(res, data) || throw(ArgumentError("$k has already been stored")) + return res + end +end function checkpoint(handler::DictHandler, name::String, data::Dict{Symbol}; tags...) # TODO: Remove duplicate wrapper code checkpoint_deprecation(tags...) with_checkpoint_tags(tags...) do debug(LOGGER, "Checkpoint $name triggered, with context: $(join(CONTEXT_TAGS[], ", ")).") - handler.objects[getkey(handler, name)] = data + commit!(handler, getkey(handler, name), data) end end diff --git a/test/runtests.jl b/test/runtests.jl index 7066a75..553be00 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -291,6 +291,16 @@ Distributed.addprocs(5) @test haskey(objects, "TestPkg/qux_b") @test objects["TestPkg/qux_b"][:data] == b + + # Test that rerunning a function and overwriting a checkpoint fails by default + @test_throws ArgumentError TestPkg.foo(x, rand(10, 10)) + + # Retry after setting `force=true` + handler = DictHandler(; objects=objects, force=true) + Checkpoints.config(handler, "TestPkg") + TestPkg.foo(x, rand(10, 10)) + @test objects["TestPkg/foo"][:x] == x + @test objects["TestPkg/foo"][:y] != y end end end