From bb80cda3947af1f7549a4638e4d7d02b573a243a Mon Sep 17 00:00:00 2001 From: SCiarella Date: Mon, 11 Nov 2024 19:15:15 +0100 Subject: [PATCH 01/14] Implement Enzyme AD interface --- Project.toml | 15 +- ext/IncompressibleNavierStokesMakieExt.jl | 423 ++++++++++++++++++ lib/NeuralClosure/Project.toml | 4 +- lib/NeuralClosure/test/Project.toml | 16 +- lib/PaperDC/Project.toml | 16 +- lib/PaperDC/postanalysis.jl | 3 + lib/PaperDC/postanalysis3D.jl | 46 +- lib/SymmetryClosure/Project.toml | 8 +- src/IncompressibleNavierStokes.jl | 11 +- src/boundary_conditions.jl | 40 +- src/initializers.jl | 2 +- src/operators.jl | 281 ++++++++++-- src/pressure.jl | 20 + src/processors.jl | 379 +--------------- src/sciml.jl | 97 ++++ src/setup.jl | 13 + src/solver.jl | 6 +- .../step_explicit_runge_kutta.jl | 1 + src/time_steppers/step_lmwray3.jl | 119 ++++- src/utils.jl | 45 +- test/Project.toml | 2 + test/chainrules_enzyme.jl | 364 +++++++++++++++ test/enzyme_integration.jl | 119 +++++ test/operators.jl | 2 +- test/runtests.jl | 8 +- test/timesteppers.jl | 161 ++----- 26 files changed, 1569 insertions(+), 632 deletions(-) create mode 100644 ext/IncompressibleNavierStokesMakieExt.jl create mode 100644 src/sciml.jl create mode 100644 test/chainrules_enzyme.jl create mode 100644 test/enzyme_integration.jl diff --git a/Project.toml b/Project.toml index 443e061a7..8f25ebe23 100644 --- a/Project.toml +++ b/Project.toml @@ -1,17 +1,19 @@ name = "IncompressibleNavierStokes" uuid = "5e318141-6589-402b-868d-77d7df8c442e" authors = ["Syver Døving Agdestein, Benjamin Sanderse, and contributors"] -version = "2.0.0" +version = "2.0.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Observables = "510215fc-4207-5dde-b226-833fc4488ee2" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" @@ -20,24 +22,26 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a" WriteVTK = "64499a7a-5c06-52f2-abe2-ccb03c286192" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [weakdeps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e" +Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" [extensions] IncompressibleNavierStokesCUDSSExt = ["CUDSS"] +IncompressibleNavierStokesMakieExt = ["Makie"] [compat] Adapt = "4" CUDA = "5" CUDSS = "0.3" -CairoMakie = "0.11" ChainRulesCore = "1" DocStringExtensions = "0.9" FFTW = "1" -GLMakie = "0.9" IterativeSolvers = "0.9" KernelAbstractions = "0.9" LinearAlgebra = "1" @@ -54,5 +58,4 @@ WriteVTK = "1" julia = "1.9" [extras] -CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" -GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a" +Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" diff --git a/ext/IncompressibleNavierStokesMakieExt.jl b/ext/IncompressibleNavierStokesMakieExt.jl new file mode 100644 index 000000000..0245d92da --- /dev/null +++ b/ext/IncompressibleNavierStokesMakieExt.jl @@ -0,0 +1,423 @@ +""" +# Makie extension for IncompressibleNavierStokes + +This module adds methods to empty plotting functions defined in +IncompressibleNavierStokes. The methods are loaded when Makie is loaded in the +environment, through GLMakie or CairoMakie. This allows for installing +IncompressibleNavierStokes without Makie on servers to reduce precompilation +time. +""" +module IncompressibleNavierStokesMakieExt + +using DocStringExtensions +using IncompressibleNavierStokes +using IncompressibleNavierStokes: Dimension, kinetic_energy!, scalewithvolume! +using Makie +using Observables + +# We will extend these functions +import IncompressibleNavierStokes: + plotgrid, + animator, + realtimeplotter, + fieldplot, + energy_history_plot, + energy_spectrum_plot + +# Inherit docstring templates +@template (MODULES, FUNCTIONS, METHODS, TYPES) = IncompressibleNavierStokes + +plotgrid(x, y; kwargs...) = wireframe( + x, + y, + zeros(eltype(x), length(x), length(y)); + axis = (; aspect = DataAspect(), xlabel = "x", ylabel = "y"), + kwargs..., +) + +function plotgrid(x, y, z) + nx, ny, nz = length(x), length(y), length(z) + T = eltype(x) + + # x = repeat(x, 1, ny, nz) + # y = repeat(reshape(y, 1, :, 1), nx, 1, nz) + # z = repeat(reshape(z, 1, 1, :), nx, ny, 1) + # vol = repeat(reshape(z, 1, 1, :), nx, ny, 1) + # volume(x, y, z, vol) + fig = Figure() + + ax = Axis3(fig[1, 1]) + wireframe!(ax, x, y, fill(z[1], length(x), length(y))) + wireframe!(ax, x, y, fill(z[end], length(x), length(y))) + wireframe!(ax, x, fill(y[1], length(z)), repeat(z, 1, length(x))') + wireframe!(ax, x, fill(y[end], length(z)), repeat(z, 1, length(x))') + wireframe!(ax, fill(x[1], length(z)), y, repeat(z, 1, length(y))) + wireframe!(ax, fill(x[end], length(z)), y, repeat(z, 1, length(y))) + ax.aspect = :data + + ax = Axis(fig[1, 2]; xlabel = "x", ylabel = "y") + wireframe!(ax, x, y, zeros(T, length(x), length(y))) + ax.aspect = DataAspect() + + ax = Axis(fig[2, 1]; xlabel = "y", ylabel = "z") + wireframe!(ax, y, z, zeros(T, length(y), length(z))) + ax.aspect = DataAspect() + + ax = Axis(fig[2, 2]; xlabel = "x", ylabel = "z") + wireframe!(ax, x, z, zeros(T, length(x), length(z))) + ax.aspect = DataAspect() + + fig +end + +""" +Animate a plot of the solution every `update` iteration. +The animation is saved to `path`, which should have one +of the following extensions: + +- ".mkv" +- ".mp4" +- ".webm" +- ".gif" + +The plot is determined by a `plotter` processor. +Additional `kwargs` are passed to `plot`. +""" +animator(; + setup, + path, + plot = fieldplot, + nupdate = 1, + framerate = 24, + visible = true, + screen = nothing, + kwargs..., +) = + processor((stream, state) -> save(path, stream)) do outerstate + ispath(dirname(path)) || mkpath(dirname(path)) + state = Observable(outerstate[]) + fig = plot(state; setup, kwargs...) + visible && isnothing(screen) && display(fig) + visible && !isnothing(screen) && display(screen, fig) + stream = VideoStream(fig; framerate, visible) + on(outerstate) do outerstate + outerstate.n % nupdate == 0 || return + state[] = outerstate + recordframe!(stream) + end + stream + end + +""" +Processor for plotting the solution in real time. + +Keyword arguments: + +- `plot`: Plot function. +- `nupdate`: Show solution every `nupdate` time step. +- `displayfig`: Display the figure at the start. +- `screen`: If `nothing`, use default display. + If `GLMakie.screen()` multiple plots can be displayed in separate + windows like in MATLAB (see also `GLMakie.closeall()`). +- `displayupdates`: Display the figure at every update (if using CairoMakie). +- `sleeptime`: The `sleeptime` is slept at every update, to give Makie + time to update the plot. Set this to `nothing` to skip sleeping. + +Additional `kwargs` are passed to the `plot` function. +""" +realtimeplotter(; + setup, + plot = fieldplot, + nupdate = 1, + displayfig = true, + screen = nothing, + displayupdates = false, + sleeptime = nothing, + kwargs..., +) = + processor() do outerstate + state = Observable(outerstate[]) + fig = plot(state; setup, kwargs...) + displayfig && isnothing(screen) && display(fig) + displayfig && !isnothing(screen) && display(screen, fig) + on(outerstate) do outerstate + outerstate.n % nupdate == 0 || return + state[] = outerstate + displayupdates && display(fig) + isnothing(sleeptime) || sleep(sleeptime) + end + fig + end + +""" +Plot `state` field in pressure points. +If `state` is `Observable`, then the plot is interactive. + +Available fieldnames are: + +- `:velocity`, +- `:vorticity`, +- `:streamfunction`, +- `:pressure`. + +Available plot `type`s for 2D are: + +- `heatmap` (default), +- `image`, +- `contour`, +- `contourf`. + +Available plot `type`s for 3D are: + +- `contour` (default). + +The `alpha` value gets passed to `contour` in 3D. +""" +fieldplot(state; setup, kwargs...) = fieldplot( + setup.grid.dimension, + state isa Observable ? state : Observable(state); + setup, + kwargs..., +) + +function fieldplot( + ::Dimension{2}, + state; + setup, + fieldname = :vorticity, + psolver = nothing, + type = heatmap, + equal_axis = true, + docolorbar = true, + size = nothing, + title = nothing, + kwargs..., +) + (; grid) = setup + (; dimension, xlims, xp, Ip, Δ) = grid + D = dimension() + + xf = Array.(getindex.(xp, Ip.indices)) + + field = observefield(state; setup, fieldname, psolver) + + lims = lift(field) do f + if type ∈ (heatmap, image) + lims = get_lims(f) + elseif type ∈ (contour, contourf) + if ≈(extrema(f)...; rtol = 1e-10) + μ = mean(f) + a = μ - 1 + b = μ + 1 + f[1] += 1 + f[end] -= 1 + else + a, b = get_lims(f) + end + lims = (a, b) + end + lims + end + + if type ∈ (heatmap, image) + kwargs = (; colorrange = lims, kwargs...) + elseif type ∈ (contour, contourf) + kwargs = (; + extendlow = :auto, + extendhigh = :auto, + levels = @lift(LinRange($(lims)..., 10)), + # colorrange = lims, + kwargs..., + ) + end + + axis = (; + xlabel = "x", + ylabel = "y", + title = isnothing(title) ? titlecase(string(fieldname)) : title, + limits = (xlims[1]..., xlims[2]...), + ) + equal_axis && (axis = (axis..., aspect = DataAspect())) + + # Image requires boundary coordinates only + if type == image + Δx = first.(Array.(Δ)) + @assert all(≈(Δx[1]), Δx) "Image requires rectangular pixels" + @assert(all(α -> all(≈(Δx[α]), Δ[α]), 1:D), "Image requires uniform grid",) + xf = map(extrema, xf) + end + + size = isnothing(size) ? (;) : (; size) + fig = Figure(; size...) + ax, hm = type(fig[1, 1], xf..., field; axis, kwargs...) + docolorbar && Colorbar(fig[1, 2], hm) + + fig +end + +function fieldplot( + ::Dimension{3}, + state; + setup, + psolver = nothing, + fieldname = :eig2field, + alpha = convert(eltype(setup.grid.x[1]), 0.1), + # isorange = convert(eltype(setup.grid.x[1]), 0.5), + equal_axis = true, + levels = LinRange{eltype(setup.grid.x[1])}(-10, 5, 10), + docolorbar = false, + size = nothing, + type = contour, + kwargs..., +) + (; grid) = setup + (; xp, Ip) = grid + + xf = Array.(getindex.(xp, Ip.indices)) + dxf = diff.(xf) + xf = map(xf) do xf + dxf = diff(xf) + if all(≈(dxf[1]), dxf) + LinRange(xf[1], xf[end], length(xf)) + else + xf + end + end + + field = observefield(state; setup, fieldname, psolver) + + # color = lift(state) do (; temp) + # Array(view(temp, Ip)) + # end + # colorrange = lift(state) do (; temp) + # extrema(view(temp, Ip)) + # end + + # lims = @lift get_lims($field) + lims = isnothing(levels) ? lift(get_lims, field) : extrema(levels) + + isnothing(levels) && (levels = @lift(LinRange($(lims)..., 10))) + + # aspect = equal_axis ? (; aspect = :data) : (;) + size = isnothing(size) ? (;) : (; size) + fig = Figure(; size...) + # ax = Axis3(fig[1, 1]; title = titlecase(string(fieldname)), aspect...) + if type == volume + hm = volume( + fig[1, 1], + xf..., + field; + # colorrange = lims, + kwargs..., + ) + elseif type == contour + hm = contour( + fig[1, 1], + # ax, + xf..., + field; + levels, + # color = xf[2]' .+ 0 .* field[], + # colorrange, + colorrange = lims, + # colorrange = extrema(levels), + alpha, + # isorange, + # highclip = :red, + # lowclip = :red, + kwargs..., + ) + end + docolorbar && Colorbar(fig[1, 2], hm) + fig +end + +""" +Create energy history plot. +""" +function energy_history_plot(state; setup) + @assert state isa Observable "Energy history requires observable state." + (; Ip) = setup.grid + e = scalarfield(setup) + _points = Point2f[] + points = lift(state) do (; u, t) + kinetic_energy!(e, u, setup) + scalewithvolume!(e, setup) + E = sum(e[Ip]) + push!(_points, Point2f(t, E)) + end + fig = lines(points; axis = (; xlabel = "t", ylabel = "Kinetic energy")) + on(_ -> autolimits!(fig.axis), points) + fig +end + +""" +Create energy spectrum plot. +The energy at a scalar wavenumber level ``\\kappa \\in \\mathbb{N}`` is defined by + +```math +\\hat{e}(\\kappa) = \\int_{\\kappa \\leq \\| k \\|_2 < \\kappa + 1} | \\hat{e}(k) | \\mathrm{d} k, +``` + +as in San and Staples [San2012](@cite). + +Keyword arguments: + +- `sloperange = [0.6, 0.9]`: Percentage (between 0 and 1) of x-axis where the slope is plotted. +- `slopeoffset = 1.3`: How far above the energy spectrum the inertial slope is plotted. +- `kwargs...`: They are passed to [`observespectrum`](@ref). +""" +function energy_spectrum_plot( + state; + setup, + sloperange = [0.6, 0.9], + slopeoffset = 1.3, + kwargs..., +) + state isa Observable || (state = Observable(state)) + + (; dimension, xp, Ip) = setup.grid + T = eltype(xp[1]) + D = dimension() + + (; ehat, κ) = observespectrum(state; setup, kwargs...) + + kmax = maximum(κ) + + # Build inertial slope above energy + krange = kmax .^ sloperange + slope, slopelabel = D == 2 ? (-T(3), L"$k^{-3}$") : (-T(5 / 3), L"$k^{-5/3}$") + inertia = lift(ehat) do ehat + (m, i) = findmax(ehat ./ κ .^ slope) + slopeconst = m + dk = exp(log(kmax) * 0.5) + # kpoints = κ[i] / dk, κ[i] * dk + kpoints = κ[i] / (dk / 3), min(κ[i] * dk, kmax) + slopepoints = @. slopeoffset * slopeconst * kpoints^slope + [Point2f(kpoints[1], slopepoints[1]), Point2f(kpoints[2], slopepoints[2])] + end + + # Nice ticks + logmax = round(Int, log2(kmax + 1)) + xticks = T(2) .^ (0:logmax) + + fig = Figure() + ax = Axis( + fig[1, 1]; + xticks, + xlabel = "k", + # ylabel = "E(k)", + xscale = log10, + yscale = log10, + limits = (1, kmax, T(1e-8), T(1)), + ) + lines!(ax, κ, ehat; label = "Kinetic energy") + lines!(ax, inertia; label = slopelabel, linestyle = :dash, color = Cycled(2)) + axislegend(ax; position = :lb) + # autolimits!(ax) + on(e -> autolimits!(ax), ehat) + autolimits!(ax) + fig +end + +end diff --git a/lib/NeuralClosure/Project.toml b/lib/NeuralClosure/Project.toml index 491ba0510..e2927e907 100644 --- a/lib/NeuralClosure/Project.toml +++ b/lib/NeuralClosure/Project.toml @@ -29,7 +29,7 @@ Accessors = "0.1" ComponentArrays = "0.15" DocStringExtensions = "0.9" IncompressibleNavierStokes = "2" -JLD2 = "0.5.7" +JLD2 = "0.5" KernelAbstractions = "0.9" LinearAlgebra = "1" Lux = "1" @@ -37,7 +37,7 @@ MLUtils = "0.4" Makie = "0.21" NNlib = "0.9" Observables = "0.5" -Optimisers = "0.3" +Optimisers = "0.3, 0.4" Printf = "1" Random = "1" Zygote = "0.6" diff --git a/lib/NeuralClosure/test/Project.toml b/lib/NeuralClosure/test/Project.toml index ce9a5a50d..8d30cdefa 100644 --- a/lib/NeuralClosure/test/Project.toml +++ b/lib/NeuralClosure/test/Project.toml @@ -10,18 +10,18 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a" +[sources.IncompressibleNavierStokes] +path = "../../.." + +[sources.NeuralClosure] +path = ".." + [compat] Aqua = "0.8" CairoMakie = "0.12" IncompressibleNavierStokes = "2" Logging = "1" -Optimisers = "0.3" -TestItemRunner = "1" +Optimisers = "0.3, 0.4" Random = "1" +TestItemRunner = "1" julia = "1.9" - -[sources.IncompressibleNavierStokes] -path = "../../.." - -[sources.NeuralClosure] -path = ".." diff --git a/lib/PaperDC/Project.toml b/lib/PaperDC/Project.toml index 844af48e6..ce93ca4a5 100644 --- a/lib/PaperDC/Project.toml +++ b/lib/PaperDC/Project.toml @@ -29,16 +29,18 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -[sources] -IncompressibleNavierStokes = {path = "../.."} -NeuralClosure = {path = "../NeuralClosure"} +[sources.IncompressibleNavierStokes] +path = "../.." + +[sources.NeuralClosure] +path = "../NeuralClosure" [compat] Accessors = "0.1" Adapt = "4" CUDA = "5" CairoMakie = "0.12" -Dates = "1.11" +Dates = "1" DocStringExtensions = "0.9" EnumX = "1" FFTW = "1" @@ -49,11 +51,11 @@ LinearAlgebra = "1" LoggingExtras = "1" Lux = "1" LuxCUDA = "0.3" -MLUtils = "0.4.4" +MLUtils = "0.4" NNlib = "0.9" NeuralClosure = "1" Observables = "0.5" -Optimisers = "0.3" -ParameterSchedulers = "0.4.2" +Optimisers = "0.3, 0.4" +ParameterSchedulers = "0.4" SparseArrays = "1" julia = "1.9" diff --git a/lib/PaperDC/postanalysis.jl b/lib/PaperDC/postanalysis.jl index c4ca92fec..2c53308a4 100644 --- a/lib/PaperDC/postanalysis.jl +++ b/lib/PaperDC/postanalysis.jl @@ -625,6 +625,8 @@ with_theme(; palette) do title = "$lesmodel", ylabel = "A-posteriori error", ylabelvisible = iorder == 1, + yticksvisible = iorder == 1, + yticklabelsvisible = iorder == 1, ) for (e, marker, label, color) in [ (epost.nomodel, :circle, "No closure", Cycled(1)), @@ -639,6 +641,7 @@ with_theme(; palette) do end # ylims!(ax, (T(0.025), T(1.00))) end + linkaxes!(filter(x -> x isa Axis, fig.content)...) g = GridLayout(fig[1, end+1]) Legend(g[1, 1], filter(x -> x isa Axis, fig.content)[1]; valign = :bottom) Legend( diff --git a/lib/PaperDC/postanalysis3D.jl b/lib/PaperDC/postanalysis3D.jl index eda872585..2a63b92ed 100644 --- a/lib/PaperDC/postanalysis3D.jl +++ b/lib/PaperDC/postanalysis3D.jl @@ -686,6 +686,50 @@ with_theme(; palette) do display(fig) end +# ## Plot both + +with_theme(; palette) do + doplot() || return + fig = Figure(; size = (800, 300)) + axes = [] + ifil = 1 + iorder = 1 + axprior = Axis( + fig[1, ifil]; + xscale = log10, + xticks = params.nles, + xlabel = "Resolution", + title = "A-priori error", + ) + for (e, marker, label, color) in [ + (eprior.nomodel, :circle, "No closure", Cycled(1)), + (eprior.prior[:, ifil], :rect, "CNN (prior)", Cycled(3)), + (eprior.post[:, ifil, 1], :diamond, "CNN (post)", Cycled(4)), + ] + scatterlines!(axprior, params.nles, e; marker, color, label) + end + axpost = Axis( + fig[1, 2]; + xscale = log10, + # yscale = log10, + xticks = params.nles, + xlabel = "Resolution", + title = "A-posteriori error", + ) + for (e, marker, label, color) in [ + (epost.nomodel, :circle, "No closure", Cycled(1)), + (epost.smag, :utriangle, "Smagorinsky", Cycled(2)), + (epost.cnn_prior, :rect, "CNN (Lprior)", Cycled(3)), + (epost.cnn_post, :diamond, "CNN (Lpost)", Cycled(4)), + ] + scatterlines!(axpost, params.nles, e[:, ifil, iorder]; color, marker, label) + end + # linkaxes!(axprior, axpost) + Legend(fig[1, end+1], axpost) + save("$plotdir/epriorandpost.pdf", fig) + display(fig) +end + ########################################################################## #src # ## Energy evolution @@ -707,7 +751,7 @@ let psolver = default_psolver(setup) sample = namedtupleload(getdatafile(outdir, nles, Φ, dns_seeds_test[1])) ustart = selectdim(sample.u, ndims(sample.u), 1) |> collect |> device - T = eltype(ustart[1]) + T = eltype(ustart) # Shorter time for DIF t_DIF = T(1) diff --git a/lib/SymmetryClosure/Project.toml b/lib/SymmetryClosure/Project.toml index 989bf032d..838a9504d 100644 --- a/lib/SymmetryClosure/Project.toml +++ b/lib/SymmetryClosure/Project.toml @@ -23,10 +23,10 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" -[sources.IncompressibleNavierStokes] +[sources.IncompressibleNavierStokes] path = "../.." -[sources.NeuralClosure] +[sources.NeuralClosure] path = "../NeuralClosure" [compat] @@ -34,7 +34,7 @@ Adapt = "4" CUDA = "5" CairoMakie = "0.12" ComponentArrays = "0.15" -FFTW = "1.8" +FFTW = "1" GLMakie = "0.10" IncompressibleNavierStokes = "2" JLD2 = "0.5" @@ -44,6 +44,6 @@ LuxCUDA = "0.3" NNlib = "0.9" NeuralClosure = "1" Observables = "0.5" -Optimisers = "0.3" +Optimisers = "0.3, 0.4" Zygote = "0.6" julia = "1.9" diff --git a/src/IncompressibleNavierStokes.jl b/src/IncompressibleNavierStokes.jl index 684ac04ca..a9931c465 100644 --- a/src/IncompressibleNavierStokes.jl +++ b/src/IncompressibleNavierStokes.jl @@ -13,12 +13,15 @@ using Adapt using ChainRulesCore using DocStringExtensions using FFTW +using Enzyme +import .EnzymeRules: reverse, augmented_primal +using .EnzymeRules using IterativeSolvers using KernelAbstractions using KernelAbstractions.Extras.LoopInfo: @unroll using LinearAlgebra -using Makie using NNlib +using Observables using PrecompileTools using Printf using Random @@ -26,6 +29,7 @@ using SparseArrays using StaticArrays using Statistics using WriteVTK: CollectionFile, paraview_collection, vtk_grid, vtk_save +using Zygote # Docstring templates @template MODULES = """ @@ -66,6 +70,7 @@ include("eddyviscosity.jl") include("matrices.jl") include("initializers.jl") include("processors.jl") +include("sciml.jl") include("solver.jl") include("utils.jl") @@ -145,4 +150,8 @@ export apply_bc_u, Dfield, Qfield +# SciML operations +export create_right_hand_side, + right_hand_side! + end diff --git a/src/boundary_conditions.jl b/src/boundary_conditions.jl index 57a30fefa..5eac0d218 100644 --- a/src/boundary_conditions.jl +++ b/src/boundary_conditions.jl @@ -155,7 +155,6 @@ ChainRulesCore.rrule(::typeof(apply_bc_temp), temp, t, setup) = ( NoTangent(), ), ) - "Apply velocity boundary conditions (in-place version)." function apply_bc_u!(u, t, setup; kwargs...) (; boundary_conditions) = setup @@ -349,7 +348,7 @@ function apply_bc_u!(bc::DirichletBC, u, β, t, setup; isright, dudt = false, kw (α, args...) -> dudt ? zero(bc.u[α]) : bc.u[α] elseif dudt # Use central difference to approximate dudt - h = sqrt(eps(eltype(u[1]))) / 2 + h = sqrt(eps(eltype(u))) / 2 function (args...) args..., t = args (bc.u(args..., t + h) - bc.u(args..., t - h)) / 2h @@ -528,3 +527,40 @@ apply_bc_temp!(bc::PressureBC, temp, β, t, setup; isright, kwargs...) = apply_bc_temp_pullback!(bc::PressureBC, φbar, β, t, setup; isright, kwargs...) = apply_bc_p_pullback!(SymmetricBC(), φbar, β, t, setup; isright, kwargs...) + + +# Wrap a function to return `nothing`, because Enzyme can not handle vector return values. +function enzyme_wrap(f::Union{typeof(apply_bc_u!), typeof(apply_bc_p!), typeof(apply_bc_temp!)}) + # the boundary condition modifies x which is usually the field that we want to differentiate, so we need to introduce a copy of it and modify it instead + function wrapped_f(y, x, args...) + y .= x + f(y, args...) + return nothing + end + return wrapped_f +end + +function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Union{Const{typeof(enzyme_wrap(apply_bc_u!))}, Const{typeof(enzyme_wrap(apply_bc_p!))}, Const{typeof(enzyme_wrap(apply_bc_temp!))}}, ::Type{<:Const}, y::Duplicated, x::Duplicated, t::Const, setup::Const) + @info "augmented_primal" + primal = func.val(y.val, x.val, t.val, setup.val) + return AugmentedReturn(primal, nothing, nothing) +end +function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(enzyme_wrap(apply_bc_u!))}, dret, tape, y::Duplicated, x::Duplicated, t::Const, setup::Const) + @info "reverse" + adj = apply_bc_u_pullback!(x.val, t.val, setup.val) + x.dval .+= adj + y.dval .= x.dval # y is a copy of x + return (nothing, nothing, nothing, nothing) +end +function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(enzyme_wrap(apply_bc_p!))}, dret, tape, y::Duplicated, x::Duplicated, t::Const, setup::Const) + adj = apply_bc_p_pullback!(x.val, t.val, setup.val) + x.dval .+= adj + y.dval .= x.dval # y is a copy of x + return (nothing, nothing, nothing, nothing) +end +function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(enzyme_wrap(apply_bc_temp!))}, dret, tape, y::Duplicated, x::Duplicated, t::Const, setup::Const) + adj = apply_bc_temp_pullback!(x.val, t.val, setup.val) + x.dval .+= adj + y.dval .= x.dval # y is a copy of x + return (nothing, nothing, nothing, nothing) +end diff --git a/src/initializers.jl b/src/initializers.jl index d6e6cc2e7..3833eaf97 100644 --- a/src/initializers.jl +++ b/src/initializers.jl @@ -202,7 +202,7 @@ function random_field( all(==((PeriodicBC(), PeriodicBC())), boundary_conditions), "Random field requires periodic boundary conditions." ) - @assert all(Δ -> all(≈(Δ[1]), Δ), Δ) "Random field requires uniform grid spacing." + @assert all(Δ -> all(≈(Δ[1]), Δ), Array.(Δ)) "Random field requires uniform grid spacing." @assert all(iseven, N) "Random field requires even number of volumes." # Create random velocity field diff --git a/src/operators.jl b/src/operators.jl index 01e54ef41..fc6dcea3c 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -198,12 +198,12 @@ end end # "Subtract pressure gradient (differentiable version)." -# applypressure(u, p, setup) = applypressure!(copy.(u), p, setup) -# -# ChainRulesCore.rrule(::typeof(applypressure), p, setup) = ( -# applypressure(u, p, setup), -# φ -> (NoTangent(), applypressure_adjoint!(scalarfield(setup), (φ...,), setup), NoTangent()), -# ) +applypressure(u, p, setup) = applypressure!(copy.(u), p, setup) + +ChainRulesCore.rrule(::typeof(applypressure), u, p, setup) = ( + applypressure(u, p, setup), + φ -> (NoTangent(), NoTangent(), applypressure_adjoint!(scalarfield(setup), φ, nothing, setup), NoTangent()), +) "Subtract pressure gradient (in-place version)." function applypressure!(u, p, setup) @@ -225,22 +225,61 @@ end end end -# function applypressure_adjoint!(pbar, φ, u, setup) -# (; grid, backend, workgroupsize) = setup -# (; dimension, Δu, N, Iu) = grid -# D = dimension() -# e = Offset(D) -# @kernel function applypressure_adjoint_kernel!(p, φ) -# I = @index(Global, Cartesian) -# p[I] = zero(eltype(p)) -# for α = 1:D -# I - e(α) ∈ Iu[α] && (p[I] += φ[α][I-e(α)] / Δu[α][I[α]-1]) -# I ∈ Iu[α] && (p[I] -= φ[α][I] / Δu[α][I[α]]) -# end -# end -# applypressure_adjoint_kernel!(backend, workgroupsize)(pbar, φ; ndrange = N) -# pbar -# end +#function applypressure_adjoint!(pbar, φ, u, setup) +# (; grid, backend, workgroupsize) = setup +# (; dimension, Δu, N, Iu) = grid +# D = dimension() +# e = Offset(D) +# @kernel function applypressure_adjoint_kernel!(p, φ) +# I = @index(Global, Cartesian) +# p[I] = zero(eltype(p)) +# for α = 1:D +# I - e(α) ∈ Iu[α] && (p[I] += φ[I-e(α),α] / Δu[α][I[α]-1]) +# I ∈ Iu[α] && (p[I] -= φ[I,α] / Δu[α][I[α]]) +# end +# end +# applypressure_adjoint_kernel!(backend, workgroupsize)(pbar, φ; ndrange = N) +# pbar +#end +function applypressure_adjoint!(pbar, φ, u, setup) + # Extract necessary components from the setup structure + (; grid, backend, workgroupsize) = setup + (; dimension, Δu, N, Iu) = grid + D = dimension() # Get the spatial dimension + e = Offset(D) # Offset function for indexing neighbors + + # Kernel definition for computing the adjoint + @kernel function applypressure_adjoint_kernel!(p, φ) + # Get the global index for the current thread + I = @index(Global, Cartesian) + + # Initialize the adjoint value at the current index to zero + local p_I = zero(eltype(p)) + + # Loop over each dimension to compute adjoint contributions + for α in 1:D + # Contribution from φ[I - e(α)] / Δu[α][I[α] - 1] + if I - e(α) ∈ Iu[α] + p_I += φ[I - e(α), α] / Δu[α][I[α] - 1] + end + + # Contribution from -φ[I, α] / Δu[α][I[α]] + if I ∈ Iu[α] + p_I -= φ[I, α] / Δu[α][I[α]] + end + end + + # Assign the computed value back to p + p[I] = p_I + end + + # Run the adjoint kernel on the backend, with specified workgroup size + applypressure_adjoint_kernel!(backend, workgroupsize)(pbar, φ; ndrange = N) + + # Return the adjoint result for p + return pbar +end + "Compute Laplacian of pressure field (differentiable version)." laplacian(p, setup) = laplacian!(scalarfield(setup), p, setup) @@ -532,7 +571,7 @@ end dims = getval(valdims) I = @index(Global, Cartesian) @unroll for α in dims - val = zero(eltype(u[1])) + val = zero(eltype(u)) @unroll for β in dims Δuαβ = α == β ? Δu[β] : Δ[β] # F[α][I] += visc * u[I+e(β), α] / (β == α ? Δ[β][I[β]+1] : Δu[β][I[β]]) @@ -636,6 +675,7 @@ convection_diffusion_temp(u, temp, setup) = function ChainRulesCore.rrule(::typeof(convection_diffusion_temp), u, temp, setup) conv = convection_diffusion_temp(u, temp, setup) convection_diffusion_temp_pullback(φ) = (NoTangent(), du, dtemp, NoTangent()) + @warn "Check if convection_diffusion_temp pullback behaves as expected" (conv, pullback) end @@ -774,12 +814,12 @@ end "Compute body force (differentiable version)." function applybodyforce(u, t, setup) (; grid, bodyforce, issteadybodyforce) = setup - (; dimension, x) = grid + (; dimension, xu) = grid D = dimension() if issteadybodyforce bodyforce else - stack(map(α -> bodyforce.(α, x[α]..., t), 1:D)) + stack(map(α -> bodyforce.(α, xu[α]..., t), 1:D)) end end @@ -797,10 +837,10 @@ function applybodyforce!(F, u, t, setup) (; grid, bodyforce, issteadybodyforce) = setup (; dimension, Iu, xu) = grid D = dimension() - for (α, Fα) in enumerate(eachslice(F; dims = D + 1)) - if issteadybodyforce - F .+= bodyforce - else + if issteadybodyforce + F .+= bodyforce + else + for (α, Fα) in enumerate(eachslice(F; dims = D + 1)) # xin = ntuple( # β -> reshape(xu[α][β][Iu[α].indices[β]], ntuple(Returns(1), β - 1)..., :), # D, @@ -870,15 +910,13 @@ Right hand side of momentum equations, excluding pressure gradient (differentiable version). """ function momentum(u, temp, t, setup) - (; grid, bodyforce, closure_model) = setup - (; dimension) = grid - D = dimension() + (; bodyforce) = setup d = diffusion(u, setup) c = convection(u, setup) F = @. d + c if !isnothing(bodyforce) f = applybodyforce(u, t, setup) - F = F .+ f + F = @. F + f end if !isnothing(temp) g = gravity(temp, setup) @@ -902,26 +940,16 @@ Right hand side of momentum equations, excluding pressure gradient (in-place version). """ function momentum!(F, u, temp, t, setup) - (; grid, closure_model, bodyforce, temperature) = setup + (; grid, bodyforce) = setup (; dimension) = grid D = dimension() fill!(F, 0) - # diffusion!(F, u, setup) - # convection!(F, u, setup) convectiondiffusion!(F, u, setup) isnothing(bodyforce) || applybodyforce!(F, u, t, setup) isnothing(temp) || gravity!(F, temp, setup) F end -# monitor(u) = (@info("Forward", typeof(u)); u) -# ChainRulesCore.rrule(::typeof(monitor), u) = -# (monitor(u), φ -> (@info("Reverse", typeof(φ)); (NoTangent(), φ))) - -# tupleadd(u...) = ntuple(α -> sum(u -> u[α], u), length(u[1])) -# ChainRulesCore.rrule(::typeof(tupleadd), u...) = -# (tupleadd(u...), φ -> (NoTangent(), map(u -> φ, u)...)) - "Compute vorticity field (differentiable version)." vorticity(u, setup) = vorticity!( setup.grid.dimension() == 2 ? scalarfield(setup) : vectorfield(setup), @@ -979,9 +1007,9 @@ end @SMatrix [∂x(u, I, α, β, Δ[β], Δu[β]) for α = 1:2, β = 1:2] @inline ∇(u, I::CartesianIndex{3}, Δ, Δu) = @SMatrix [∂x(u, I, α, β, Δ[β], Δu[β]) for α = 1:3, β = 1:3] -@inline idtensor(u, I::CartesianIndex{2}) = +@inline idtensor(u, ::CartesianIndex{2}) = @SMatrix [(α == β) * oneunit(eltype(u)) for α = 1:2, β = 1:2] -@inline idtensor(u, I::CartesianIndex{3}) = +@inline idtensor(u, ::CartesianIndex{3}) = @SMatrix [(α == β) * oneunit(eltype(u)) for α = 1:3, β = 1:3] @inline function strain(u, I, Δ, Δu) ∇u = ∇(u, I, Δ, Δu) @@ -1432,3 +1460,164 @@ function get_scale_numbers(u, setup) τ = nothing (; uavg, ϵ, η, λ, Reλ, L, τ) end + +# Wrap a function to return `nothing`, because Enzyme can not handle vector return values. +function enzyme_wrap(f::Union{typeof(divergence!), typeof(pressuregradient!), typeof(convection!), typeof(diffusion!), typeof(applybodyforce!), typeof(gravity!), typeof(dissipation!), typeof(convection_diffusion_temp!), typeof(momentum!)}) + function wrapped_f(args...) + f(args...) + return nothing + end + return wrapped_f +end + + +function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Union{Const{typeof(enzyme_wrap(divergence!))},Const{typeof(enzyme_wrap(pressuregradient!))},Const{typeof(enzyme_wrap(convection!))},Const{typeof(enzyme_wrap(diffusion!))},Const{typeof(enzyme_wrap(gravity!))}}, ::Type{<:Const}, y::Duplicated, u::Duplicated, setup::Const) + primal = func.val(y.val, u.val, setup.val) + if overwritten(config)[3] + tape = copy(u.val) + else + tape = nothing + end + return AugmentedReturn(primal, nothing, tape) +end +function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(enzyme_wrap(divergence!))}, dret, tape, y::Duplicated, u::Duplicated, setup::Const) + adj = vectorfield(setup.val) + divergence_adjoint!(adj, y.val, setup.val) + u.dval .+= adj + make_zero!(y.dval) + return (nothing, nothing, nothing) +end +function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(enzyme_wrap(pressuregradient!))}, dret, tape, y::Duplicated, p::Duplicated, setup::Const) + adj = scalarfield(setup.val) + pressuregradient_adjoint!(adj, y.val, setup.val) + p.dval .+= adj + make_zero!(y.dval) + return (nothing, nothing, nothing) +end +function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(enzyme_wrap(convection!))}, dret, tape, y::Duplicated, u::Duplicated, setup::Const) + adj = zero(u.val) + convection_adjoint!(adj, y.val, u.val, setup.val) + u.dval .+= adj + make_zero!(y.dval) + return (nothing, nothing, nothing) +end +function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(enzyme_wrap(diffusion!))}, dret, tape, y::Duplicated, u::Duplicated, setup::Const) + adj = zero(u.val) + diffusion_adjoint!(adj, y.val, setup.val) + u.dval .+= adj + make_zero!(y.dval) + return (nothing, nothing, nothing) +end + + +function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Union{Const{typeof(enzyme_wrap(applybodyforce!))}}, ::Type{<:Const}, y::Duplicated, u::Duplicated, t::Const, setup::Const) + primal = func.val(y.val, u.val, t.val, setup.val) + if overwritten(config)[3] + tape = copy(u.val) + else + tape = nothing + end + return AugmentedReturn(primal, nothing, tape) +end +function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(enzyme_wrap(applybodyforce!))}, dret, tape, y::Duplicated, u::Duplicated, t::Const, setup::Const) + @warn "bodyforce Enzyme-AD tested only for issteadybodyforce=true" + adj = setup.val.bodyforce + u.dval .+= adj .* y.dval + make_zero!(y.dval) + return (nothing, nothing, nothing, nothing) +end + +function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(enzyme_wrap(gravity!))}, dret, tape, y::Duplicated, temp::Duplicated, setup::Const) + (; grid, backend, workgroupsize, temperature) = setup.val + (; dimension, Δ, N, Iu) = grid + (; gdir, α2) = temperature + backend = get_backend(temp.val) + D = dimension() + e = Offset(D) + function gravity_pullback(φ) + @kernel function g!(tempbar, φbar, valα) + α = getval(valα) + J = @index(Global, Cartesian) + t = zero(eltype(tempbar)) + # 1 + I = J + I ∈ Iu[α] && (t += α2 * Δ[α][I[α]+1] * φbar[I, α] / (Δ[α][I[α]] + Δ[α][I[α]+1])) + # 2 + I = J - e(α) + I ∈ Iu[α] && (t += α2 * Δ[α][I[α]] * φbar[I, α] / (Δ[α][I[α]] + Δ[α][I[α]+1])) + tempbar[J] = t + end + tempbar = zero(temp.val) + g!(backend, workgroupsize)(tempbar, φ, Val(gdir); ndrange = N) + tempbar + end + adj = gravity_pullback(y.val) + temp.dval .+= adj + make_zero!(y.dval) + return (nothing, nothing, nothing) +end + + +function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Union{Const{typeof(enzyme_wrap(dissipation!))}, Const{typeof(enzyme_wrap(convection_diffusion_temp!))}}, ::Type{<:Const}, y::Duplicated, x1::Duplicated, x2::Duplicated, setup::Const) + primal = func.val(y.val, x1.val, x2.val, setup.val) + if overwritten(config)[3] + tape = copy(x2.val) + else + tape = nothing + end + return AugmentedReturn(primal, nothing, tape) +end +function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(enzyme_wrap(dissipation!))}, dret, tape, y::Duplicated, d::Duplicated, u::Duplicated, setup::Const) + (; grid, backend, workgroupsize, Re, temperature) = setup.val + (; dimension, N, Ip) = grid + (; α1, γ) = temperature + D = dimension() + e = Offset(D) + @kernel function ∂φ!(ubar, dbar, φbar, d, u, valdims) + J = @index(Global, Cartesian) + @unroll for β in getval(valdims) + # Compute ubar + a = zero(eltype(u)) + # 1 + I = J + e(β) + I ∈ Ip && (a += Re * α1 / γ * d[I-e(β), β] / 2) + # 2 + I = J + I ∈ Ip && (a += Re * α1 / γ * d[I, β] / 2) + ubar[J, β] += a + + # Compute dbar + b = zero(eltype(u)) + # 1 + I = J + e(β) + I ∈ Ip && (b += Re * α1 / γ * u[I-e(β), β] / 2) + # 2 + I = J + I ∈ Ip && (b += Re * α1 / γ * u[I, β] / 2) + dbar[J, β] += b + end + end + function dissipation_pullback(φbar) + # Dφ/Du = ∂φ(u, d)/∂u + ∂φ(u, d)/∂d ⋅ ∂d(u)/∂u + dbar = zero(u.val) + ubar = zero(u.val) + ∂φ!(backend, workgroupsize)(ubar, dbar, φbar, d.val, u.val, Val(1:D); ndrange = N) + diffusion_adjoint!(ubar, dbar, setup.val) + ubar + end + adj = dissipation_pullback(y.val) + u.dval .+= adj + make_zero!(y.dval) + return (nothing, nothing, nothing, nothing) +end + +function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(enzyme_wrap(convection_diffusion_temp!))}, dret, tape, y::Duplicated, temp::Duplicated, u::Duplicated, setup::Const) + @error "convection_diffusion_temp Enzyme-AD not yet implemented" +end + +function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Union{Const{typeof(enzyme_wrap(momentum!))}}, ::Type{<:Const}, y::Duplicated, x1::Duplicated, x2::Duplicated, x3::Duplicated, t::Const, setup::Const) + @error "momentum Enzyme-AD not yet implemented" +end +function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(enzyme_wrap(momentum!))}, dret, tape, y::Duplicated, u::Duplicated, temp::Duplicated, t::Const, setup::Const) + @error "momentum Enzyme-AD not yet implemented" +end \ No newline at end of file diff --git a/src/pressure.jl b/src/pressure.jl index b96c8d1d1..a9d8d4521 100644 --- a/src/pressure.jl +++ b/src/pressure.jl @@ -357,3 +357,23 @@ function psolver_spectral(setup) p end end + +# Wrap a function to return `nothing`, because Enzyme can not handle vector return values. +function enzyme_wrap(f::typeof(poisson!)) + function wrapped_f(y, args...) + y.= f(args...) + return nothing + end + return wrapped_f +end +function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(enzyme_wrap(poisson!))}, ::Type{<:Const}, y::Duplicated, psolver::Const, div::Duplicated) + primal = func.val(y.val, psolver.val, div.val) + return AugmentedReturn(primal, nothing, nothing) +end +function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(enzyme_wrap(poisson!))}, dret, tape, y::Duplicated, psolver::Const, div::Duplicated) + auto_adj = copy(y.val) + func.val(auto_adj, psolver.val, y.val ) + div.dval .+= auto_adj .* y.dval + make_zero!(y.dval) + return (nothing, nothing, nothing) +end diff --git a/src/processors.jl b/src/processors.jl index 421977198..bc3f7114d 100644 --- a/src/processors.jl +++ b/src/processors.jl @@ -139,7 +139,7 @@ function observefield( end # Observe field - field = lift(state) do (; u, temp, t) + field = map(state) do (; u, temp, t) f = if fieldname in (1, 2, 3) interpolate_u_p!(up, u, setup) upf @@ -285,287 +285,6 @@ fieldsaver(; setup, nupdate = 1) = states end -""" -Animate a plot of the solution every `update` iteration. -The animation is saved to `path`, which should have one -of the following extensions: - -- ".mkv" -- ".mp4" -- ".webm" -- ".gif" - -The plot is determined by a `plotter` processor. -Additional `kwargs` are passed to `plot`. -""" -animator(; - setup, - path, - plot = fieldplot, - nupdate = 1, - framerate = 24, - visible = true, - screen = nothing, - kwargs..., -) = - processor((stream, state) -> save(path, stream)) do outerstate - ispath(dirname(path)) || mkpath(dirname(path)) - state = Observable(outerstate[]) - fig = plot(state; setup, kwargs...) - visible && isnothing(screen) && display(fig) - visible && !isnothing(screen) && display(screen, fig) - stream = VideoStream(fig; framerate, visible) - on(outerstate) do outerstate - outerstate.n % nupdate == 0 || return - state[] = outerstate - recordframe!(stream) - end - stream - end - -""" -Processor for plotting the solution in real time. - -Keyword arguments: - -- `plot`: Plot function. -- `nupdate`: Show solution every `nupdate` time step. -- `displayfig`: Display the figure at the start. -- `screen`: If `nothing`, use default display. - If `GLMakie.screen()` multiple plots can be displayed in separate - windows like in MATLAB (see also `GLMakie.closeall()`). -- `displayupdates`: Display the figure at every update (if using CairoMakie). -- `sleeptime`: The `sleeptime` is slept at every update, to give Makie - time to update the plot. Set this to `nothing` to skip sleeping. - -Additional `kwargs` are passed to the `plot` function. -""" -realtimeplotter(; - setup, - plot = fieldplot, - nupdate = 1, - displayfig = true, - screen = nothing, - displayupdates = false, - sleeptime = nothing, - kwargs..., -) = - processor() do outerstate - state = Observable(outerstate[]) - fig = plot(state; setup, kwargs...) - displayfig && isnothing(screen) && display(fig) - displayfig && !isnothing(screen) && display(screen, fig) - on(outerstate) do outerstate - outerstate.n % nupdate == 0 || return - state[] = outerstate - displayupdates && display(fig) - isnothing(sleeptime) || sleep(sleeptime) - end - fig - end - -""" -Plot `state` field in pressure points. -If `state` is `Observable`, then the plot is interactive. - -Available fieldnames are: - -- `:velocity`, -- `:vorticity`, -- `:streamfunction`, -- `:pressure`. - -Available plot `type`s for 2D are: - -- `heatmap` (default), -- `image`, -- `contour`, -- `contourf`. - -Available plot `type`s for 3D are: - -- `contour` (default). - -The `alpha` value gets passed to `contour` in 3D. -""" -fieldplot(state; setup, kwargs...) = fieldplot( - setup.grid.dimension, - state isa Observable ? state : Observable(state); - setup, - kwargs..., -) - -function fieldplot( - ::Dimension{2}, - state; - setup, - fieldname = :vorticity, - psolver = nothing, - type = heatmap, - equal_axis = true, - docolorbar = true, - size = nothing, - title = nothing, - kwargs..., -) - (; grid) = setup - (; dimension, xlims, xp, Ip, Δ) = grid - D = dimension() - - xf = Array.(getindex.(xp, Ip.indices)) - - field = observefield(state; setup, fieldname, psolver) - - lims = lift(field) do f - if type ∈ (heatmap, image) - lims = get_lims(f) - elseif type ∈ (contour, contourf) - if ≈(extrema(f)...; rtol = 1e-10) - μ = mean(f) - a = μ - 1 - b = μ + 1 - f[1] += 1 - f[end] -= 1 - else - a, b = get_lims(f) - end - lims = (a, b) - end - lims - end - - if type ∈ (heatmap, image) - kwargs = (; colorrange = lims, kwargs...) - elseif type ∈ (contour, contourf) - kwargs = (; - extendlow = :auto, - extendhigh = :auto, - levels = @lift(LinRange($(lims)..., 10)), - # colorrange = lims, - kwargs..., - ) - end - - axis = (; - xlabel = "x", - ylabel = "y", - title = isnothing(title) ? titlecase(string(fieldname)) : title, - limits = (xlims[1]..., xlims[2]...), - ) - equal_axis && (axis = (axis..., aspect = DataAspect())) - - # Image requires boundary coordinates only - if type == image - Δx = first.(Array.(Δ)) - @assert all(≈(Δx[1]), Δx) "Image requires rectangular pixels" - @assert(all(α -> all(≈(Δx[α]), Δ[α]), 1:D), "Image requires uniform grid",) - xf = map(extrema, xf) - end - - size = isnothing(size) ? (;) : (; size) - fig = Figure(; size...) - ax, hm = type(fig[1, 1], xf..., field; axis, kwargs...) - docolorbar && Colorbar(fig[1, 2], hm) - - fig -end - -function fieldplot( - ::Dimension{3}, - state; - setup, - psolver = nothing, - fieldname = :eig2field, - alpha = convert(eltype(setup.grid.x[1]), 0.1), - # isorange = convert(eltype(setup.grid.x[1]), 0.5), - equal_axis = true, - levels = LinRange{eltype(setup.grid.x[1])}(-10, 5, 10), - docolorbar = false, - size = nothing, - type = contour, - kwargs..., -) - (; grid) = setup - (; xp, Ip) = grid - - xf = Array.(getindex.(xp, Ip.indices)) - dxf = diff.(xf) - xf = map(xf) do xf - dxf = diff(xf) - if all(≈(dxf[1]), dxf) - LinRange(xf[1], xf[end], length(xf)) - else - xf - end - end - - field = observefield(state; setup, fieldname, psolver) - - # color = lift(state) do (; temp) - # Array(view(temp, Ip)) - # end - # colorrange = lift(state) do (; temp) - # extrema(view(temp, Ip)) - # end - - # lims = @lift get_lims($field) - lims = isnothing(levels) ? lift(get_lims, field) : extrema(levels) - - isnothing(levels) && (levels = @lift(LinRange($(lims)..., 10))) - - # aspect = equal_axis ? (; aspect = :data) : (;) - size = isnothing(size) ? (;) : (; size) - fig = Figure(; size...) - # ax = Axis3(fig[1, 1]; title = titlecase(string(fieldname)), aspect...) - if type == volume - hm = volume( - fig[1, 1], - xf..., - field; - # colorrange = lims, - kwargs..., - ) - elseif type == contour - hm = contour( - fig[1, 1], - # ax, - xf..., - field; - levels, - # color = xf[2]' .+ 0 .* field[], - # colorrange, - colorrange = lims, - # colorrange = extrema(levels), - alpha, - # isorange, - # highclip = :red, - # lowclip = :red, - kwargs..., - ) - end - docolorbar && Colorbar(fig[1, 2], hm) - fig -end - -""" -Create energy history plot. -""" -function energy_history_plot(state; setup) - @assert state isa Observable "Energy history requires observable state." - (; Ip) = setup.grid - e = scalarfield(setup) - _points = Point2f[] - points = lift(state) do (; u, t) - kinetic_energy!(e, u, setup) - scalewithvolume!(e, setup) - E = sum(e[Ip]) - push!(_points, Point2f(t, E)) - end - fig = lines(points; axis = (; xlabel = "t", ylabel = "Kinetic energy")) - on(_ -> autolimits!(fig.axis), points) - fig -end - "Observe energy spectrum of `state`." function observespectrum(state; setup, npoint = 100, a = typeof(setup.Re)(1 + sqrt(5)) / 2) state isa Observable || (state = Observable(state)) @@ -580,7 +299,7 @@ function observespectrum(state; setup, npoint = 100, a = typeof(setup.Re)(1 + sq uhat = similar(xp[1], Complex{T}, Np) # up = interpolate_u_p(state[].u, setup) _ehat = zeros(T, length(κ)) - ehat = lift(state) do (; u) + ehat = map(state) do (; u) # interpolate_u_p!(up, u, setup) up = u # TODO: Maybe preallocate e and A * e @@ -598,88 +317,12 @@ function observespectrum(state; setup, npoint = 100, a = typeof(setup.Re)(1 + sq (; ehat, κ) end -""" -Create energy spectrum plot. -The energy at a scalar wavenumber level ``\\kappa \\in \\mathbb{N}`` is defined by - -```math -\\hat{e}(\\kappa) = \\int_{\\kappa \\leq \\| k \\|_2 < \\kappa + 1} | \\hat{e}(k) | \\mathrm{d} k, -``` - -as in San and Staples [San2012](@cite). - -Keyword arguments: - -- `sloperange = [0.6, 0.9]`: Percentage (between 0 and 1) of x-axis where the slope is plotted. -- `slopeoffset = 1.3`: How far above the energy spectrum the inertial slope is plotted. -- `kwargs...`: They are passed to [`observespectrum`](@ref). -""" -function energy_spectrum_plot( - state; - setup, - sloperange = [0.6, 0.9], - slopeoffset = 1.3, - kwargs..., -) - state isa Observable || (state = Observable(state)) - - (; dimension, xp, Ip) = setup.grid - T = eltype(xp[1]) - D = dimension() - - (; ehat, κ) = observespectrum(state; setup, kwargs...) - - kmax = maximum(κ) - - # Build inertial slope above energy - krange = kmax .^ sloperange - slope, slopelabel = D == 2 ? (-T(3), L"$k^{-3}$") : (-T(5 / 3), L"$k^{-5/3}$") - inertia = lift(ehat) do ehat - (m, i) = findmax(ehat ./ κ .^ slope) - slopeconst = m - dk = exp(log(kmax) * 0.5) - # kpoints = κ[i] / dk, κ[i] * dk - kpoints = κ[i] / (dk / 3), min(κ[i] * dk, kmax) - slopepoints = @. slopeoffset * slopeconst * kpoints^slope - [Point2f(kpoints[1], slopepoints[1]), Point2f(kpoints[2], slopepoints[2])] - end - - # Nice ticks - logmax = round(Int, log2(kmax + 1)) - xticks = T(2) .^ (0:logmax) - - fig = Figure() - ax = Axis( - fig[1, 1]; - xticks, - xlabel = "k", - # ylabel = "E(k)", - xscale = log10, - yscale = log10, - limits = (1, kmax, T(1e-8), T(1)), - ) - lines!(ax, κ, ehat; label = "Kinetic energy") - lines!(ax, inertia; label = slopelabel, linestyle = :dash, color = Cycled(2)) - axislegend(ax; position = :lb) - # autolimits!(ax) - on(e -> autolimits!(ax), ehat) - autolimits!(ax) - fig -end - -# # Make sure the figure is fully rendered before allowing code to continue -# if displayfig -# render = display(espec) -# done_rendering = Ref(false) -# on(render.render_tick) do _ -# done_rendering[] = true -# end -# on(state) do s -# # State is updated, block code execution until GLMakie has rendered -# # figure update -# done_rendering[] = false -# while !done_rendering[] -# sleep(checktime) -# end -# end -# end +# These empty functions are defined here, but implemented in +# ext/IncompressibleNavierStokesMakieExt.jl +# To use them, load Makie using e.g. GLMakie or CairoMakie +# This reduces a lot of dependencies if plotting is not required +function animator end +function realtimeplotter end +function fieldplot end +function energy_history_plot end +function energy_spectrum_plot end diff --git a/src/sciml.jl b/src/sciml.jl new file mode 100644 index 000000000..b6474cbbd --- /dev/null +++ b/src/sciml.jl @@ -0,0 +1,97 @@ + +function create_right_hand_side(setup, psolver) + function right_hand_side(u, param, t) + F = zeros(size(u)) + u = apply_bc_u(u, t, setup) + #F = convection(u, setup) .+ diffusion(u, setup) + F = momentum(u, nothing, t, setup) + F = apply_bc_u(F, t, setup; dudt = true) + FP = project(F, setup; psolver) + #p = divergence(F, setup) + #p = scalewithvolume(p, setup) + #p = poisson(psolver, p) + #p = apply_bc_p(p, t, setup) + #G = pressuregradient(p, setup) + #F .- G + end +end + + +function right_hand_side!(dudt, u, params_ref, t) + params = params_ref[] + setup = params[1] + psolver = params[2] + p = scalarfield(setup) + # [!]*** be careful to not touch u in this function! + temp_vector = copy(u) + apply_bc_u!(temp_vector, t, setup) + momentum!(dudt, temp_vector, nothing, t, setup) + #fill!(dudt, 0) + #convectiondiffusion!(dudt, temp_vector, setup) + apply_bc_u!(dudt, t, setup) + project!(dudt, setup; psolver, p) + #divergence!(p, dudt, setup) + #scalewithvolume!(p, setup) + #poisson!(psolver, p) + #apply_bc_p!(p, t, setup) + #applypressure!(dudt, p, setup) + return nothing +end + + +using Enzyme +import .EnzymeRules: reverse, augmented_primal +using .EnzymeRules +function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(right_hand_side!)}, ::Type{<:Const}, dudt::Duplicated, u::Duplicated, params_ref::Any, t::Const) + # this runs function to modify dudt and store the intermediates + params = params_ref.val[] + setup = params[1] + psolver = params[2] + p = scalarfield(setup) + u_bc = copy(u.val) + apply_bc_u!(u_bc, t.val, setup) + momentum!(dudt.val, u_bc, nothing, t, setup) + #fill!(dudt.val, 0) + #convectiondiffusion!(dudt.val, u_bc, setup) + apply_bc_u!(dudt.val, t.val, setup) + project!(dudt.val, setup; psolver, p) + #divergence!(p, dudt.val, setup) + #scalewithvolume!(p, setup) + #poisson!(psolver, p) + #apply_bc_p!(p, t.val, setup) + #applypressure!(dudt.val, p, setup) + return AugmentedReturn(nothing, nothing, u_bc) +end +function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(right_hand_side!)}, dret, u_bc, dudt::Duplicated, u::Duplicated, params_ref::Const, t::Const) + # unpack the parameters + params = params_ref.val[] + setup = params[1] + psolver = params[2] + temp_scalar = scalarfield(setup) + dp = scalarfield(setup) + temp_vector = vectorfield(setup) + + # traverse the graph backwards + # [!] notice that the chain starts from the final value of dudt because it gets modified in place in the forward pass + dudt.dval .*= dudt.val + # [!] the minus sign is missing somewhere in the adjoint + dp .= - applypressure_adjoint!(temp_scalar, dudt.dval, nothing, setup) + + apply_bc_p_pullback!(dp, t.val, setup) + + poisson!(psolver,dp) + scalewithvolume!(dp, setup) + + dudt.dval .+= divergence_adjoint!(temp_vector, dp, setup) + + apply_bc_u_pullback!(dudt.dval, t.val, setup) + + fill!(temp_vector, 0) + u.dval .= convection_adjoint!(temp_vector, dudt.dval, u_bc, setup) + fill!(temp_vector, 0) + u.dval .+= diffusion_adjoint!(temp_vector, dudt.dval, setup) + + apply_bc_u_pullback!(u.dval, t.val, setup) + + return (nothing, nothing, nothing, nothing) +end \ No newline at end of file diff --git a/src/setup.jl b/src/setup.jl index eb56c63af..38d841a82 100644 --- a/src/setup.jl +++ b/src/setup.jl @@ -3,6 +3,7 @@ function Setup(; x, boundary_conditions = ntuple(d -> (PeriodicBC(), PeriodicBC()), length(x)), bodyforce = nothing, + dbodyforce = nothing, issteadybodyforce = true, closure_model = nothing, backend = CPU(), @@ -29,6 +30,18 @@ function Setup(; bodyforce = applybodyforce!(F, u, T(0), setup) setup = (; setup..., issteadybodyforce = true, bodyforce) end + if !isnothing(dbodyforce) + @warn "dbodyforce is not used at the moment. No need to define it." + if issteadybodyforce + dsetup = (; setup..., bodyforce = dbodyforce, issteadybodyforce = false) + (; x) = setup.grid + T = eltype(x[1]) + u = vectorfield(setup) + F = vectorfield(setup) + dbodyforce = applybodyforce!(F, u, T(0), dsetup) + end + setup = (; setup..., dbodyforce) + end setup end diff --git a/src/solver.jl b/src/solver.jl index a82a85670..d85143242 100644 --- a/src/solver.jl +++ b/src/solver.jl @@ -20,11 +20,11 @@ function solve_unsteady(; tlims, ustart, tempstart = nothing, - method = RKMethods.RK44(; T = eltype(ustart[1])), + method = RKMethods.RK44(; T = eltype(ustart)), psolver = default_psolver(setup), Δt = nothing, Δt_min = nothing, - cfl = eltype(ustart[1])(0.9), + cfl = eltype(ustart)(0.9), n_adapt_Δt = 1, docopy = true, processors = (;), @@ -104,7 +104,7 @@ function get_cfl_timestep!(buf, u, setup) D = dimension() # Initial maximum step size - Δt = eltype(u[1])(Inf) + Δt = eltype(u)(Inf) # Check maximum step size in each dimension for (α, uα) in enumerate(eachslice(u; dims = D + 1)) diff --git a/src/time_steppers/step_explicit_runge_kutta.jl b/src/time_steppers/step_explicit_runge_kutta.jl index 89e6193e9..64ec6e953 100644 --- a/src/time_steppers/step_explicit_runge_kutta.jl +++ b/src/time_steppers/step_explicit_runge_kutta.jl @@ -68,6 +68,7 @@ function timestep(method::ExplicitRungeKuttaMethod, stepper, Δt; θ = nothing) # Update current solution (does not depend on previous step size) tstart = t ustart = u + tempstart = temp ku = () ktemp = () diff --git a/src/time_steppers/step_lmwray3.jl b/src/time_steppers/step_lmwray3.jl index 5ced422a7..9ef7f5545 100644 --- a/src/time_steppers/step_lmwray3.jl +++ b/src/time_steppers/step_lmwray3.jl @@ -13,18 +13,20 @@ function timestep!(method::LMWray3, stepper, Δt; θ = nothing, cache) # Right-hand side function (without projection) function f!(dx, x, t, setup) - (; u, temp) = x - apply_bc_u!(u, t, setup) - isnothing(temp) || apply_bc_temp!(temp, t, setup) - momentum!(dx.u, u, temp, t, setup) - if !isnothing(temp) - dx.temp .= 0 - convection_diffusion_temp!(dx.temp, u, temp, setup) - temperature.dodissipation && dissipation!(dx.temp, diff, u, setup) - end + # Velocity equation + apply_bc_u!(x.u, t, setup) + isnothing(x.temp) || apply_bc_temp!(x.temp, t, setup) + momentum!(dx.u, x.u, x.temp, t, setup) # Add closure term - isnothing(m) || (dx.u .+= m(u, θ)) + isnothing(m) || (dx.u .+= m(x.u, θ)) + + # Temperature equation + if !isnothing(x.temp) + fill!(dx.temp, 0) + convection_diffusion_temp!(dx.temp, x.u, x.temp, setup) + temperature.dodissipation && dissipation!(dx.temp, diff, x.u, setup) + end dx end @@ -47,10 +49,8 @@ function timestep!(method::LMWray3, stepper, Δt; θ = nothing, cache) # Compute y = a * x + y for states x, y function state_axpy!(a, x, y) - @. y.u += a * x.u - if !isnothing(temp) - @. y.temp += a * x.temp - end + axpy!(a, x.u, y.u) + isnothing(temp) || axpy!(a, x.temp, y.temp) end # States @@ -69,7 +69,7 @@ function timestep!(method::LMWray3, stepper, Δt; θ = nothing, cache) # c4 | b1 b2 a3 0 ⋯ 0 # ⋮ | ⋮ ⋮ ⋮ ⋱ ⋱ ⋮ # cn | b1 b2 b3 ⋯ an-1 0 - # --+-------------------- + # ---+-------------------- # | b1 b2 b3 ⋯ bn-1 an # # Note the definition of (ai)i. @@ -101,9 +101,94 @@ function timestep!(method::LMWray3, stepper, Δt; θ = nothing, cache) # since we divide by an infinitely thin (eps(T)) volume width in the # diffusion term apply_bc_u!(x.u, t, setup) - isnothing(temp) || apply_bc_temp!(x.temp, t, setup) + isnothing(x.temp) || apply_bc_temp!(x.temp, t, setup) create_stepper(method; setup, psolver, x.u, x.temp, t, n = n + 1) end -timestep(method::LMWray3, stepper, Δt; θ = nothing) = error("Not yet implemented") +function timestep(method::LMWray3, stepper, Δt; θ = nothing) + (; setup, psolver, u, temp, t, n) = stepper + (; closure_model, temperature) = setup + m = closure_model + T = eltype(u) + + # We wrap the state in x = (; u, temp), and define some + # functions that operate on x + + # Right-hand side function (without projection) + function f(u, temp, t, setup) + u = apply_bc_u(u, t, setup) + if isnothing(temp) + dtemp = nothing + else + temp = apply_bc_temp(temp, t, setup) + dtemp = convection_diffusion_temp(u, temp, setup) + if temperature.dodissipation + dtemp += dissipation(u, setup) + end + end + du = momentum(u, temp, t, setup) + + # Add closure term + isnothing(m) || (du += m(u, θ)) + + du, dtemp + end + + # Update current state + tstart = t + ustart = u + tempstart = temp + + # Low-storage Butcher tableau: + # c1 | 0 ⋯ 0 + # c2 | a1 0 ⋯ 0 + # c3 | b1 a2 0 ⋯ 0 + # c4 | b1 b2 a3 0 ⋯ 0 + # ⋮ | ⋮ ⋮ ⋮ ⋱ ⋱ ⋮ + # cn | b1 b2 b3 ⋯ an-1 0 + # ---+-------------------- + # | b1 b2 b3 ⋯ bn-1 an + # + # Note the definition of (ai)i. + # They are shifted to simplify the for-loop. + # TODO: Make generic by passing a, b, c as inputs + a = T(8 / 15), T(5 / 12), T(3 / 4) + b = T(1 / 4), T(0) + c = T(0), T(8 / 15), T(2 / 3) + nstage = length(a) + + for i = 1:nstage + t = tstart + c[i] * Δt + du, dtemp = f(u, temp, t, setup) + + # Compute state at current stage + u = @. ustart + Δt * a[i] * du + u = apply_bc_u(u, t, setup) + u = project(u, setup; psolver) + if !isnothing(temp) + temp = @. tempstart + Δt * a[i] * dtemp + end + + # Advance start state (skip for last iter) + if i < nstage + ustart = @. ustart + Δt * b[i] * du + if !isnothing(temp) + tempstart = @. tempstart + Δt * b[i] * dtemp + end + end + end + + # Full time step + t = tstart + Δt + + # This is redundant, but Neumann BC need to have _exact_ copies + # since we divide by an infinitely thin (eps(T)) volume width in the + # diffusion term + u = apply_bc_u(u, t, setup) + if !isnothing(temp) + temp = apply_bc_temp(temp, t, setup) + end + + create_stepper(method; setup, psolver, u, temp, t, n = n + 1) +end diff --git a/src/utils.jl b/src/utils.jl index e4cde8afa..3f6abce7c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -31,49 +31,6 @@ Plot nonuniform Cartesian grid. """ function plotgrid end -plotgrid(x, y; kwargs...) = wireframe( - x, - y, - zeros(eltype(x), length(x), length(y)); - axis = (; aspect = DataAspect(), xlabel = "x", ylabel = "y"), - kwargs..., -) - -function plotgrid(x, y, z) - nx, ny, nz = length(x), length(y), length(z) - T = eltype(x) - - # x = repeat(x, 1, ny, nz) - # y = repeat(reshape(y, 1, :, 1), nx, 1, nz) - # z = repeat(reshape(z, 1, 1, :), nx, ny, 1) - # vol = repeat(reshape(z, 1, 1, :), nx, ny, 1) - # volume(x, y, z, vol) - fig = Figure() - - ax = Axis3(fig[1, 1]) - wireframe!(ax, x, y, fill(z[1], length(x), length(y))) - wireframe!(ax, x, y, fill(z[end], length(x), length(y))) - wireframe!(ax, x, fill(y[1], length(z)), repeat(z, 1, length(x))') - wireframe!(ax, x, fill(y[end], length(z)), repeat(z, 1, length(x))') - wireframe!(ax, fill(x[1], length(z)), y, repeat(z, 1, length(y))) - wireframe!(ax, fill(x[end], length(z)), y, repeat(z, 1, length(y))) - ax.aspect = :data - - ax = Axis(fig[1, 2]; xlabel = "x", ylabel = "y") - wireframe!(ax, x, y, zeros(T, length(x), length(y))) - ax.aspect = DataAspect() - - ax = Axis(fig[2, 1]; xlabel = "y", ylabel = "z") - wireframe!(ax, y, z, zeros(T, length(y), length(z))) - ax.aspect = DataAspect() - - ax = Axis(fig[2, 2]; xlabel = "x", ylabel = "z") - wireframe!(ax, x, z, zeros(T, length(x), length(z))) - ax.aspect = DataAspect() - - fig -end - "Get utilities to compute energy spectrum." function spectral_stuff(setup; npoint = 100, a = typeof(setup.Re)(1 + sqrt(5)) / 2) (; dimension, xp, Np) = setup.grid @@ -169,4 +126,4 @@ function get_spectrum(setup; npoint = 100, a = typeof(e.setup.Re)(1 + sqrt(5)) / BoolArray = typeof(similar(xp[1], Bool, ntuple(Returns(0), D)...)) masks = adapt.(BoolArray, masks) (; κ, masks, K) -end +end \ No newline at end of file diff --git a/test/Project.toml b/test/Project.toml index aab0361a3..9b2102deb 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -11,6 +11,8 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] Aqua = "0.8" diff --git a/test/chainrules_enzyme.jl b/test/chainrules_enzyme.jl new file mode 100644 index 000000000..62a09c095 --- /dev/null +++ b/test/chainrules_enzyme.jl @@ -0,0 +1,364 @@ + + +@testsnippet ChainRulesStuff begin + using ChainRulesCore + using ChainRulesTestUtils + using Enzyme + using Zygote + using IncompressibleNavierStokes: IncompressibleNavierStokes as INS + import .EnzymeRules: reverse, augmented_primal + using .EnzymeRules +end + +@testmodule Case begin + using IncompressibleNavierStokes + + D2, D3 = map((2, 3)) do D + T = Float64 + Re = T(1_000) + n = if D == 2 + 8 + elseif D == 3 + # 4^3 = 64 grid points + # 3*64 = 192 velocity components + # 192^2 = 36864 fininite difference pairs in convection/diffusion + # TODO: Check if `test_rrule` computes all combinations or only a subset + 4 + end + lims = T(0), T(1) + x = if D == 2 + tanh_grid(lims..., n), tanh_grid(lims..., n, 1.3) + elseif D == 3 + tanh_grid(lims..., n, 1.2), tanh_grid(lims..., n, 1.1), cosine_grid(lims..., n) + end + boundary_conditions = ntuple(d -> (DirichletBC(), DirichletBC()), D) + temperature = temperature_equation(; + Pr = T(0.71), + Ra = T(1e6), + Ge = T(1.0), + boundary_conditions, + ) + if D == 2 + bodyforce = (dim, x, y, t) -> (dim == 1) * 5 * sinpi(8 * y) + dbodyforce = (dim, x, y, t) -> (dim == 1) * 5 * pi * 8 * cos(pi * 8 * y) + elseif D == 3 + bodyforce = (dim, x, y, z, t) -> (dim == 1) * 5 * sinpi(8 * y) + dbodyforce = (dim, x, y, z, t) -> (dim == 1) * 5 * pi * 8 * cos(pi * 8 * y) + end + setup = Setup(; x, boundary_conditions, Re, temperature, bodyforce, dbodyforce, issteadybodyforce = true) + psolver = default_psolver(setup) + u = randn(T, setup.grid.N..., D) + p = randn(T, setup.grid.N) + temp = randn(T, setup.grid.N) + div = divergence(u, setup) + (; setup, psolver, u, p, temp, div) + end +end + +@testitem "Chain rules (boundary conditions)" setup = [ChainRulesStuff] begin + import .EnzymeRules: reverse, augmented_primal + using .EnzymeRules + T = Float64 + Re = T(1_000) + Pr = T(0.71) + Ra = T(1e6) + Ge = T(1.0) + n = 7 + lims = T(0), T(1) + x = range(lims..., n + 1), range(lims..., n + 1) + + for bc in (PeriodicBC(), DirichletBC(), SymmetricBC(), PressureBC()) + boundary_conditions = (bc, bc), (bc, bc) + setup = Setup(; + x, + Re, + boundary_conditions, + temperature = temperature_equation(; Pr, Ra, Ge, boundary_conditions), + ) + u = randn(T, setup.grid.N..., 2) + p = randn(T, setup.grid.N) + temp = randn(T, setup.grid.N) + u0 = copy(u) + p0 = copy(p) + temp0 = copy(temp) + + # --- bc_u + Zygote.pullback(apply_bc_u, u, nothing, setup)[2](u0)[1] + zpull, z_time = @timed Zygote.pullback(apply_bc_u, u, nothing, setup)[2](u0)[1] + du = Enzyme.make_zero(u) + y = Enzyme.make_zero(u) + dy = Enzyme.make_zero(u) .+1 + f = INS.enzyme_wrap(INS.apply_bc_u!) + Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(y, dy), Duplicated(u, du), Const(nothing), Const(setup)) + du = Enzyme.make_zero(u) + y = Enzyme.make_zero(u) + dy = Enzyme.make_zero(u) .+1 + eg, e_time = @timed Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(y, dy), Duplicated(u0, du), Const(nothing), Const(setup)) + if e_time < z_time + @info "Enzyme is faster (bc_u): ", e_time, " vs ", z_time + else + @info "Zygote is faster (bc_u): ", z_time, " vs ", e_time + end + @test du == zpull + + # --- bc_p + Zygote.pullback(apply_bc_p, p, nothing, setup)[2](p0)[1] + zpull, z_time = @timed Zygote.pullback(apply_bc_p, p, nothing, setup)[2](p0)[1] + dp = Enzyme.make_zero(p) + y = Enzyme.make_zero(p) + dy = Enzyme.make_zero(p) .+1 + f = INS.enzyme_wrap(INS.apply_bc_p!) + Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(y, dy), Duplicated(p0, dp), Const(nothing), Const(setup)) + dp = Enzyme.make_zero(p) + y = Enzyme.make_zero(p) + dy = Enzyme.make_zero(p) .+1 + eg, e_time = @timed Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(y, dy), Duplicated(p0, dp), Const(nothing), Const(setup)) + if e_time < z_time + @info "Enzyme is faster (bc_p): ", e_time, " vs ", z_time + else + @info "Zygote is faster (bc_p): ", z_time, " vs ", e_time + end + @test dp == zpull + + + # --- bc_temp + Zygote.pullback(apply_bc_temp, temp, nothing, setup)[2](temp0)[1] + zpull, z_time = @timed Zygote.pullback(apply_bc_temp, temp, nothing, setup)[2](temp0)[1] + dtemp = Enzyme.make_zero(temp) + y = Enzyme.make_zero(temp) + dy = Enzyme.make_zero(temp) .+1 + f = INS.enzyme_wrap(INS.apply_bc_temp!) + Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(y, dy), Duplicated(temp0, dtemp), Const(nothing), Const(setup)) + dtemp = Enzyme.make_zero(temp) + y = Enzyme.make_zero(temp) + dy = Enzyme.make_zero(temp) .+1 + f = INS.enzyme_wrap(INS.apply_bc_temp!) + eg, e_time = @timed Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(y, dy), Duplicated(temp0, dtemp), Const(nothing), Const(setup)) + if e_time < z_time + @info "Enzyme is faster (bc_temp): ", e_time, " vs ", z_time + else + @info "Zygote is faster (bc_temp): ", z_time, " vs ", e_time + end + @test dtemp == zpull + + end +end + + +@testitem "Divergence" setup = [Case, ChainRulesStuff] begin + + for (u,setup,d) in ((Case.D2.u, Case.D2.setup, Case.D2.div), (Case.D3.u, Case.D3.setup, Case.D3.div)) + d0 = copy(d) + u0 = copy(u) + Zygote.pullback(INS.divergence, u, setup)[2](d0)[1] + zpull, z_time = @timed Zygote.pullback(INS.divergence, u, setup)[2](d0)[1] + dd = Enzyme.make_zero(d) .+1 + du = Enzyme.make_zero(u) + f = INS.enzyme_wrap(INS.divergence!) + Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(d, dd), Duplicated(u, du), Const(setup)) + dd = Enzyme.make_zero(d) .+1 + du = Enzyme.make_zero(u) + eg, e_time = @timed Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(d0, dd), Duplicated(u0, du), Const(setup)) + if e_time < z_time + @info "Enzyme is faster (divergence): ", e_time, " vs ", z_time + else + @info "Zygote is faster (divergence): ", z_time, " vs ", e_time + end + @test du == zpull + end +end + +@testitem "Pressuregradient" setup = [Case, ChainRulesStuff] begin + for (p,setup) in ((Case.D2.p, Case.D2.setup), (Case.D3.p, Case.D3.setup)) + p0 = copy(p) + pg = INS.pressuregradient(p, setup) + pg0 = copy(pg) + Zygote.pullback(INS.pressuregradient, p, setup)[2](pg0)[1] + zpull, z_time = @timed Zygote.pullback(INS.pressuregradient, p, setup)[2](pg0)[1] + dpg = Enzyme.make_zero(pg) .+1 + dp = Enzyme.make_zero(p) + f = INS.enzyme_wrap(INS.pressuregradient!) + Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(pg, dpg), Duplicated(p, dp), Const(setup)) + dpg = Enzyme.make_zero(pg) .+1 + dp = Enzyme.make_zero(p) + eg, e_time = @timed Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(pg0, dpg), Duplicated(p0, dp), Const(setup)) + if e_time < z_time + @info "Enzyme is faster (pressuregradient): ", e_time, " vs ", z_time + else + @info "Zygote is faster (pressuregradient): ", z_time, " vs ", e_time + end + @test dp == zpull + end + +end + +@testitem "Poisson" setup = [Case, ChainRulesStuff] begin + for (psolver,d,setup) in ((Case.D2.psolver, Case.D2.div, Case.D2.setup), (Case.D3.psolver, Case.D3.div, Case.D3.setup)) + + p0 = INS.poisson(psolver, d) + Zygote.pullback(INS.poisson, psolver, d)[2](p0)[1] + zpull, z_time = @timed Zygote.pullback(INS.poisson, psolver, d)[2](p0)[2] + + dd = Enzyme.make_zero(d) + p = Enzyme.make_zero(p0) + dp = Enzyme.make_zero(p) .+1 + f = INS.enzyme_wrap(INS.poisson!) + + Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(p, dp), Const(psolver), Duplicated(d, dd)) + ep, e_time = @timed Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(p0, dp), Const(psolver), Duplicated(d, dd)) + if e_time < z_time + @info "Enzyme is faster (poisson): ", e_time, " vs ", z_time + else + @info "Zygote is faster (poisson): ", z_time, " vs ", e_time + end + @test dd == zpull + end + +end + +@testitem "Convection" setup = [Case, ChainRulesStuff] begin + for (u,setup) in ((Case.D2.u, Case.D2.setup), (Case.D3.u, Case.D3.setup)) + c = INS.convection(u, setup) + u0 = copy(u) + Zygote.pullback(INS.convection, u, setup)[2](u)[1] + zpull, z_time = @timed Zygote.pullback(INS.convection, u, setup)[2](c)[1] + + # [!] convection! wants to start from 0 initialized field + Enzyme.make_zero!(c) + dc = Enzyme.make_zero(c) .+1 + du = Enzyme.make_zero(u) + f = INS.enzyme_wrap(INS.convection!) + Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(c, dc), Duplicated(u, du), Const(setup)) + Enzyme.make_zero!(c) + dc = Enzyme.make_zero(c) .+1 + du = Enzyme.make_zero(u) + ec, e_time = @timed Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(c, dc), Duplicated(u, du), Const(setup)) + if e_time < z_time + @info "Enzyme is faster (convection): ", e_time, " vs ", z_time + else + @info "Zygote is faster (convection): ", z_time, " vs ", e_time + end + @test du == zpull + end +end + +@testitem "Diffusion" setup = [Case, ChainRulesStuff] begin + for (u,setup) in ((Case.D2.u, Case.D2.setup), (Case.D3.u, Case.D3.setup)) + d = INS.diffusion(u, setup) + u0 = copy(u) + Zygote.pullback(INS.diffusion, u, setup)[2](d)[1] + zpull, z_time = @timed Zygote.pullback(INS.diffusion, u, setup)[2](d)[1] + + # [!] diffusion! wants to start from 0 initialized field + Enzyme.make_zero!(d) + dd = Enzyme.make_zero(d) .+1 + du = Enzyme.make_zero(u) + f = INS.enzyme_wrap(INS.diffusion!) + Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(d, dd), Duplicated(u, du), Const(setup)) + Enzyme.make_zero!(d) + dd = Enzyme.make_zero(d) .+1 + du = Enzyme.make_zero(u) + ec, e_time = @timed Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(d, dd), Duplicated(u, du), Const(setup)) + if e_time < z_time + @info "Enzyme is faster (diffusion): ", e_time, " vs ", z_time + else + @info "Zygote is faster (diffusion): ", z_time, " vs ", e_time + end + @test du == zpull + end +end + +@testitem "Bodyforce" setup = [Case, ChainRulesStuff] begin + @warn "bodyforce is tested only in the static case" + for (u,setup) in ((Case.D2.u, Case.D2.setup), (Case.D3.u, Case.D3.setup)) + t = 0.5 + bf = INS.applybodyforce(u, t, setup) + bf0 = copy(bf) + setup0 = deepcopy(setup) + Zygote.pullback(INS.applybodyforce, u, t, setup)[2](bf0) + zpull, z_time = @timed Zygote.pullback(INS.applybodyforce, u, t, setup)[2](bf0)[3].bodyforce + + # We can also test Zygote autodiff + @test zpull == setup.bodyforce + + bf = bf .*0 + dbf = Enzyme.make_zero(bf) .+1 + du = Enzyme.make_zero(u) + f = INS.enzyme_wrap(INS.applybodyforce!) + Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(bf, dbf), Duplicated(u, du), Const(t), Const(setup)) + bf = bf .*0 + dbf = Enzyme.make_zero(bf) .+1 + du = Enzyme.make_zero(u) + eb, e_time = @timed Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(bf, dbf), Duplicated(u, du), Const(t), Const(setup)) + if e_time < z_time + @info "Enzyme is faster (bodyforce): ", e_time, " vs ", z_time + else + @info "Zygote is faster (bodyforce): ", z_time, " vs ", e_time + end + @test du == zpull + end +end + +@testitem "Gravity" setup = [Case, ChainRulesStuff] begin + for (t,setup) in ((Case.D2.temp, Case.D2.setup), (Case.D3.temp, Case.D3.setup)) + + g = INS.gravity(t, setup) + Zygote.pullback(INS.gravity, t, setup)[2](g) + zpull, z_time = @timed Zygote.pullback(INS.gravity, t, setup)[2](g)[1] + + g = vectorfield(setup) + dg = Enzyme.make_zero(g) .+1 + dt = Enzyme.make_zero(t) + f = INS.enzyme_wrap(INS.gravity!) + Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(g, dg), Duplicated(t, dt), Const(setup)) + g = vectorfield(setup) + dg = Enzyme.make_zero(g) .+1 + dt = Enzyme.make_zero(t) + gb, e_time = @timed Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(g, dg), Duplicated(t, dt), Const(setup)) + if e_time < z_time + @info "Enzyme is faster (gravity): ", e_time, " vs ", z_time + else + @info "Zygote is faster (gravity): ", z_time, " vs ", e_time + end + @test dt == zpull + + end +end + +@testitem "Dissipation" setup = [Case, ChainRulesStuff] begin + for (u,setup) in ((Case.D2.u, Case.D2.setup), (Case.D3.u, Case.D3.setup)) + + diss = INS.dissipation(u, setup) + Zygote.pullback(INS.dissipation, u, setup)[2](diss) + zpull, z_time = @timed Zygote.pullback(INS.dissipation, u, setup)[2](diss)[1] + + diss = scalarfield(setup) + diff = vectorfield(setup) + ddiss = Enzyme.make_zero(diss) .+1 + ddiff = Enzyme.make_zero(diff) + du = Enzyme.make_zero(u) + f = INS.enzyme_wrap(INS.dissipation!) + Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(diss, ddiss), Duplicated(diff, ddiff), Duplicated(u,du), Const(setup)) + diss = scalarfield(setup) + diff = vectorfield(setup) + diss = Enzyme.make_zero(diss) .+1 + diff = Enzyme.make_zero(diff) + du = Enzyme.make_zero(u) + ed, e_time = @timed Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(diss, ddiss), Duplicated(diff, ddiff), Duplicated(u,du), Const(setup)) + if e_time < z_time + @info "Enzyme is faster (dissipation): ", e_time, " vs ", z_time + else + @info "Zygote is faster (dissipation): ", z_time, " vs ", e_time + end + @test du == zpull + + end +end +@testitem "Convection_diffusion_temp" setup = [Case, ChainRulesStuff] begin + @test_broken 1 == 2 +end + +@testitem "Convectiondiffusion" setup = [Case, ChainRulesStuff] begin +# the pullback rule is missing for this one + @test_broken 1 == 2 +end diff --git a/test/enzyme_integration.jl b/test/enzyme_integration.jl new file mode 100644 index 000000000..15159d269 --- /dev/null +++ b/test/enzyme_integration.jl @@ -0,0 +1,119 @@ + +@testsnippet EnzymeStuff begin + using IncompressibleNavierStokes + using Enzyme + using Zygote + using Random + rng = Random.default_rng(); +end + +@testmodule Case begin + using IncompressibleNavierStokes + T = Float64 + ArrayType = Array + Re = T(1_000) + D = 2 + n = 64 + N = n + 2 + lims = T(0), T(1); + x = tanh_grid(lims..., n), tanh_grid(lims..., n, 1.3) ; + boundary_conditions = ntuple(d -> (DirichletBC(), DirichletBC()), D) + setup = Setup(;x, boundary_conditions, Re); + psolver = default_psolver(setup) +end + +@testitem "Enzyme one force pullback" setup = [Case, EnzymeStuff] begin + for (setup, psolver, T, N) in ((Case.setup, Case.psolver, Case.T, Case.N), ) + dudt = zeros(T, (N, N, 2)) ; + u = rand(T, (N, N, 2)); + u0 = copy(u); + params = [setup, psolver]; + params_ref = Ref(params); + right_hand_side!(dudt, u, params_ref, T(0)) + F_out = create_right_hand_side(setup, psolver) + @test dudt ≈ F_out(u, nothing, T(0)) + @test u == u0 + @test sum(dudt) != 0 + + niter = 5000 + list_u = [rand(T, (N, N, 2)) for i in 1:niter]; + list_z = [] + _, tz, mz = @timed begin + for i in 1:niter + dudt = F_out(list_u[i], nothing, T(0)) + push!(list_z, dudt) + end + end + list_e = [zeros(T, (N, N, 2)) for i in 1:niter]; + _, te, me = @timed begin + for i in 1:niter + right_hand_side!(list_e[i], list_u[i], params_ref, T(0)) + end + end + @test all([list_z[i] ≈ list_e[i] for i in 1:niter]) + if te < tz + @info "One F in-place is faster by a factor of $(tz/te)" + else + @info "One F out-of-place is faster by a factor of $(te/tz)" + end + if me < mz + @info "One F in-place is more memory efficient by a factor of $(mz/me)" + else + @info "One F out-of-place is more memory efficient by a factor of $(me/mz)" + end + + end +end + +@testitem "Enzyme RHS pullback" setup = [Case, EnzymeStuff] begin + for (setup, psolver, T, N) in ((Case.setup, Case.psolver, Case.T, Case.N), ) + F_out = create_right_hand_side(setup, psolver) + dudt = zeros(T, (N, N, 2)) ; + u = rand(T, (N, N, 2)); + u0 = copy(u) + du = Enzyme.make_zero(u); + dd = Enzyme.make_zero(dudt) .+ 1; + params = [setup, psolver]; + params_ref = Ref(params); + Enzyme.autodiff(Enzyme.Reverse, right_hand_side!, Duplicated(dudt,dd), Duplicated(u,du), Const(params_ref), Const(T(0))) + @test u0 == u + @test dudt ≈ F_out(u, nothing, T(0)) + zpull = Zygote.pullback(F_out, u, nothing, T(0)); + @test zpull[1] ≈ dudt + @test zpull[2](dudt)[1] ==du + + + # Now I run each option multiple times from different random initial conditions + niter = 3000 + list_u = [rand(T, (N, N, 2)) for i in 1:niter]; + list_z = [] + _, tz, mz = @timed begin + for i in 1:niter + du = Enzyme.make_zero(u); + dd = Enzyme.make_zero(dudt) .+ 1; + zpull = Zygote.pullback(F_out, list_u[i], nothing, T(0)); + push!(list_z, zpull[2](zpull[1])[1]) + end + end + list_e = [] + _, te, me = @timed begin + for i in 1:niter + du = Enzyme.make_zero(u); + dd = Enzyme.make_zero(dudt) .+ 1; + Enzyme.autodiff(Enzyme.Reverse, right_hand_side!, Duplicated(dudt,dd), Duplicated(list_u[i],du), Const(params_ref), Const(T(0))) + push!(list_e, du) + end + end + @test all([list_z[i] ≈ list_e[i] for i in 1:niter]) + if te < tz + @info "Reverse AD using Enzyme is faster by a factor of $(tz/te)" + else + @info "Reverse AD using Zygote is faster by a factor of $(te/tz)" + end + if me < mz + @info "Reverse AD using Enzyme is more memory efficient by a factor of $(mz/me)" + else + @info "Reverse AD using Zygote is more memory efficient by a factor of $(me/mz)" + end + end +end \ No newline at end of file diff --git a/test/operators.jl b/test/operators.jl index 5d6217680..9da7996cb 100644 --- a/test/operators.jl +++ b/test/operators.jl @@ -105,7 +105,7 @@ end @testitem "Diffusion" setup = [Setup2D, Setup3D] begin for (u, setup) in ((Setup2D.u, Setup2D.setup), (Setup3D.u, Setup3D.setup)) - T = eltype(u[1]) + T = eltype(u) (; dimension, Iu, Δ, Δu) = setup.grid d = diffusion(u, setup) D = dimension() diff --git a/test/runtests.jl b/test/runtests.jl index 71a4a7ffb..e758e2e97 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,4 +3,10 @@ using TestItemRunner # @testitem "Time steppers" begin include("timesteppers.jl") end # Only run tests from this test dir, and not from other packages in monorepo -@run_package_tests filter = t -> occursin(@__DIR__, t.filename) +#@run_package_tests filter = t -> occursin(@__DIR__, t.filename) + +# Or you can run a single test using the following +function myfilter(t) + return endswith(t.filename, "enzyme_integration.jl") || endswith(t.filename, "chainrules_enzyme.jl") +end +@run_package_tests filter = myfilter \ No newline at end of file diff --git a/test/timesteppers.jl b/test/timesteppers.jl index 52e91901f..5b2150d17 100644 --- a/test/timesteppers.jl +++ b/test/timesteppers.jl @@ -1,122 +1,43 @@ -T = Float64 -Re = 500.0 - -n = 50 -x = LinRange(0, 2π, n + 1), LinRange(0, 2π, n + 1) -setup = Setup(; x, Re) - -psolver = psolver_spectral(setup) - -t_start, t_end = tlims = (0.0, 5.0) - -initial_velocity_u(x, y) = cos(x)sin(y) -initial_velocity_v(x, y) = -sin(x)cos(y) -initial_pressure(x, y) = -1 / 4 * (cos(2x) + cos(2y)) -V = velocityfield( - setup, - initial_velocity_u, - initial_velocity_v, - t_start; - initial_pressure, - psolver, -) - -@testset "Steady state" begin - V, p = solve_steady_state(setup, V₀, p₀) - uₕ = V[setup.grid.indu] - vₕ = V[setup.grid.indv] - @test norm(uₕ .- mean(uₕ)) / mean(uₕ) < 1e-8 - @test norm(vₕ .- mean(vₕ)) / mean(vₕ) < 1e-8 -end - -# Exact solutions -F(t) = exp(-2t / Re) -u(x, y, t) = initial_velocity_u(x, y) * F(t) -v(x, y, t) = initial_velocity_v(x, y) * F(t) -(; xu, yu, xv, yv) = setup.grid -uₕ = u.(xu, yu, t_end) -vₕ = v.(xv, yv, t_end) -V_exact = [uₕ[:]; vₕ[:]] - -@testset "Unsteady solvers" begin - @testset "Explicit Runge Kutta" begin - state, outputs = - solve_unsteady(setup, V₀, tlims; Δt = 0.01, psolver, inplace = false) - @test norm(state.u - u_exact) / norm(u_exact) < 1e-4 - stateip, outputsip = - solve_unsteady(setup, V₀, tlims; Δt = 0.01, psolver, inplace = true) - @test stateip.u ≈ state.u - @test stateip.p ≈ state.p - end - - @testset "Implicit Runge Kutta" begin - @test_broken solve_unsteady( - setup, - V₀, - tlims; - method = RIA2(), - Δt = 0.01, - psolver, - inplace = false, - ) isa Tuple - (; u, t), outputs = solve_unsteady( - setup, - V₀, - tlims; - method = RIA2(), - Δt = 0.01, - psolver, - inplace = true, - processors = (timelogger(),), - ) - @test_broken norm(u - u_exact) / norm(u_exact) < 1e-3 - end - - @testset "One-leg beta method" begin - state, outputs = solve_unsteady( - setup, - V₀, - tlims; - method = OneLegMethod(T), - Δt = 0.01, - psolver, - inplace = false, - ) - @test norm(state.u - u_exact) / norm(u_exact) < 1e-4 - stateip, outputsip = solve_unsteady( - setup, - V₀, - tlims; - method = OneLegMethod(T), - Δt = 0.01, - psolver, - inplace = true, - ) - @test stateip.u ≈ state.u - @test stateip.p ≈ state.p - end - - @testset "Adams-Bashforth Crank-Nicolson" begin - state, outputs = solve_unsteady( - setup, - V₀, - tlims; - method = AdamsBashforthCrankNicolsonMethod(T), - Δt = 0.01, - psolver, - inplace = false, - ) - @test norm(state.u - u_exact) / norm(u_exact) < 1e-4 - stateip, outputs = solve_unsteady( - setup, - V₀, - tlims; - method = AdamsBashforthCrankNicolsonMethod(T), - Δt = 0.01, - psolver, - inplace = true, - ) - @test stateip.u ≈ state.u - @test stateip.p ≈ state.p +@testitem "In/out-of place" begin + using Random + ax = range(0, 1, 17) + temperature = temperature_equation(; + Pr = 0.71, + Ra = 1e7, + Ge = 1.0, + boundary_conditions = ((PeriodicBC(), PeriodicBC()), (PeriodicBC(), PeriodicBC())), + ) + setup = Setup(; x = (ax, ax), Re = 1e3, temperature) + psolver = default_psolver(setup) + u = random_field(setup, psolver) + temp = randn!(scalarfield(setup)) + temp = apply_bc_temp(temp, 0.0, setup) + Δt = 0.1 + for method in [LMWray3(), RKMethods.RK44()] + stepper_outplace = let + stepper = create_stepper( + method; + setup, + psolver, + u = copy(u), + temp = copy(temp), + t = 0.0, + ) + timestep(method, stepper, Δt) + end + stepper_inplace = let + cache = IncompressibleNavierStokes.ode_method_cache(method, setup) + stepper = create_stepper( + method; + setup, + psolver, + u = copy(u), + temp = copy(temp), + t = 0.0, + ) + IncompressibleNavierStokes.timestep!(method, stepper, Δt; cache) + end + @test stepper_inplace.u ≈ stepper_outplace.u + @test stepper_inplace.temp ≈ stepper_outplace.temp end end From b2195c497e0b1e4766bfc6c03a7e020f2e76750a Mon Sep 17 00:00:00 2001 From: SCiarella Date: Mon, 11 Nov 2024 19:20:31 +0100 Subject: [PATCH 02/14] Fix tests --- test/enzyme_integration.jl | 10 +++++----- test/runtests.jl | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/test/enzyme_integration.jl b/test/enzyme_integration.jl index 15159d269..4b345cf60 100644 --- a/test/enzyme_integration.jl +++ b/test/enzyme_integration.jl @@ -7,7 +7,7 @@ rng = Random.default_rng(); end -@testmodule Case begin +@testmodule Info begin using IncompressibleNavierStokes T = Float64 ArrayType = Array @@ -22,8 +22,8 @@ end psolver = default_psolver(setup) end -@testitem "Enzyme one force pullback" setup = [Case, EnzymeStuff] begin - for (setup, psolver, T, N) in ((Case.setup, Case.psolver, Case.T, Case.N), ) +@testitem "Enzyme one force pullback" setup = [Info, EnzymeStuff] begin + for (setup, psolver, T, N) in ((Info.setup, Info.psolver, Info.T, Info.N), ) dudt = zeros(T, (N, N, 2)) ; u = rand(T, (N, N, 2)); u0 = copy(u); @@ -65,8 +65,8 @@ end end end -@testitem "Enzyme RHS pullback" setup = [Case, EnzymeStuff] begin - for (setup, psolver, T, N) in ((Case.setup, Case.psolver, Case.T, Case.N), ) +@testitem "Enzyme RHS pullback" setup = [Info, EnzymeStuff] begin + for (setup, psolver, T, N) in ((Info.setup, Info.psolver, Info.T, Info.N), ) F_out = create_right_hand_side(setup, psolver) dudt = zeros(T, (N, N, 2)) ; u = rand(T, (N, N, 2)); diff --git a/test/runtests.jl b/test/runtests.jl index e758e2e97..41f28869b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,7 +5,7 @@ using TestItemRunner # Only run tests from this test dir, and not from other packages in monorepo #@run_package_tests filter = t -> occursin(@__DIR__, t.filename) -# Or you can run a single test using the following +# Or you can run only specific tests using the following function myfilter(t) return endswith(t.filename, "enzyme_integration.jl") || endswith(t.filename, "chainrules_enzyme.jl") end From 5a8a1a204bc4c51b6c708101476aa65cec913fd1 Mon Sep 17 00:00:00 2001 From: SCiarella Date: Wed, 13 Nov 2024 16:18:42 +0100 Subject: [PATCH 03/14] add docstrings add docpages clean deps make @info optional in tests --- Project.toml | 3 - docs/src/manual/differentiability.md | 64 ++++++++++++++- docs/src/manual/sciml.md | 46 +++++++++++ src/IncompressibleNavierStokes.jl | 1 - src/boundary_conditions.jl | 2 - src/sciml.jl | 48 +++++------ test/chainrules_enzyme.jl | 115 ++++++++++++++++----------- test/enzyme_integration.jl | 8 +- test/runtests.jl | 12 +-- 9 files changed, 208 insertions(+), 91 deletions(-) create mode 100644 docs/src/manual/sciml.md diff --git a/Project.toml b/Project.toml index 8f25ebe23..62fc55f2e 100644 --- a/Project.toml +++ b/Project.toml @@ -6,11 +6,9 @@ version = "2.0.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" -Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -24,7 +22,6 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a" WriteVTK = "64499a7a-5c06-52f2-abe2-ccb03c286192" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [weakdeps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" diff --git a/docs/src/manual/differentiability.md b/docs/src/manual/differentiability.md index 8da9df70f..4bb0fd479 100644 --- a/docs/src/manual/differentiability.md +++ b/docs/src/manual/differentiability.md @@ -7,19 +7,24 @@ CurrentModule = IncompressibleNavierStokes IncompressibleNavierStokes is [reverse-mode differentiable](https://juliadiff.org/ChainRulesCore.jl/stable/index.html#Reverse-mode-AD-rules-(rrules)), which means that you can back-propagate gradients through the code. +Two AD libraries are currently supported: +* **[Zygote.jl](https://github.com/FluxML/Zygote.jl)**: it is the default AD library in the Julia ecosystem and is the most widely used. +* **[Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl)**: currently has low coverage over the Julia programming language, however it is usually the most efficient if applicable. + +## Automatic differentiation with Zygote + +Zygote.jl is the default choice for AD backend because it is easy to understand, compatible with most of the Julia ecosystem and good with vectorized code and BLAS. This comes at a cost however, as intermediate velocity fields need to be stored in memory for use in the backward pass. For this reason, many of the operators come in two versions: a slow differentiable allocating non-mutating variant (e.g. [`divergence`](@ref)) and fast non-differentiable non-allocating mutating variant (e.g. [`divergence!`](@ref).) -!!! warning "Differentiable code" +!!! warning "Zygote limitation: array mutation" To make your code differentiable, you must use the differentiable versions of the operators (without the exclamation marks). -To differentiate the code, use [Zygote.jl](https://github.com/FluxML/Zygote.jl). - -## Example: Gradient of kinetic energy +#### Example: Gradient of kinetic energy To differentiate outputs of a simulation with respect to the initial conditions, make a time stepping loop composed of differentiable operations: @@ -55,3 +60,54 @@ Now `g` is the gradient of `final_energy` with respect to the initial conditions Note that every operation in the `final_energy` function is non-mutating and thus differentiable. + +--- +## Automatic differentiation with Enzyme + +Enzyme.jl is highly-efficient and its ability to perform AD on optimized code allows Enzyme to meet or exceed the performance of state-of-the-art AD tools. +The downside is that restricts the user's defined f function to not do things like require garbage collection or calls to BLAS/LAPACK. However, mutation is supported, meaning that in-place f with fully mutating non-allocating code will work with Enzyme and this will be the most efficient adjoint implementation. + +!!! warning "Enzyme limitation: vector returns" + Enzyme's autodiff function can only handle functions with scalar output. To implement pullbacks for array-valued functions, use a mutating function that returns `nothing` and stores its result in one of the arguments, which must be passed wrapped in a Duplicated. + In IncompressibleNavierStokes, we provide `enzyme_wrapper` to automatically wrap the function and its arguments in the correct way. + +#### Example: Gradient of the right-hand side + +In this example we differentiate the right-hand side of the Navier-Stokes equations with respect to the velocity field `u`: + +```julia +import IncompressibleNavierStokes as INS +ax = range(0, 1, 101) +setup = INS.Setup(; x = (ax, ax), Re = 500.0) +psolver = INS.default_psolver(setup) +u = INS.random_field(setup) +t = 0.0 +f! = INS.right_hand_side! +``` +Notice that we are using the mutating (in-place) version of the right-hand side function. This function can not be differentiate by Zygote, which requires the slower non-mutating version of the right-hand side. + +We then define the `Dual` part of the input and output, required to store the adjoint values: +```julia +ddudt = Enzyme.make_zero(dudt) .+ 1; +du = Enzyme.make_zero(u); +``` +Remember that the derivative of the output (also called the *seed*) has to be set to $1$ in order to compute the gradient. In this case the output is the force, that we store mutating the value of `dudt` inside `right_hand_side!`. + +Then we pack the parameters to be passed to `right_hand_side!`: +```julia +params = [setup, psolver]; +params_ref = Ref(params); +``` +Now, we call the `autodiff` function from Enzyme: +```julia +Enzyme.autodiff(Enzyme.Reverse, f!, Duplicated(dudt,dd), Duplicated(u,du), Const(params_ref), Const(t)) +``` +Since we have passed a `Duplicated` object, the gradient of `u` is stored in `du`. + +Finally, we can also compare its value with the one obtained by Zygote differentiating the out-of-place (non-mutating) version of the right-hand side: +```julia +using Zygote +f = create_right_hand_side(setup, psolver) +_, zpull = Zygote.pullback(f, u, nothing, T(0)); +@assert zpull(dudt)[1] == du +``` \ No newline at end of file diff --git a/docs/src/manual/sciml.md b/docs/src/manual/sciml.md new file mode 100644 index 000000000..af943810e --- /dev/null +++ b/docs/src/manual/sciml.md @@ -0,0 +1,46 @@ +```@meta +CurrentModule = IncompressibleNavierStokes +``` + +# Using IncompressibleNavierStokes in SciML + +The [SciML organization](https://sciml.ai/) is a collection of tools for solving equations and modeling systems. It has a coherent development principle, unified APIs over large collections of equation solvers, pervasive differentiability and sensitivity analysis, and features many of the highest performance and parallel implementations one can find. + +In particular, [DifferentialEquations.jl](https://docs.sciml.ai/DiffEqDocs/stable/) contains tools to solve differential equations defined as $\dfrac{du}{dt} = f(u, t)$ that include a large collection of solvers, sensitivity analysis, and more. + +Using IncompressibleNavierStokes it is possible to write the momentum equations without the pressure by explicitly solving the discrete Poisson equation and obtaining: + +```math +\begin{align*} +\frac{\mathrm{d} u_h}{\mathrm{d} t} &= (I - G L^{-1} W M) +(F(u_h) - y_G) - G L^{-1} W \frac{\mathrm{d} y_M}{\mathrm{d} t}\\ &=f(u_h). +\end{align*} +``` + +The derivation and the drawbacks of this approach are discussed in the [documentation](/docs/src/manual/spatial.md). + +This projected right-hand side can be used in the SciML solvers to solve the Navier-Stokes equations. The following example shows how to use the SciML solvers to solve the ODEs obtained from the Navier-Stokes equations. + +```julia +using DifferentialEquations +f(u, p, t) = create_right_hand_side(setup, psolver) +u0 = INITIAL_CONDITION +tspan = (0.0, 1.0) # time span where to solve. +problem = ODEProblem(f, u0, tspan) #SciMLBase.ODEProblem +sol = solve(problem, Tsit5(), reltol = 1e-8, abstol = 1e-8) # sol: SciMLBase.ODESolution +``` + +Alternatively, it is also possible to use an [in-place formulation](https://docs.sciml.ai/DiffEqDocs/stable/basics/problem/#In-place-vs-Out-of-Place-Function-Definition-Forms) + +```julia +f(du,u,p,t) = right_hand_side!(du, u, Ref([setup, psolver]), t) +``` +that is usually faster than the out-of-place formulation. + +You can look [here](https://docs.sciml.ai/DiffEqDocs/stable/basics/overview/) for more information on how to use the SciML solvers and all the options available. + +## API +```@autodocs +Modules = [IncompressibleNavierStokes] +Pages = ["sciml.jl"] +``` \ No newline at end of file diff --git a/src/IncompressibleNavierStokes.jl b/src/IncompressibleNavierStokes.jl index a9931c465..8ab2fe584 100644 --- a/src/IncompressibleNavierStokes.jl +++ b/src/IncompressibleNavierStokes.jl @@ -29,7 +29,6 @@ using SparseArrays using StaticArrays using Statistics using WriteVTK: CollectionFile, paraview_collection, vtk_grid, vtk_save -using Zygote # Docstring templates @template MODULES = """ diff --git a/src/boundary_conditions.jl b/src/boundary_conditions.jl index 5eac0d218..944e64683 100644 --- a/src/boundary_conditions.jl +++ b/src/boundary_conditions.jl @@ -541,12 +541,10 @@ function enzyme_wrap(f::Union{typeof(apply_bc_u!), typeof(apply_bc_p!), typeof(a end function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Union{Const{typeof(enzyme_wrap(apply_bc_u!))}, Const{typeof(enzyme_wrap(apply_bc_p!))}, Const{typeof(enzyme_wrap(apply_bc_temp!))}}, ::Type{<:Const}, y::Duplicated, x::Duplicated, t::Const, setup::Const) - @info "augmented_primal" primal = func.val(y.val, x.val, t.val, setup.val) return AugmentedReturn(primal, nothing, nothing) end function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(enzyme_wrap(apply_bc_u!))}, dret, tape, y::Duplicated, x::Duplicated, t::Const, setup::Const) - @info "reverse" adj = apply_bc_u_pullback!(x.val, t.val, setup.val) x.dval .+= adj y.dval .= x.dval # y is a copy of x diff --git a/src/sciml.jl b/src/sciml.jl index b6474cbbd..3ab11c30d 100644 --- a/src/sciml.jl +++ b/src/sciml.jl @@ -1,22 +1,39 @@ +""" + create_right_hand_side(setup, psolver) +Creates a function that computes the right-hand side of the Navier-Stokes equations for a given setup and pressure solver. + +# Arguments +- `setup`: The simulation setup containing grid and boundary conditions. +- `psolver`: The pressure solver to be used. + +# Returns +A function that computes the right-hand side of the Navier-Stokes equations. +""" function create_right_hand_side(setup, psolver) function right_hand_side(u, param, t) F = zeros(size(u)) u = apply_bc_u(u, t, setup) - #F = convection(u, setup) .+ diffusion(u, setup) F = momentum(u, nothing, t, setup) F = apply_bc_u(F, t, setup; dudt = true) FP = project(F, setup; psolver) - #p = divergence(F, setup) - #p = scalewithvolume(p, setup) - #p = poisson(psolver, p) - #p = apply_bc_p(p, t, setup) - #G = pressuregradient(p, setup) - #F .- G end end +""" + right_hand_side!(dudt, u, params_ref, t) + +Computes the right-hand side of the Navier-Stokes equations in-place. + +# Arguments +- `dudt`: The array to store the computed right-hand side. +- `u`: The current velocity field. +- `params_ref`: A reference to the parameters containing the setup and pressure solver. +- `t`: The current time. +# Returns +Nothing. The result is stored in `dudt`. +""" function right_hand_side!(dudt, u, params_ref, t) params = params_ref[] setup = params[1] @@ -26,22 +43,12 @@ function right_hand_side!(dudt, u, params_ref, t) temp_vector = copy(u) apply_bc_u!(temp_vector, t, setup) momentum!(dudt, temp_vector, nothing, t, setup) - #fill!(dudt, 0) - #convectiondiffusion!(dudt, temp_vector, setup) apply_bc_u!(dudt, t, setup) project!(dudt, setup; psolver, p) - #divergence!(p, dudt, setup) - #scalewithvolume!(p, setup) - #poisson!(psolver, p) - #apply_bc_p!(p, t, setup) - #applypressure!(dudt, p, setup) return nothing end -using Enzyme -import .EnzymeRules: reverse, augmented_primal -using .EnzymeRules function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(right_hand_side!)}, ::Type{<:Const}, dudt::Duplicated, u::Duplicated, params_ref::Any, t::Const) # this runs function to modify dudt and store the intermediates params = params_ref.val[] @@ -51,15 +58,8 @@ function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Const{typ u_bc = copy(u.val) apply_bc_u!(u_bc, t.val, setup) momentum!(dudt.val, u_bc, nothing, t, setup) - #fill!(dudt.val, 0) - #convectiondiffusion!(dudt.val, u_bc, setup) apply_bc_u!(dudt.val, t.val, setup) project!(dudt.val, setup; psolver, p) - #divergence!(p, dudt.val, setup) - #scalewithvolume!(p, setup) - #poisson!(psolver, p) - #apply_bc_p!(p, t.val, setup) - #applypressure!(dudt.val, p, setup) return AugmentedReturn(nothing, nothing, u_bc) end function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(right_hand_side!)}, dret, u_bc, dudt::Duplicated, u::Duplicated, params_ref::Const, t::Const) diff --git a/test/chainrules_enzyme.jl b/test/chainrules_enzyme.jl index 62a09c095..2b951c626 100644 --- a/test/chainrules_enzyme.jl +++ b/test/chainrules_enzyme.jl @@ -1,5 +1,3 @@ - - @testsnippet ChainRulesStuff begin using ChainRulesCore using ChainRulesTestUtils @@ -8,6 +6,7 @@ using IncompressibleNavierStokes: IncompressibleNavierStokes as INS import .EnzymeRules: reverse, augmented_primal using .EnzymeRules + ENABLE_LOGGING = false end @testmodule Case begin @@ -94,12 +93,14 @@ end y = Enzyme.make_zero(u) dy = Enzyme.make_zero(u) .+1 eg, e_time = @timed Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(y, dy), Duplicated(u0, du), Const(nothing), Const(setup)) - if e_time < z_time - @info "Enzyme is faster (bc_u): ", e_time, " vs ", z_time - else - @info "Zygote is faster (bc_u): ", z_time, " vs ", e_time + if ENABLE_LOGGING + if e_time < z_time + @info "Enzyme is faster (bc_u): ", e_time, " vs ", z_time + else + @info "Zygote is faster (bc_u): ", z_time, " vs ", e_time + end end - @test du == zpull + @test du == zpull # --- bc_p Zygote.pullback(apply_bc_p, p, nothing, setup)[2](p0)[1] @@ -113,10 +114,12 @@ end y = Enzyme.make_zero(p) dy = Enzyme.make_zero(p) .+1 eg, e_time = @timed Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(y, dy), Duplicated(p0, dp), Const(nothing), Const(setup)) - if e_time < z_time - @info "Enzyme is faster (bc_p): ", e_time, " vs ", z_time - else - @info "Zygote is faster (bc_p): ", z_time, " vs ", e_time + if ENABLE_LOGGING + if e_time < z_time + @info "Enzyme is faster (bc_p): ", e_time, " vs ", z_time + else + @info "Zygote is faster (bc_p): ", z_time, " vs ", e_time + end end @test dp == zpull @@ -134,10 +137,12 @@ end dy = Enzyme.make_zero(temp) .+1 f = INS.enzyme_wrap(INS.apply_bc_temp!) eg, e_time = @timed Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(y, dy), Duplicated(temp0, dtemp), Const(nothing), Const(setup)) - if e_time < z_time - @info "Enzyme is faster (bc_temp): ", e_time, " vs ", z_time - else - @info "Zygote is faster (bc_temp): ", z_time, " vs ", e_time + if ENABLE_LOGGING + if e_time < z_time + @info "Enzyme is faster (bc_temp): ", e_time, " vs ", z_time + else + @info "Zygote is faster (bc_temp): ", z_time, " vs ", e_time + end end @test dtemp == zpull @@ -159,10 +164,12 @@ end dd = Enzyme.make_zero(d) .+1 du = Enzyme.make_zero(u) eg, e_time = @timed Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(d0, dd), Duplicated(u0, du), Const(setup)) - if e_time < z_time - @info "Enzyme is faster (divergence): ", e_time, " vs ", z_time - else - @info "Zygote is faster (divergence): ", z_time, " vs ", e_time + if ENABLE_LOGGING + if e_time < z_time + @info "Enzyme is faster (divergence): ", e_time, " vs ", z_time + else + @info "Zygote is faster (divergence): ", z_time, " vs ", e_time + end end @test du == zpull end @@ -182,10 +189,12 @@ end dpg = Enzyme.make_zero(pg) .+1 dp = Enzyme.make_zero(p) eg, e_time = @timed Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(pg0, dpg), Duplicated(p0, dp), Const(setup)) - if e_time < z_time - @info "Enzyme is faster (pressuregradient): ", e_time, " vs ", z_time - else - @info "Zygote is faster (pressuregradient): ", z_time, " vs ", e_time + if ENABLE_LOGGING + if e_time < z_time + @info "Enzyme is faster (pressuregradient): ", e_time, " vs ", z_time + else + @info "Zygote is faster (pressuregradient): ", z_time, " vs ", e_time + end end @test dp == zpull end @@ -206,10 +215,12 @@ end Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(p, dp), Const(psolver), Duplicated(d, dd)) ep, e_time = @timed Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(p0, dp), Const(psolver), Duplicated(d, dd)) - if e_time < z_time - @info "Enzyme is faster (poisson): ", e_time, " vs ", z_time - else - @info "Zygote is faster (poisson): ", z_time, " vs ", e_time + if ENABLE_LOGGING + if e_time < z_time + @info "Enzyme is faster (poisson): ", e_time, " vs ", z_time + else + @info "Zygote is faster (poisson): ", z_time, " vs ", e_time + end end @test dd == zpull end @@ -233,10 +244,12 @@ end dc = Enzyme.make_zero(c) .+1 du = Enzyme.make_zero(u) ec, e_time = @timed Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(c, dc), Duplicated(u, du), Const(setup)) - if e_time < z_time - @info "Enzyme is faster (convection): ", e_time, " vs ", z_time - else - @info "Zygote is faster (convection): ", z_time, " vs ", e_time + if ENABLE_LOGGING + if e_time < z_time + @info "Enzyme is faster (convection): ", e_time, " vs ", z_time + else + @info "Zygote is faster (convection): ", z_time, " vs ", e_time + end end @test du == zpull end @@ -259,10 +272,12 @@ end dd = Enzyme.make_zero(d) .+1 du = Enzyme.make_zero(u) ec, e_time = @timed Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(d, dd), Duplicated(u, du), Const(setup)) - if e_time < z_time - @info "Enzyme is faster (diffusion): ", e_time, " vs ", z_time - else - @info "Zygote is faster (diffusion): ", z_time, " vs ", e_time + if ENABLE_LOGGING + if e_time < z_time + @info "Enzyme is faster (diffusion): ", e_time, " vs ", z_time + else + @info "Zygote is faster (diffusion): ", z_time, " vs ", e_time + end end @test du == zpull end @@ -290,10 +305,12 @@ end dbf = Enzyme.make_zero(bf) .+1 du = Enzyme.make_zero(u) eb, e_time = @timed Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(bf, dbf), Duplicated(u, du), Const(t), Const(setup)) - if e_time < z_time - @info "Enzyme is faster (bodyforce): ", e_time, " vs ", z_time - else - @info "Zygote is faster (bodyforce): ", z_time, " vs ", e_time + if ENABLE_LOGGING + if e_time < z_time + @info "Enzyme is faster (bodyforce): ", e_time, " vs ", z_time + else + @info "Zygote is faster (bodyforce): ", z_time, " vs ", e_time + end end @test du == zpull end @@ -315,10 +332,12 @@ end dg = Enzyme.make_zero(g) .+1 dt = Enzyme.make_zero(t) gb, e_time = @timed Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(g, dg), Duplicated(t, dt), Const(setup)) - if e_time < z_time - @info "Enzyme is faster (gravity): ", e_time, " vs ", z_time - else - @info "Zygote is faster (gravity): ", z_time, " vs ", e_time + if ENABLE_LOGGING + if e_time < z_time + @info "Enzyme is faster (gravity): ", e_time, " vs ", z_time + else + @info "Zygote is faster (gravity): ", z_time, " vs ", e_time + end end @test dt == zpull @@ -345,10 +364,12 @@ end diff = Enzyme.make_zero(diff) du = Enzyme.make_zero(u) ed, e_time = @timed Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(diss, ddiss), Duplicated(diff, ddiff), Duplicated(u,du), Const(setup)) - if e_time < z_time - @info "Enzyme is faster (dissipation): ", e_time, " vs ", z_time - else - @info "Zygote is faster (dissipation): ", z_time, " vs ", e_time + if ENABLE_LOGGING + if e_time < z_time + @info "Enzyme is faster (dissipation): ", e_time, " vs ", z_time + else + @info "Zygote is faster (dissipation): ", z_time, " vs ", e_time + end end @test du == zpull diff --git a/test/enzyme_integration.jl b/test/enzyme_integration.jl index 4b345cf60..ced39c567 100644 --- a/test/enzyme_integration.jl +++ b/test/enzyme_integration.jl @@ -52,14 +52,14 @@ end end @test all([list_z[i] ≈ list_e[i] for i in 1:niter]) if te < tz - @info "One F in-place is faster by a factor of $(tz/te)" + @info "Right-hand side in-place is faster by a factor of $(tz/te)" else - @info "One F out-of-place is faster by a factor of $(te/tz)" + @info "Right-hand side out-of-place is faster by a factor of $(te/tz)" end if me < mz - @info "One F in-place is more memory efficient by a factor of $(mz/me)" + @info "Right-hand side in-place is more memory efficient by a factor of $(mz/me)" else - @info "One F out-of-place is more memory efficient by a factor of $(me/mz)" + @info "Right-hand side out-of-place is more memory efficient by a factor of $(me/mz)" end end diff --git a/test/runtests.jl b/test/runtests.jl index 41f28869b..97a2f4d68 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,10 +3,10 @@ using TestItemRunner # @testitem "Time steppers" begin include("timesteppers.jl") end # Only run tests from this test dir, and not from other packages in monorepo -#@run_package_tests filter = t -> occursin(@__DIR__, t.filename) +@run_package_tests filter = t -> occursin(@__DIR__, t.filename) -# Or you can run only specific tests using the following -function myfilter(t) - return endswith(t.filename, "enzyme_integration.jl") || endswith(t.filename, "chainrules_enzyme.jl") -end -@run_package_tests filter = myfilter \ No newline at end of file +## Or you can run only specific tests using the following +#function myfilter(t) +# return endswith(t.filename, "enzyme_integration.jl") || endswith(t.filename, "chainrules_enzyme.jl") +#end +#@run_package_tests filter = myfilter \ No newline at end of file From f3acd84ca4f45d19f4e3df29e044047bd9bcb3fc Mon Sep 17 00:00:00 2001 From: SCiarella Date: Thu, 14 Nov 2024 14:37:40 +0100 Subject: [PATCH 04/14] Remove unused Pkg --- Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/Project.toml b/Project.toml index 62fc55f2e..3be81b73a 100644 --- a/Project.toml +++ b/Project.toml @@ -20,7 +20,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a" WriteVTK = "64499a7a-5c06-52f2-abe2-ccb03c286192" [weakdeps] From c657a3f0719854b95fbdf8d5aa65d8fadda4b48f Mon Sep 17 00:00:00 2001 From: SCiarella Date: Tue, 19 Nov 2024 08:23:44 +0100 Subject: [PATCH 05/14] Add Simone to CITATION.cff --- CITATION.cff | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CITATION.cff b/CITATION.cff index 294bef371..353e5ebe2 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -9,6 +9,10 @@ authors: family-names: Agdestein affiliation: Centrum Wiskunde & Informatica orcid: 'https://orcid.org/0000-0002-1589-2916' + - given-names: Simone + family-names: Ciarella + affiliation: Netherlands eScience Center + orcid: 'https://orcid.org/0000-0002-9247-139X' - given-names: "Benjamin" family-names: "Sanderse" affiliation: Centrum Wiskunde & Informatica From 5dc7a59c23f54a5f4228db753f961a83ac092720 Mon Sep 17 00:00:00 2001 From: SCiarella Date: Wed, 20 Nov 2024 14:44:33 +0100 Subject: [PATCH 06/14] Make AD example runnable --- docs/Project.toml | 15 ++++++++++----- docs/src/manual/differentiability.md | 14 ++++++++------ 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index 72d354c3f..ba198a622 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -2,16 +2,12 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244" DocumenterVitepress = "4710194d-e776-4893-9690-8d956a29c365" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" Examples = "318dbb63-4243-420f-99f2-d56058123f9d" IncompressibleNavierStokes = "5e318141-6589-402b-868d-77d7df8c442e" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" NeuralClosure = "099dac27-d7f2-4047-93d5-0baee36b9c25" -[sources] -Examples = {path = "../examples"} -IncompressibleNavierStokes = {path = ".."} -NeuralClosure = {path = "../lib/NeuralClosure"} - [compat] Documenter = "1" DocumenterCitations = "1" @@ -21,3 +17,12 @@ IncompressibleNavierStokes = "2" Literate = "2" NeuralClosure = "1" julia = "1.9" + +[sources.Examples] +path = "../examples" + +[sources.IncompressibleNavierStokes] +path = ".." + +[sources.NeuralClosure] +path = "../lib/NeuralClosure" diff --git a/docs/src/manual/differentiability.md b/docs/src/manual/differentiability.md index 4bb0fd479..cccfcdf80 100644 --- a/docs/src/manual/differentiability.md +++ b/docs/src/manual/differentiability.md @@ -75,37 +75,39 @@ The downside is that restricts the user's defined f function to not do things li In this example we differentiate the right-hand side of the Navier-Stokes equations with respect to the velocity field `u`: -```julia +```@example import IncompressibleNavierStokes as INS +using Enzyme ax = range(0, 1, 101) setup = INS.Setup(; x = (ax, ax), Re = 500.0) psolver = INS.default_psolver(setup) u = INS.random_field(setup) +dudt = similar(u) t = 0.0 f! = INS.right_hand_side! ``` Notice that we are using the mutating (in-place) version of the right-hand side function. This function can not be differentiate by Zygote, which requires the slower non-mutating version of the right-hand side. We then define the `Dual` part of the input and output, required to store the adjoint values: -```julia +```@example ddudt = Enzyme.make_zero(dudt) .+ 1; du = Enzyme.make_zero(u); ``` Remember that the derivative of the output (also called the *seed*) has to be set to $1$ in order to compute the gradient. In this case the output is the force, that we store mutating the value of `dudt` inside `right_hand_side!`. Then we pack the parameters to be passed to `right_hand_side!`: -```julia +```@example params = [setup, psolver]; params_ref = Ref(params); ``` Now, we call the `autodiff` function from Enzyme: -```julia -Enzyme.autodiff(Enzyme.Reverse, f!, Duplicated(dudt,dd), Duplicated(u,du), Const(params_ref), Const(t)) +```@example +Enzyme.autodiff(Enzyme.Reverse, f!, Duplicated(dudt,ddudt), Duplicated(u,du), Const(params_ref), Const(t)) ``` Since we have passed a `Duplicated` object, the gradient of `u` is stored in `du`. Finally, we can also compare its value with the one obtained by Zygote differentiating the out-of-place (non-mutating) version of the right-hand side: -```julia +```@example using Zygote f = create_right_hand_side(setup, psolver) _, zpull = Zygote.pullback(f, u, nothing, T(0)); From 539b32013aeb374ff702254055c0539473cbd2eb Mon Sep 17 00:00:00 2001 From: SCiarella Date: Wed, 20 Nov 2024 14:50:21 +0100 Subject: [PATCH 07/14] Format --- src/IncompressibleNavierStokes.jl | 3 +- src/boundary_conditions.jl | 54 +++++- src/operators.jl | 190 ++++++++++++++++--- src/pressure.jl | 23 ++- src/sciml.jl | 44 +++-- src/utils.jl | 2 +- test/chainrules_enzyme.jl | 291 ++++++++++++++++++++++-------- test/enzyme_integration.jl | 88 +++++---- test/runtests.jl | 2 +- 9 files changed, 527 insertions(+), 170 deletions(-) diff --git a/src/IncompressibleNavierStokes.jl b/src/IncompressibleNavierStokes.jl index 8ab2fe584..e363a37f5 100644 --- a/src/IncompressibleNavierStokes.jl +++ b/src/IncompressibleNavierStokes.jl @@ -150,7 +150,6 @@ export apply_bc_u, Qfield # SciML operations -export create_right_hand_side, - right_hand_side! +export create_right_hand_side, right_hand_side! end diff --git a/src/boundary_conditions.jl b/src/boundary_conditions.jl index 944e64683..bb9172bd4 100644 --- a/src/boundary_conditions.jl +++ b/src/boundary_conditions.jl @@ -528,9 +528,10 @@ apply_bc_temp!(bc::PressureBC, temp, β, t, setup; isright, kwargs...) = apply_bc_temp_pullback!(bc::PressureBC, φbar, β, t, setup; isright, kwargs...) = apply_bc_p_pullback!(SymmetricBC(), φbar, β, t, setup; isright, kwargs...) - # Wrap a function to return `nothing`, because Enzyme can not handle vector return values. -function enzyme_wrap(f::Union{typeof(apply_bc_u!), typeof(apply_bc_p!), typeof(apply_bc_temp!)}) +function enzyme_wrap( + f::Union{typeof(apply_bc_u!),typeof(apply_bc_p!),typeof(apply_bc_temp!)}, +) # the boundary condition modifies x which is usually the field that we want to differentiate, so we need to introduce a copy of it and modify it instead function wrapped_f(y, x, args...) y .= x @@ -540,25 +541,64 @@ function enzyme_wrap(f::Union{typeof(apply_bc_u!), typeof(apply_bc_p!), typeof(a return wrapped_f end -function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Union{Const{typeof(enzyme_wrap(apply_bc_u!))}, Const{typeof(enzyme_wrap(apply_bc_p!))}, Const{typeof(enzyme_wrap(apply_bc_temp!))}}, ::Type{<:Const}, y::Duplicated, x::Duplicated, t::Const, setup::Const) +function EnzymeRules.augmented_primal( + config::RevConfigWidth{1}, + func::Union{ + Const{typeof(enzyme_wrap(apply_bc_u!))}, + Const{typeof(enzyme_wrap(apply_bc_p!))}, + Const{typeof(enzyme_wrap(apply_bc_temp!))}, + }, + ::Type{<:Const}, + y::Duplicated, + x::Duplicated, + t::Const, + setup::Const, +) primal = func.val(y.val, x.val, t.val, setup.val) return AugmentedReturn(primal, nothing, nothing) end -function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(enzyme_wrap(apply_bc_u!))}, dret, tape, y::Duplicated, x::Duplicated, t::Const, setup::Const) +function EnzymeRules.reverse( + config::RevConfigWidth{1}, + func::Const{typeof(enzyme_wrap(apply_bc_u!))}, + dret, + tape, + y::Duplicated, + x::Duplicated, + t::Const, + setup::Const, +) adj = apply_bc_u_pullback!(x.val, t.val, setup.val) x.dval .+= adj y.dval .= x.dval # y is a copy of x return (nothing, nothing, nothing, nothing) end -function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(enzyme_wrap(apply_bc_p!))}, dret, tape, y::Duplicated, x::Duplicated, t::Const, setup::Const) +function EnzymeRules.reverse( + config::RevConfigWidth{1}, + func::Const{typeof(enzyme_wrap(apply_bc_p!))}, + dret, + tape, + y::Duplicated, + x::Duplicated, + t::Const, + setup::Const, +) adj = apply_bc_p_pullback!(x.val, t.val, setup.val) x.dval .+= adj y.dval .= x.dval # y is a copy of x return (nothing, nothing, nothing, nothing) end -function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(enzyme_wrap(apply_bc_temp!))}, dret, tape, y::Duplicated, x::Duplicated, t::Const, setup::Const) +function EnzymeRules.reverse( + config::RevConfigWidth{1}, + func::Const{typeof(enzyme_wrap(apply_bc_temp!))}, + dret, + tape, + y::Duplicated, + x::Duplicated, + t::Const, + setup::Const, +) adj = apply_bc_temp_pullback!(x.val, t.val, setup.val) - x.dval .+= adj + x.dval .+= adj y.dval .= x.dval # y is a copy of x return (nothing, nothing, nothing, nothing) end diff --git a/src/operators.jl b/src/operators.jl index fc6dcea3c..b33dd298d 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -202,7 +202,12 @@ applypressure(u, p, setup) = applypressure!(copy.(u), p, setup) ChainRulesCore.rrule(::typeof(applypressure), u, p, setup) = ( applypressure(u, p, setup), - φ -> (NoTangent(), NoTangent(), applypressure_adjoint!(scalarfield(setup), φ, nothing, setup), NoTangent()), + φ -> ( + NoTangent(), + NoTangent(), + applypressure_adjoint!(scalarfield(setup), φ, nothing, setup), + NoTangent(), + ), ) "Subtract pressure gradient (in-place version)." @@ -257,10 +262,10 @@ function applypressure_adjoint!(pbar, φ, u, setup) local p_I = zero(eltype(p)) # Loop over each dimension to compute adjoint contributions - for α in 1:D + for α = 1:D # Contribution from φ[I - e(α)] / Δu[α][I[α] - 1] if I - e(α) ∈ Iu[α] - p_I += φ[I - e(α), α] / Δu[α][I[α] - 1] + p_I += φ[I-e(α), α] / Δu[α][I[α]-1] end # Contribution from -φ[I, α] / Δu[α][I[α]] @@ -280,7 +285,6 @@ function applypressure_adjoint!(pbar, φ, u, setup) return pbar end - "Compute Laplacian of pressure field (differentiable version)." laplacian(p, setup) = laplacian!(scalarfield(setup), p, setup) @@ -1462,7 +1466,19 @@ function get_scale_numbers(u, setup) end # Wrap a function to return `nothing`, because Enzyme can not handle vector return values. -function enzyme_wrap(f::Union{typeof(divergence!), typeof(pressuregradient!), typeof(convection!), typeof(diffusion!), typeof(applybodyforce!), typeof(gravity!), typeof(dissipation!), typeof(convection_diffusion_temp!), typeof(momentum!)}) +function enzyme_wrap( + f::Union{ + typeof(divergence!), + typeof(pressuregradient!), + typeof(convection!), + typeof(diffusion!), + typeof(applybodyforce!), + typeof(gravity!), + typeof(dissipation!), + typeof(convection_diffusion_temp!), + typeof(momentum!), + }, +) function wrapped_f(args...) f(args...) return nothing @@ -1470,8 +1486,20 @@ function enzyme_wrap(f::Union{typeof(divergence!), typeof(pressuregradient!), ty return wrapped_f end - -function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Union{Const{typeof(enzyme_wrap(divergence!))},Const{typeof(enzyme_wrap(pressuregradient!))},Const{typeof(enzyme_wrap(convection!))},Const{typeof(enzyme_wrap(diffusion!))},Const{typeof(enzyme_wrap(gravity!))}}, ::Type{<:Const}, y::Duplicated, u::Duplicated, setup::Const) +function EnzymeRules.augmented_primal( + config::RevConfigWidth{1}, + func::Union{ + Const{typeof(enzyme_wrap(divergence!))}, + Const{typeof(enzyme_wrap(pressuregradient!))}, + Const{typeof(enzyme_wrap(convection!))}, + Const{typeof(enzyme_wrap(diffusion!))}, + Const{typeof(enzyme_wrap(gravity!))}, + }, + ::Type{<:Const}, + y::Duplicated, + u::Duplicated, + setup::Const, +) primal = func.val(y.val, u.val, setup.val) if overwritten(config)[3] tape = copy(u.val) @@ -1480,37 +1508,76 @@ function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Union{Con end return AugmentedReturn(primal, nothing, tape) end -function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(enzyme_wrap(divergence!))}, dret, tape, y::Duplicated, u::Duplicated, setup::Const) +function EnzymeRules.reverse( + config::RevConfigWidth{1}, + func::Const{typeof(enzyme_wrap(divergence!))}, + dret, + tape, + y::Duplicated, + u::Duplicated, + setup::Const, +) adj = vectorfield(setup.val) divergence_adjoint!(adj, y.val, setup.val) - u.dval .+= adj + u.dval .+= adj make_zero!(y.dval) return (nothing, nothing, nothing) end -function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(enzyme_wrap(pressuregradient!))}, dret, tape, y::Duplicated, p::Duplicated, setup::Const) +function EnzymeRules.reverse( + config::RevConfigWidth{1}, + func::Const{typeof(enzyme_wrap(pressuregradient!))}, + dret, + tape, + y::Duplicated, + p::Duplicated, + setup::Const, +) adj = scalarfield(setup.val) pressuregradient_adjoint!(adj, y.val, setup.val) - p.dval .+= adj + p.dval .+= adj make_zero!(y.dval) return (nothing, nothing, nothing) end -function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(enzyme_wrap(convection!))}, dret, tape, y::Duplicated, u::Duplicated, setup::Const) +function EnzymeRules.reverse( + config::RevConfigWidth{1}, + func::Const{typeof(enzyme_wrap(convection!))}, + dret, + tape, + y::Duplicated, + u::Duplicated, + setup::Const, +) adj = zero(u.val) convection_adjoint!(adj, y.val, u.val, setup.val) - u.dval .+= adj + u.dval .+= adj make_zero!(y.dval) return (nothing, nothing, nothing) end -function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(enzyme_wrap(diffusion!))}, dret, tape, y::Duplicated, u::Duplicated, setup::Const) +function EnzymeRules.reverse( + config::RevConfigWidth{1}, + func::Const{typeof(enzyme_wrap(diffusion!))}, + dret, + tape, + y::Duplicated, + u::Duplicated, + setup::Const, +) adj = zero(u.val) diffusion_adjoint!(adj, y.val, setup.val) - u.dval .+= adj + u.dval .+= adj make_zero!(y.dval) return (nothing, nothing, nothing) end - -function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Union{Const{typeof(enzyme_wrap(applybodyforce!))}}, ::Type{<:Const}, y::Duplicated, u::Duplicated, t::Const, setup::Const) +function EnzymeRules.augmented_primal( + config::RevConfigWidth{1}, + func::Union{Const{typeof(enzyme_wrap(applybodyforce!))}}, + ::Type{<:Const}, + y::Duplicated, + u::Duplicated, + t::Const, + setup::Const, +) primal = func.val(y.val, u.val, t.val, setup.val) if overwritten(config)[3] tape = copy(u.val) @@ -1519,7 +1586,16 @@ function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Union{Con end return AugmentedReturn(primal, nothing, tape) end -function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(enzyme_wrap(applybodyforce!))}, dret, tape, y::Duplicated, u::Duplicated, t::Const, setup::Const) +function EnzymeRules.reverse( + config::RevConfigWidth{1}, + func::Const{typeof(enzyme_wrap(applybodyforce!))}, + dret, + tape, + y::Duplicated, + u::Duplicated, + t::Const, + setup::Const, +) @warn "bodyforce Enzyme-AD tested only for issteadybodyforce=true" adj = setup.val.bodyforce u.dval .+= adj .* y.dval @@ -1527,7 +1603,15 @@ function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(enzym return (nothing, nothing, nothing, nothing) end -function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(enzyme_wrap(gravity!))}, dret, tape, y::Duplicated, temp::Duplicated, setup::Const) +function EnzymeRules.reverse( + config::RevConfigWidth{1}, + func::Const{typeof(enzyme_wrap(gravity!))}, + dret, + tape, + y::Duplicated, + temp::Duplicated, + setup::Const, +) (; grid, backend, workgroupsize, temperature) = setup.val (; dimension, Δ, N, Iu) = grid (; gdir, α2) = temperature @@ -1557,8 +1641,18 @@ function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(enzym return (nothing, nothing, nothing) end - -function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Union{Const{typeof(enzyme_wrap(dissipation!))}, Const{typeof(enzyme_wrap(convection_diffusion_temp!))}}, ::Type{<:Const}, y::Duplicated, x1::Duplicated, x2::Duplicated, setup::Const) +function EnzymeRules.augmented_primal( + config::RevConfigWidth{1}, + func::Union{ + Const{typeof(enzyme_wrap(dissipation!))}, + Const{typeof(enzyme_wrap(convection_diffusion_temp!))}, + }, + ::Type{<:Const}, + y::Duplicated, + x1::Duplicated, + x2::Duplicated, + setup::Const, +) primal = func.val(y.val, x1.val, x2.val, setup.val) if overwritten(config)[3] tape = copy(x2.val) @@ -1567,7 +1661,16 @@ function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Union{Con end return AugmentedReturn(primal, nothing, tape) end -function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(enzyme_wrap(dissipation!))}, dret, tape, y::Duplicated, d::Duplicated, u::Duplicated, setup::Const) +function EnzymeRules.reverse( + config::RevConfigWidth{1}, + func::Const{typeof(enzyme_wrap(dissipation!))}, + dret, + tape, + y::Duplicated, + d::Duplicated, + u::Duplicated, + setup::Const, +) (; grid, backend, workgroupsize, Re, temperature) = setup.val (; dimension, N, Ip) = grid (; α1, γ) = temperature @@ -1611,13 +1714,42 @@ function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(enzym return (nothing, nothing, nothing, nothing) end -function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(enzyme_wrap(convection_diffusion_temp!))}, dret, tape, y::Duplicated, temp::Duplicated, u::Duplicated, setup::Const) - @error "convection_diffusion_temp Enzyme-AD not yet implemented" -end - -function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Union{Const{typeof(enzyme_wrap(momentum!))}}, ::Type{<:Const}, y::Duplicated, x1::Duplicated, x2::Duplicated, x3::Duplicated, t::Const, setup::Const) +function EnzymeRules.reverse( + config::RevConfigWidth{1}, + func::Const{typeof(enzyme_wrap(convection_diffusion_temp!))}, + dret, + tape, + y::Duplicated, + temp::Duplicated, + u::Duplicated, + setup::Const, +) + @error "convection_diffusion_temp Enzyme-AD not yet implemented" +end + +function EnzymeRules.augmented_primal( + config::RevConfigWidth{1}, + func::Union{Const{typeof(enzyme_wrap(momentum!))}}, + ::Type{<:Const}, + y::Duplicated, + x1::Duplicated, + x2::Duplicated, + x3::Duplicated, + t::Const, + setup::Const, +) @error "momentum Enzyme-AD not yet implemented" end -function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(enzyme_wrap(momentum!))}, dret, tape, y::Duplicated, u::Duplicated, temp::Duplicated, t::Const, setup::Const) +function EnzymeRules.reverse( + config::RevConfigWidth{1}, + func::Const{typeof(enzyme_wrap(momentum!))}, + dret, + tape, + y::Duplicated, + u::Duplicated, + temp::Duplicated, + t::Const, + setup::Const, +) @error "momentum Enzyme-AD not yet implemented" -end \ No newline at end of file +end diff --git a/src/pressure.jl b/src/pressure.jl index a9d8d4521..1dd13d198 100644 --- a/src/pressure.jl +++ b/src/pressure.jl @@ -361,18 +361,33 @@ end # Wrap a function to return `nothing`, because Enzyme can not handle vector return values. function enzyme_wrap(f::typeof(poisson!)) function wrapped_f(y, args...) - y.= f(args...) + y .= f(args...) return nothing end return wrapped_f end -function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(enzyme_wrap(poisson!))}, ::Type{<:Const}, y::Duplicated, psolver::Const, div::Duplicated) +function EnzymeRules.augmented_primal( + config::RevConfigWidth{1}, + func::Const{typeof(enzyme_wrap(poisson!))}, + ::Type{<:Const}, + y::Duplicated, + psolver::Const, + div::Duplicated, +) primal = func.val(y.val, psolver.val, div.val) return AugmentedReturn(primal, nothing, nothing) end -function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(enzyme_wrap(poisson!))}, dret, tape, y::Duplicated, psolver::Const, div::Duplicated) +function EnzymeRules.reverse( + config::RevConfigWidth{1}, + func::Const{typeof(enzyme_wrap(poisson!))}, + dret, + tape, + y::Duplicated, + psolver::Const, + div::Duplicated, +) auto_adj = copy(y.val) - func.val(auto_adj, psolver.val, y.val ) + func.val(auto_adj, psolver.val, y.val) div.dval .+= auto_adj .* y.dval make_zero!(y.dval) return (nothing, nothing, nothing) diff --git a/src/sciml.jl b/src/sciml.jl index 3ab11c30d..f7c1ad13b 100644 --- a/src/sciml.jl +++ b/src/sciml.jl @@ -10,14 +10,12 @@ Creates a function that computes the right-hand side of the Navier-Stokes equati # Returns A function that computes the right-hand side of the Navier-Stokes equations. """ -function create_right_hand_side(setup, psolver) - function right_hand_side(u, param, t) - F = zeros(size(u)) - u = apply_bc_u(u, t, setup) - F = momentum(u, nothing, t, setup) - F = apply_bc_u(F, t, setup; dudt = true) - FP = project(F, setup; psolver) - end +create_right_hand_side(setup, psolver) = function right_hand_side(u, param, t) + F = zeros(size(u)) + u = apply_bc_u(u, t, setup) + F = momentum(u, nothing, t, setup) + F = apply_bc_u(F, t, setup; dudt = true) + FP = project(F, setup; psolver) end """ @@ -48,8 +46,15 @@ function right_hand_side!(dudt, u, params_ref, t) return nothing end - -function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(right_hand_side!)}, ::Type{<:Const}, dudt::Duplicated, u::Duplicated, params_ref::Any, t::Const) +function EnzymeRules.augmented_primal( + config::RevConfigWidth{1}, + func::Const{typeof(right_hand_side!)}, + ::Type{<:Const}, + dudt::Duplicated, + u::Duplicated, + params_ref::Any, + t::Const, +) # this runs function to modify dudt and store the intermediates params = params_ref.val[] setup = params[1] @@ -62,7 +67,16 @@ function EnzymeRules.augmented_primal(config::RevConfigWidth{1}, func::Const{typ project!(dudt.val, setup; psolver, p) return AugmentedReturn(nothing, nothing, u_bc) end -function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(right_hand_side!)}, dret, u_bc, dudt::Duplicated, u::Duplicated, params_ref::Const, t::Const) +function EnzymeRules.reverse( + config::RevConfigWidth{1}, + func::Const{typeof(right_hand_side!)}, + dret, + u_bc, + dudt::Duplicated, + u::Duplicated, + params_ref::Const, + t::Const, +) # unpack the parameters params = params_ref.val[] setup = params[1] @@ -75,11 +89,11 @@ function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(right # [!] notice that the chain starts from the final value of dudt because it gets modified in place in the forward pass dudt.dval .*= dudt.val # [!] the minus sign is missing somewhere in the adjoint - dp .= - applypressure_adjoint!(temp_scalar, dudt.dval, nothing, setup) + dp .= -applypressure_adjoint!(temp_scalar, dudt.dval, nothing, setup) apply_bc_p_pullback!(dp, t.val, setup) - poisson!(psolver,dp) + poisson!(psolver, dp) scalewithvolume!(dp, setup) dudt.dval .+= divergence_adjoint!(temp_vector, dp, setup) @@ -92,6 +106,6 @@ function EnzymeRules.reverse(config::RevConfigWidth{1}, func::Const{typeof(right u.dval .+= diffusion_adjoint!(temp_vector, dudt.dval, setup) apply_bc_u_pullback!(u.dval, t.val, setup) - + return (nothing, nothing, nothing, nothing) -end \ No newline at end of file +end diff --git a/src/utils.jl b/src/utils.jl index 3f6abce7c..3e11e3ebe 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -126,4 +126,4 @@ function get_spectrum(setup; npoint = 100, a = typeof(e.setup.Re)(1 + sqrt(5)) / BoolArray = typeof(similar(xp[1], Bool, ntuple(Returns(0), D)...)) masks = adapt.(BoolArray, masks) (; κ, masks, K) -end \ No newline at end of file +end diff --git a/test/chainrules_enzyme.jl b/test/chainrules_enzyme.jl index 2b951c626..a05ee5195 100644 --- a/test/chainrules_enzyme.jl +++ b/test/chainrules_enzyme.jl @@ -44,7 +44,15 @@ end bodyforce = (dim, x, y, z, t) -> (dim == 1) * 5 * sinpi(8 * y) dbodyforce = (dim, x, y, z, t) -> (dim == 1) * 5 * pi * 8 * cos(pi * 8 * y) end - setup = Setup(; x, boundary_conditions, Re, temperature, bodyforce, dbodyforce, issteadybodyforce = true) + setup = Setup(; + x, + boundary_conditions, + Re, + temperature, + bodyforce, + dbodyforce, + issteadybodyforce = true, + ) psolver = default_psolver(setup) u = randn(T, setup.grid.N..., D) p = randn(T, setup.grid.N) @@ -65,7 +73,7 @@ end n = 7 lims = T(0), T(1) x = range(lims..., n + 1), range(lims..., n + 1) - + for bc in (PeriodicBC(), DirichletBC(), SymmetricBC(), PressureBC()) boundary_conditions = (bc, bc), (bc, bc) setup = Setup(; @@ -86,13 +94,27 @@ end zpull, z_time = @timed Zygote.pullback(apply_bc_u, u, nothing, setup)[2](u0)[1] du = Enzyme.make_zero(u) y = Enzyme.make_zero(u) - dy = Enzyme.make_zero(u) .+1 + dy = Enzyme.make_zero(u) .+ 1 f = INS.enzyme_wrap(INS.apply_bc_u!) - Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(y, dy), Duplicated(u, du), Const(nothing), Const(setup)) + Enzyme.autodiff( + Enzyme.Reverse, + f, + Duplicated(y, dy), + Duplicated(u, du), + Const(nothing), + Const(setup), + ) du = Enzyme.make_zero(u) y = Enzyme.make_zero(u) - dy = Enzyme.make_zero(u) .+1 - eg, e_time = @timed Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(y, dy), Duplicated(u0, du), Const(nothing), Const(setup)) + dy = Enzyme.make_zero(u) .+ 1 + eg, e_time = @timed Enzyme.autodiff( + Enzyme.Reverse, + f, + Duplicated(y, dy), + Duplicated(u0, du), + Const(nothing), + Const(setup), + ) if ENABLE_LOGGING if e_time < z_time @info "Enzyme is faster (bc_u): ", e_time, " vs ", z_time @@ -100,20 +122,34 @@ end @info "Zygote is faster (bc_u): ", z_time, " vs ", e_time end end - @test du == zpull + @test du == zpull # --- bc_p Zygote.pullback(apply_bc_p, p, nothing, setup)[2](p0)[1] zpull, z_time = @timed Zygote.pullback(apply_bc_p, p, nothing, setup)[2](p0)[1] dp = Enzyme.make_zero(p) y = Enzyme.make_zero(p) - dy = Enzyme.make_zero(p) .+1 + dy = Enzyme.make_zero(p) .+ 1 f = INS.enzyme_wrap(INS.apply_bc_p!) - Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(y, dy), Duplicated(p0, dp), Const(nothing), Const(setup)) + Enzyme.autodiff( + Enzyme.Reverse, + f, + Duplicated(y, dy), + Duplicated(p0, dp), + Const(nothing), + Const(setup), + ) dp = Enzyme.make_zero(p) y = Enzyme.make_zero(p) - dy = Enzyme.make_zero(p) .+1 - eg, e_time = @timed Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(y, dy), Duplicated(p0, dp), Const(nothing), Const(setup)) + dy = Enzyme.make_zero(p) .+ 1 + eg, e_time = @timed Enzyme.autodiff( + Enzyme.Reverse, + f, + Duplicated(y, dy), + Duplicated(p0, dp), + Const(nothing), + Const(setup), + ) if ENABLE_LOGGING if e_time < z_time @info "Enzyme is faster (bc_p): ", e_time, " vs ", z_time @@ -123,20 +159,34 @@ end end @test dp == zpull - # --- bc_temp Zygote.pullback(apply_bc_temp, temp, nothing, setup)[2](temp0)[1] - zpull, z_time = @timed Zygote.pullback(apply_bc_temp, temp, nothing, setup)[2](temp0)[1] + zpull, z_time = + @timed Zygote.pullback(apply_bc_temp, temp, nothing, setup)[2](temp0)[1] dtemp = Enzyme.make_zero(temp) y = Enzyme.make_zero(temp) - dy = Enzyme.make_zero(temp) .+1 + dy = Enzyme.make_zero(temp) .+ 1 f = INS.enzyme_wrap(INS.apply_bc_temp!) - Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(y, dy), Duplicated(temp0, dtemp), Const(nothing), Const(setup)) + Enzyme.autodiff( + Enzyme.Reverse, + f, + Duplicated(y, dy), + Duplicated(temp0, dtemp), + Const(nothing), + Const(setup), + ) dtemp = Enzyme.make_zero(temp) y = Enzyme.make_zero(temp) - dy = Enzyme.make_zero(temp) .+1 + dy = Enzyme.make_zero(temp) .+ 1 f = INS.enzyme_wrap(INS.apply_bc_temp!) - eg, e_time = @timed Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(y, dy), Duplicated(temp0, dtemp), Const(nothing), Const(setup)) + eg, e_time = @timed Enzyme.autodiff( + Enzyme.Reverse, + f, + Duplicated(y, dy), + Duplicated(temp0, dtemp), + Const(nothing), + Const(setup), + ) if ENABLE_LOGGING if e_time < z_time @info "Enzyme is faster (bc_temp): ", e_time, " vs ", z_time @@ -145,25 +195,35 @@ end end end @test dtemp == zpull - end end - @testitem "Divergence" setup = [Case, ChainRulesStuff] begin - - for (u,setup,d) in ((Case.D2.u, Case.D2.setup, Case.D2.div), (Case.D3.u, Case.D3.setup, Case.D3.div)) + for (u, setup, d) in + ((Case.D2.u, Case.D2.setup, Case.D2.div), (Case.D3.u, Case.D3.setup, Case.D3.div)) d0 = copy(d) u0 = copy(u) Zygote.pullback(INS.divergence, u, setup)[2](d0)[1] zpull, z_time = @timed Zygote.pullback(INS.divergence, u, setup)[2](d0)[1] - dd = Enzyme.make_zero(d) .+1 + dd = Enzyme.make_zero(d) .+ 1 du = Enzyme.make_zero(u) f = INS.enzyme_wrap(INS.divergence!) - Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(d, dd), Duplicated(u, du), Const(setup)) - dd = Enzyme.make_zero(d) .+1 + Enzyme.autodiff( + Enzyme.Reverse, + f, + Duplicated(d, dd), + Duplicated(u, du), + Const(setup), + ) + dd = Enzyme.make_zero(d) .+ 1 du = Enzyme.make_zero(u) - eg, e_time = @timed Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(d0, dd), Duplicated(u0, du), Const(setup)) + eg, e_time = @timed Enzyme.autodiff( + Enzyme.Reverse, + f, + Duplicated(d0, dd), + Duplicated(u0, du), + Const(setup), + ) if ENABLE_LOGGING if e_time < z_time @info "Enzyme is faster (divergence): ", e_time, " vs ", z_time @@ -176,19 +236,31 @@ end end @testitem "Pressuregradient" setup = [Case, ChainRulesStuff] begin - for (p,setup) in ((Case.D2.p, Case.D2.setup), (Case.D3.p, Case.D3.setup)) + for (p, setup) in ((Case.D2.p, Case.D2.setup), (Case.D3.p, Case.D3.setup)) p0 = copy(p) pg = INS.pressuregradient(p, setup) pg0 = copy(pg) Zygote.pullback(INS.pressuregradient, p, setup)[2](pg0)[1] zpull, z_time = @timed Zygote.pullback(INS.pressuregradient, p, setup)[2](pg0)[1] - dpg = Enzyme.make_zero(pg) .+1 + dpg = Enzyme.make_zero(pg) .+ 1 dp = Enzyme.make_zero(p) f = INS.enzyme_wrap(INS.pressuregradient!) - Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(pg, dpg), Duplicated(p, dp), Const(setup)) - dpg = Enzyme.make_zero(pg) .+1 + Enzyme.autodiff( + Enzyme.Reverse, + f, + Duplicated(pg, dpg), + Duplicated(p, dp), + Const(setup), + ) + dpg = Enzyme.make_zero(pg) .+ 1 dp = Enzyme.make_zero(p) - eg, e_time = @timed Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(pg0, dpg), Duplicated(p0, dp), Const(setup)) + eg, e_time = @timed Enzyme.autodiff( + Enzyme.Reverse, + f, + Duplicated(pg0, dpg), + Duplicated(p0, dp), + Const(setup), + ) if ENABLE_LOGGING if e_time < z_time @info "Enzyme is faster (pressuregradient): ", e_time, " vs ", z_time @@ -198,23 +270,36 @@ end end @test dp == zpull end - end @testitem "Poisson" setup = [Case, ChainRulesStuff] begin - for (psolver,d,setup) in ((Case.D2.psolver, Case.D2.div, Case.D2.setup), (Case.D3.psolver, Case.D3.div, Case.D3.setup)) - + for (psolver, d, setup) in ( + (Case.D2.psolver, Case.D2.div, Case.D2.setup), + (Case.D3.psolver, Case.D3.div, Case.D3.setup), + ) p0 = INS.poisson(psolver, d) Zygote.pullback(INS.poisson, psolver, d)[2](p0)[1] zpull, z_time = @timed Zygote.pullback(INS.poisson, psolver, d)[2](p0)[2] - dd = Enzyme.make_zero(d) - p = Enzyme.make_zero(p0) - dp = Enzyme.make_zero(p) .+1 + dd = Enzyme.make_zero(d) + p = Enzyme.make_zero(p0) + dp = Enzyme.make_zero(p) .+ 1 f = INS.enzyme_wrap(INS.poisson!) - Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(p, dp), Const(psolver), Duplicated(d, dd)) - ep, e_time = @timed Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(p0, dp), Const(psolver), Duplicated(d, dd)) + Enzyme.autodiff( + Enzyme.Reverse, + f, + Duplicated(p, dp), + Const(psolver), + Duplicated(d, dd), + ) + ep, e_time = @timed Enzyme.autodiff( + Enzyme.Reverse, + f, + Duplicated(p0, dp), + Const(psolver), + Duplicated(d, dd), + ) if ENABLE_LOGGING if e_time < z_time @info "Enzyme is faster (poisson): ", e_time, " vs ", z_time @@ -224,11 +309,10 @@ end end @test dd == zpull end - end @testitem "Convection" setup = [Case, ChainRulesStuff] begin - for (u,setup) in ((Case.D2.u, Case.D2.setup), (Case.D3.u, Case.D3.setup)) + for (u, setup) in ((Case.D2.u, Case.D2.setup), (Case.D3.u, Case.D3.setup)) c = INS.convection(u, setup) u0 = copy(u) Zygote.pullback(INS.convection, u, setup)[2](u)[1] @@ -236,14 +320,26 @@ end # [!] convection! wants to start from 0 initialized field Enzyme.make_zero!(c) - dc = Enzyme.make_zero(c) .+1 + dc = Enzyme.make_zero(c) .+ 1 du = Enzyme.make_zero(u) f = INS.enzyme_wrap(INS.convection!) - Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(c, dc), Duplicated(u, du), Const(setup)) + Enzyme.autodiff( + Enzyme.Reverse, + f, + Duplicated(c, dc), + Duplicated(u, du), + Const(setup), + ) Enzyme.make_zero!(c) - dc = Enzyme.make_zero(c) .+1 + dc = Enzyme.make_zero(c) .+ 1 du = Enzyme.make_zero(u) - ec, e_time = @timed Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(c, dc), Duplicated(u, du), Const(setup)) + ec, e_time = @timed Enzyme.autodiff( + Enzyme.Reverse, + f, + Duplicated(c, dc), + Duplicated(u, du), + Const(setup), + ) if ENABLE_LOGGING if e_time < z_time @info "Enzyme is faster (convection): ", e_time, " vs ", z_time @@ -256,7 +352,7 @@ end end @testitem "Diffusion" setup = [Case, ChainRulesStuff] begin - for (u,setup) in ((Case.D2.u, Case.D2.setup), (Case.D3.u, Case.D3.setup)) + for (u, setup) in ((Case.D2.u, Case.D2.setup), (Case.D3.u, Case.D3.setup)) d = INS.diffusion(u, setup) u0 = copy(u) Zygote.pullback(INS.diffusion, u, setup)[2](d)[1] @@ -264,14 +360,26 @@ end # [!] diffusion! wants to start from 0 initialized field Enzyme.make_zero!(d) - dd = Enzyme.make_zero(d) .+1 + dd = Enzyme.make_zero(d) .+ 1 du = Enzyme.make_zero(u) f = INS.enzyme_wrap(INS.diffusion!) - Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(d, dd), Duplicated(u, du), Const(setup)) + Enzyme.autodiff( + Enzyme.Reverse, + f, + Duplicated(d, dd), + Duplicated(u, du), + Const(setup), + ) Enzyme.make_zero!(d) - dd = Enzyme.make_zero(d) .+1 + dd = Enzyme.make_zero(d) .+ 1 du = Enzyme.make_zero(u) - ec, e_time = @timed Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(d, dd), Duplicated(u, du), Const(setup)) + ec, e_time = @timed Enzyme.autodiff( + Enzyme.Reverse, + f, + Duplicated(d, dd), + Duplicated(u, du), + Const(setup), + ) if ENABLE_LOGGING if e_time < z_time @info "Enzyme is faster (diffusion): ", e_time, " vs ", z_time @@ -285,26 +393,41 @@ end @testitem "Bodyforce" setup = [Case, ChainRulesStuff] begin @warn "bodyforce is tested only in the static case" - for (u,setup) in ((Case.D2.u, Case.D2.setup), (Case.D3.u, Case.D3.setup)) + for (u, setup) in ((Case.D2.u, Case.D2.setup), (Case.D3.u, Case.D3.setup)) t = 0.5 bf = INS.applybodyforce(u, t, setup) bf0 = copy(bf) setup0 = deepcopy(setup) Zygote.pullback(INS.applybodyforce, u, t, setup)[2](bf0) - zpull, z_time = @timed Zygote.pullback(INS.applybodyforce, u, t, setup)[2](bf0)[3].bodyforce - + zpull, z_time = + @timed Zygote.pullback(INS.applybodyforce, u, t, setup)[2](bf0)[3].bodyforce + # We can also test Zygote autodiff @test zpull == setup.bodyforce - bf = bf .*0 - dbf = Enzyme.make_zero(bf) .+1 + bf = bf .* 0 + dbf = Enzyme.make_zero(bf) .+ 1 du = Enzyme.make_zero(u) f = INS.enzyme_wrap(INS.applybodyforce!) - Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(bf, dbf), Duplicated(u, du), Const(t), Const(setup)) - bf = bf .*0 - dbf = Enzyme.make_zero(bf) .+1 + Enzyme.autodiff( + Enzyme.Reverse, + f, + Duplicated(bf, dbf), + Duplicated(u, du), + Const(t), + Const(setup), + ) + bf = bf .* 0 + dbf = Enzyme.make_zero(bf) .+ 1 du = Enzyme.make_zero(u) - eb, e_time = @timed Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(bf, dbf), Duplicated(u, du), Const(t), Const(setup)) + eb, e_time = @timed Enzyme.autodiff( + Enzyme.Reverse, + f, + Duplicated(bf, dbf), + Duplicated(u, du), + Const(t), + Const(setup), + ) if ENABLE_LOGGING if e_time < z_time @info "Enzyme is faster (bodyforce): ", e_time, " vs ", z_time @@ -317,21 +440,32 @@ end end @testitem "Gravity" setup = [Case, ChainRulesStuff] begin - for (t,setup) in ((Case.D2.temp, Case.D2.setup), (Case.D3.temp, Case.D3.setup)) - + for (t, setup) in ((Case.D2.temp, Case.D2.setup), (Case.D3.temp, Case.D3.setup)) g = INS.gravity(t, setup) Zygote.pullback(INS.gravity, t, setup)[2](g) zpull, z_time = @timed Zygote.pullback(INS.gravity, t, setup)[2](g)[1] g = vectorfield(setup) - dg = Enzyme.make_zero(g) .+1 + dg = Enzyme.make_zero(g) .+ 1 dt = Enzyme.make_zero(t) f = INS.enzyme_wrap(INS.gravity!) - Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(g, dg), Duplicated(t, dt), Const(setup)) + Enzyme.autodiff( + Enzyme.Reverse, + f, + Duplicated(g, dg), + Duplicated(t, dt), + Const(setup), + ) g = vectorfield(setup) - dg = Enzyme.make_zero(g) .+1 + dg = Enzyme.make_zero(g) .+ 1 dt = Enzyme.make_zero(t) - gb, e_time = @timed Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(g, dg), Duplicated(t, dt), Const(setup)) + gb, e_time = @timed Enzyme.autodiff( + Enzyme.Reverse, + f, + Duplicated(g, dg), + Duplicated(t, dt), + Const(setup), + ) if ENABLE_LOGGING if e_time < z_time @info "Enzyme is faster (gravity): ", e_time, " vs ", z_time @@ -340,30 +474,42 @@ end end end @test dt == zpull - end end @testitem "Dissipation" setup = [Case, ChainRulesStuff] begin - for (u,setup) in ((Case.D2.u, Case.D2.setup), (Case.D3.u, Case.D3.setup)) - + for (u, setup) in ((Case.D2.u, Case.D2.setup), (Case.D3.u, Case.D3.setup)) diss = INS.dissipation(u, setup) Zygote.pullback(INS.dissipation, u, setup)[2](diss) zpull, z_time = @timed Zygote.pullback(INS.dissipation, u, setup)[2](diss)[1] diss = scalarfield(setup) diff = vectorfield(setup) - ddiss = Enzyme.make_zero(diss) .+1 + ddiss = Enzyme.make_zero(diss) .+ 1 ddiff = Enzyme.make_zero(diff) du = Enzyme.make_zero(u) f = INS.enzyme_wrap(INS.dissipation!) - Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(diss, ddiss), Duplicated(diff, ddiff), Duplicated(u,du), Const(setup)) + Enzyme.autodiff( + Enzyme.Reverse, + f, + Duplicated(diss, ddiss), + Duplicated(diff, ddiff), + Duplicated(u, du), + Const(setup), + ) diss = scalarfield(setup) diff = vectorfield(setup) - diss = Enzyme.make_zero(diss) .+1 + diss = Enzyme.make_zero(diss) .+ 1 diff = Enzyme.make_zero(diff) du = Enzyme.make_zero(u) - ed, e_time = @timed Enzyme.autodiff(Enzyme.Reverse, f, Duplicated(diss, ddiss), Duplicated(diff, ddiff), Duplicated(u,du), Const(setup)) + ed, e_time = @timed Enzyme.autodiff( + Enzyme.Reverse, + f, + Duplicated(diss, ddiss), + Duplicated(diff, ddiff), + Duplicated(u, du), + Const(setup), + ) if ENABLE_LOGGING if e_time < z_time @info "Enzyme is faster (dissipation): ", e_time, " vs ", z_time @@ -372,7 +518,6 @@ end end end @test du == zpull - end end @testitem "Convection_diffusion_temp" setup = [Case, ChainRulesStuff] begin @@ -380,6 +525,6 @@ end end @testitem "Convectiondiffusion" setup = [Case, ChainRulesStuff] begin -# the pullback rule is missing for this one + # the pullback rule is missing for this one @test_broken 1 == 2 end diff --git a/test/enzyme_integration.jl b/test/enzyme_integration.jl index ced39c567..8559b1eef 100644 --- a/test/enzyme_integration.jl +++ b/test/enzyme_integration.jl @@ -4,7 +4,7 @@ using Enzyme using Zygote using Random - rng = Random.default_rng(); + rng = Random.default_rng() end @testmodule Info begin @@ -15,20 +15,20 @@ end D = 2 n = 64 N = n + 2 - lims = T(0), T(1); - x = tanh_grid(lims..., n), tanh_grid(lims..., n, 1.3) ; + lims = T(0), T(1) + x = tanh_grid(lims..., n), tanh_grid(lims..., n, 1.3) boundary_conditions = ntuple(d -> (DirichletBC(), DirichletBC()), D) - setup = Setup(;x, boundary_conditions, Re); + setup = Setup(; x, boundary_conditions, Re) psolver = default_psolver(setup) end @testitem "Enzyme one force pullback" setup = [Info, EnzymeStuff] begin - for (setup, psolver, T, N) in ((Info.setup, Info.psolver, Info.T, Info.N), ) - dudt = zeros(T, (N, N, 2)) ; - u = rand(T, (N, N, 2)); - u0 = copy(u); - params = [setup, psolver]; - params_ref = Ref(params); + for (setup, psolver, T, N) in ((Info.setup, Info.psolver, Info.T, Info.N),) + dudt = zeros(T, (N, N, 2)) + u = rand(T, (N, N, 2)) + u0 = copy(u) + params = [setup, psolver] + params_ref = Ref(params) right_hand_side!(dudt, u, params_ref, T(0)) F_out = create_right_hand_side(setup, psolver) @test dudt ≈ F_out(u, nothing, T(0)) @@ -36,21 +36,21 @@ end @test sum(dudt) != 0 niter = 5000 - list_u = [rand(T, (N, N, 2)) for i in 1:niter]; + list_u = [rand(T, (N, N, 2)) for i = 1:niter] list_z = [] _, tz, mz = @timed begin - for i in 1:niter + for i = 1:niter dudt = F_out(list_u[i], nothing, T(0)) push!(list_z, dudt) end end - list_e = [zeros(T, (N, N, 2)) for i in 1:niter]; + list_e = [zeros(T, (N, N, 2)) for i = 1:niter] _, te, me = @timed begin - for i in 1:niter + for i = 1:niter right_hand_side!(list_e[i], list_u[i], params_ref, T(0)) end end - @test all([list_z[i] ≈ list_e[i] for i in 1:niter]) + @test all([list_z[i] ≈ list_e[i] for i = 1:niter]) if te < tz @info "Right-hand side in-place is faster by a factor of $(tz/te)" else @@ -61,50 +61,62 @@ end else @info "Right-hand side out-of-place is more memory efficient by a factor of $(me/mz)" end - end end @testitem "Enzyme RHS pullback" setup = [Info, EnzymeStuff] begin - for (setup, psolver, T, N) in ((Info.setup, Info.psolver, Info.T, Info.N), ) + for (setup, psolver, T, N) in ((Info.setup, Info.psolver, Info.T, Info.N),) F_out = create_right_hand_side(setup, psolver) - dudt = zeros(T, (N, N, 2)) ; - u = rand(T, (N, N, 2)); + dudt = zeros(T, (N, N, 2)) + u = rand(T, (N, N, 2)) u0 = copy(u) - du = Enzyme.make_zero(u); - dd = Enzyme.make_zero(dudt) .+ 1; - params = [setup, psolver]; - params_ref = Ref(params); - Enzyme.autodiff(Enzyme.Reverse, right_hand_side!, Duplicated(dudt,dd), Duplicated(u,du), Const(params_ref), Const(T(0))) + du = Enzyme.make_zero(u) + dd = Enzyme.make_zero(dudt) .+ 1 + params = [setup, psolver] + params_ref = Ref(params) + Enzyme.autodiff( + Enzyme.Reverse, + right_hand_side!, + Duplicated(dudt, dd), + Duplicated(u, du), + Const(params_ref), + Const(T(0)), + ) @test u0 == u @test dudt ≈ F_out(u, nothing, T(0)) - zpull = Zygote.pullback(F_out, u, nothing, T(0)); + zpull = Zygote.pullback(F_out, u, nothing, T(0)) @test zpull[1] ≈ dudt - @test zpull[2](dudt)[1] ==du - + @test zpull[2](dudt)[1] == du # Now I run each option multiple times from different random initial conditions niter = 3000 - list_u = [rand(T, (N, N, 2)) for i in 1:niter]; + list_u = [rand(T, (N, N, 2)) for i = 1:niter] list_z = [] _, tz, mz = @timed begin - for i in 1:niter - du = Enzyme.make_zero(u); - dd = Enzyme.make_zero(dudt) .+ 1; - zpull = Zygote.pullback(F_out, list_u[i], nothing, T(0)); + for i = 1:niter + du = Enzyme.make_zero(u) + dd = Enzyme.make_zero(dudt) .+ 1 + zpull = Zygote.pullback(F_out, list_u[i], nothing, T(0)) push!(list_z, zpull[2](zpull[1])[1]) end end list_e = [] _, te, me = @timed begin - for i in 1:niter - du = Enzyme.make_zero(u); - dd = Enzyme.make_zero(dudt) .+ 1; - Enzyme.autodiff(Enzyme.Reverse, right_hand_side!, Duplicated(dudt,dd), Duplicated(list_u[i],du), Const(params_ref), Const(T(0))) + for i = 1:niter + du = Enzyme.make_zero(u) + dd = Enzyme.make_zero(dudt) .+ 1 + Enzyme.autodiff( + Enzyme.Reverse, + right_hand_side!, + Duplicated(dudt, dd), + Duplicated(list_u[i], du), + Const(params_ref), + Const(T(0)), + ) push!(list_e, du) end end - @test all([list_z[i] ≈ list_e[i] for i in 1:niter]) + @test all([list_z[i] ≈ list_e[i] for i = 1:niter]) if te < tz @info "Reverse AD using Enzyme is faster by a factor of $(tz/te)" else @@ -116,4 +128,4 @@ end @info "Reverse AD using Zygote is more memory efficient by a factor of $(me/mz)" end end -end \ No newline at end of file +end diff --git a/test/runtests.jl b/test/runtests.jl index 97a2f4d68..f24a13581 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,4 +9,4 @@ using TestItemRunner #function myfilter(t) # return endswith(t.filename, "enzyme_integration.jl") || endswith(t.filename, "chainrules_enzyme.jl") #end -#@run_package_tests filter = myfilter \ No newline at end of file +#@run_package_tests filter = myfilter From 1b23e429ff2e27da57f6321b9d482cdf837428cb Mon Sep 17 00:00:00 2001 From: SCiarella Date: Wed, 20 Nov 2024 14:56:55 +0100 Subject: [PATCH 08/14] Replace Enzyme with EnzymeCore --- Project.toml | 2 +- src/IncompressibleNavierStokes.jl | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 3be81b73a..ae303d873 100644 --- a/Project.toml +++ b/Project.toml @@ -7,7 +7,7 @@ version = "2.0.1" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" diff --git a/src/IncompressibleNavierStokes.jl b/src/IncompressibleNavierStokes.jl index e363a37f5..69c70ca35 100644 --- a/src/IncompressibleNavierStokes.jl +++ b/src/IncompressibleNavierStokes.jl @@ -13,9 +13,8 @@ using Adapt using ChainRulesCore using DocStringExtensions using FFTW -using Enzyme -import .EnzymeRules: reverse, augmented_primal -using .EnzymeRules +using EnzymeCore +using EnzymeCore.EnzymeRules using IterativeSolvers using KernelAbstractions using KernelAbstractions.Extras.LoopInfo: @unroll From 6c8c5846a372071cf740b7307ef19aa8007ebd4f Mon Sep 17 00:00:00 2001 From: SCiarella Date: Wed, 20 Nov 2024 17:28:16 +0100 Subject: [PATCH 09/14] Fix make_zero! calls --- src/operators.jl | 14 +++++++------- src/pressure.jl | 2 +- test/chainrules_enzyme.jl | 8 ++++---- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/operators.jl b/src/operators.jl index b33dd298d..1a9a1c5a8 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -1520,7 +1520,7 @@ function EnzymeRules.reverse( adj = vectorfield(setup.val) divergence_adjoint!(adj, y.val, setup.val) u.dval .+= adj - make_zero!(y.dval) + EnzymeCore.make_zero!(y.dval) return (nothing, nothing, nothing) end function EnzymeRules.reverse( @@ -1535,7 +1535,7 @@ function EnzymeRules.reverse( adj = scalarfield(setup.val) pressuregradient_adjoint!(adj, y.val, setup.val) p.dval .+= adj - make_zero!(y.dval) + EnzymeCore.make_zero!(y.dval) return (nothing, nothing, nothing) end function EnzymeRules.reverse( @@ -1550,7 +1550,7 @@ function EnzymeRules.reverse( adj = zero(u.val) convection_adjoint!(adj, y.val, u.val, setup.val) u.dval .+= adj - make_zero!(y.dval) + EnzymeCore.make_zero!(y.dval) return (nothing, nothing, nothing) end function EnzymeRules.reverse( @@ -1565,7 +1565,7 @@ function EnzymeRules.reverse( adj = zero(u.val) diffusion_adjoint!(adj, y.val, setup.val) u.dval .+= adj - make_zero!(y.dval) + EnzymeCore.make_zero!(y.dval) return (nothing, nothing, nothing) end @@ -1599,7 +1599,7 @@ function EnzymeRules.reverse( @warn "bodyforce Enzyme-AD tested only for issteadybodyforce=true" adj = setup.val.bodyforce u.dval .+= adj .* y.dval - make_zero!(y.dval) + EnzymeCore.make_zero!(y.dval) return (nothing, nothing, nothing, nothing) end @@ -1637,7 +1637,7 @@ function EnzymeRules.reverse( end adj = gravity_pullback(y.val) temp.dval .+= adj - make_zero!(y.dval) + EnzymeCore.make_zero!(y.dval) return (nothing, nothing, nothing) end @@ -1710,7 +1710,7 @@ function EnzymeRules.reverse( end adj = dissipation_pullback(y.val) u.dval .+= adj - make_zero!(y.dval) + EnzymeCore.make_zero!(y.dval) return (nothing, nothing, nothing, nothing) end diff --git a/src/pressure.jl b/src/pressure.jl index 1dd13d198..a420b9257 100644 --- a/src/pressure.jl +++ b/src/pressure.jl @@ -389,6 +389,6 @@ function EnzymeRules.reverse( auto_adj = copy(y.val) func.val(auto_adj, psolver.val, y.val) div.dval .+= auto_adj .* y.dval - make_zero!(y.dval) + EnzymeCore.make_zero!(y.dval) return (nothing, nothing, nothing) end diff --git a/test/chainrules_enzyme.jl b/test/chainrules_enzyme.jl index a05ee5195..77f7e4d14 100644 --- a/test/chainrules_enzyme.jl +++ b/test/chainrules_enzyme.jl @@ -319,7 +319,7 @@ end zpull, z_time = @timed Zygote.pullback(INS.convection, u, setup)[2](c)[1] # [!] convection! wants to start from 0 initialized field - Enzyme.make_zero!(c) + EnzymeCore.make_zero!(c) dc = Enzyme.make_zero(c) .+ 1 du = Enzyme.make_zero(u) f = INS.enzyme_wrap(INS.convection!) @@ -330,7 +330,7 @@ end Duplicated(u, du), Const(setup), ) - Enzyme.make_zero!(c) + EnzymeCore.make_zero!(c) dc = Enzyme.make_zero(c) .+ 1 du = Enzyme.make_zero(u) ec, e_time = @timed Enzyme.autodiff( @@ -359,7 +359,7 @@ end zpull, z_time = @timed Zygote.pullback(INS.diffusion, u, setup)[2](d)[1] # [!] diffusion! wants to start from 0 initialized field - Enzyme.make_zero!(d) + EnzymeCore.make_zero!(d) dd = Enzyme.make_zero(d) .+ 1 du = Enzyme.make_zero(u) f = INS.enzyme_wrap(INS.diffusion!) @@ -370,7 +370,7 @@ end Duplicated(u, du), Const(setup), ) - Enzyme.make_zero!(d) + EnzymeCore.make_zero!(d) dd = Enzyme.make_zero(d) .+ 1 du = Enzyme.make_zero(u) ec, e_time = @timed Enzyme.autodiff( From 046cfd424340416e1e72727092e671a99e70604c Mon Sep 17 00:00:00 2001 From: SCiarella Date: Thu, 21 Nov 2024 08:31:27 +0100 Subject: [PATCH 10/14] Fix Enzyme-Core deps in Project --- Project.toml | 1 + test/Project.toml | 9 +++++---- test/chainrules_enzyme.jl | 1 + 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index ae303d873..e877456f9 100644 --- a/Project.toml +++ b/Project.toml @@ -37,6 +37,7 @@ CUDA = "5" CUDSS = "0.3" ChainRulesCore = "1" DocStringExtensions = "0.9" +EnzymeCore = "0.8" FFTW = "1" IterativeSolvers = "0.9" KernelAbstractions = "0.9" diff --git a/test/Project.toml b/test/Project.toml index 9b2102deb..db7c243dc 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -3,6 +3,8 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" IncompressibleNavierStokes = "5e318141-6589-402b-868d-77d7df8c442e" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" @@ -11,9 +13,11 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a" -Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +[sources] +IncompressibleNavierStokes = {path = ".."} + [compat] Aqua = "0.8" CairoMakie = "0.12" @@ -25,6 +29,3 @@ Logging = "1" Statistics = "1" TestItemRunner = "1" julia = "1.9" - -[sources.IncompressibleNavierStokes] -path = ".." diff --git a/test/chainrules_enzyme.jl b/test/chainrules_enzyme.jl index 77f7e4d14..d465d6de7 100644 --- a/test/chainrules_enzyme.jl +++ b/test/chainrules_enzyme.jl @@ -4,6 +4,7 @@ using Enzyme using Zygote using IncompressibleNavierStokes: IncompressibleNavierStokes as INS + using EnzymeCore import .EnzymeRules: reverse, augmented_primal using .EnzymeRules ENABLE_LOGGING = false From 19fd96e506cba6e5f99a07f961bb6a653e1f4333 Mon Sep 17 00:00:00 2001 From: SCiarella Date: Thu, 21 Nov 2024 12:34:38 +0100 Subject: [PATCH 11/14] Fix testsnippet names --- test/chainrules.jl | 3 ++- test/chainrules_enzyme.jl | 41 ++++++++++++++++++++------------------ test/enzyme_integration.jl | 6 +++--- 3 files changed, 27 insertions(+), 23 deletions(-) diff --git a/test/chainrules.jl b/test/chainrules.jl index d5c13938d..6c3f8fd82 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -4,8 +4,9 @@ # Test chain rule correctness by comparing with finite differences "Use function name only as test set name" - test_rrule_named(f, args...; kwargs...) = + function test_rrule_named(f, args...; kwargs...) test_rrule(f, args...; testset_name = string(f), kwargs...) + end end @testitem "Chain rules (boundary conditions)" setup = [ChainRulesStuff] begin diff --git a/test/chainrules_enzyme.jl b/test/chainrules_enzyme.jl index d465d6de7..d6b7c6764 100644 --- a/test/chainrules_enzyme.jl +++ b/test/chainrules_enzyme.jl @@ -1,12 +1,8 @@ -@testsnippet ChainRulesStuff begin - using ChainRulesCore - using ChainRulesTestUtils +@testsnippet EnzymeSnip begin using Enzyme + using EnzymeCore using Zygote using IncompressibleNavierStokes: IncompressibleNavierStokes as INS - using EnzymeCore - import .EnzymeRules: reverse, augmented_primal - using .EnzymeRules ENABLE_LOGGING = false end @@ -63,9 +59,8 @@ end end end -@testitem "Chain rules (boundary conditions)" setup = [ChainRulesStuff] begin - import .EnzymeRules: reverse, augmented_primal - using .EnzymeRules +@testitem "Chain rules (boundary conditions)" setup = [EnzymeSnip] begin + using IncompressibleNavierStokes: IncompressibleNavierStokes as INS T = Float64 Re = T(1_000) Pr = T(0.71) @@ -199,7 +194,8 @@ end end end -@testitem "Divergence" setup = [Case, ChainRulesStuff] begin +@testitem "Divergence" setup = [Case, EnzymeSnip] begin + using IncompressibleNavierStokes: IncompressibleNavierStokes as INS for (u, setup, d) in ((Case.D2.u, Case.D2.setup, Case.D2.div), (Case.D3.u, Case.D3.setup, Case.D3.div)) d0 = copy(d) @@ -236,7 +232,8 @@ end end end -@testitem "Pressuregradient" setup = [Case, ChainRulesStuff] begin +@testitem "Pressuregradient" setup = [Case, EnzymeSnip] begin + using IncompressibleNavierStokes: IncompressibleNavierStokes as INS for (p, setup) in ((Case.D2.p, Case.D2.setup), (Case.D3.p, Case.D3.setup)) p0 = copy(p) pg = INS.pressuregradient(p, setup) @@ -273,7 +270,8 @@ end end end -@testitem "Poisson" setup = [Case, ChainRulesStuff] begin +@testitem "Poisson" setup = [Case, EnzymeSnip] begin + using IncompressibleNavierStokes: IncompressibleNavierStokes as INS for (psolver, d, setup) in ( (Case.D2.psolver, Case.D2.div, Case.D2.setup), (Case.D3.psolver, Case.D3.div, Case.D3.setup), @@ -312,7 +310,8 @@ end end end -@testitem "Convection" setup = [Case, ChainRulesStuff] begin +@testitem "Convection" setup = [Case, EnzymeSnip] begin + using IncompressibleNavierStokes: IncompressibleNavierStokes as INS for (u, setup) in ((Case.D2.u, Case.D2.setup), (Case.D3.u, Case.D3.setup)) c = INS.convection(u, setup) u0 = copy(u) @@ -352,7 +351,8 @@ end end end -@testitem "Diffusion" setup = [Case, ChainRulesStuff] begin +@testitem "Diffusion" setup = [Case, EnzymeSnip] begin + using IncompressibleNavierStokes: IncompressibleNavierStokes as INS for (u, setup) in ((Case.D2.u, Case.D2.setup), (Case.D3.u, Case.D3.setup)) d = INS.diffusion(u, setup) u0 = copy(u) @@ -392,7 +392,8 @@ end end end -@testitem "Bodyforce" setup = [Case, ChainRulesStuff] begin +@testitem "Bodyforce" setup = [Case, EnzymeSnip] begin + using IncompressibleNavierStokes: IncompressibleNavierStokes as INS @warn "bodyforce is tested only in the static case" for (u, setup) in ((Case.D2.u, Case.D2.setup), (Case.D3.u, Case.D3.setup)) t = 0.5 @@ -440,7 +441,8 @@ end end end -@testitem "Gravity" setup = [Case, ChainRulesStuff] begin +@testitem "Gravity" setup = [Case, EnzymeSnip] begin + using IncompressibleNavierStokes: IncompressibleNavierStokes as INS for (t, setup) in ((Case.D2.temp, Case.D2.setup), (Case.D3.temp, Case.D3.setup)) g = INS.gravity(t, setup) Zygote.pullback(INS.gravity, t, setup)[2](g) @@ -478,7 +480,8 @@ end end end -@testitem "Dissipation" setup = [Case, ChainRulesStuff] begin +@testitem "Dissipation" setup = [Case, EnzymeSnip] begin + using IncompressibleNavierStokes: IncompressibleNavierStokes as INS for (u, setup) in ((Case.D2.u, Case.D2.setup), (Case.D3.u, Case.D3.setup)) diss = INS.dissipation(u, setup) Zygote.pullback(INS.dissipation, u, setup)[2](diss) @@ -521,11 +524,11 @@ end @test du == zpull end end -@testitem "Convection_diffusion_temp" setup = [Case, ChainRulesStuff] begin +@testitem "Convection_diffusion_temp" setup = [Case, EnzymeSnip] begin @test_broken 1 == 2 end -@testitem "Convectiondiffusion" setup = [Case, ChainRulesStuff] begin +@testitem "Convectiondiffusion" setup = [Case, EnzymeSnip] begin # the pullback rule is missing for this one @test_broken 1 == 2 end diff --git a/test/enzyme_integration.jl b/test/enzyme_integration.jl index 8559b1eef..4fa895c43 100644 --- a/test/enzyme_integration.jl +++ b/test/enzyme_integration.jl @@ -1,5 +1,5 @@ -@testsnippet EnzymeStuff begin +@testsnippet EnzymeSnipPull begin using IncompressibleNavierStokes using Enzyme using Zygote @@ -22,7 +22,7 @@ end psolver = default_psolver(setup) end -@testitem "Enzyme one force pullback" setup = [Info, EnzymeStuff] begin +@testitem "Enzyme one force pullback" setup = [Info, EnzymeSnipPull] begin for (setup, psolver, T, N) in ((Info.setup, Info.psolver, Info.T, Info.N),) dudt = zeros(T, (N, N, 2)) u = rand(T, (N, N, 2)) @@ -64,7 +64,7 @@ end end end -@testitem "Enzyme RHS pullback" setup = [Info, EnzymeStuff] begin +@testitem "Enzyme RHS pullback" setup = [Info, EnzymeSnipPull] begin for (setup, psolver, T, N) in ((Info.setup, Info.psolver, Info.T, Info.N),) F_out = create_right_hand_side(setup, psolver) dudt = zeros(T, (N, N, 2)) From f4820e635f5eda57a02066cf699fdc6cc2a123a0 Mon Sep 17 00:00:00 2001 From: SCiarella Date: Thu, 21 Nov 2024 14:34:55 +0100 Subject: [PATCH 12/14] Add tests to increase coverage --- src/pressure.jl | 5 +++-- test/chainrules_enzyme.jl | 31 +++++++++++++++++++++++++++++-- 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/src/pressure.jl b/src/pressure.jl index a420b9257..3523057a8 100644 --- a/src/pressure.jl +++ b/src/pressure.jl @@ -360,8 +360,9 @@ end # Wrap a function to return `nothing`, because Enzyme can not handle vector return values. function enzyme_wrap(f::typeof(poisson!)) - function wrapped_f(y, args...) - y .= f(args...) + function wrapped_f(p, psolve, d) + p .= d + f(psolve, p) return nothing end return wrapped_f diff --git a/test/chainrules_enzyme.jl b/test/chainrules_enzyme.jl index d6b7c6764..d218f9f0d 100644 --- a/test/chainrules_enzyme.jl +++ b/test/chainrules_enzyme.jl @@ -92,6 +92,9 @@ end y = Enzyme.make_zero(u) dy = Enzyme.make_zero(u) .+ 1 f = INS.enzyme_wrap(INS.apply_bc_u!) + f(y, u, nothing, setup) + @test y != u + @test any(!iszero, y) Enzyme.autodiff( Enzyme.Reverse, f, @@ -127,6 +130,9 @@ end y = Enzyme.make_zero(p) dy = Enzyme.make_zero(p) .+ 1 f = INS.enzyme_wrap(INS.apply_bc_p!) + f(y, p, nothing, setup) + @test y != p + @test any(!iszero, y) Enzyme.autodiff( Enzyme.Reverse, f, @@ -163,6 +169,9 @@ end y = Enzyme.make_zero(temp) dy = Enzyme.make_zero(temp) .+ 1 f = INS.enzyme_wrap(INS.apply_bc_temp!) + f(y, temp, nothing, setup) + @test y != temp + @test any(!iszero, y) Enzyme.autodiff( Enzyme.Reverse, f, @@ -205,6 +214,8 @@ end dd = Enzyme.make_zero(d) .+ 1 du = Enzyme.make_zero(u) f = INS.enzyme_wrap(INS.divergence!) + f(d, u, setup) + @test d == d0 Enzyme.autodiff( Enzyme.Reverse, f, @@ -243,6 +254,8 @@ end dpg = Enzyme.make_zero(pg) .+ 1 dp = Enzyme.make_zero(p) f = INS.enzyme_wrap(INS.pressuregradient!) + f(pg, p, setup) + @test pg == pg0 Enzyme.autodiff( Enzyme.Reverse, f, @@ -284,6 +297,9 @@ end p = Enzyme.make_zero(p0) dp = Enzyme.make_zero(p) .+ 1 f = INS.enzyme_wrap(INS.poisson!) + f(p, psolver, d) + @test p == p0 + dp = Enzyme.make_zero(p) .+ 1 Enzyme.autodiff( Enzyme.Reverse, @@ -314,7 +330,7 @@ end using IncompressibleNavierStokes: IncompressibleNavierStokes as INS for (u, setup) in ((Case.D2.u, Case.D2.setup), (Case.D3.u, Case.D3.setup)) c = INS.convection(u, setup) - u0 = copy(u) + c0 = copy(c) Zygote.pullback(INS.convection, u, setup)[2](u)[1] zpull, z_time = @timed Zygote.pullback(INS.convection, u, setup)[2](c)[1] @@ -323,6 +339,8 @@ end dc = Enzyme.make_zero(c) .+ 1 du = Enzyme.make_zero(u) f = INS.enzyme_wrap(INS.convection!) + f(c, u, setup) + @test c == c0 Enzyme.autodiff( Enzyme.Reverse, f, @@ -355,7 +373,7 @@ end using IncompressibleNavierStokes: IncompressibleNavierStokes as INS for (u, setup) in ((Case.D2.u, Case.D2.setup), (Case.D3.u, Case.D3.setup)) d = INS.diffusion(u, setup) - u0 = copy(u) + d0 = copy(d) Zygote.pullback(INS.diffusion, u, setup)[2](d)[1] zpull, z_time = @timed Zygote.pullback(INS.diffusion, u, setup)[2](d)[1] @@ -364,6 +382,8 @@ end dd = Enzyme.make_zero(d) .+ 1 du = Enzyme.make_zero(u) f = INS.enzyme_wrap(INS.diffusion!) + f(d, u, setup) + @test d == d0 Enzyme.autodiff( Enzyme.Reverse, f, @@ -411,6 +431,8 @@ end dbf = Enzyme.make_zero(bf) .+ 1 du = Enzyme.make_zero(u) f = INS.enzyme_wrap(INS.applybodyforce!) + f(bf, u, t, setup) + @test bf == bf0 Enzyme.autodiff( Enzyme.Reverse, f, @@ -452,6 +474,8 @@ end dg = Enzyme.make_zero(g) .+ 1 dt = Enzyme.make_zero(t) f = INS.enzyme_wrap(INS.gravity!) + f(g, t, setup) + @test g != 0 Enzyme.autodiff( Enzyme.Reverse, f, @@ -484,6 +508,7 @@ end using IncompressibleNavierStokes: IncompressibleNavierStokes as INS for (u, setup) in ((Case.D2.u, Case.D2.setup), (Case.D3.u, Case.D3.setup)) diss = INS.dissipation(u, setup) + diss0 = copy(diss) Zygote.pullback(INS.dissipation, u, setup)[2](diss) zpull, z_time = @timed Zygote.pullback(INS.dissipation, u, setup)[2](diss)[1] @@ -493,6 +518,8 @@ end ddiff = Enzyme.make_zero(diff) du = Enzyme.make_zero(u) f = INS.enzyme_wrap(INS.dissipation!) + f(diss, diff, u, setup) + @test diss == diss0 Enzyme.autodiff( Enzyme.Reverse, f, From 71439c7c4c235d43097cda4417bf7d4453e949e7 Mon Sep 17 00:00:00 2001 From: SCiarella Date: Fri, 22 Nov 2024 10:07:01 +0100 Subject: [PATCH 13/14] Add tests to increase coverage --- src/IncompressibleNavierStokes.jl | 2 ++ test/chainrules_enzyme.jl | 17 ++++++++-- test/enzyme_integration.jl | 6 ++++ test/operators.jl | 55 +++++++++++++++++++++++++++++-- 4 files changed, 75 insertions(+), 5 deletions(-) diff --git a/src/IncompressibleNavierStokes.jl b/src/IncompressibleNavierStokes.jl index 69c70ca35..3ff2ac607 100644 --- a/src/IncompressibleNavierStokes.jl +++ b/src/IncompressibleNavierStokes.jl @@ -122,6 +122,7 @@ export apply_bc_u, apply_bc_p, apply_bc_temp, applybodyforce, + applypressure, convection_diffusion_temp, convection, diffusion, @@ -138,6 +139,7 @@ export apply_bc_u, laplacian_mat, momentum, poisson, + pressure, pressuregradient, project, scalewithvolume, diff --git a/test/chainrules_enzyme.jl b/test/chainrules_enzyme.jl index d218f9f0d..fa3794545 100644 --- a/test/chainrules_enzyme.jl +++ b/test/chainrules_enzyme.jl @@ -92,9 +92,10 @@ end y = Enzyme.make_zero(u) dy = Enzyme.make_zero(u) .+ 1 f = INS.enzyme_wrap(INS.apply_bc_u!) + @test f isa Function f(y, u, nothing, setup) @test y != u - @test any(!iszero, y) + @test any(!isnan, y) Enzyme.autodiff( Enzyme.Reverse, f, @@ -130,9 +131,10 @@ end y = Enzyme.make_zero(p) dy = Enzyme.make_zero(p) .+ 1 f = INS.enzyme_wrap(INS.apply_bc_p!) + @test f isa Function f(y, p, nothing, setup) @test y != p - @test any(!iszero, y) + @test any(!isnan, y) Enzyme.autodiff( Enzyme.Reverse, f, @@ -169,9 +171,10 @@ end y = Enzyme.make_zero(temp) dy = Enzyme.make_zero(temp) .+ 1 f = INS.enzyme_wrap(INS.apply_bc_temp!) + @test f isa Function f(y, temp, nothing, setup) @test y != temp - @test any(!iszero, y) + @test any(!isnan, y) Enzyme.autodiff( Enzyme.Reverse, f, @@ -214,6 +217,7 @@ end dd = Enzyme.make_zero(d) .+ 1 du = Enzyme.make_zero(u) f = INS.enzyme_wrap(INS.divergence!) + @test f isa Function f(d, u, setup) @test d == d0 Enzyme.autodiff( @@ -254,6 +258,7 @@ end dpg = Enzyme.make_zero(pg) .+ 1 dp = Enzyme.make_zero(p) f = INS.enzyme_wrap(INS.pressuregradient!) + @test f isa Function f(pg, p, setup) @test pg == pg0 Enzyme.autodiff( @@ -297,6 +302,7 @@ end p = Enzyme.make_zero(p0) dp = Enzyme.make_zero(p) .+ 1 f = INS.enzyme_wrap(INS.poisson!) + @test f isa Function f(p, psolver, d) @test p == p0 dp = Enzyme.make_zero(p) .+ 1 @@ -339,6 +345,7 @@ end dc = Enzyme.make_zero(c) .+ 1 du = Enzyme.make_zero(u) f = INS.enzyme_wrap(INS.convection!) + @test f isa Function f(c, u, setup) @test c == c0 Enzyme.autodiff( @@ -382,6 +389,7 @@ end dd = Enzyme.make_zero(d) .+ 1 du = Enzyme.make_zero(u) f = INS.enzyme_wrap(INS.diffusion!) + @test f isa Function f(d, u, setup) @test d == d0 Enzyme.autodiff( @@ -431,6 +439,7 @@ end dbf = Enzyme.make_zero(bf) .+ 1 du = Enzyme.make_zero(u) f = INS.enzyme_wrap(INS.applybodyforce!) + @test f isa Function f(bf, u, t, setup) @test bf == bf0 Enzyme.autodiff( @@ -474,6 +483,7 @@ end dg = Enzyme.make_zero(g) .+ 1 dt = Enzyme.make_zero(t) f = INS.enzyme_wrap(INS.gravity!) + @test f isa Function f(g, t, setup) @test g != 0 Enzyme.autodiff( @@ -518,6 +528,7 @@ end ddiff = Enzyme.make_zero(diff) du = Enzyme.make_zero(u) f = INS.enzyme_wrap(INS.dissipation!) + @test f isa Function f(diss, diff, u, setup) @test diss == diss0 Enzyme.autodiff( diff --git a/test/enzyme_integration.jl b/test/enzyme_integration.jl index 4fa895c43..cefc44917 100644 --- a/test/enzyme_integration.jl +++ b/test/enzyme_integration.jl @@ -30,7 +30,9 @@ end params = [setup, psolver] params_ref = Ref(params) right_hand_side!(dudt, u, params_ref, T(0)) + @test all(!isnan, dudt) F_out = create_right_hand_side(setup, psolver) + @test all(!isnan, F_out(u, nothing, T(0))) @test dudt ≈ F_out(u, nothing, T(0)) @test u == u0 @test sum(dudt) != 0 @@ -82,6 +84,10 @@ end Const(params_ref), Const(T(0)), ) + @test all(!isnan, u) + @test all(!isnan, du) + @test all(!isnan, dudt) + @test all(!isnan, dd) @test u0 == u @test dudt ≈ F_out(u, nothing, T(0)) zpull = Zygote.pullback(F_out, u, nothing, T(0)) diff --git a/test/operators.jl b/test/operators.jl index 9da7996cb..0f795a7b3 100644 --- a/test/operators.jl +++ b/test/operators.jl @@ -6,7 +6,19 @@ lims = T(0), T(1) x = tanh_grid(lims..., n), tanh_grid(lims..., n, 1.3) bc = DirichletBC(), DirichletBC() - setup = Setup(; x, Re, boundary_conditions = (bc, bc)) + boundary_conditions = (bc, bc) + temperature = + temperature_equation(; Pr = T(0.71), Ra = T(1e6), Ge = T(1.0), boundary_conditions) + bodyforce = (dim, x, y, t) -> (dim == 1) * 5 * sinpi(8 * y) + setup = Setup(; + x, + boundary_conditions, + Re, + temperature, + bodyforce, + issteadybodyforce = true, + ) + psolver = default_psolver(setup) uref(dim, x, y, args...) = -(dim == 1) * sin(x) * cos(y) + (dim == 2) * cos(x) * sin(y) u = velocityfield(setup, uref, T(0)) end @@ -19,7 +31,19 @@ end lims = T(0), T(1) x = tanh_grid(lims..., n, 1.2), tanh_grid(lims..., n, 1.1), cosine_grid(lims..., n) bc = DirichletBC(), DirichletBC(), DirichletBC() - setup = Setup(; x, Re, boundary_conditions = (bc, bc, bc)) + boundary_conditions = (bc, bc, bc) + temperature = + temperature_equation(; Pr = T(0.71), Ra = T(1e6), Ge = T(1.0), boundary_conditions) + bodyforce = (dim, x, y, z, t) -> (dim == 1) * 5 * sinpi(8 * y) + setup = Setup(; + x, + boundary_conditions, + Re, + temperature, + bodyforce, + issteadybodyforce = true, + ) + psolver = default_psolver(setup) uref(dim, x, y, args...) = -(dim == 1) * sin(x) * cos(y) + (dim == 2) * cos(x) * sin(y) u = velocityfield(setup, uref, T(0)) end @@ -144,6 +168,33 @@ end end end +@testitem "Apply body force" setup = [Setup2D, Setup3D] begin + using Random + for (u, setup) in ((Setup2D.u, Setup2D.setup), (Setup3D.u, Setup3D.setup)) + T = eltype(u) + F = applybodyforce(u, T(0), setup) + @test F isa Array{T} + @test all(!isnan, F) + end +end + +@testitem "Pressure" setup = [Setup2D, Setup3D] begin + using Random + for (u, setup, psolver) in ( + (Setup2D.u, Setup2D.setup, Setup2D.psolver), + (Setup3D.u, Setup3D.setup, Setup3D.psolver), + ) + T = eltype(u) + temp = randn(T, setup.grid.N) + p = pressure(u, temp, T(0), setup; psolver = psolver) + @test p isa Array{T} + @test all(!isnan, p) + F = applypressure(u, p, setup) + @test F isa Array{T} + @test all(!isnan, F) + end +end + @testitem "Other fields" setup = [Setup2D, Setup3D] begin using Random for (u, setup) in ((Setup2D.u, Setup2D.setup), (Setup3D.u, Setup3D.setup)) From 6a582de708cb42df8c7c08c0bd34ebbfee71d2b3 Mon Sep 17 00:00:00 2001 From: SCiarella Date: Fri, 22 Nov 2024 12:24:01 +0100 Subject: [PATCH 14/14] Mask some inner functions to codecov --- src/boundary_conditions.jl | 2 ++ src/operators.jl | 2 ++ src/pressure.jl | 2 ++ src/sciml.jl | 2 ++ 4 files changed, 8 insertions(+) diff --git a/src/boundary_conditions.jl b/src/boundary_conditions.jl index bb9172bd4..1480af1a3 100644 --- a/src/boundary_conditions.jl +++ b/src/boundary_conditions.jl @@ -528,6 +528,7 @@ apply_bc_temp!(bc::PressureBC, temp, β, t, setup; isright, kwargs...) = apply_bc_temp_pullback!(bc::PressureBC, φbar, β, t, setup; isright, kwargs...) = apply_bc_p_pullback!(SymmetricBC(), φbar, β, t, setup; isright, kwargs...) +# COV_EXCL_START # Wrap a function to return `nothing`, because Enzyme can not handle vector return values. function enzyme_wrap( f::Union{typeof(apply_bc_u!),typeof(apply_bc_p!),typeof(apply_bc_temp!)}, @@ -602,3 +603,4 @@ function EnzymeRules.reverse( y.dval .= x.dval # y is a copy of x return (nothing, nothing, nothing, nothing) end +# COV_EXCL_STOP diff --git a/src/operators.jl b/src/operators.jl index 1a9a1c5a8..c436d274f 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -1465,6 +1465,7 @@ function get_scale_numbers(u, setup) (; uavg, ϵ, η, λ, Reλ, L, τ) end +# COV_EXCL_START # Wrap a function to return `nothing`, because Enzyme can not handle vector return values. function enzyme_wrap( f::Union{ @@ -1753,3 +1754,4 @@ function EnzymeRules.reverse( ) @error "momentum Enzyme-AD not yet implemented" end +# COV_EXCL_STOP diff --git a/src/pressure.jl b/src/pressure.jl index 3523057a8..7405bbd8c 100644 --- a/src/pressure.jl +++ b/src/pressure.jl @@ -358,6 +358,7 @@ function psolver_spectral(setup) end end +# COV_EXCL_START # Wrap a function to return `nothing`, because Enzyme can not handle vector return values. function enzyme_wrap(f::typeof(poisson!)) function wrapped_f(p, psolve, d) @@ -393,3 +394,4 @@ function EnzymeRules.reverse( EnzymeCore.make_zero!(y.dval) return (nothing, nothing, nothing) end +# COV_EXCL_STOP diff --git a/src/sciml.jl b/src/sciml.jl index f7c1ad13b..becb909a1 100644 --- a/src/sciml.jl +++ b/src/sciml.jl @@ -46,6 +46,7 @@ function right_hand_side!(dudt, u, params_ref, t) return nothing end +# COV_EXCL_START function EnzymeRules.augmented_primal( config::RevConfigWidth{1}, func::Const{typeof(right_hand_side!)}, @@ -109,3 +110,4 @@ function EnzymeRules.reverse( return (nothing, nothing, nothing, nothing) end +# COV_EXCL_STOP