Skip to content

Commit

Permalink
Prevent unnecessary refactorisations when matrices haven't changed (#569
Browse files Browse the repository at this point in the history
)
  • Loading branch information
JHopeCollins authored Dec 18, 2024
1 parent d573de5 commit a1d5959
Show file tree
Hide file tree
Showing 19 changed files with 355 additions and 90 deletions.
156 changes: 124 additions & 32 deletions gusto/core/function_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
"""

from gusto.core.logging import logger
from firedrake import (HCurl, HDiv, FunctionSpace, FiniteElement,
TensorProductElement, interval)
from firedrake import (HCurl, HDiv, FunctionSpace, FiniteElement, VectorElement,
TensorProductElement, BrokenElement, interval)

__all__ = ["Spaces", "check_degree_args"]

Expand Down Expand Up @@ -71,6 +71,7 @@ def __init__(self, mesh):
self.mesh = mesh
self.extruded_mesh = hasattr(mesh, "_base_mesh")
self.de_rham_complex = {}
self.continuity = {}

def __call__(self, name):
"""
Expand All @@ -89,7 +90,7 @@ def __call__(self, name):
else:
raise ValueError(f'The space container has no space {name}')

def add_space(self, name, space, overwrite_space=False):
def add_space(self, name, space, overwrite_space=False, continuity=None):
"""
Adds a function space to the container.
Expand All @@ -100,6 +101,10 @@ def add_space(self, name, space, overwrite_space=False):
overwrite_space (bool, optional): Logical to allow space existing in
container to be overwritten by an incoming space. Defaults to
False.
continuity: Whether the space is continuous or not. For spaces on
extruded meshes, must be a dictionary with entries for the
'horizontal' and 'vertical' continuity.
If None, defaults to value of is_cg(space).
"""

if hasattr(self, name) and not overwrite_space:
Expand All @@ -109,9 +114,28 @@ def add_space(self, name, space, overwrite_space=False):

setattr(self, name, space)

def create_space(self, name, family, degree, overwrite_space=False):
# set the continuity of the space - for extruded meshes specify both directions
if continuity is None:
continuity = is_cg(space)
if self.extruded_mesh:
self.continuity[name] = {
'horizontal': continuity,
'vertical': continuity
}
else:
self.continuity[name] = continuity
else:
if self.extruded_mesh:
self.continuity[name] = {
'horizontal': continuity['horizontal'],
'vertical': continuity['vertical']
}
else:
self.continuity[name] = continuity

def create_space(self, name, family, degree, overwrite_space=False, continuity=None):
"""
Creates a space and adds it to the container..
Creates a space and adds it to the container.
Args:
name (str): the name to give to the space.
Expand All @@ -120,18 +144,18 @@ def create_space(self, name, family, degree, overwrite_space=False):
overwrite_space (bool, optional): Logical to allow space existing in
container to be overwritten by an incoming space. Defaults to
False.
continuity: Whether the space is continuous or not. For spaces on
extruded meshes, must be a tuple for the (horizontal, vertical)
continuity. If None, defaults to value of is_cg(space).
Returns:
:class:`FunctionSpace`: the desired function space.
"""

if hasattr(self, name) and not overwrite_space:
raise RuntimeError(f'Space {name} already exists. If you really '
+ 'to create it then set `overwrite_space` as '
+ 'to be True')

space = FunctionSpace(self.mesh, family, degree, name=name)
setattr(self, name, space)

self.add_space(name, space,
overwrite_space=overwrite_space,
continuity=continuity)
return space

