Skip to content

Commit

Permalink
Expose common factorisations (#83)
Browse files Browse the repository at this point in the history
* qr and lq both work. lu doesnt

* need to fix QRColPiv - will come back to it

* lu fixed, cholesky started

* non distributed cholesky, lu, qr and lq all pass

* test distributed too

* move helper functions into core

* add convenience setindex! with warning if used with scalars

* whitespace

* add more convenience functions

* qol functions with darray

* not going to do pivoting this time

* delete trailing whitespace

* test/factor.jl wasnt designed to be run in loop

* delete commented code from test/factor

* delete calls to elemental library that were never needed

* change indentation to 4 spaces as appears to be used already

* AbstractArray -> Array

* typeof -> isa, but potentially delete this one

* error when setindex! with scalars

* added type signatures to factorisation struct outer constructors

* remove outdated comment

* do not test functions with preprended _

* clearer types to throw error on scalar setindex
  • Loading branch information
jwscook authored Mar 28, 2024
1 parent 9807b28 commit a3478e6
Show file tree
Hide file tree
Showing 9 changed files with 505 additions and 4 deletions.
42 changes: 42 additions & 0 deletions src/core/distmatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,23 @@ for (elty, ext) in ((:ElInt, :i),
A.obj, i, j))
return A
end

function isLocalRow(A::DistMatrix{$elty}, i::Integer)
rv = Ref{ElInt}(0)
ElError(ccall(($(string("ElDistMatrixIsLocalRow_", ext)), libEl), Cuint,
(Ptr{Cvoid}, ElInt, Ref{ElInt}),
A.obj, i - 1, rv))
return Bool(rv[])
end

function isLocalCol(A::DistMatrix{$elty}, i::Integer)
rv = Ref{ElInt}(0)
ElError(ccall(($(string("ElDistMatrixIsLocalCol_", ext)), libEl), Cuint,
(Ptr{Cvoid}, ElInt, Ref{ElInt}),
A.obj, i - 1, rv))
return Bool(rv[])
end

end
end

Expand Down Expand Up @@ -205,3 +222,28 @@ function hcat(x::Vector{DistMatrix{T}}) where {T}
return A
end
end

import DistributedArrays.localpart
# used in testing
function localpart(A::Elemental.DistMatrix{T}) where T
buffer = Base.zeros(T, localHeight(A), localWidth(A))
return localpart!(buffer, A)
end

function localpart!(buffer, A::Elemental.DistMatrix)
@assert size(buffer) == (localHeight(A), localWidth(A))
for j in 1:localWidth(A), i in 1:localHeight(A)
buffer[i, j] = getLocal(A, i, j)
end
return buffer
end

import DistributedArrays.localindices
# used in testing
function localindices(A::Elemental.DistMatrix{T}) where T
# sometimes they aren't contigous so cant do start:start+length
rows = findall(isLocalRow(A, i) for i in 1:height(A))
cols = findall(isLocalCol(A, i) for i in 1:width(A))
return (rows, cols)
end

