Skip to content

Commit

Permalink
fix: use atomics for now
Browse files Browse the repository at this point in the history
  • Loading branch information
agdestein committed Dec 5, 2024
1 parent 3da3f94 commit aa9448d
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 33 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "2.1.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
Expand Down Expand Up @@ -33,6 +34,7 @@ IncompressibleNavierStokesMakieExt = ["Makie"]

[compat]
Adapt = "4"
Atomix = "1"
CUDA = "5"
CUDSS = "0.3"
ChainRulesCore = "1"
Expand Down
2 changes: 1 addition & 1 deletion lib/SymmetryClosure/src/SymmetryClosure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using LinearAlgebra
using IncompressibleNavierStokes
using NeuralClosure

include("tensor.jl")
include("tensorclosure.jl")

export tensorclosure, polynomial

Expand Down
59 changes: 55 additions & 4 deletions lib/SymmetryClosure/test_tensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,70 @@ using IncompressibleNavierStokes
using SymmetryClosure
using CUDA
using Zygote
using Random
using LinearAlgebra

lines(cumsum(randn(100)))

# Setup
n = 32
ax = range(0.0, 1.0, n + 1)
setup = Setup(; x = (ax, ax), Re = 1e4, backend = CUDABackend());
ustart = random_field(setup, 0.0)
n = 8
# ax = range(0.0, 1.0, n + 1)
# x = ax, ax
x = tanh_grid(0.0, 1.0, n + 1), stretched_grid(-0.2, 1.2, n + 1)
setup = Setup(;
x,
Re = 1e4,
backend = CUDABackend(),
boundary_conditions = ((DirichletBC(), DirichletBC()), (DirichletBC(), DirichletBC())),
);
ustart = vectorfield(setup) |> randn!

u = ustart

let
B, V = tensorbasis(u, setup)
# B, V = randn!(B), randn!(V)
V = randn!(V)
function f(u)
Bi, Vi = tensorbasis(u, setup)
# dot(Bi, B) + dot(Vi, V)
# dot(getindex.(Bi, 1), getindex.(B, 1)) + dot(Vi, V)
dot(Vi, V)
# dot(Vi[:, :, 1], V[:, :, 1])
end

fd = map(eachindex(u)) do i
h = 1e-2
v1 = copy(u)
v2 = copy(u)
CUDA.@allowscalar v1[i] -= h / 2
CUDA.@allowscalar v2[i] += h / 2
(f(v2) - f(v1)) / h
end |> x -> reshape(x, size(u))

ad = Zygote.gradient(f, u)[1] |> Array

# mask = @. abs(fd - ad) > 1e-3

# i = 1
# V[:, :, i] |> display
# # (mask .* u)[:, :, i] |> display
# (mask .* fd)[:, :, i] |> display
# (mask .* ad)[:, :, i] |> display

# fd .- ad |> display
@show fd - ad .|> abs |> maximum
# @show f(u)
nothing
end

B, V = tensorbasis(u, setup)

typeof(B)
getindex.(B, 1)

B[:, :, 1]

tb, pb = SymmetryClosure.ChainRulesCore.rrule(tensorbasis, u, setup)

ubar = pb(tb)[2]
Expand Down
1 change: 1 addition & 0 deletions src/IncompressibleNavierStokes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ $(EXPORTS)
module IncompressibleNavierStokes

using Adapt
using Atomix: @atomic
using ChainRulesCore
using DocStringExtensions
using FFTW
Expand Down
62 changes: 36 additions & 26 deletions src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1021,13 +1021,16 @@ function vorticity!(::Dimension{3}, ω, u, setup)
end

