Skip to content

Commit

Permalink
Support data preconditionners with OOC data via lazy application
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Mar 21, 2024
1 parent ec0fc97 commit 940ae55
Show file tree
Hide file tree
Showing 6 changed files with 201 additions and 65 deletions.
38 changes: 0 additions & 38 deletions src/TimeModeling/LinearOperators/lazy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,24 +57,6 @@ struct judiRHS{D} <: judiMultiSourceVector{D}
d::judiVector
end

"""
LazyAdd
nsrc
A
B
sign
Lazy addition of two RHS (currently only judiVector). The addition isn't evaluated to avoid
large memory allocation but instead evaluates the addition (with sign `sign`) `A + sign * B`
for a single source at propagation time.
"""
struct LazyAdd{D} <: judiMultiSourceVector{D}
nsrc::Integer
A
B
sign
end


############################################################################################################################
# Constructors
Expand Down Expand Up @@ -140,7 +122,6 @@ getindex(P::judiWavelet{D}, i) where D = judiWavelet{D}(P.m[i], P.n[i], P.wavele
getindex(P::judiWavelet{D}, i::Integer) where D = judiWavelet{D}(P.m[i], P.n[i], P.wavelet[i:i], P.dt[i:i])
getindex(rhs::judiRHS{D}, i::Integer) where D = judiRHS{D}(length(i), rhs.P[i], rhs.d[i])
getindex(rhs::judiRHS{D}, i::RangeOrVec) where D = judiRHS{D}(length(i), rhs.P[i], rhs.d[i])
getindex(la::LazyAdd{D}, i::RangeOrVec) where D = LazyAdd{D}(length(i), la.A[i], la.B[i], la.sign)

# Backward compatible subsample
subsample(P::judiNoopOperator{D}, i) where D = getindex(P, i)
Expand Down Expand Up @@ -235,22 +216,3 @@ _as_src(::judiNoopOperator, ::AbstractModel, q::judiMultiSourceVector) = q
############################################################################################################################
###### Evaluate lazy operation
eval(rhs::judiRHS) = rhs.d

function eval(ls::LazyAdd{D}) where D
aloc = eval(ls.A)
bloc = eval(ls.B)
ga = aloc.geometry
gb = bloc.geometry
@assert (ga.nt == gb.nt && ga.dt == gb.dt && ga.t == gb.t)
xloc = [vcat(ga.xloc[1], gb.xloc[1])]
yloc = [vcat(ga.yloc[1], gb.yloc[1])]
zloc = [vcat(ga.zloc[1], gb.zloc[1])]
geom = GeometryIC{D}(xloc, yloc, zloc, ga.dt, ga.nt, ga.t)
data = hcat(aloc.data[1], ls.sign*bloc.data[1])
judiVector{D, Matrix{D}}(1, geom, [data])
end

function make_src(ls::LazyAdd{D}) where D
q = eval(ls)
return q.geometry[1], q.data[1]
end
42 changes: 21 additions & 21 deletions src/TimeModeling/Modeling/misfit_fg.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@

export fwi_objective, lsrtm_objective, fwi_objective!, lsrtm_objective!

function multi_src_fg(model_full::AbstractModel, source::judiVector, dObs::judiVector, dm, options::JUDIOptions,
# Type of accepted input
Dtypes = Union{<:judiVector, NTuple{N, <:judiVector} where N, Vector{<:judiVector}, <:LazyMul}
MTypes = Union{<:AbstractModel, NTuple{N, <:AbstractModel} where N, Vector{<:AbstractModel}}
dmTypes = Union{dmType, NTuple{N, dmType} where N, Vector{dmType}}


function multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, dm, options::JUDIOptions,
nlind::Bool, lin::Bool, misfit::Function, illum::Bool)
GC.gc(true)
devito.clear_cache()
Expand All @@ -10,14 +16,14 @@ function multi_src_fg(model_full::AbstractModel, source::judiVector, dObs::judiV
@assert dObs.nsrc == 1 "Multiple-source data is used in a single-source fwi_objective"

# Load full geometry for out-of-core geometry containers
dObs.geometry = Geometry(dObs.geometry)
source.geometry = Geometry(source.geometry)
d_geometry = Geometry(dObs.geometry)
s_geometry = Geometry(source.geometry)

# Limit model to area with sources/receivers
if options.limit_m == true
@juditime "Limit model to geometry" begin
model = deepcopy(model_full)
model, dm = limit_model_to_receiver_area(source.geometry, dObs.geometry, model, options.buffer_size; pert=dm)
model, dm = limit_model_to_receiver_area(s_geometry, d_geometry, model, options.buffer_size; pert=dm)
end
else
model = model_full
Expand All @@ -30,14 +36,14 @@ function multi_src_fg(model_full::AbstractModel, source::judiVector, dObs::judiV
end

# Extrapolate input data to computational grid
qIn = time_resample(make_input(source), source.geometry, dtComp)
dObserved = time_resample(make_input(dObs), dObs.geometry, dtComp)
qIn, dObserved = _maybe_pad_t0(qIn, source.geometry, dObserved, dObs.geometry, dtComp)
qIn = time_resample(make_input(source), s_geometry, dtComp)
dObserved = time_resample(make_input(dObs), d_geometry, dtComp)
qIn, dObserved = _maybe_pad_t0(qIn, s_geometry, dObserved, d_geometry, dtComp)

# Set up coordinates
@juditime "Sparse coords setup" begin
src_coords = setup_grid(source.geometry, size(model)) # shifts source coordinates by origin
rec_coords = setup_grid(dObs.geometry, size(model)) # shifts rec coordinates by origin
src_coords = setup_grid(s_geometry, size(model)) # shifts source coordinates by origin
rec_coords = setup_grid(d_geometry, size(model)) # shifts rec coordinates by origin
end

mfunc = pyfunction(misfit, Matrix{Float32}, Matrix{Float32})
Expand Down Expand Up @@ -74,14 +80,14 @@ end


####### Defaults
multi_src_fg(model_full::AbstractModel, source::judiVector, dObs::judiVector, dm, options::JUDIOptions, nlind::Bool, lin::Bool) =
multi_src_fg(model_full::AbstractModel, source::judiVector, dObs::judiVector, dm, options::JUDIOptions, nlind::Bool, lin::Bool, mse, false)
multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, dm, options::JUDIOptions, nlind::Bool, lin::Bool) =
multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, dm, options::JUDIOptions, nlind::Bool, lin::Bool, mse, false)

multi_src_fg(model_full::AbstractModel, source::judiVector, dObs::judiVector, dm, options::JUDIOptions, nlind::Bool, lin::Bool, illum::Bool) =
multi_src_fg(model_full::AbstractModel, source::judiVector, dObs::judiVector, dm, options::JUDIOptions, nlind::Bool, lin::Bool, mse, illum)
multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, dm, options::JUDIOptions, nlind::Bool, lin::Bool, illum::Bool) =
multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, dm, options::JUDIOptions, nlind::Bool, lin::Bool, mse, illum)

