Memory analysis (#11)
* support for low memory

Co-authored-by: turquoisedragon2926 <rarockiasamy3@gatech.edu>
turquoisedragon2926 and Richard2926 authored Mar 14, 2024
1 parent 14e8c4c commit 22f8ebd
Showing 7 changed files with 29 additions and 3 deletions.
3 changes: 3 additions & 0 deletions src/ParBroadcasted.jl
@@ -32,8 +32,11 @@ end

(A::ParBroadcasted{D,R,L,<:Applicable,F})(x::X) where {D,R,L,F,X<:AbstractVector{D}} = A.op(x)
(A::ParBroadcasted{D,R,L,<:Applicable,F})(x::X) where {D,R,L,F,X<:AbstractMatrix{D}} = A.op(x)
(A::ParBroadcasted{D,R,L,<:Applicable,F})(x::X) where {D,R,L,F,X<:AbstractArray{D,3}} = A.op(x)
*(x::X, A::ParBroadcasted{D,R,Linear,<:Applicable,F}) where {D,R,F,X<:AbstractMatrix{D}} = x*A.op
+(x::X, A::ParBroadcasted{D,R,Linear,<:Applicable,F}) where {D,R,F,X<:AbstractMatrix{D}} = x+A.op
+(x::X, A::ParBroadcasted{D,R,Linear,<:Applicable,F}) where {D,R,F,X<:AbstractArray{D}} = x+A.op
/(A::ParBroadcasted{D,R,Linear,<:Applicable,F}, x::X) where {D,R,F,X<:AbstractArray{D}} = A.op/x

function ChainRulesCore.rrule(A::ParBroadcasted{D,R,L,Parametric,F}, params) where {D,R,L,F}
op_out = A(params)
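The three new methods extend the wrapper's pass-through behavior: 3-D (batched) inputs, + on arbitrary arrays, and / are all forwarded to the wrapped operator. A self-contained sketch of that delegation pattern in plain Julia; BroadcastedSketch and its inner op are hypothetical stand-ins for illustration, not library API:

# Hypothetical stand-in for ParBroadcasted: the wrapper owns an op
# and forwards application and arithmetic to it, as the methods above do.
struct BroadcastedSketch{F}
    op::F
end
(B::BroadcastedSketch)(x::AbstractArray{<:Any,3}) = B.op(x)   # mirrors the new 3-D method
Base.:+(x::AbstractArray, B::BroadcastedSketch) = x + B.op    # mirrors +(x, A) above
Base.:/(B::BroadcastedSketch, x::AbstractArray) = B.op / x    # mirrors /(A, x) above

W = rand(Float32, 4, 4)
B = BroadcastedSketch(x -> reshape(W * reshape(x, 4, :), 4, size(x, 2), size(x, 3)))
y = B(rand(Float32, 4, 5, 2))  # 3-D input forwarded to the inner op; size(y) == (4, 5, 2)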
1 change: 1 addition & 0 deletions src/ParCommon.jl
@@ -64,6 +64,7 @@ function local_size(global_size::Integer, rank::Integer, num_ranks::Integer)
end

function rotate_dims_batched(x, rot)
# TODO: Fix this bottleneck.
n = length(size(x))
perm = [circshift(collect(1:n-1), rot)..., n]

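The permutation built here cyclically rotates the leading dimensions while pinning the last (batch) dimension in place; the collapsed remainder of the body presumably applies it with permutedims, which materializes a full copy and is the likely bottleneck the new TODO refers to. A small illustration (the permutedims call is an assumption about the collapsed code):

n = 4                                        # e.g. a 4-D batched array
perm = [circshift(collect(1:n-1), 1)..., n]  # rot = 1 gives [3, 1, 2, 4]; batch dim 4 stays put
x = rand(Float32, 2, 3, 5, 7)
y = permutedims(x, perm)                     # size(y) == (5, 2, 3, 7); allocates a copy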
5 changes: 5 additions & 0 deletions src/ParMatrix.jl
@@ -18,6 +18,8 @@ Range(A::ParMatrix) = A.m

complexity(A::ParMatrix{T}) where {T} = elementwise_multiplication_cost(T)*A.n*A.m

distribute(A::ParMatrix) = ParBroadcasted(A)

function init!(A::ParMatrix{T}, d::Parameters) where {T<:Real}
if A.n == 1
d[A] = zeros(T, A.m, A.n)
@@ -39,6 +39,7 @@ end

(A::ParParameterized{T,T,Linear,ParMatrix{T},V})(x::X) where {T,V,X<:AbstractVector{T}} = A.params*x
(A::ParParameterized{T,T,Linear,ParMatrix{T},V})(x::X) where {T,V,X<:AbstractMatrix{T}} = A.params*x
(A::ParParameterized{T,T,Linear,ParMatrix{T},V})(x::X) where {T,V,X<:AbstractArray{T,3}} = batched_mul(A.params,x)
(A::ParParameterized{T,T,Linear,ParAdjoint{T,T,Parametric,ParMatrix{T}},V})(x::X) where {T,V,X<:AbstractVector{T}} = A.params[A.op.op]'*x
(A::ParParameterized{T,T,Linear,ParAdjoint{T,T,Parametric,ParMatrix{T}},V})(x::X) where {T,V,X<:AbstractMatrix{T}} = A.params[A.op.op]'*x

@@ -54,6 +57,8 @@ end
+(x::X, A::ParParameterized{T,T,Linear,ParAdjoint{T,T,Parametric,ParMatrix{T}},V}) where {T,V,X<:AbstractArray{T}} = x+A.params[A.op.op]'
+(x::X, A::ParParameterized{T,T,Linear,ParAdjoint{T,T,Parametric,ParMatrix{T}},V}) where {T,V,X<:AbstractMatrix{T}} = x+A.params[A.op.op]'

/(A::ParParameterized{T,T,Linear,ParMatrix{T},V}, x::X) where {T,V,X<:AbstractMatrix{T}} = A.params./x

function to_Dict(A::ParMatrix{T}) where {T}
rv = Dict{String, Any}(
"type" => "ParMatrix",
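The new 3-D method relies on batched_mul (imported via Flux in src/ParametricOperators.jl below), which broadcasts a plain matrix across the batch dimension of a 3-D array. A rough sketch with stand-in shapes:

using Flux: batched_mul

W = rand(Float32, 4, 3)      # stands in for A.params
x = rand(Float32, 3, 5, 8)   # (input dim, channels, batch)
y = batched_mul(W, x)        # applies W to every slice x[:, :, k]; size(y) == (4, 5, 8)

Together with the new * method in src/ParOperator.jl below, this also makes A * x work when x is a 3-D array.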
1 change: 1 addition & 0 deletions src/ParOperator.jl
@@ -220,6 +220,7 @@ Apply a linear operator to a vector or matrix through multiplication.
"""
*(A::ParOperator{D,R,L,<:Applicable,T}, x::X) where {D,R,L,T,X<:AbstractVector{D}} = A(x)
*(A::ParOperator{D,R,L,<:Applicable,T}, x::X) where {D,R,L,T,X<:AbstractMatrix{D}} = A(x)
*(A::ParOperator{D,R,L,<:Applicable,T}, x::X) where {D,R,L,T,X<:AbstractArray{D,3}} = A(x)

"""
Apply a matrix to a linear operator. By default, use rules of the adjoint.
17 changes: 17 additions & 0 deletions src/ParReduce.jl
@@ -30,6 +30,15 @@ function (A::ParReduce{T})(x::X) where {T,X<:AbstractArray{T}}
end
end

function (A::ParReduce{T})(x::X) where {T,X<:AbstractMatrix{T}}
device = get_device(x)
if device == "cpu"
return MPI.Allreduce(x, MPI.SUM, A.comm)
elseif device == "gpu"
return MPI.Allreduce(x |> cpu, MPI.SUM, A.comm) |> gpu
end
end

function ChainRulesCore.rrule(A::ParReduce{T}, x::X) where {T,X<:AbstractVector{T}}
op_out = A(x)
function pullback(op)
@@ -45,3 +54,11 @@ function ChainRulesCore.rrule(A::ParReduce{T}, x::X) where {T,X<:AbstractArray{T}}
end
return op_out, pullback
end

function ChainRulesCore.rrule(A::ParReduce{T}, x::X) where {T,X<:AbstractMatrix{T}}
op_out = A(x)
function pullback(op)
return NoTangent(), op
end
return op_out, pullback
end
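The new matrix method stages the GPU reduction through host memory, presumably so that a CUDA-aware MPI build is not required; the cost is a device-to-host and host-to-device copy per call. A minimal sketch of the same pattern, assuming Flux's cpu/gpu helpers and an initialized MPI environment:

using MPI
using Flux: cpu, gpu

MPI.Init()
comm = MPI.COMM_WORLD
x = gpu(rand(Float32, 4, 4))                       # device-resident matrix (no-op without a GPU)
y = MPI.Allreduce(x |> cpu, MPI.SUM, comm) |> gpu  # host round-trip, summed across ranks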
3 changes: 0 additions & 3 deletions src/ParTensor.jl
@@ -1,8 +1,5 @@
export ParTensor

-using OMEinsum
-using Flux:batched_mul

"""
Dense N dimensional tensor operator.
"""
2 changes: 2 additions & 0 deletions src/ParametricOperators.jl
@@ -14,6 +14,8 @@ using Match
using MPI
using Random
using UUIDs
using Flux:batched_mul
using OMEinsum

# ==== Includes ====