@inline ∂x(u, I::CartesianIndex{D}, α, β, Δβ, Δuβ; e = Offset(D)) where {D} =
α == β ? (u[I, α] - u[I-e(β), α]) / Δβ[I[β]] :
(
(u[I+e(β), α] - u[I, α]) / Δuβ[I[β]] +
(u[I-e(α)+e(β), α] - u[I-e(α), α]) / Δuβ[I[β]] +
(u[I, α] - u[I-e(β), α]) / Δuβ[I[β]-1] +
(u[I-e(α), α] - u[I-e(α)-e(β), α]) / Δuβ[I[β]-1]
) / 4
if α == β
(u[I, α] - u[I-e(β), α]) / Δβ[I[β]]
else
(
(u[I+e(β), α] - u[I, α]) / Δuβ[I[β]] +
(u[I-e(α)+e(β), α] - u[I-e(α), α]) / Δuβ[I[β]] +
(u[I, α] - u[I-e(β), α]) / Δuβ[I[β]-1] +
(u[I-e(α), α] - u[I-e(α)-e(β), α]) / Δuβ[I[β]-1]
) / 4
end
@inline function ∂x_adjoint!(
φ,
u,
Expand All @@ -1038,24 +1041,31 @@ end
Δuβ;
e = Offset(D),
) where {D}
# TODO:
# - Invert input/output indices
# - combine all output indices into one
# - Get rid of @atomic
if α == β
# φ = (u[I, α] - u[I-e(β), α]) / Δβ[I[β]]
u[I, α] += φ / Δβ[I[β]]
u[I-e(β), α] -= φ / Δβ[I[β]]
val = φ / Δβ[I[β]]
@atomic u[I, α] += val
@atomic u[I-e(β), α] -= val
else
# φ =
# (u[I+e(β), α] - u[I, α]) / 4Δuβ[I[β]] +
# (u[I-e(α)+e(β), α] - u[I-e(α), α]) / 4Δuβ[I[β]] +
# (u[I, α] - u[I-e(β), α]) / 4Δuβ[I[β]-1] +
# (u[I-e(α), α] - u[I-e(α)-e(β), α]) / 4Δuβ[I[β]-1]
u[I+e(β), α] += φ / 4Δuβ[I[β]]
u[I, α] -= φ / 4Δuβ[I[β]]
u[I-e(α)+e(β), α] += φ / 4Δuβ[I[β]]
u[I-e(α), α] -= φ / 4Δuβ[I[β]]
u[I, α] += φ / 4Δuβ[I[β]-1]
u[I-e(β), α] -= φ / 4Δuβ[I[β]-1]
u[I-e(α), α] += φ / 4Δuβ[I[β]-1]
u[I-e(α)-e(β), α] -= φ / 4Δuβ[I[β]-1]
val = φ / 4Δuβ[I[β]]
@atomic u[I+e(β), α] += val
@atomic u[I, α] -= val
@atomic u[I-e(α)+e(β), α] += val
@atomic u[I-e(α), α] -= val
val = φ / 4Δuβ[I[β]-1]
@atomic u[I, α] += val
@atomic u[I-e(β), α] -= val
@atomic u[I-e(α), α] += val
@atomic u[I-e(α)-e(β), α] -= val
end
u
end
Expand Down Expand Up @@ -1084,15 +1094,15 @@ end
u
end
@inline function ∇_adjoint!(∇u, u, I::CartesianIndex{3}, Δ, Δu)
∂x_adjoint!(∇u, u, I, 1, 1, Δ[1], Δu[1])
∂x_adjoint!(∇u, u, I, 2, 1, Δ[1], Δu[1])
∂x_adjoint!(∇u, u, I, 3, 1, Δ[1], Δu[1])
∂x_adjoint!(∇u, u, I, 1, 2, Δ[2], Δu[2])
∂x_adjoint!(∇u, u, I, 2, 2, Δ[2], Δu[2])
∂x_adjoint!(∇u, u, I, 3, 2, Δ[2], Δu[2])
∂x_adjoint!(∇u, u, I, 1, 3, Δ[3], Δu[3])
∂x_adjoint!(∇u, u, I, 2, 3, Δ[3], Δu[3])
∂x_adjoint!(∇u, u, I, 3, 3, Δ[3], Δu[3])
∂x_adjoint!(∇u[1, 1], u, I, 1, 1, Δ[1], Δu[1])
∂x_adjoint!(∇u[2, 1], u, I, 2, 1, Δ[1], Δu[1])
∂x_adjoint!(∇u[3, 1], u, I, 3, 1, Δ[1], Δu[1])
∂x_adjoint!(∇u[1, 2], u, I, 1, 2, Δ[2], Δu[2])
∂x_adjoint!(∇u[2, 2], u, I, 2, 2, Δ[2], Δu[2])
∂x_adjoint!(∇u[3, 2], u, I, 3, 2, Δ[2], Δu[2])
∂x_adjoint!(∇u[1, 3], u, I, 1, 3, Δ[3], Δu[3])
∂x_adjoint!(∇u[2, 3], u, I, 2, 3, Δ[3], Δu[3])
∂x_adjoint!(∇u[3, 3], u, I, 3, 3, Δ[3], Δu[3])
u
end
@inline idtensor(u, ::CartesianIndex{2}) = SMatrix{2,2,eltype(u),4}(1, 0, 0, 1)
Expand Down
8 changes: 6 additions & 2 deletions src/tensorbasis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,13 @@ end
@kernel function tensorbasis_adjoint_kernel!(::Dimension{2}, ubar, Bbar, Vbar, u, Δ, Δu, I0)
I = @index(Global, Cartesian)
I = I + I0