41 changes: 41 additions & 0 deletions src/julia/darray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -250,5 +250,46 @@ for (elty, ext) in ((:ElInt, :i),
processQueues(A)
return A
end

function convert(::Type{DistMatrix{$elty}}, DA::DistributedArrays.DArray)
npr, npc = size(procs(DA))
if npr*npc != MPI.Comm_size(MPI.COMM_WORLD)
error("Used non MPI.COMM_WORLD DArray for DistMatrix, ",
"as procs(DA)=($npr,$npc) is incompatible with ",
"MPI.Comm_size(MPI.COMM_WORLD)=$(MPI.Comm_size(MPI.COMM_WORLD))")
end

m, n = size(DA)
A = DistMatrix($elty, m, n)
@sync begin
for id in workers()
let A = A, DA = DA
@async remotecall_fetch(id) do
rows, cols = DistributedArrays.localindices(DA)
reserve(A,length(rows) * length(cols))
for j in cols, i in rows
queueUpdate(A, i - 1, j - 1, DA[i, j])
end
end
end
end
end
processQueues(A)
return A
end

function copyto!(DA::DistributedArrays.DArray{$elty}, A::DistMatrix{$elty} )
@sync begin
ijs = localindices(DA)
for j in ijs[2], i in ijs[1]
queuePull(A, i, j)
end
DAlocal = DA[:L]

DAlocal_mat = ndims(DAlocal) == 1 ? reshape(DAlocal, :, 1) : DAlocal
processPullQueue(A, DAlocal_mat)
end
return DA
end
end
end
21 changes: 21 additions & 0 deletions src/julia/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,21 @@ Base.convert(::Type{Array}, xd::DistMatrix{T}) where {T} =

Base.Array(xd::DistMatrix) = convert(Array, xd)

function Base.setindex!(A::DistMatrix, values::Number, i::Integer, j::Integer)
throw(ArgumentError("setindex! with scalars is disallowed.
Use a large collection to setindex! in bulk."))
end
function Base.setindex!(A::DistMatrix,
values,
globalis,
globaljs)
for (cj, globalj) in enumerate(globaljs), (ci, globali) in enumerate(globalis)
queueUpdate(A, globali, globalj, values[ci, cj])
end
processQueues(A)
end


LinearAlgebra.norm(x::ElementalMatrix) = nrm2(x)
# function LinearAlgebra.norm(x::ElementalMatrix)
# if size(x, 2) == 1
Expand All @@ -194,6 +209,12 @@ LinearAlgebra.cholesky!(A::Hermitian{<:Union{Real,Complex},<:ElementalMatrix}) =
LinearAlgebra.cholesky(A::Hermitian{<:Union{Real,Complex},<:ElementalMatrix}) = cholesky!(copy(A))

LinearAlgebra.lu(A::ElementalMatrix) = _lu!(copy(A))
LinearAlgebra.lu!(A::ElementalMatrix) = _lu!(A)
LinearAlgebra.qr(A::ElementalMatrix) = _qr!(copy(A))
LinearAlgebra.qr!(A::ElementalMatrix) = _qr!(A)
LinearAlgebra.lq(A::ElementalMatrix) = _lq!(copy(A))
LinearAlgebra.lq!(A::ElementalMatrix) = _lq!(A)
LinearAlgebra.cholesky!(A::ElementalMatrix) = _cholesky!(A)

# Mixed multiplication with Julia Arrays
(*)(A::DistMatrix{T}, B::StridedVecOrMat{T}) where {T} = A*convert(DistMatrix{T}, B)
Expand Down
136 changes: 132 additions & 4 deletions src/lapack_like/factor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,142 @@ for mattype in ("", "Dist")
uplo, A.obj))
return A
end
end
end
end

# These are the number types that Elemental supports

for mattype in ("", "Dist")
mat = Symbol(mattype, "Matrix")
_p = Symbol(mattype, "Permutation")

# TODO - fix QRColPiv
#QRColPivStructName = Symbol("QRColPiv$(string(mattype))")
QRStructName = Symbol("QR$(string(mattype))")
LQStructName = Symbol("LQ$(string(mattype))")
LUStructName = Symbol("LU$(string(mattype))")
CHStructName = Symbol("Cholesky$(string(mattype))")

@eval begin

struct $QRStructName{T,U<:Real}
A::$mat{T}
t::$mat{T}
d::$mat{U}
orientation::Ref{Orientation}
end
function $QRStructName(A::$mat{T}, t::$mat{T}, d::$mat{U}
) where {U<:Union{Float32, Float64}, T<:Union{Complex{U}, U}}
return $QRStructName(A, t, d, Ref(NORMAL::Orientation))
end

struct $LQStructName{T, U<:Real}
A::$mat{T}
householderscalars::$mat{T}
signature::$mat{U}
orientation::Ref{Orientation}
end
function $LQStructName(A::$mat{T}, householderscalars::$mat{T}, signature::$mat{U}
) where {U<:Union{Float32, Float64}, T<:Union{Complex{U}, U}}
return $LQStructName(A, householderscalars, signature, Ref(NORMAL::Orientation))
end

struct $LUStructName{T}
A::$mat{T}
p::$_p
orientation::Ref{Orientation}
end
function $LUStructName(A::$mat{T}, p::$_p
) where {T<:Union{Float32, Float64, ComplexF32, ComplexF64}}
return $LUStructName(A, p, Ref(NORMAL::Orientation))
end

struct $CHStructName{T}
uplo::UpperOrLower
A::$mat{T}
orientation::Ref{Orientation}
end
function $CHStructName(uplo::UpperOrLower, A::$mat{T}
) where {T<:Union{Float32, Float64, ComplexF32, ComplexF64}}
return $CHStructName(uplo, A, Ref(NORMAL::Orientation))
end

end

for (elty, ext) in ((:Float32, :s),
(:Float64, :d),
(:ComplexF32, :c),
(:ComplexF64, :z))

@eval begin

function _lu!(A::$mat{$elty})
p = $_p()
ElError(ccall(($(string("ElLUPartialPiv", mattype, "_", ext)), libEl), Cuint,
(Ptr{Cvoid}, Ptr{Cvoid}),
A.obj, p.obj))
return A, p
ElError(ccall(($(string("ElLU", mattype, "_", ext)), libEl), Cuint,
(Ptr{Cvoid}, Ptr{Cvoid}),
A.obj, p.obj))
return $LUStructName(A, p)
end

function LinearAlgebra.:\(lu::$LUStructName{$elty}, b::$mat{$elty})
x = deepcopy(b)#$mat($elty)
ElError(ccall(($(string("ElSolveAfterLU", mattype, "_", ext)), libEl), Cuint,
(Orientation, Ptr{Cvoid}, Ptr{Cvoid}),
lu.orientation[], lu.A.obj, x.obj))
return x
end

function _qr!(A::$mat{$elty})
t = $mat($elty)
d = $mat(real($elty))
ElError(ccall(($(string("ElQR", mattype, "_", ext)), libEl), Cuint,
(Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}),
A.obj, t.obj, d.obj))
return $QRStructName(A, t, d)
end

