Add an AllAtOnceCofunction class #158

Merged
10 commits merged on Jan 16, 2024
245 changes: 182 additions & 63 deletions asQ/allatonce/function.py
@@ -4,10 +4,11 @@
from functools import reduce
from operator import mul
import contextlib
from ufl.duals import is_primal, is_dual
from asQ.profiling import profiler
from asQ.allatonce.mixin import TimePartitionMixin

__all__ = ['time_average', 'AllAtOnceFunction']
__all__ = ['time_average', 'AllAtOnceFunction', 'AllAtOnceCofunction']


@profiler()
@@ -42,19 +43,20 @@ def time_average(aaofunc, uout, uwrk, average='window'):
return


class AllAtOnceFunction(TimePartitionMixin):
class AllAtOnceFunctionBase(TimePartitionMixin):
@profiler()
def __init__(self, ensemble, time_partition, function_space):
"""
A function representing multiple timesteps of a time-dependent finite-element problem,
A (co)function representing multiple timesteps of a time-dependent finite-element problem,
i.e. the solution to an all-at-once system.

:arg ensemble: time-parallel ensemble communicator. The timesteps are partitioned
over the ensemble members according to time_partition so
ensemble.ensemble_comm.size == len(time_partition) must be True.
:arg time_partition: a list of integers for the number of timesteps stored on each
ensemble rank.
:arg function_space: a FunctionSpace for the solution at a single timestep.
:arg function_space: the space for a single timestep.
Either a `FunctionSpace` or a `DualSpace`, depending on whether the child is an AllAtOnceFunction or an AllAtOnceCofunction.
"""
self._time_partition_setup(ensemble, time_partition)

@@ -67,28 +69,33 @@ def __init__(self, ensemble, time_partition, function_space):

self.ncomponents = len(self.field_function_space.subfunctions)

self.function = fd.Function(self.function_space)
self.initial_condition = fd.Function(self.field_function_space)
# this will be aliased as either self.function or self.cofunction by the subclass
self._fbuf = fd.Function(self.function_space)

# Functions to view each timestep
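# Each view shares its dat(s) with self._fbuf, so writing through a
# view writes directly into the all-at-once buffer.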
def field_function(i):
dats = (self.function.subfunctions[j].dat
for j in self._component_indices(i))
if self.ncomponents == 1:
j = self._component_indices(i)[0]
dat = self._fbuf.subfunctions[j].dat
else:
dat = MixedDat((self._fbuf.subfunctions[j].dat
for j in self._component_indices(i)))

return fd.Function(self.field_function_space,
val=MixedDat(dats))
val=dat)

self._fields = tuple(field_function(i)
for i in range(self.nlocal_timesteps))

# functions containing the last step of the previous
# (co)functions containing the last step of the previous
# and current slice for parallel communication
self.uprev = fd.Function(self.field_function_space)
self.unext = fd.Function(self.field_function_space)

self.nlocal_dofs = self.function_space.node_set.size
self.nglobal_dofs = self.ntimesteps*self.field_function_space.dim()

with self.function.dat.vec as fvec:
with self._fbuf.dat.vec as fvec:
sizes = (self.nlocal_dofs, self.nglobal_dofs)
self._vec = PETSc.Vec().createWithArray(fvec.array,
size=sizes,
@@ -204,8 +211,8 @@ def copy(self, copy_values=True):
:arg copy_values: If true, the values of the current AllAtOnceFunction
will be copied into the new AllAtOnceFunction.
"""
new = AllAtOnceFunction(self.ensemble, self.time_partition,
self.field_function_space)
new = type(self)(self.ensemble, self.time_partition,
self.field_function_space)
if copy_values:
new.assign(self)
return new
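
Since `copy` now builds the new object with `type(self)`, it returns the same concrete subclass for either variant. A short hedged sketch, assuming an existing AllAtOnceFunction `aaofunc` (see the construction sketch after the classes below):

```python
# same subclass as the source; values left unset when copy_values=False
empty = aaofunc.copy(copy_values=False)
assert type(empty) is type(aaofunc)
```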
@@ -227,34 +234,34 @@ def assign(self, src, update_halos=True, blocking=True):
whether blocking communication is used. A list of MPI Requests is returned
if non-blocking communication is used.
"""
if isinstance(src, AllAtOnceFunction):
dst_funcs = [self.function, self.initial_condition]
src_funcs = [src.function, src.initial_condition]
# these buffers will just be overwritten if the halos are updated
if not update_halos:
dst_funcs.extend([self.uprev, self.unext])
src_funcs.extend([src.uprev, src.unext])
for dst, src in zip(dst_funcs, src_funcs):
dst.assign(src)

def func_assign(x, y):
return y.assign(x)

def vec_assign(x, y):
x.copy(y)

if isinstance(src, type(self)):
return self._vs_op(src, func_assign, vec_assign,
update_ics=True,
update_halos=update_halos,
blocking=blocking)

# TODO: We should be able to use _vs_op here too but
# test_allatoncesolver:::test_solve_heat_equation
# fails if we do. The only difference is that
# _vs_op accesses the global vec with read/write
# access instead of write only.
# It isn't clear why this makes a difference (it
# shouldn't).
elif isinstance(src, PETSc.Vec):
with self.global_vec_wo() as gvec:
src.copy(gvec)

elif isinstance(src, fd.Function):
if src.function_space() == self.field_function_space:
for i in range(self.nlocal_timesteps):
self[i].assign(src)
self.initial_condition.assign(src)
if not update_halos:
self.uprev.assign(src)
self.unext.assign(src)
elif src.function_space() == self.function_space:
self.function.assign(src)
else:
raise ValueError(f"src must be be in the `function_space` {self.function_space}"
+ " or `field_function_space` {self.field_function_space} of the"
+ " the AllAtOnceFunction, not in {src.function_space}")
elif isinstance(src, type(self._fbuf)):
return self._vs_op(src, func_assign, vec_assign,
update_ics=True,
update_halos=update_halos,
blocking=blocking)

else:
raise TypeError(f"src value must be AllAtOnceFunction or PETSc.Vec or field Function, not {type(src)}")
@@ -270,7 +277,9 @@ def zero(self, subset=None):
:arg subset: pyop2.types.set.Subset indicating the nodes to zero.
If None then the whole function is zeroed.
"""
funcs = (self.initial_condition, self.function, self.uprev, self.unext)
funcs = [self._fbuf, self.uprev, self.unext]
if hasattr(self, 'initial_condition'):
funcs.append(self.initial_condition)
for f in funcs:
f.zero(subset=subset)
return self
@@ -289,12 +298,10 @@ def scale(self, a, update_ics=False,
whether blocking communication is used. A list of MPI Requests is returned
if non-blocking communication is used.
"""
alpha = fd.Constant(a)

self.function.assign(alpha*self.function)
self._fbuf.assign(a*self._fbuf)

if update_ics:
self.initial_condition.assign(alpha*self.initial_condition)
if update_ics and hasattr(self, 'initial_condition'):
self.initial_condition.assign(a*self.initial_condition)

if update_halos:
return self.update_time_halos(blocking=blocking)
@@ -320,10 +327,8 @@ def axpy(self, a, x, update_ics=False,
whether blocking communication is used. A list of MPI Requests is returned
if non-blocking communication is used.
"""
alpha = fd.Constant(a)

def func_axpy(x, y):
return y.assign(alpha*x + y)
return y.assign(a*x + y)

def vec_axpy(x, y):
y.axpy(a, x)
@@ -354,10 +359,8 @@ def aypx(self, a, x, update_ics=False,
whether blocking communication is used. A list of MPI Requests is returned
if non-blocking communication is used.
"""
alpha = fd.Constant(a)

def func_aypx(x, y):
return y.assign(x + alpha*y)
return y.assign(x + a*y)

def vec_aypx(x, y):
y.aypx(a, x)
@@ -389,11 +392,8 @@ def axpby(self, a, b, x, update_ics=False,
whether blocking communication is used. A list of MPI Requests is returned
if non-blocking communication is used.
"""
alpha = fd.Constant(a)
beta = fd.Constant(b)

def func_axpby(x, y):
return y.assign(alpha*x + beta*y)
return y.assign(a*x + b*y)

def vec_axpby(x, y):
y.axpby(a, b, x)
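
`scale`, `axpy`, `aypx`, and `axpby` mirror the PETSc Vec operations of the same names, applied across every local timestep (and optionally the ics). A hedged sketch with illustrative scalars, again assuming `aaofunc` from the construction sketch below:

```python
x = aaofunc.copy()
y = aaofunc.copy()

y.axpy(2.0, x)           # y = 2*x + y
y.aypx(0.5, x)           # y = x + 0.5*y
y.axpby(1.0, -1.0, x)    # y = x - y
y.scale(3.0)             # y = 3*y
```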
@@ -425,32 +425,32 @@ def _vs_op(self, x, func_op, vec_op, update_ics=False,
whether blocking communication is used. A list of MPI Requests is returned
if non-blocking communication is used.
"""
if isinstance(x, AllAtOnceFunction):
func_op(x.function, self.function)
if update_ics:
if isinstance(x, type(self)):
func_op(x._fbuf, self._fbuf)
if update_ics and hasattr(self, 'initial_condition'):
func_op(x.initial_condition, self.initial_condition)

elif isinstance(x, PETSc.Vec):
with self.global_vec() as gvec:
vec_op(x, gvec)

elif isinstance(x, fd.Function):
elif isinstance(x, type(self._fbuf)):
if x.function_space() == self.field_function_space:
for i in range(self.nlocal_timesteps):
func_op(x, self[i])
if update_ics:
if update_ics and hasattr(self, 'initial_condition'):
func_op(x, self.initial_condition)

elif x.function_space() == self.function_space:
func_op(x, self.function)
func_op(x, self._fbuf)

else:
raise ValueError(f"x must be be in the `function_space` {self.function_space}"
+ f" or `field_function_space` {self.field_function_space} of the"
+ f" the AllAtOnceFunction, not in {x.function_space}")

else:
raise TypeError(f"x value must be AllAtOnceFunction or PETSc.Vec or field Function, not {type(x)}")
raise TypeError(f"x value must be AllAtOnce(Co)Function or PETSc.Vec or field (Co)Function, not {type(x)}")

if update_halos:
return self.update_time_halos(blocking=blocking)
@@ -466,7 +466,7 @@ def global_vec(self):
# fvec shares the same storage as _vec, so we need this context
# manager to make sure that the data gets copied to/from the
# Function.dat storage and _vec.
with self.function.dat.vec:
with self._fbuf.dat.vec:
self._vec.stateIncrease()
yield self._vec

@@ -481,7 +481,7 @@ def global_vec_ro(self):
# fvec shares the same storage as _vec, so we need this context
# manager to make sure that the data gets copied into _vec from
# the Function.dat storage.
with self.function.dat.vec_ro:
with self._fbuf.dat.vec_ro:
self._vec.stateIncrease()
yield self._vec

@@ -496,5 +496,124 @@ def global_vec_wo(self):
# fvec shares the same storage as _vec, so we need this context
# manager to make sure that the data gets copied back into the
# Function.dat storage from _vec.
with self.function.dat.vec_wo:
with self._fbuf.dat.vec_wo:
yield self._vec
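
The three context managers differ only in the access intent (`vec`, `vec_ro`, `vec_wo`) used on the underlying dat, which controls when data is copied between the Function storage and `_vec`. A hedged sketch, with `aaofunc` as above:

```python
with aaofunc.global_vec_ro() as gvec:
    norm = gvec.norm()   # read-only access; collective over the global comm

with aaofunc.global_vec_wo() as gvec:
    gvec.set(0.0)        # write-only: previous contents are not copied in
```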


class AllAtOnceFunction(AllAtOnceFunctionBase):
@profiler()
def __init__(self, ensemble, time_partition, function_space):
"""
A function representing multiple timesteps of a time-dependent finite-element problem,
i.e. the solution to an all-at-once system.

:arg ensemble: time-parallel ensemble communicator. The timesteps are partitioned
over the ensemble members according to time_partition so
ensemble.ensemble_comm.size == len(time_partition) must be True.
:arg time_partition: a list of integers for the number of timesteps stored on each
ensemble rank.
:arg function_space: a FunctionSpace for the solution at a single timestep.
"""
if not is_primal(function_space):
raise TypeError("Cannot only make AllAtOnceFunction from a FunctionSpace")
super().__init__(ensemble, time_partition, function_space)
self.function = self._fbuf
self.initial_condition = fd.Function(self.field_function_space)
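
For orientation, a minimal construction sketch (not part of the diff): it assumes 4 MPI ranks, the `from asQ import AllAtOnceFunction` import path, and an illustrative mesh and partition.

```python
import firedrake as fd
from asQ import AllAtOnceFunction  # assumed import path

# Two time-parallel ensemble members with two spatial ranks each (4 ranks).
# len(time_partition) must equal ensemble.ensemble_comm.size.
ensemble = fd.Ensemble(fd.COMM_WORLD, 2)
time_partition = [2, 2]  # two timesteps stored on each ensemble member

mesh = fd.UnitSquareMesh(8, 8, comm=ensemble.comm)
V = fd.FunctionSpace(mesh, "CG", 1)

aaofunc = AllAtOnceFunction(ensemble, time_partition, V)
# broadcast one field into every local timestep and the initial condition
aaofunc.assign(fd.Function(V).interpolate(fd.Constant(1.0)))
```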


class AllAtOnceCofunction(AllAtOnceFunctionBase):
@profiler()
def __init__(self, ensemble, time_partition, function_space):
"""
A Cofunction representing multiple timesteps of a time-dependent finite-element problem,
e.g. the right-hand side of an all-at-once system.

:arg ensemble: time-parallel ensemble communicator. The timesteps are partitioned
over the ensemble members according to time_partition so
ensemble.ensemble_comm.size == len(time_partition) must be True.
:arg time_partition: a list of integers for the number of timesteps stored on each
ensemble rank.
:arg function_space: a DualSpace for a single timestep.
"""
if not is_dual(function_space):
raise TypeError("Can only make an AllAtOnceCofunction from a DualSpace")
super().__init__(ensemble, time_partition, function_space)
self.cofunction = self._fbuf

@profiler()
def scale(self, a, update_halos=False, blocking=True):
"""
Scale the AllAtOnceCofunction by a scalar.

:arg a: scalar to multiply the function by.
:arg update_halos: if True then the time-halos will be updated.
:arg blocking: if update_halos is True, then this argument determines
whether blocking communication is used. A list of MPI Requests is returned
if non-blocking communication is used.
"""
return super().scale(a, update_halos=update_halos, blocking=blocking,
update_ics=False)

@profiler()
def axpy(self, a, x, update_halos=False, blocking=True):
"""
Compute y = a*x + y where y is this AllAtOnceCofunction.

:arg a: scalar to multiply x.
:arg x: other object for calculation. Can be one of:
- AllAtOnceCofunction: all timesteps are updated.
- PETSc Vec: all timesteps are updated.
- firedrake.Cofunction in self.function_space:
all timesteps are updated.
- firedrake.Cofunction in self.field_function_space:
all timesteps are updated.
:arg update_halos: if True then the time-halos will be updated.
:arg blocking: if update_halos is True, then this argument determines
whether blocking communication is used. A list of MPI Requests is returned
if non-blocking communication is used.
"""
return super().axpy(a, x, update_halos=update_halos, blocking=blocking,
update_ics=False)

@profiler()
def aypx(self, a, x, update_halos=False, blocking=True):
"""
Compute y = x + a*y where y is this AllAtOnceCofunction.

:arg a: scalar to multiply y.
:arg x: other object for calculation. Can be one of:
- AllAtOnceCofunction: all timesteps are updated.
- PETSc Vec: all timesteps are updated.
- firedrake.Cofunction in self.function_space:
all timesteps are updated.
- firedrake.Cofunction in self.field_function_space:
all timesteps are updated.
:arg update_halos: if True then the time-halos will be updated.
:arg blocking: if update_halos is True, then this argument determines
whether blocking communication is used. A list of MPI Requests is returned
if non-blocking communication is used.
"""
return super().aypx(a, x, update_halos=update_halos, blocking=blocking,
update_ics=False)

@profiler()
def axpby(self, a, b, x, update_halos=False, blocking=True):
"""
Compute y = a*x + b*y where y is this AllAtOnceCofunction.

:arg a: scalar to multiply x.
:arg b: scalar to multiply y.
:arg x: other object for calculation. Can be one of:
- AllAtOnceCofunction: all timesteps are updated.
- PETSc Vec: all timesteps are updated.
- firedrake.Cofunction in self.function_space:
all timesteps are updated.
- firedrake.Cofunction in self.field_function_space:
all timesteps are updated.
:arg update_halos: if True then the time-halos will be updated.
:arg blocking: if update_halos is True, then this argument determines
whether blocking communication is used. A list of MPI Requests is returned
if non-blocking communication is used.
"""
return super().axpby(a, b, x, update_halos=update_halos, blocking=blocking,
update_ics=False)
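
Finally, a hedged sketch of the new dual-side class, reusing `ensemble`, `time_partition`, and `V` from the construction example above; `V.dual()` and the import path are assumptions.

```python
from asQ import AllAtOnceCofunction  # assumed import path

# Built from the DualSpace; there is no initial_condition on the dual side.
aaocofunc = AllAtOnceCofunction(ensemble, time_partition, V.dual())
aaocofunc.zero()

rhs = fd.Cofunction(V.dual())
aaocofunc.axpy(1.0, rhs)   # accumulate one residual into every local timestep
```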