# Forward pass
∇u = (u, I, Δ, Δu)
S = (∇u + ∇u') / 2
R = (∇u - ∇u') / 2

# Reverse pass (requires S and R from forward pass)
Sbar = Bbar[I, 2] + Bbar[I, 3] * R' - R' * Bbar[I, 3] + 2 * Vbar[I, 1] * S
Rbar = S' * Bbar[I, 3] - Bbar[I, 3] * S' + 2 * Vbar[I, 2] * R
∇ubar = (Sbar + Sbar') / 2 + (Rbar - Rbar') / 2
Expand All @@ -88,7 +92,7 @@ end
end

@kernel function tensorbasis_adjoint_kernel!(::Dimension{3}, ubar, Bbar, Vbar, u, Δ, Δu, I0)
# TODO
# TODO: 3D adjoint
end

"""
Expand All @@ -107,7 +111,7 @@ ChainRulesCore.rrule(::typeof(lastdimcontract), a, b, setup) = (
function (cbar)
abar = zero(a)
bbar = zero(b)
lastdimcontract_adjoint!(abar, bbar, cbar, a, b, setup)
lastdimcontract_adjoint!(abar, bbar, cbar |> unthunk, a, b, setup)
(NoTangent(), abar, bbar, NoTangent())
end,
)
Expand Down
17 changes: 17 additions & 0 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,23 @@ end
test_rrule_named(diffusion, Case.D3.u, Case.D3.setup NoTangent())
end

@testitem "Tensor basis" setup = [Case, ChainRulesStuff] begin
using IncompressibleNavierStokes.StaticArrays
using Random
test_rrule_named(tensorbasis, Case.D2.u, Case.D2.setup NoTangent())
@test_broken false # TODO: 3D adjoint
# test_rrule_named(tensorbasis, Case.D3.u, Case.D3.setup ⊢ NoTangent())
T = eltype(Case.D2.u)
a = similar(Case.D2.u, size(Case.D2.u)..., 5) |> randn!
b = similar(Case.D2.u, SMatrix{2,2,T,4}, size(Case.D2.u)..., 5) |> randn!
test_rrule_named(
IncompressibleNavierStokes.lastdimcontract,
a,
b,
Case.D2.setup NoTangent(),
)
end

@testitem "Temperature" setup = [Case, ChainRulesStuff] begin
for case in (Case.D2, Case.D3)
(; u, temp, setup) = case
Expand Down

0 comments on commit aa9448d

Please sign in to comment.