function LinearAlgebra.:\(qr::$QRStructName{$elty}, b::$mat{$elty})
x = $mat($elty)
ElError(ccall(($(string("ElSolveAfterQR", mattype, "_", ext)), libEl), Cuint,
(Orientation, Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}),
qr.orientation[], qr.A.obj, qr.t.obj, qr.d.obj, b.obj, x.obj))
return x
end

function _lq!(A::$mat{$elty})
householderscalars = $mat($elty)
signature = $mat(real($elty))
ElError(ccall(($(string("ElLQ", mattype, "_", ext)), libEl), Cuint,
(Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}),
A.obj, householderscalars.obj, signature.obj))
return $LQStructName(A, householderscalars, signature)
end

function LinearAlgebra.:\(lq::$LQStructName{$elty}, b::$mat{$elty})
x = $mat($elty)
ElError(ccall(($(string("ElSolveAfterLQ", mattype, "_", ext)), libEl), Cuint,
(Orientation, Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}),
lq.orientation[], lq.A.obj, lq.householderscalars.obj,
lq.signature.obj, b.obj, x.obj))
return x
end

function _cholesky!(A::$mat{$elty}, uplo::UpperOrLower=UPPER::UpperOrLower)
ElError(ccall(($(string("ElCholesky", mattype, "_", ext)), libEl), Cuint,
(UpperOrLower, Ptr{Cvoid}),
uplo, A.obj))
return $CHStructName(uplo, A)
end

function LinearAlgebra.:\(ch::$CHStructName{$elty}, b::$mat{$elty})
x = deepcopy(b)#$mat($elty)
ElError(ccall(($(string("ElSolveAfterCholesky", mattype, "_", ext)), libEl), Cuint,
(UpperOrLower, Orientation, Ptr{Cvoid}, Ptr{Cvoid}),
ch.uplo, ch.orientation[], ch.A.obj, x.obj))
return x
end

end
end
end
43 changes: 43 additions & 0 deletions test/distcholesky.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
using MPI, MPIClusterManagers, Distributed

man = MPIManager(np = 2);

addprocs(man);

@everywhere using LinearAlgebra, Elemental

const M = 400
const N = M

@mpi_do man M = @fetchfrom 1 M
@mpi_do man N = @fetchfrom 1 N

const Ahost = rand(Float64, M, N)
Ahost .+= Ahost'
Ahost .+= M * I(M)
const bhost = rand(Float64, M)

@mpi_do man Aall = @fetchfrom 1 Ahost
@mpi_do man ball = @fetchfrom 1 bhost

@mpi_do man A = Elemental.DistMatrix(Float64);
@mpi_do man b = Elemental.DistMatrix(Float64);

@mpi_do man A = Elemental.resize!(A, M, N);
@mpi_do man b = Elemental.resize!(b, M);

@mpi_do man copyto!(A, Aall)
@mpi_do man copyto!(b, ball)

@mpi_do man chA = Elemental.cholesky!(A);

@mpi_do man x = chA \ b;

@mpi_do man localx = zeros(Float64, Elemental.localHeight(x), Elemental.localWidth(x))
@mpi_do man copyto!(localx, Elemental.localpart(x))

using Test
x = vcat((fetch(@spawnat p localx)[:] for p in workers())...)
@testset "Cholesky" begin
@test x Ahost \ bhost
end
41 changes: 41 additions & 0 deletions test/distlq.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
using MPI, MPIClusterManagers, Distributed

man = MPIManager(np = 2);

addprocs(man);

@everywhere using LinearAlgebra, Elemental

const M = 300
const N = 400

@mpi_do man M = @fetchfrom 1 M
@mpi_do man N = @fetchfrom 1 N

const Ahost = rand(Float64, M, N)
const bhost = rand(Float64, M)

@mpi_do man Aall = @fetchfrom 1 Ahost
@mpi_do man ball = @fetchfrom 1 bhost

@mpi_do man A = Elemental.DistMatrix(Float64);
@mpi_do man b = Elemental.DistMatrix(Float64);

@mpi_do man A = Elemental.resize!(A, M, N);
@mpi_do man b = Elemental.resize!(b, M);

@mpi_do man copyto!(A, Aall)
@mpi_do man copyto!(b, ball)

@mpi_do man lqA = Elemental.lq!(A);

@mpi_do man x = lqA \ b;

@mpi_do man localx = zeros(Float64, Elemental.localHeight(x), Elemental.localWidth(x))
@mpi_do man copyto!(localx, Elemental.localpart(x))

using Test
x = vcat((fetch(@spawnat p localx)[:] for p in workers())...)
@testset "lq" begin
@test x Ahost \ bhost
end
Loading

0 comments on commit a3478e6

Please sign in to comment.