multi_src_fg(model_full::AbstractModel, source::judiVector, dObs::judiVector, dm, options::JUDIOptions, nlind::Bool, lin::Bool, phi::Function) =
multi_src_fg(model_full::AbstractModel, source::judiVector, dObs::judiVector, dm, options::JUDIOptions, nlind::Bool, lin::Bool, phi, false)
multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, dm, options::JUDIOptions, nlind::Bool, lin::Bool, phi::Function) =
multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, dm, options::JUDIOptions, nlind::Bool, lin::Bool, phi, false)

# Find number of experiments
"""
Expand Down Expand Up @@ -117,12 +123,6 @@ function check_args(args...)
return nexp
end


# Type of accepted input
Dtypes = Union{<:judiVector, NTuple{N, <:judiVector} where N, Vector{<:judiVector}}
MTypes = Union{<:AbstractModel, NTuple{N, <:AbstractModel} where N, Vector{<:AbstractModel}}
dmTypes = Union{dmType, NTuple{N, dmType} where N, Vector{dmType}}

################################################################################################
####################### User Interface #########################################################
################################################################################################
Expand Down
7 changes: 2 additions & 5 deletions src/TimeModeling/Preconditioners/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,5 @@ getproperty(J::Preconditioner, s::Symbol) = _get_property(J, Val{s}())
mul!(out::judiMultiSourceVector, J::Preconditioner, ms::judiMultiSourceVector) = copyto!(out, matvec(J, ms))
mul!(out::PhysicalParameter, J::Preconditioner, ms::PhysicalParameter) = copyto!(out, matvec(J, ms))

# Unsupported OOC
function *(J::DataPreconditioner, v::judiVector{T, SegyIO.SeisCon}) where T
@warn "Data preconditionners only support in-core judiVector. Converting (might run out of memory)"
return J * get_data(v)
end
# OOC judiVector
*(J::DataPreconditioner, v::judiVector{T, SegyIO.SeisCon}) where T = LazyMul(v.nsrc, J, v)
19 changes: 19 additions & 0 deletions src/TimeModeling/TimeModeling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ include("Types/OptionsStructure.jl")
#############################################################################
# Abstract vectors
include("Types/abstract.jl")
include("Types/lazy_msv.jl")
include("Types/broadcasting.jl")
include("Types/judiWavefield.jl") # dense RHS (wavefield)
include("Types/judiWeights.jl") # Extended source weight vector
Expand Down Expand Up @@ -53,3 +54,21 @@ include("Preconditioners/base.jl")
include("Preconditioners/utils.jl")
include("Preconditioners/DataPreconditioners.jl")
include("Preconditioners/ModelPreconditioners.jl")


#############################################################################
# Extra that need all imports

############################################################################################################################
# Enforce right precedence. Mainly we always want (rightfully)
# - First data operation on the right
# - Then propagation
# - Then right preconditioning
# I.e Ml * P * M * q must do Ml * (P * (M * q))
# It''s easier to just hard code the few cases that can happen

for T in [judiMultiSourceVector, dmType]
@eval *(Ml::Preconditioner, P::judiPropagator, Mr::Preconditioner, v::$(T)) = Ml * (P * (Mr * v))
@eval *(P::judiPropagator, Mr::Preconditioner, v::$(T)) = P * (Mr * v)
@eval *(Ml::Preconditioner, P::judiPropagator, v::$(T)) = Ml * (P * v)
end
92 changes: 92 additions & 0 deletions src/TimeModeling/Types/lazy_msv.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@

struct LazyData{D} <: AbstractVector{D}
msv::judiMultiSourceVector{D}
end

getindex(ld::LazyData{D}, i) where D = get_data(ld.msv[i]).data

setindex!(::LazyData{D}, ::Any, ::Any) where D = throw(MethodError(setindex!, "LazyData is read-only"))

size(A::LazyData) = size(A.msv)

get_data(ld::LazyData{D}) where D = get_data(ld.msv)

"""
LazyAdd
nsrc
A
B
sign
Lazy addition of two RHS (currently only judiVector). The addition isn't evaluated to avoid
large memory allocation but instead evaluates the addition (with sign `sign`) `A + sign * B`
for a single source at propagation time.
"""

struct LazyAdd{D} <: judiMultiSourceVector{D}
nsrc::Integer
A
B
sign
end


getindex(la::LazyAdd{D}, i::RangeOrVec) where D = LazyAdd{D}(length(i), la.A[i], la.B[i], la.sign)


function eval(ls::LazyAdd{D}) where D
aloc = eval(ls.A)
bloc = eval(ls.B)
ga = aloc.geometry
gb = bloc.geometry
@assert (ga.nt == gb.nt && ga.dt == gb.dt && ga.t == gb.t)
xloc = [vcat(ga.xloc[1], gb.xloc[1])]
yloc = [vcat(ga.yloc[1], gb.yloc[1])]
zloc = [vcat(ga.zloc[1], gb.zloc[1])]
geom = GeometryIC{D}(xloc, yloc, zloc, ga.dt, ga.nt, ga.t)
data = hcat(aloc.data[1], ls.sign*bloc.data[1])
judiVector{D, Matrix{D}}(1, geom, [data])
end

function make_src(ls::LazyAdd{D}) where D
q = eval(ls)
return q.geometry[1], q.data[1]
end


"""
LazyMul
nsrc
A
B
sign
Lazy addition of two RHS (currently only judiVector). The addition isn't evaluated to avoid
large memory allocation but instead evaluates the addition (with sign `sign`) `A + sign * B`
for a single source at propagation time.
"""

struct LazyMul{D} <: judiMultiSourceVector{D}
nsrc::Integer
P::joAbstractLinearOperator
msv::judiMultiSourceVector{D}
end

getindex(la::LazyMul{D}, i::RangeOrVec) where D = LazyMul{D}(length(i), la.P[i], la.msv[i])

function make_input(lm::LazyMul{D}) where D
@assert lm.nsrc == 1
return make_input(lm.P * get_data(lm.msv))
end

get_data(lm::LazyMul{D}) where D = lm.P * get_data(lm.msv)

function getproperty(lm::LazyMul{D}, s::Symbol) where D
if s == :data
return LazyData(lm)
elseif s == :geometry
return lm.msv.geometry
else
return getfield(lm, s)
end
end
68 changes: 67 additions & 1 deletion test/test_preconditioners.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ dm = model0.m - model.m
Mdr = judiDataMute(srcGeometry, recGeometry; mode=:reflection)
Mdt = judiDataMute(srcGeometry, recGeometry; mode=:turning)
Mdg = judiTimeGain(recGeometry, 2f0)
Mm = judiTopmute(model.n, 10, 1)
order = .25f0
Dt = judiTimeDerivative(recGeometry, order)
It = judiTimeIntegration(recGeometry, order)
Mm = judiTopmute(model.n, 20, 1)

# Time differential only
@test inv(It) == Dt
Expand Down Expand Up @@ -219,4 +219,70 @@ dm = model0.m - model.m
@test isapprox(dm, dml; rtol=ftol)
end
end

@timeit TIMEROUTPUT "OOC Data Preconditioners tests" begin
datapath = joinpath(dirname(pathof(JUDI)))*"/../data/"
# OOC judiVector
container = segy_scan(datapath, "unit_test_shot_records_2",
["GroupX", "GroupY", "RecGroupElevation", "SourceSurfaceElevation", "dt"])
d_cont = judiVector(container; segy_depth_key="RecGroupElevation")
src_geometry = Geometry(container; key = "source", segy_depth_key = "SourceDepth")
wavelet = ricker_wavelet(src_geometry.t[1], src_geometry.dt[1], 0.005)
q_cont = judiVector(src_geometry, wavelet)

# Make sure we test OOC
@test typeof(d_cont) == judiVector{Float32, SeisCon}
@test isequal(d_cont.nsrc, 2)
@test isequal(typeof(d_cont.data), Array{SegyIO.SeisCon, 1})
@test isequal(typeof(d_cont.geometry), GeometryOOC{Float32})

# Make OOC preconditioner
Mdt = judiDataMute(src_geometry, d_cont.geometry)
Mdg = judiTimeGain(d_cont.geometry, 2f0)

# Test OOC DataPrecon
for Pc in [Mdt, Mdg]
# mul
m = Pc * d_cont
@test isa(m, JUDI.LazyMul{Float32})
@test m.nsrc == d_cont.nsrc
@test m.P == Pc
@test m.msv == d_cont

ma = Pc' * d_cont
@test isa(ma, JUDI.LazyMul{Float32})
@test isa(ma[1], JUDI.LazyMul{Float32})
@test ma.nsrc == d_cont.nsrc
@test ma.P == Pc'
@test ma.msv == d_cont

# getindex
m1 = m[1]
@test isa(m1, JUDI.LazyMul{Float32})
@test m1.nsrc == 1
@test m1.msv == d_cont[1]
@test get_data(m1) == get_data(Pc[1] * d_cont[1])

# data
@test isa(m.data, JUDI.LazyData{Float32})
@test_throws MethodError m.data[1] = 1

@test m.data[1] (Pc[1] * get_data(d_cont[1])).data
@test get_data(m.data) Pc * get_data(d_cont)

# Propagation
Fooc = judiModeling(model, src_geometry, d_cont.geometry)

d_syn = Fooc' * Pc' * d_cont
d_synic = Fooc' * Pc' * get_data(d_cont)

@test isapprox(d_syn, d_synic; rtol=1f-5)

f, g = fwi_objective(model0, Pc*d_cont, q_cont)
f2, g2 = fwi_objective(model0, get_data(Pc*d_cont), get_data(q_cont))
@test isapprox(f, f2)
@test isapprox(g, g2)

end
end
end

0 comments on commit 940ae55

Please sign in to comment.