def build_compatible_spaces(self, family, horizontal_degree,
Expand Down Expand Up @@ -181,9 +205,15 @@ def build_compatible_spaces(self, family, horizontal_degree,
setattr(self, "L2"+complex_name, de_rham_complex.L2)
# Register L2 space as DG also
setattr(self, "DG"+complex_name, de_rham_complex.L2)
if hasattr(de_rham_complex, "theta"+complex_name):
if hasattr(de_rham_complex, "theta"):
setattr(self, "theta"+complex_name, de_rham_complex.theta)

# Grab the continuity information from the complex
for space_type in ("H1, HCurl", "HDiv", "L2", "DG", "theta"):
space_name = space_type + complex_name
if hasattr(de_rham_complex, space_type):
self.continuity[space_name] = de_rham_complex.continuity[space_type]

def build_dg1_equispaced(self):
"""
Builds the equispaced variant of the DG1 function space, which is used in
Expand All @@ -198,12 +228,15 @@ def build_dg1_equispaced(self):
hori_elt = FiniteElement('DG', cell, 1, variant='equispaced')
vert_elt = FiniteElement('DG', interval, 1, variant='equispaced')
V_elt = TensorProductElement(hori_elt, vert_elt)
continuity = {'horizontal': False, 'vertical': False}
else:
cell = self.mesh.ufl_cell().cellname()
V_elt = FiniteElement('DG', cell, 1, variant='equispaced')
continuity = False

space = FunctionSpace(self.mesh, V_elt, name='DG1_equispaced')
setattr(self, 'DG1_equispaced', space)

self.add_space('DG1_equispaced', space, continuity=continuity)
return space


Expand Down Expand Up @@ -234,6 +267,7 @@ def __init__(self, mesh, family, horizontal_degree, vertical_degree=None,
self.extruded_mesh = hasattr(mesh, '_base_mesh')
self.family = family
self.complex_name = complex_name
self.continuity = {}
self.build_base_spaces(family, horizontal_degree, vertical_degree)
self.build_compatible_spaces()

Expand Down Expand Up @@ -303,15 +337,24 @@ def build_compatible_spaces(self):
if self.extruded_mesh:
# Horizontal and vertical degrees
# need specifying separately. Vtheta needs returning.
Vcg = self.build_h1_space()
Vcg, continuity = self.build_h1_space()
self.continuity["H1"] = continuity
setattr(self, "H1", Vcg)
Vcurl = self.build_hcurl_space()

Vcurl, continuity = self.build_hcurl_space()
self.continuity["HCurl"] = continuity
setattr(self, "HCurl", Vcurl)
Vu = self.build_hdiv_space()

Vu, continuity = self.build_hdiv_space()
self.continuity["HDiv"] = continuity
setattr(self, "HDiv", Vu)
Vdg = self.build_l2_space()

Vdg, continuity = self.build_l2_space()
self.continuity["L2"] = continuity
setattr(self, "L2", Vdg)
Vth = self.build_theta_space()

Vth, continuity = self.build_theta_space()
self.continuity["theta"] = continuity
setattr(self, "theta", Vth)

return Vcg, Vcurl, Vu, Vdg, Vth
Expand All @@ -320,25 +363,39 @@ def build_compatible_spaces(self):
# 2D: two de Rham complexes (hcurl or hdiv) with 3 spaces
# 3D: one de Rham complexes with 4 spaces
# either way, build all spaces
Vcg = self.build_h1_space()
Vcg, continuity = self.build_h1_space()
self.continuity["H1"] = continuity
setattr(self, "H1", Vcg)
Vcurl = self.build_hcurl_space()

Vcurl, continuity = self.build_hcurl_space()
self.continuity["HCurl"] = continuity
setattr(self, "HCurl", Vcurl)
Vu = self.build_hdiv_space()

Vu, continuity = self.build_hdiv_space()
self.continuity["HDiv"] = continuity
setattr(self, "HDiv", Vu)
Vdg = self.build_l2_space()

Vdg, continuity = self.build_l2_space()
self.continuity["L2"] = continuity
setattr(self, "L2", Vdg)

return Vcg, Vcurl, Vu, Vdg

else:
# 1D domain, de Rham complex has 2 spaces
# CG, hdiv and hcurl spaces should be the same
Vcg = self.build_h1_space()
Vcg, continuity = self.build_h1_space()

self.continuity["H1"] = continuity
setattr(self, "H1", Vcg)

setattr(self, "HCurl", None)

self.continuity["HDiv"] = continuity
setattr(self, "HDiv", Vcg)
Vdg = self.build_l2_space()

Vdg, continuity = self.build_l2_space()
self.continuity["L2"] = continuity
setattr(self, "L2", Vdg)

return Vcg, Vdg
Expand All @@ -360,18 +417,21 @@ def build_hcurl_space(self):
"""
if hdiv_hcurl_dict[self.family] is None:
logger.warning('There is no HCurl space for this family. Not creating one')
return None
return None, None

if self.extruded_mesh:
Vh_elt = HCurl(TensorProductElement(self.base_elt_hori_hcurl,
self.base_elt_vert_cg))
Vv_elt = HCurl(TensorProductElement(self.base_elt_hori_cg,
self.base_elt_vert_dg))
V_elt = Vh_elt + Vv_elt
continuity = {'horizontal': True, 'vertical': True}
else:
V_elt = self.base_elt_hori_hcurl
continuity = True

return FunctionSpace(self.mesh, V_elt, name='HCurl'+self.complex_name)
space_name = 'HCurl'+self.complex_name
return FunctionSpace(self.mesh, V_elt, name=space_name), continuity

def build_hdiv_space(self):
"""
Expand All @@ -387,9 +447,13 @@ def build_hdiv_space(self):
self.base_elt_vert_cg)
Vv_elt = HDiv(Vt_elt)
V_elt = Vh_elt + Vv_elt
continuity = {'horizontal': True, 'vertical': True}
else:
V_elt = self.base_elt_hori_hdiv
return FunctionSpace(self.mesh, V_elt, name='HDiv'+self.complex_name)
continuity = True

space_name = 'HDiv'+self.complex_name
return FunctionSpace(self.mesh, V_elt, name=space_name), continuity

def build_l2_space(self):
"""
Expand All @@ -401,10 +465,13 @@ def build_l2_space(self):

if self.extruded_mesh:
V_elt = TensorProductElement(self.base_elt_hori_dg, self.base_elt_vert_dg)
continuity = {'horizontal': False, 'vertical': False}
else:
V_elt = self.base_elt_hori_dg
continuity = False

return FunctionSpace(self.mesh, V_elt, name='L2'+self.complex_name)
space_name = 'L2'+self.complex_name
return FunctionSpace(self.mesh, V_elt, name=space_name), continuity

def build_theta_space(self):
"""
Expand All @@ -423,8 +490,10 @@ def build_theta_space(self):
assert self.extruded_mesh, 'Cannot create theta space if mesh is not extruded'

V_elt = TensorProductElement(self.base_elt_hori_dg, self.base_elt_vert_cg)
continuity = {'horizontal': False, 'vertical': True}

return FunctionSpace(self.mesh, V_elt, name='theta'+self.complex_name)
space_name = 'theta'+self.complex_name
return FunctionSpace(self.mesh, V_elt, name=space_name), continuity

def build_h1_space(self):
"""
Expand All @@ -440,11 +509,13 @@ def build_h1_space(self):

if self.extruded_mesh:
V_elt = TensorProductElement(self.base_elt_hori_cg, self.base_elt_vert_cg)

continuity = {'horizontal': True, 'vertical': True}
else:
V_elt = self.base_elt_hori_cg
continuity = True

return FunctionSpace(self.mesh, V_elt, name='H1'+self.complex_name)
space_name = 'H1'+self.complex_name
return FunctionSpace(self.mesh, V_elt, name=space_name), continuity


def check_degree_args(name, mesh, degree, horizontal_degree, vertical_degree):
Expand Down Expand Up @@ -476,3 +547,24 @@ def check_degree_args(name, mesh, degree, horizontal_degree, vertical_degree):
raise ValueError(f'Cannot pass both "degree" and "vertical_degree" to {name}')
if not extruded_mesh and vertical_degree is not None:
raise ValueError(f'Cannot pass "vertical_degree" to {name} if mesh is not extruded')


def is_cg(V):
"""
Checks if a :class:`FunctionSpace` is continuous.
Function to check if a given space, V, is CG. Broken elements are always
discontinuous; for vector elements we check the names of the Sobolev spaces
of the subelements and for all other elements we just check the Sobolev
space name.
Args:
V (:class:`FunctionSpace`): the space to check.
"""
ele = V.ufl_element()
if isinstance(ele, BrokenElement):
return False
elif type(ele) == VectorElement:
return all([e.sobolev_space.name == "H1" for e in ele._sub_elements])
else:
return V.ufl_element().sobolev_space.name == "H1"
9 changes: 6 additions & 3 deletions gusto/diagnostics/compressible_euler_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,8 @@ def setup(self, domain, state_fields):

bcs = self.equations.bcs['u']

imbalanceproblem = LinearVariationalProblem(a, L, imbalance, bcs=bcs)
imbalanceproblem = LinearVariationalProblem(a, L, imbalance, bcs=bcs,
constant_jacobian=True)
self.imbalance_solver = LinearVariationalSolver(imbalanceproblem)
self.expr = dot(imbalance, domain.k)
super().setup(domain, state_fields)
Expand Down Expand Up @@ -789,12 +790,14 @@ def setup(self, domain, state_fields):
eqn_rhs = domain.dt * self.phi * (rain * dot(- v, domain.k) * rho / area) * ds_b

# Compute area normalisation
area_prob = LinearVariationalProblem(eqn_lhs, area_rhs, area)
area_prob = LinearVariationalProblem(eqn_lhs, area_rhs, area,
constant_jacobian=True)
area_solver = LinearVariationalSolver(area_prob)
area_solver.solve()

# setup solver
rain_prob = LinearVariationalProblem(eqn_lhs, eqn_rhs, self.flux)
rain_prob = LinearVariationalProblem(eqn_lhs, eqn_rhs, self.flux,
constant_jacobian=True)
self.solver = LinearVariationalSolver(rain_prob)
self.field = state_fields(self.name, space=DG0, dump=True, pick_up=True)
# Initialise field to zero, if picking up this will be overridden
Expand Down
3 changes: 2 additions & 1 deletion gusto/diagnostics/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,8 @@ def setup(self, domain, state_fields):
L = -inner(div(test), f)*dx
if space.extruded:
L += dot(dot(test, n), f)*(ds_t + ds_b)
prob = LinearVariationalProblem(a, L, self.field)
prob = LinearVariationalProblem(a, L, self.field,
constant_jacobian=True)
self.evaluator = LinearVariationalSolver(prob)


Expand Down
5 changes: 4 additions & 1 deletion gusto/diagnostics/shallow_water_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,15 +182,18 @@ def setup(self, domain, state_fields, vorticity_type=None):

if vorticity_type == "potential":
a = q*gamma*D*dx
constant_jacobian = False
else:
a = q*gamma*dx
constant_jacobian = True

L = (- inner(domain.perp(grad(gamma)), u))*dx
if vorticity_type != "relative":
f = state_fields("coriolis")
L += gamma*f*dx

problem = LinearVariationalProblem(a, L, self.field)
problem = LinearVariationalProblem(a, L, self.field,
constant_jacobian=constant_jacobian)
self.evaluator = LinearVariationalSolver(problem, solver_parameters={"ksp_type": "cg"})


Expand Down
3 changes: 2 additions & 1 deletion gusto/solvers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from gusto.solvers.parameters import * # noqa
from gusto.solvers.linear_solvers import * # noqa
from gusto.solvers.preconditioners import * # noqa
from gusto.solvers.preconditioners import * # noqa
Loading

0 comments on commit a1d5959

Please sign in to comment.