Skip to content

Commit

Permalink
Dolci/fix interpolate warnings (#603)
Browse files Browse the repository at this point in the history
Co-authored-by: Thomas Bendall <thomas.bendall@metoffice.gov.uk>
  • Loading branch information
Ig-dolci and tommbendall authored Dec 19, 2024
1 parent 0db493e commit e305589
Show file tree
Hide file tree
Showing 10 changed files with 68 additions and 61 deletions.
11 changes: 6 additions & 5 deletions gusto/diagnostics/diagnostics.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
"""Common diagnostic fields."""


from firedrake import (assemble, dot, dx, Function, sqrt, TestFunction,
from firedrake import (dot, dx, Function, sqrt, TestFunction,
TrialFunction, Constant, grad, inner, FacetNormal,
LinearVariationalProblem, LinearVariationalSolver,
ds_b, ds_v, ds_t, dS_h, dS_v, ds, dS, div, avg, pi,
TensorFunctionSpace, SpatialCoordinate, as_vector,
Projector, Interpolator, FunctionSpace, FiniteElement,
Projector, assemble, FunctionSpace, FiniteElement,
TensorProductElement)
from firedrake.assign import Assigner
from firedrake.__future__ import interpolate
from ufl.domain import extract_unique_domain

from abc import ABCMeta, abstractmethod, abstractproperty
Expand Down Expand Up @@ -193,7 +194,7 @@ def setup(self, domain, state_fields, space=None):

# Solve method must be declared in diagnostic's own setup routine
if self.method == 'interpolate':
self.evaluator = Interpolator(self.expr, self.field)
self.evaluator = interpolate(self.expr, self.space)
elif self.method == 'project':
self.evaluator = Projector(self.expr, self.field)
elif self.method == 'assign':
Expand All @@ -207,7 +208,7 @@ def compute(self):
logger.debug(f'Computing diagnostic {self.name} with {self.method} method')

if self.method == 'interpolate':
self.evaluator.interpolate()
self.field.assign(assemble(self.evaluator))
elif self.method == 'assign':
self.evaluator.assign()
elif self.method == 'project':
Expand Down Expand Up @@ -294,7 +295,7 @@ def setup(self, domain, state_fields, space=None):

# Solve method must be declared in diagnostic's own setup routine
if self.method == 'interpolate':
self.evaluator = Interpolator(self.expr, self.field)
self.evaluator = interpolate(self.expr, self.space)
elif self.method == 'project':
self.evaluator = Projector(self.expr, self.field)
elif self.method == 'assign':
Expand Down
11 changes: 6 additions & 5 deletions gusto/diagnostics/shallow_water_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@


from firedrake import (
dx, TestFunction, TrialFunction, grad, inner, curl, Function, Interpolator,
dx, TestFunction, TrialFunction, grad, inner, curl, Function, assemble,
LinearVariationalProblem, LinearVariationalSolver, conditional
)
from firedrake.__future__ import interpolate
from gusto.diagnostics.diagnostics import DiagnosticField, Energy

__all__ = ["ShallowWaterKineticEnergy", "ShallowWaterPotentialEnergy",
Expand Down Expand Up @@ -321,14 +322,14 @@ def setup(self, domain, state_fields):

qsat_expr = self.equation.compute_saturation(state_fields.X(
self.equation.field_name))
self.qsat_interpolator = Interpolator(qsat_expr, self.qsat_func)
self.qsat_interpolate = interpolate(qsat_expr, space)
self.expr = conditional(q_t < self.qsat_func, q_t, self.qsat_func)

super().setup(domain, state_fields, space=space)

def compute(self):
"""Performs the computation of the diagnostic field."""
self.qsat_interpolator.interpolate()
self.qsat_func.assign(assemble(self.qsat_interpolate))
super().compute()


Expand Down Expand Up @@ -371,13 +372,13 @@ def setup(self, domain, state_fields):

qsat_expr = self.equation.compute_saturation(state_fields.X(
self.equation.field_name))
self.qsat_interpolator = Interpolator(qsat_expr, self.qsat_func)
self.qsat_interpolate = interpolate(qsat_expr, space)
vapour = conditional(q_t < self.qsat_func, q_t, self.qsat_func)
self.expr = q_t - vapour

super().setup(domain, state_fields, space=space)

def compute(self):
"""Performs the computation of the diagnostic field."""
self.qsat_interpolator.interpolate()
self.qsat_func.assign(assemble(self.qsat_interpolate))
super().compute()
23 changes: 11 additions & 12 deletions gusto/physics/microphysics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
"""

from firedrake import (
Interpolator, conditional, Function, dx, min_value, max_value, Constant, pi,
Projector
conditional, Function, dx, min_value, max_value, Constant, pi,
Projector, assemble
)
from firedrake.__future__ import interpolate
from firedrake.fml import identity, Term, subject
from gusto.equations import Phases, TracerVariableType
from gusto.recovery import Recoverer, BoundaryMethod
Expand Down Expand Up @@ -170,8 +171,7 @@ def __init__(self, equation, vapour_name='water_vapour',
# Add terms to equations and make interpolators
# -------------------------------------------------------------------- #
self.source = [Function(V) for factor in factors]
self.source_interpolators = [Interpolator(sat_adj_expr*factor, source)
for factor, source in zip(factors, self.source)]
self.source_interpolate = [interpolate(sat_adj_expr*factor, V) for factor in factors]

tests = [equation.tests[idx] for idx in V_idxs]

Expand All @@ -195,8 +195,8 @@ def evaluate(self, x_in, dt):
if isinstance(self.equation, CompressibleEulerEquations):
self.rho_recoverer.project()
# Evaluate the source
for interpolator in self.source_interpolators:
interpolator.interpolate()
for interpolator, src in zip(self.source_interpolate, self.source):
src.assign(assemble(interpolator))


class AdvectedMoments(Enum):
Expand Down Expand Up @@ -440,7 +440,7 @@ def __init__(self, equation, cloud_name='cloud_water', rain_name='rain',
min_value(accu_rate, self.cloud_water / self.dt),
min_value(accr_rate + accu_rate, self.cloud_water / self.dt))))

self.source_interpolator = Interpolator(rain_expr, self.source)
self.source_interpolate = interpolate(rain_expr, Vt)

# Add term to equation's residual
test_cl = equation.tests[self.cloud_idx]
Expand All @@ -464,7 +464,7 @@ def evaluate(self, x_in, dt):
self.rain.assign(x_in.subfunctions[self.rain_idx])
self.cloud_water.assign(x_in.subfunctions[self.cloud_idx])
# Evaluate the source
self.source.assign(self.source_interpolator.interpolate())
self.source.assign(assemble(self.source_interpolate))


class EvaporationOfRain(PhysicsParametrisation):
Expand Down Expand Up @@ -609,8 +609,7 @@ def __init__(self, equation, rain_name='rain', vapour_name='water_vapour',
# Add terms to equations and make interpolators
# -------------------------------------------------------------------- #
self.source = [Function(V) for factor in factors]
self.source_interpolators = [Interpolator(evap_rate*factor, source)
for factor, source in zip(factors, self.source)]
self.source_interpolate = [interpolate(evap_rate*factor, V) for factor in factors]

tests = [equation.tests[idx] for idx in V_idxs]

Expand All @@ -634,5 +633,5 @@ def evaluate(self, x_in, dt):
if isinstance(self.equation, CompressibleEulerEquations):
self.rho_recoverer.project()
# Evaluate the source
for interpolator in self.source_interpolators:
interpolator.interpolate()
for interpolator, src in zip(self.source_interpolate, self.source):
src.assign(assemble(interpolator))
9 changes: 5 additions & 4 deletions gusto/physics/physics_parametrisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
"""

from abc import ABCMeta, abstractmethod
from firedrake import Interpolator, Function, dx, Projector
from firedrake import Function, dx, Projector, assemble
from firedrake.__future__ import interpolate
from firedrake.fml import subject
from gusto.core.labels import PhysicsLabel
from gusto.core.logging import logger
Expand Down Expand Up @@ -117,14 +118,14 @@ def __init__(self, equation, variable_name, rate_expression,

# Handle method of evaluating source/sink
if self.method == 'interpolate':
self.source_interpolator = Interpolator(expression, V)
self.source_interpolate = interpolate(expression, V)
else:
self.source_projector = Projector(expression, V)

# If not time-varying, evaluate for the first time here
if not self.time_varying:
if self.method == 'interpolate':
self.source.assign(self.source_interpolator.interpolate())
self.source.assign(assemble(self.source_interpolate))
else:
self.source.assign(self.source_projector.project())

Expand All @@ -140,7 +141,7 @@ def evaluate(self, x_in, dt):
if self.time_varying:
logger.info(f'Evaluating physics parametrisation {self.label.label}')
if self.method == 'interpolate':
self.source.assign(self.source_interpolator.interpolate())
self.source.assign(assemble(self.source_interpolate))
else:
self.source.assign(self.source_projector.project())
else:
Expand Down
15 changes: 8 additions & 7 deletions gusto/physics/shallow_water_microphysics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
"""

from firedrake import (
Interpolator, conditional, Function, dx, min_value, max_value, Constant
conditional, Function, dx, min_value, max_value, Constant, assemble
)
from firedrake.__future__ import interpolate
from firedrake.fml import subject
from gusto.core.logging import logger
from gusto.physics.physics_parametrisation import PhysicsParametrisation
Expand Down Expand Up @@ -134,7 +135,7 @@ def __init__(self, equation, saturation_curve,
self.evaluate)

# interpolator does the conversion of vapour to rain
self.source_interpolator = Interpolator(conditional(
self.source_interpolate = interpolate(conditional(
self.water_v > self.saturation_curve,
(1/self.tau)*gamma_r*(self.water_v - self.saturation_curve),
0), Vv)
Expand All @@ -159,7 +160,7 @@ def evaluate(self, x_in, dt):
if self.set_tau_to_dt:
self.tau.assign(dt)
self.water_v.assign(x_in.subfunctions[self.Vv_idx])
self.source.assign(self.source_interpolator.interpolate())
self.source.assign(assemble(self.source_interpolate))


class SWSaturationAdjustment(PhysicsParametrisation):
Expand Down Expand Up @@ -321,8 +322,8 @@ def __init__(self, equation, saturation_curve,
# Add terms to equations and make interpolators
# sources have the same order as V_idxs and factors
self.source = [Function(Vc) for factor in factors]
self.source_interpolators = [Interpolator(sat_adj_expr*factor, source)
for factor, source in zip(factors, self.source)]
self.source_interpolate = [interpolate(sat_adj_expr*factor, Vc)
for factor in factors]

# test functions have the same order as factors and sources (vapour,
# cloud, depth, buoyancy) so that the correct test function multiplies
Expand Down Expand Up @@ -359,5 +360,5 @@ def evaluate(self, x_in, dt):
self.cloud.assign(x_in.subfunctions[self.Vc_idx])
if self.time_varying_gamma_v:
self.gamma_v.interpolate(self.gamma_v_computation(x_in))
for interpolator in self.source_interpolators:
interpolator.interpolate()
for interpolator, src in zip(self.source_interpolate, self.source):
src.assign(assemble(interpolator))
11 changes: 6 additions & 5 deletions gusto/recovery/recovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
Function, FunctionSpace, Interpolator, Projector,
SpatialCoordinate, TensorProductElement,
VectorFunctionSpace, as_vector, function, interval,
VectorElement)
VectorElement, assemble)
from firedrake.__future__ import interpolate
from gusto.recovery import Averager
from .recovery_kernels import (BoundaryRecoveryExtruded, BoundaryRecoveryHCurl,
BoundaryGaussianElimination)
Expand Down Expand Up @@ -144,14 +145,14 @@ def __init__(self, x_inout, method=BoundaryMethod.extruded, eff_coords=None):
V_broken = FunctionSpace(mesh, BrokenElement(V_inout.ufl_element()))
self.x_DG1_wrong = Function(V_broken)
self.x_DG1_correct = Function(V_broken)
self.interpolator = Interpolator(self.x_inout, self.x_DG1_wrong)
self.interpolate = interpolate(self.x_inout, V_broken)
self.averager = Averager(self.x_DG1_correct, self.x_inout)
self.kernel = BoundaryGaussianElimination(V_broken)

def apply(self):
"""Applies the boundary recovery process."""
if self.method == BoundaryMethod.taylor:
self.interpolator.interpolate()
self.x_DG1_wrong.assign(assemble(self.interpolate))
self.kernel.apply(self.x_DG1_wrong, self.x_DG1_correct,
self.act_coords, self.eff_coords, self.num_ext)
self.averager.project()
Expand Down Expand Up @@ -275,7 +276,7 @@ def __init__(self, x_in, x_out, method='interpolate', boundary_method=None):
self.boundary_recoverers.append(BoundaryRecoverer(x_out_scalars[i],
method=BoundaryMethod.taylor,
eff_coords=eff_coords[i]))
self.interpolate_to_vector = Interpolator(as_vector(x_out_scalars), self.x_out)
self.interpolate_to_vector = interpolate(as_vector(x_out_scalars), V_out)

def project(self):
"""Perform the whole recovery step."""
Expand All @@ -294,7 +295,7 @@ def project(self):
# Correct at boundaries
boundary_recoverer.apply()
# Combine the components to obtain the vector field
self.interpolate_to_vector.interpolate()
self.x_out.assign(assemble(self.interpolate_to_vector))
else:
# Extrapolate at boundaries
self.boundary_recoverer.apply()
Expand Down
15 changes: 8 additions & 7 deletions gusto/recovery/reversible_recovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
"""

from gusto.core.conservative_projection import ConservativeProjector
from firedrake import (Projector, Function, Interpolator)
from firedrake import (Projector, Function, assemble)
from firedrake.__future__ import interpolate
from .recovery import Recoverer

__all__ = ["ReversibleRecoverer", "ConservativeRecoverer"]
Expand Down Expand Up @@ -52,7 +53,7 @@ def __init__(self, source_field, target_field, reconstruct_opts):
elif self.opts.project_high_method == 'project':
self.projector_high = Projector(self.q_recovered, self.q_rec_high)
elif self.opts.project_high_method == 'interpolate':
self.projector_high = Interpolator(self.q_recovered, self.q_rec_high)
self.projector_high = interpolate(self.q_recovered, target_field.function_space())
self.interp_high = True
else:
raise ValueError(f'Method {self.opts.project_high_method} '
Expand All @@ -68,7 +69,7 @@ def __init__(self, source_field, target_field, reconstruct_opts):
elif self.opts.project_low_method == 'project':
self.projector_low = Projector(self.q_rec_high, self.q_corr_low)
elif self.opts.project_low_method == 'interpolate':
self.projector_low = Interpolator(self.q_rec_high, self.q_corr_low)
self.projector_low = interpolate(self.q_rec_high, source_field.function_space())
self.interp_low = True
else:
raise ValueError(f'Method {self.opts.project_low_method} '
Expand All @@ -84,17 +85,17 @@ def __init__(self, source_field, target_field, reconstruct_opts):
elif self.opts.injection_method == 'project':
self.injector = Projector(self.q_corr_low, self.q_corr_high)
elif self.opts.injection_method == 'interpolate':
self.injector = Interpolator(self.q_corr_low, self.q_corr_high)
self.injector = interpolate(self.q_corr_low, target_field.function_space())
self.interp_inj = True
else:
raise ValueError(f'Method {self.opts.injection_method} for injection not valid')

def project(self):
self.recoverer.project()
self.projector_high.interpolate() if self.interp_high else self.projector_high.project()
self.projector_low.interpolate() if self.interp_low else self.projector_low.project()
self.q_rec_high.assign(assemble(self.projector_high)) if self.interp_high else self.projector_high.project()
self.q_corr_low.assign(assemble(self.projector_low)) if self.interp_low else self.projector_low.project()
self.q_corr_low.assign(self.q_low - self.q_corr_low)
self.injector.interpolate() if self.interp_inj else self.injector.project()
self.q_corr_high.assign(assemble(self.injector)) if self.interp_inj else self.injector.project()
self.q_high.assign(self.q_corr_high + self.q_rec_high)


Expand Down
15 changes: 8 additions & 7 deletions gusto/solvers/linear_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
rhs, FacetNormal, div, dx, jump, avg, dS, dS_v, dS_h, ds_v, ds_t, ds_b,
ds_tb, inner, action, dot, grad, Function, VectorSpaceBasis, cross,
BrokenElement, FunctionSpace, MixedFunctionSpace, DirichletBC, as_vector,
Interpolator, conditional
assemble, conditional
)
from firedrake.fml import Term, drop
from firedrake.petsc import flatten_parameters
from firedrake.__future__ import interpolate
from pyop2.profiling import timed_function, timed_region

from gusto.equations.active_tracers import TracerVariableType
Expand Down Expand Up @@ -660,11 +661,11 @@ def _setup_solver(self):
qtbar = split(equation.X_ref)[3]

# set up interpolators that use the X_ref values for D and b_e
self.q_sat_expr_interpolator = Interpolator(
equation.compute_saturation(equation.X_ref), self.q_sat_func)
self.q_v_interpolator = Interpolator(
self.q_sat_expr_interpolate = interpolate(
equation.compute_saturation(equation.X_ref), VD)
self.q_v_interpolate = interpolate(
conditional(qtbar < self.q_sat_func, qtbar, self.q_sat_func),
self.qvbar)
VD)

# bbar was be_bar and here we correct to become bbar
bbar += equation.parameters.beta2 * self.qvbar
Expand Down Expand Up @@ -729,8 +730,8 @@ def trace_nullsp(T):
@timed_function("Gusto:UpdateReferenceProfiles")
def update_reference_profiles(self):
if self.equations.equivalent_buoyancy:
self.q_sat_expr_interpolator.interpolate()
self.q_v_interpolator.interpolate()
self.q_sat_func.assign(assemble(self.q_sat_expr_interpolate))
self.qvbar.assign(assemble(self.q_v_interpolate))

@timed_function("Gusto:LinearSolve")
def solve(self, xrhs, dy):
Expand Down
Loading

0 comments on commit e305589

Please sign in to comment.