Clean support for 2D + 3D DFNO (#9)
* all major changes

* authors + deps

* machine agnostic

* Fix for tensor on GPU

* machine agnostic reduction

* proper reduction

* syntax

* Log repartition + support for Old 2d FNO

---------

Co-authored-by: turquoisedragon2926 <rarockiasamy3@gatech.edu>
Co-authored-by: Rex Arockiasamy <rarockiasamy3@cos-4a10678.cos.gatech.edu>
3 people authored Feb 13, 2024
1 parent 4e90ece commit 47b8cee
Showing 14 changed files with 394 additions and 25 deletions.
4 changes: 3 additions & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "ParametricOperators"
uuid = "db9e0614-c73c-4112-a40c-114e5b366d0d"
authors = ["Thomas Grady <tgrady@gatech.edu>"]
authors = ["Thomas Grady <tgrady@gatech.edu>", "Richard Rex <richardr2926@gatech.edu>"]
version = "0.1.0"

[deps]
@@ -9,11 +9,13 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0"
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637"
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Match = "7eb4fadd-790c-5f42-8a69-bfa0b872bfbf"
OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
23 changes: 21 additions & 2 deletions src/ParBroadcasted.jl
@@ -4,7 +4,7 @@ struct ParBroadcasted{D,R,L,P,F} <: ParOperator{D,R,L,P,Internal}
op::F
comm::MPI.Comm
root::Int
ParBroadcasted(op, comm, root::Int = 0) = new{DDT(op),RDT(op),linearity(op),parametricity(op),typeof(op)}(op, comm, root)
ParBroadcasted(op, comm::Any=MPI.COMM_WORLD, root::Int = 0) = new{DDT(op),RDT(op),linearity(op),parametricity(op),typeof(op)}(op, comm, root)
end

bcasted(A::ParOperator{D,R,L,P,External}, comm = MPI.COMM_WORLD, root = 0) where {D,R,L,P} =
@@ -32,4 +32,23 @@ end

(A::ParBroadcasted{D,R,L,<:Applicable,F})(x::X) where {D,R,L,F,X<:AbstractVector{D}} = A.op(x)
(A::ParBroadcasted{D,R,L,<:Applicable,F})(x::X) where {D,R,L,F,X<:AbstractMatrix{D}} = A.op(x)
*(x::X, A::ParBroadcasted{D,R,Linear,<:Applicable,F}) where {D,R,F,X<:AbstractMatrix{D}} = x*A.op
+(x::X, A::ParBroadcasted{D,R,Linear,<:Applicable,F}) where {D,R,F,X<:AbstractMatrix{D}} = x+A.op

function ChainRulesCore.rrule(A::ParBroadcasted{D,R,L,Parametric,F}, params) where {D,R,L,F}
op_out = A(params)
function pullback(op)
device = get_device(op.op.params)
θ_global = MPI.Reduce(op.op.params |> cpu, MPI.SUM, A.root, A.comm)

if MPI.Comm_rank(A.comm) == A.root
if device == "cpu"
return NoTangent(), Dict(A.op=>θ_global)
end
return NoTangent(), Dict(A.op=>(θ_global |> gpu))
else
return NoTangent(), NoTangent()
end
end
return op_out, pullback
end
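A minimal sketch of the reduction pattern this pullback uses, in isolation (run with, e.g., mpiexec -n 4 julia reduce.jl; the positional-root MPI.Reduce form mirrors the call above):

using MPI
MPI.Init()
comm, root = MPI.COMM_WORLD, 0
θ_local = rand(Float64, 16)                          # stands in for op.op.params
θ_global = MPI.Reduce(θ_local, MPI.SUM, root, comm)  # non-root ranks receive nothing
if MPI.Comm_rank(comm) == root
    println("summed parameter total: ", sum(θ_global))
end
MPI.Finalize()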
19 changes: 18 additions & 1 deletion src/ParCommon.jl
@@ -66,6 +66,12 @@ end
function rotate_dims_batched(x, rot)
n = length(size(x))
perm = [circshift(collect(1:n-1), rot)..., n]

device = get_device(x)
if device != "cpu"
0 in size(x) && return permutedims(x |> cpu, perm) |> gpu
end

return permutedims(x, perm)
end

@@ -83,4 +89,15 @@ zeros_like(::AbstractArray{T}, dims...) where {T} = zeros(T, dims...)
if CUDA.functional()
zeros_like(::CuArray{T}, dims) where {T} = CUDA.zeros(T, dims)
zeros_like(::CuArray{T}, dims...) where {T} = CUDA.zeros(T, dims...)
end

"""
Returns "gpu" if the input array lives on an NVIDIA GPU, "cpu" otherwise
"""
function get_device(x::AbstractArray)
if isa(x, CUDA.CuArray)
return "gpu"
else
return "cpu"
end
end
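A small check of the helper's contract (the GPU branch assumes a functional CUDA device):

x = rand(Float32, 4, 4)
@assert get_device(x) == "cpu"
# with CUDA available:
# @assert get_device(CUDA.rand(Float32, 4, 4)) == "gpu"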
8 changes: 4 additions & 4 deletions src/ParDFT.jl
@@ -23,12 +23,12 @@ Range(A::ParDFT) = A.m

complexity(A::ParDFT{D,R}) where {D,R} = elementwise_multiplication_cost(R)*A.n*log2(A.n)

(A::ParDFT{D,R})(x::X) where {D<:Complex,R,X<:AbstractMatrix{D}} = convert(Matrix{R}, fft(x, 1) ./ sqrt(A.n))
(A::ParDFT{D,R})(x::X) where {D<:Real,R,X<:AbstractMatrix{D}} = convert(Matrix{R}, rfft(x, 1) ./ sqrt(A.n))
(A::ParDFT{D,R})(x::X) where {D<:Complex,R,X<:AbstractMatrix{D}} = 0 in size(x) ? x : fft(x, 1)
(A::ParDFT{D,R})(x::X) where {D<:Real,R,X<:AbstractMatrix{D}} = rfft(x, 1)
(A::ParDFT{D,R})(x::X) where {D,R,X<:AbstractVector{D}} = vec(A(reshape(x, length(x), 1)))

(A::ParAdjoint{D,R,NonParametric,ParDFT{D,R}})(x::X) where {D<:Complex,R,X<:AbstractMatrix{R}} = ifft(x, 1).*convert(real(D), sqrt(A.op.n))
(A::ParAdjoint{D,R,NonParametric,ParDFT{D,R}})(x::X) where {D<:Real,R,X<:AbstractMatrix{R}} = irfft(x, A.op.n, 1).*convert(D, sqrt(A.op.n))
(A::ParAdjoint{D,R,NonParametric,ParDFT{D,R}})(x::X) where {D<:Complex,R,X<:AbstractMatrix{R}} = 0 in size(x) ? x : ifft(x, 1)
(A::ParAdjoint{D,R,NonParametric,ParDFT{D,R}})(x::X) where {D<:Real,R,X<:AbstractMatrix{R}} = irfft(x, A.op.n, 1)
(A::ParAdjoint{D,R,NonParametric,ParDFT{D,R}})(x::X) where {D,R,X<:AbstractVector{R}} = vec(A(reshape(x, length(x), 1)))

to_Dict(A::ParDFT{D,R}) where {D,R} = Dict{String, Any}("type" => "ParDFT", "T" => string(D), "n" => A.n, "m" => A.m)
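Note the convention change above: the forward transform drops the 1/sqrt(n) scaling, so the operator is no longer unitary, but the adjoint (ifft/irfft, which carry the 1/n factor) still inverts it exactly. A quick standalone check of the new convention:

using FFTW
x = rand(ComplexF64, 8, 1)
y = fft(x, 1)                    # unnormalized forward, as in the new ParDFT
@assert isapprox(x, ifft(y, 1))  # ifft carries 1/n, so the round trip still holds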
34 changes: 30 additions & 4 deletions src/ParDiagonal.jl
@@ -5,8 +5,11 @@ Diagonal matrix (elementwise) operator.
"""
struct ParDiagonal{T} <: ParLinearOperator{T,T,Parametric,External}
n::Int
ParDiagonal(T, n) = new{T}(n)
ParDiagonal(n) = new{Float64}(n)
id::Any
ParDiagonal(T::DataType, n::Int) = new{T}(n, uuid4(Random.GLOBAL_RNG))
ParDiagonal(n::Int) = new{Float64}(n, uuid4(Random.GLOBAL_RNG))
ParDiagonal(T::DataType, n::Int, id) = new{T}(n, id)
ParDiagonal(n::Int, id) = new{Float64}(n, id)
end

Domain(A::ParDiagonal) = A.n
@@ -27,13 +30,36 @@ end
*(x::X, A::ParParameterized{T,T,Linear,ParAdjoint{T,T,Parametric,ParDiagonal{T}},V}) where {T,V,X<:AbstractVector{T}} = x.*conj(A.params[A.op.op])
*(x::X, A::ParParameterized{T,T,Linear,ParAdjoint{T,T,Parametric,ParDiagonal{T}},V}) where {T,V,X<:AbstractMatrix{T}} = x.*conj(A.params[A.op.op])

to_Dict(A::ParDiagonal{T}) where {T} = Dict{String, Any}("type" => "ParDiagonal", "T" => string(T), "n" => A.n)
function to_Dict(A::ParDiagonal{T}) where {T}
rv = Dict{String, Any}(
"type" => "ParDiagonal",
"T" => string(T),
"n" => A.n
)
if typeof(A.id) == String
rv["id"] = A.id
elseif typeof(A.id) == UUID
rv["id"] = "UUID:$(string(A.id))"
else
throw(ParException("I don't know how to encode id of type $(typeof(A.id))"))
end
rv
end

function from_Dict(::Type{ParDiagonal}, d)
ts = d["T"]
if !haskey(Data_TYPES, ts)
throw(ParException("unknown data type `$ts`"))
end
dtype = Data_TYPES[ts]
ParDiagonal(dtype, d["n"])
mid = d["id"]
if startswith(mid, "UUID:")
mid = UUID(mid[6:end])
end
ParDiagonal(dtype, d["n"], mid)
end

function distribute(A::ParDiagonal{T}, comm::MPI.Comm = MPI.COMM_WORLD) where {T}
local_n = local_size(A.n, MPI.Comm_rank(comm), MPI.Comm_size(comm))
return ParDiagonal(T, local_n)
end
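The new id field lets a deserialized operator be re-associated with saved parameters; a sketch of the intended round trip (calls assume the package's namespace):

A = ParDiagonal(Float32, 16)
d = to_Dict(A)                  # id is encoded as "UUID:..."
B = from_Dict(ParDiagonal, d)
@assert A.id == B.id            # identity survives serialization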
2 changes: 2 additions & 0 deletions src/ParIdentity.jl
@@ -31,3 +31,5 @@ function from_Dict(::Type{ParIdentity}, d)
dtype = Data_TYPES[ts]
ParIdentity(dtype, d["n"])
end

kron(A::ParIdentity{T}, B::ParIdentity{T}) where {T} = ParIdentity(T,B.n*A.n)
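With this rule, adjacent identities fuse into one larger identity instead of nesting inside a ParKron:

I2 = ParIdentity(Float64, 2)
I3 = ParIdentity(Float64, 3)
kron(I2, I3)   # ParIdentity(Float64, 6) rather than a two-factor ParKron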
22 changes: 17 additions & 5 deletions src/ParKron.jl
@@ -71,7 +71,7 @@ end
kron(A::ParLinearOperator, B::ParLinearOperator) = ParKron(A, B)
kron(A::ParKron, B::ParLinearOperator) = ParKron(A.ops..., B)
kron(A::ParLinearOperator, B::ParKron) = ParKron(A, B.ops...)
kron(A::ParKron, B::ParKron) = ParKron(A.ops..., B.ops...)
⊗(A::ParLinearOperator, B::ParLinearOperator) = kron(A, B)

Domain(A::ParSeparableOperator) = prod(map(Domain, children(A)))
@@ -236,15 +236,26 @@ function latex_string(A::ParKron{D,R,P,F,N}) where {D,R,P,F,N}
return out
end

rebuild(A::ParBroadcasted{D,R,L,Parametric,F}, cs) where {D,R,L,F<:ParKron} = rebuild(A.op, collect(map(c -> parametricity(c) == Parametric ? ParBroadcasted(c, A.comm, A.root) : c, children(cs[1]))))

"""
Distributes Kronecker product over the given communicator
Distributes Kronecker product over the given dimensions
"""
function distribute(A::ParKron, dims_in, dims_out=dims_in, parent_comm=MPI.COMM_WORLD)

function distribute(A::ParKron, dims_in::Vector{Int64}, dims_out::Vector{Int64}=dims_in, parent_comm=MPI.COMM_WORLD)
comm_in = MPI.Cart_create(parent_comm, dims_in)
comm_out = MPI.Cart_create(parent_comm, dims_out)

return distribute(A, comm_in, comm_out, parent_comm)
end

"""
Distributes Kronecker product over the given communicator
"""
function distribute(A::ParKron, comm_in::MPI.Comm, comm_out::MPI.Comm, parent_comm=MPI.COMM_WORLD)

dims, _, _ = MPI.Cart_get(comm_in)
dims_out, _, _ = MPI.Cart_get(comm_out)

N = length(dims)
@assert length(A.ops) == N

@@ -271,6 +282,7 @@ function distribute(A::ParKron, dims_in, dims_out=dims_in, parent_comm=MPI.COMM_
coords_i = MPI.Cart_coords(comm_i)

# Create repartition operator
!isequal(dims_prev, dims_i) && (MPI.Comm_rank(parent_comm) == 0) && println("Adding Repartition")
!isequal(dims_prev, dims_i) && pushfirst!(ops, ParRepartition(DDT(Ai), comm_prev, comm_i, tuple(size_curr...)))

# Create Kronecker w/ distributed identities
@@ -284,7 +296,7 @@ function distribute(A::ParKron, dims_in, dims_out=dims_in, parent_comm=MPI.COMM_
pushfirst!(idents_dim_upper, ParDistributed(ParIdentity(DDT(Ai), size_curr[j]), coords_i[j], dims_i[j]))
end

pushfirst!(ops, ParKron(idents_dim_lower..., ParBroadcasted(Ai, comm_i), idents_dim_upper...))
pushfirst!(ops, ParKron(idents_dim_lower..., rebuild(ParBroadcasted(Ai, comm_i), [Ai]), idents_dim_upper...))

size_curr[d] = Range(Ai)
comm_prev = comm_i
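A hedged sketch of the two distribute entry points (the operator construction is illustrative; run under mpiexec with 4 ranks):

using MPI, ParametricOperators
MPI.Init()
A = ParMatrix(Float64, 32, 32) ⊗ ParMatrix(Float64, 16, 16) ⊗ ParMatrix(Float64, 8, 8)
dims = [1, 2, 2]                       # rank grid over the three dimensions
Ad = distribute(A, dims)               # builds the Cartesian communicators itself
# or reuse prebuilt communicators across several operators (the new overload):
comm = MPI.Cart_create(MPI.COMM_WORLD, dims)
Ad2 = distribute(A, comm, comm)
MPI.Finalize()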
23 changes: 21 additions & 2 deletions src/ParMatrix.jl
@@ -19,22 +19,41 @@ Range(A::ParMatrix) = A.m
complexity(A::ParMatrix{T}) where {T} = elementwise_multiplication_cost(T)*A.n*A.m

function init!(A::ParMatrix{T}, d::Parameters) where {T<:Real}
d[A] = rand(T, A.m, A.n)/convert(T, sqrt(A.m*A.n))
if A.n == 1
d[A] = zeros(T, A.m, A.n)
return
end
scale = sqrt(24.0f0 / sum((A.m, A.n)))
d[A] = (rand(T, (A.n, A.m)) .- 0.5f0) .* scale
d[A] = permutedims(d[A], [2, 1])
end

function init!(A::ParMatrix{T}, d::Parameters) where {T<:Complex}
d[A] = rand(T, A.m, A.n)/convert(real(T), sqrt(A.m*A.n))
if A.n == 1
d[A] = zeros(T, A.m, A.n)
return
end
d[A] = rand(T, A.n, A.m)/convert(real(T), sqrt(A.m*A.n))
d[A] = permutedims(d[A], [2, 1])
end

(A::ParParameterized{T,T,Linear,ParMatrix{T},V})(x::X) where {T,V,X<:AbstractVector{T}} = A.params*x
(A::ParParameterized{T,T,Linear,ParMatrix{T},V})(x::X) where {T,V,X<:AbstractMatrix{T}} = A.params*x
(A::ParParameterized{T,T,Linear,ParAdjoint{T,T,Parametric,ParMatrix{T}},V})(x::X) where {T,V,X<:AbstractVector{T}} = A.params[A.op.op]'*x
(A::ParParameterized{T,T,Linear,ParAdjoint{T,T,Parametric,ParMatrix{T}},V})(x::X) where {T,V,X<:AbstractMatrix{T}} = A.params[A.op.op]'*x

*(x::X, A::ParParameterized{T,T,Linear,ParMatrix{T},V}) where {T,V,X<:AbstractVector{T}} = x*A.params
*(x::X, A::ParParameterized{T,T,Linear,ParMatrix{T},V}) where {T,V,X<:AbstractMatrix{T}} = x*A.params
*(x::X, A::ParParameterized{T,T,Linear,ParAdjoint{T,T,Parametric,ParMatrix{T}},V}) where {T,V,X<:AbstractVector{T}} = x*A.params[A.op.op]'
*(x::X, A::ParParameterized{T,T,Linear,ParAdjoint{T,T,Parametric,ParMatrix{T}},V}) where {T,V,X<:AbstractMatrix{T}} = x*A.params[A.op.op]'

+(x::X, A::ParParameterized{T,T,Linear,ParMatrix{T},V}) where {T,V,X<:AbstractVector{T}} = x.+A.params
+(x::X, A::ParParameterized{T,T,Linear,ParMatrix{T},V}) where {T,V,X<:AbstractArray{T}} = x.+A.params
+(x::X, A::ParParameterized{T,T,Linear,ParMatrix{T},V}) where {T,V,X<:AbstractMatrix{T}} = x.+A.params
+(x::X, A::ParParameterized{T,T,Linear,ParAdjoint{T,T,Parametric,ParMatrix{T}},V}) where {T,V,X<:AbstractVector{T}} = x+A.params[A.op.op]'
+(x::X, A::ParParameterized{T,T,Linear,ParAdjoint{T,T,Parametric,ParMatrix{T}},V}) where {T,V,X<:AbstractArray{T}} = x+A.params[A.op.op]'
+(x::X, A::ParParameterized{T,T,Linear,ParAdjoint{T,T,Parametric,ParMatrix{T}},V}) where {T,V,X<:AbstractMatrix{T}} = x+A.params[A.op.op]'

function to_Dict(A::ParMatrix{T}) where {T}
rv = Dict{String, Any}(
"type" => "ParMatrix",
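The new real-valued init! is effectively a Glorot/Xavier-style uniform draw: (rand .- 0.5) .* sqrt(24 / (m + n)) has support ±sqrt(6 / (m + n)). A quick standalone check:

m, n = 128, 256
W = (rand(Float32, n, m) .- 0.5f0) .* sqrt(24f0 / (m + n))
W = permutedims(W, [2, 1])                     # stored as m × n, matching init!
@assert maximum(abs, W) <= sqrt(6f0 / (m + n))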
4 changes: 2 additions & 2 deletions src/ParOperator.jl
@@ -132,7 +132,7 @@ const Parameters = Dict{<:ParOperator,Any}
Move objects to cpu.
"""
cpu(x::CuArray{<:Number}) = Array(x)
cpu(x::Vector{CuArray}) = [cpu(y) fpr y in x]
cpu(x::Vector{CuArray}) = [cpu(y) for y in x]
cpu(x::AbstractArray) = x
cpu(x::Parameters) = Dict(k => cpu(v) for (k, v) in pairs(x))

Expand All @@ -141,7 +141,7 @@ if CUDA.functional()
Move objects to gpu.
"""
gpu(x::AbstractArray{<:Number}) = CuArray(x)
gpu(x::Vector{<:AbstractArray}) = [gpu(y) fpr y in x]
gpu(x::Vector{<:AbstractArray}) = [gpu(y) for y in x]
gpu(x::CuArray) = x
gpu(x::Parameters) = Dict(k => gpu(v) for (k, v) in pairs(x))
end
47 changes: 47 additions & 0 deletions src/ParReduce.jl
@@ -0,0 +1,47 @@
export ParReduce

"""
Reduction operator. Sums contributions across the given communicator
"""
struct ParReduce{T} <: ParOperator{T,T,Linear,NonParametric,External}
comm::MPI.Comm

ParReduce() = new{Float64}(MPI.COMM_WORLD)
ParReduce(T::DataType) = new{T}(MPI.COMM_WORLD)
ParReduce(T::DataType, comm::MPI.Comm) = new{T}(comm)
ParReduce(comm::MPI.Comm) = new{Float64}(comm)
end

function (A::ParReduce{T})(x::X) where {T,X<:AbstractVector{T}}
device = get_device(x)
if device == "cpu"
return MPI.Allreduce(x, MPI.SUM, A.comm)
elseif device == "gpu"
return MPI.Allreduce(x |> cpu, MPI.SUM, A.comm) |> gpu
end
end

function (A::ParReduce{T})(x::X) where {T,X<:AbstractArray{T}}
device = get_device(x)
if device == "cpu"
return MPI.Allreduce(x, MPI.SUM, A.comm)
elseif device == "gpu"
return MPI.Allreduce(x |> cpu, MPI.SUM, A.comm) |> gpu
end
end

function ChainRulesCore.rrule(A::ParReduce{T}, x::X) where {T,X<:AbstractVector{T}}
op_out = A(x)
function pullback(op)
return NoTangent(), op
end
return op_out, pullback
end

function ChainRulesCore.rrule(A::ParReduce{T}, x::X) where {T,X<:AbstractArray{T}}
op_out = A(x)
function pullback(op)
return NoTangent(), op
end
return op_out, pullback
end
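A minimal usage sketch (run under mpiexec; the GPU path round-trips through the CPU exactly as the methods above do):

using MPI, ParametricOperators
MPI.Init()
R = ParReduce(Float64)                       # reduces over MPI.COMM_WORLD
y_local = fill(Float64(MPI.Comm_rank(MPI.COMM_WORLD)), 4)
y = R(y_local)                               # every rank gets the elementwise sum
MPI.Finalize()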
14 changes: 11 additions & 3 deletions src/ParRepartition.jl
@@ -7,8 +7,8 @@ mutable struct ParRepartition{T,N} <: ParLinearOperator{T,T,NonParametric,Extern
global_size::NTuple{N, Integer}
local_size_in::NTuple{N, Integer}
local_size_out::NTuple{N, Integer}
send_data::OrderedDict{Integer, Tuple{NTuple{N, UnitRange{Integer}}, Option{Vector{T}}}}
recv_data::OrderedDict{Integer, Tuple{NTuple{N, UnitRange{Integer}}, Option{Vector{T}}}}
send_data::OrderedDict{Int32, Tuple{NTuple{N, UnitRange{Int32}}, Option{Vector{T}}}}
recv_data::OrderedDict{Int32, Tuple{NTuple{N, UnitRange{Int32}}, Option{Vector{T}}}}
batch_size::Option{Integer}

function ParRepartition(T, comm_in, comm_out, global_size)
@@ -194,4 +194,12 @@ end
function (R::ParRepartition{T,N})(x::X) where {T,N,X<:AbstractVector{T}}
y = R(reshape(x, length(x), 1))
return vec(y)
end

function ChainRulesCore.rrule(A::ParRepartition{T,N}, x::X) where {T,N,X<:AbstractMatrix{T}}
op_out = A(x)
function pullback(op)
return NoTangent(), A'(op)
end
return op_out, pullback
end
16 changes: 16 additions & 0 deletions src/ParRestriction.jl
@@ -53,3 +53,19 @@ function from_Dict(::Type{ParRestriction}, d)
dtype = Data_TYPES[ts]
ParRestriction(dtype, d["n"], ranges)
end

function ChainRulesCore.rrule(A::ParRestriction{T}, x::X) where {T,X<:AbstractMatrix{T}}
op_out = A(x)
function pullback(op)
return (NoTangent(), A'(op))
end
return op_out, pullback
end

function ChainRulesCore.rrule(A::ParAdjoint{T,T,NonParametric,ParRestriction{T}}, x::X) where {T,X<:AbstractMatrix{T}}
op_out = A(x)
function pullback(op)
return (NoTangent(), A.op(op))
end
return op_out, pullback
end
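A hedged sketch of the effect (assumes a ChainRules-aware AD such as Zygote; the constructor call mirrors from_Dict above):

using Zygote
A = ParRestriction(Float64, 16, [1:4])       # restrict to the first 4 entries
x = rand(16, 1)
g = Zygote.gradient(x -> sum(A(x)), x)[1]    # pullback applies A' to the cotangent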