Skip to content

Commit

Permalink
chore: update script
Browse files: browse the repository at this point in the history
  • Loading branch information
agdestein committed Sep 22, 2024
1 parent fb0a3e4 commit 4c37146
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 16 deletions.
9 changes: 7 additions & 2 deletions lib/NeuralClosure/src/training.jl
Original file line number | Diff line number | Diff line change
Expand Up @@ -243,7 +243,9 @@ If not using interactive GLMakie window, set `displayupdates` to
function create_callback(
err;
θ,
callbackstate = (; n = 0, θmin = θ, emin = eltype(θ)(Inf), hist = Point2f[]),
callbackstate = (; n = 0, θmin = θ, emin = eltype(θ)(Inf), hist = Point2f[],
ctime = time(),
),
displayref = true,
displayfig = true,
displayupdates = false,
Expand All @@ -261,7 +263,10 @@ function create_callback(
if n % nupdate == 0
(; θ) = trainstate
e = err(θ)
@info "Iteration $n \trelative error: $e"
newtime = time()
itertime = (newtime - callbackstate.ctime) / nupdate
@reset callbackstate.ctime = newtime
@info "Iteration $n \t relative error: $e \t sec/iter: $itertime"
hist = push!(copy(hist), Point2f(n, e))
@reset callbackstate.hist = hist
obs[] = hist
Expand Down
34 changes: 20 additions & 14 deletions lib/PaperDC/les3D.jl
Original file line number | Diff line number | Diff line change
Expand Up @@ -53,7 +53,8 @@ palette = (; color = ["#3366cc", "#cc0000", "#669900", "#ff9900"])
# Choose where to put output
# outdir = joinpath(@__DIR__, "output", "les3D")
outdir = joinpath(ENV["DEEPDIP"], "output", "les3D")
plotdir = "$outdir/plots"
# plotdir = "$outdir/plots"
plotdir = joinpath(@__DIR__, "output", "les3D")
ispath(outdir) || mkpath(outdir)
ispath(plotdir) || mkpath(plotdir)

Expand Down Expand Up @@ -112,7 +113,7 @@ end
#
# Create filtered DNS data for training, validation, and testing.

ntrajectory = 10
ntrajectory = 4
dns_seeds = splitseed(seeds.dns, ntrajectory)
filenames = map(seed -> "$outdir/data_$(repr(seed)).jld2", dns_seeds)

Expand Down Expand Up @@ -152,13 +153,17 @@ end

# Load filtered DNS data
data = load.(filenames, "data");
Base.summarysize(data) * 1e-9
@info(
"Data: ",
Base.summarysize(data) * 1e-9,
length.(getfield.(data, :t)),
)

sum(d -> d.comptime, data) / 3600

data_train = data[1:8];
data_valid = data[9:9];
data_test = data[10:10];
data_train = data[1:3];
data_valid = data[4:4];
# data_test = data[10:10];

# Build LES setup and assemble operators
setups = map(
Expand All @@ -173,7 +178,7 @@ setups = map(
# Create input/output arrays for a-priori training (ubar vs c)
io_train = create_io_arrays(data_train, setups);
io_valid = create_io_arrays(data_valid, setups);
io_test = create_io_arrays(data_test, setups);
# io_test = create_io_arrays(data_test, setups);

# ### Plot data

Expand Down Expand Up @@ -276,15 +281,16 @@ end
let
I = CartesianIndices(io_train)
itask = parse(Int, ENV["SLURM_ARRAY_TASK_ID"])
igrid, ifil = I[itask].I
# ig, ifil = I[itask].I
ig, ifil = 1, 1
# ngrid, nfilter = size(io_train)
# for ifil = 1:nfilter, ig = 1:ngrid
clean()
starttime = time()
@info "Training a-priori for ig = $ig, ifil = $ifil"
trainseed, validseed = splitseed(seeds.prior, 2) # Same seed for all training setups
dataloader = create_dataloader_prior(io_train[ig, ifil]; batchsize = 50, device)
θ = T(1.0e0) * device(θ₀)
θ = T(1.0) * device(θ₀)
loss = create_loss_prior(mean_squared_error, closure)
opt = Adam(T(1.0e-3))
optstate = Optimisers.setup(opt, θ)
Expand All @@ -301,22 +307,22 @@ let
trainstate = (; optstate, θ, rng = Xoshiro(trainseed))
base, ext = splitext(priorfiles[ig, ifil])
checkpointname = "$(base)_checkpoint.jld2"
icheck = 1
ncheck = 0
if false
# Resume from checkpoint
icheck, trainstate, callbackstate =
load(checkpointname, "icheck", "trainstate", "callbackstate")
ncheck, trainstate, callbackstate =
load(checkpointname, "ncheck", "trainstate", "callbackstate")
trainstate = trainstate |> gpu_device()
@reset callbackstate.θmin = callbackstate.θmin |> gpu_device()
end
for ickeck = icheck:3
for icheck = ncheck+1:10
(; trainstate, callbackstate) =
train(; dataloader, loss, trainstate, callbackstate, callback, niter = 1_000)
# Save all states to resume training later
# First move all arrays to CPU
c = callbackstate |> cpu_device()
t = trainstate |> cpu_device()
jldsave(checkpointname; icheck, callbackstate = c, trainstate = t)
jldsave(checkpointname; ncheck = icheck, callbackstate = c, trainstate = t)
end
θ = callbackstate.θmin # Use best θ instead of last θ
prior = (; θ = Array(θ), comptime = time() - starttime, callbackstate.hist)
Expand Down

0 comments on commit 4c37146

Please sign in to comment.