Skip to content

Commit

Permalink
Merge pull request #212 from firedrakeproject/wence/complex
Browse files Browse the repository at this point in the history
wence/complex
  • Loading branch information
wence- authored May 13, 2020
2 parents d2fc040 + 41c3bc2 commit d961ab4
Show file tree
Hide file tree
Showing 7 changed files with 181 additions and 87 deletions.
11 changes: 5 additions & 6 deletions gem/gem.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,11 @@ def __new__(cls, array):
return super(Literal, cls).__new__(cls)

def __init__(self, array):
array = asarray(array)
try:
self.array = asarray(array, dtype=float)
self.array = array.astype(float, casting="safe")
except TypeError:
self.array = asarray(array, dtype=complex)
self.array = array.astype(complex)

def is_equal(self, other):
if type(self) != type(other):
Expand All @@ -208,10 +209,8 @@ def get_hash(self):

@property
def value(self):
try:
return float(self.array)
except TypeError:
return complex(self.array)
assert self.shape == ()
return self.array.dtype.type(self.array)

@property
def shape(self):
Expand Down
27 changes: 25 additions & 2 deletions gem/refactorise.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from gem.node import Memoizer, traversal
from gem.gem import (Node, Conditional, Zero, Product, Sum, Indexed,
ListTensor, one)
ListTensor, one, MathFunction)
from gem.optimise import (remove_componenttensors, sum_factorise,
traverse_product, traverse_sum, unroll_indexsum,
make_rename_map, make_renamer)
Expand Down Expand Up @@ -169,7 +169,7 @@ def stop_at(expr):
sums = []
for expr in compounds:
summands = traverse_sum(expr, stop_at=stop_at)
if len(summands) <= 1 and not isinstance(expr, Conditional):
if len(summands) <= 1 and not isinstance(expr, (Conditional, MathFunction)):
# Compound term is not an addition, avoid infinite
# recursion and fail gracefully raising an exception.
raise FactorisationError(expr)
Expand Down Expand Up @@ -211,6 +211,29 @@ def stop_at(expr):
return result


@_collect_monomials.register(MathFunction)
def _collect_monomials_mathfunction(expression, self):
name = expression.name
if name in {"conj", "real", "imag"}:
# These are allowed to be applied to arguments, and hence must
# be dealt with specially. Just push the function onto each
# entry in the monomialsum of the child.
# NOTE: This presently assumes that the "atomics" part of a
# MonomialSum are real. This is true for the coffee, tensor,
# spectral modes: the atomics are indexed tabulation matrices
# (which are guaranteed real).
# If the classifier puts (potentially) complex expressions in
# atomics, then this code needs fixed.
child_ms, = map(self, expression.children)
result = MonomialSum()
for k, v in child_ms.monomials.items():
result.monomials[k] = MathFunction(name, v)
result.ordering = child_ms.ordering.copy()
return result
else:
return _collect_monomials.dispatch(MathFunction.mro()[1])(expression, self)


@_collect_monomials.register(Conditional)
def _collect_monomials_conditional(expression, self):
"""Refactorises a conditional expression into a sum-of-products form,
Expand Down
19 changes: 9 additions & 10 deletions tsfc/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,7 @@ def compile_form(form, prefix="form", parameters=None, interface=None, coffee=Tr
assert isinstance(form, Form)

# Determine whether in complex mode:
# complex nodes would break the refactoriser.
complex_mode = parameters and is_complex(parameters.get("scalar_type"))
if complex_mode:
logger.warning("Disabling whole expression optimisations"
" in GEM for supporting complex mode.")
parameters = parameters.copy()
parameters["mode"] = 'vanilla'

fd = ufl_utils.compute_form_data(form, complex_mode=complex_mode)
logger.info(GREEN % "compute_form_data finished in %g seconds.", time.time() - cpu_time)

Expand Down Expand Up @@ -98,6 +91,10 @@ def compile_integral(integral_data, form_data, prefix, parameters, interface, co
# Delayed import, loopy is a runtime dependency
import tsfc.kernel_interface.firedrake_loopy as firedrake_interface_loopy
interface = firedrake_interface_loopy.KernelBuilder
if coffee:
scalar_type = parameters["scalar_type_c"]
else:
scalar_type = parameters["scalar_type"]

# Remove these here, they're handled below.
if parameters.get("quadrature_degree") in ["auto", "default", None, -1, "-1"]:
Expand All @@ -123,7 +120,7 @@ def compile_integral(integral_data, form_data, prefix, parameters, interface, co
domain_numbering = form_data.original_form.domain_numbering()
builder = interface(integral_type, integral_data.subdomain_id,
domain_numbering[integral_data.domain],
parameters["scalar_type"],
scalar_type,
diagonal=diagonal)
argument_multiindices = tuple(builder.create_element(arg.ufl_element()).get_indices()
for arg in arguments)
Expand Down Expand Up @@ -158,7 +155,8 @@ def compile_integral(integral_data, form_data, prefix, parameters, interface, co
integration_dim=integration_dim,
entity_ids=entity_ids,
argument_multiindices=argument_multiindices,
index_cache=index_cache)
index_cache=index_cache,
complex_mode=is_complex(parameters.get("scalar_type")))

mode_irs = collections.OrderedDict()
for integral in integral_data.integrals:
Expand Down Expand Up @@ -345,7 +343,8 @@ def compile_expression_dual_evaluation(expression, to_element, coordinates, inte
ufl_cell=coordinates.ufl_domain().ufl_cell(),
precision=parameters["precision"],
argument_multiindices=argument_multiindices,
index_cache={})
index_cache={},
complex_mode=complex_mode)

if all(isinstance(dual, PointEvaluation) for dual in to_element.dual_basis()):
# This is an optimisation for point-evaluation nodes which
Expand Down
26 changes: 13 additions & 13 deletions tsfc/fem.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ class ContextBase(ProxyKernelInterface):
'precision',
'argument_multiindices',
'facetarea',
'index_cache')
'index_cache',
'complex_mode')

def __init__(self, interface, **kwargs):
ProxyKernelInterface.__init__(self, interface)
Expand Down Expand Up @@ -130,7 +131,7 @@ def __init__(self, mt, interface):
@property
def config(self):
config = {name: getattr(self.interface, name)
for name in ["ufl_cell", "precision", "index_cache"]}
for name in ["ufl_cell", "precision", "index_cache", "complex_mode"]}
config["interface"] = self.interface
return config

Expand All @@ -143,11 +144,10 @@ def jacobian_at(self, point):
expr = PositiveRestricted(expr)
elif self.mt.restriction == '-':
expr = NegativeRestricted(expr)
expr = preprocess_expression(expr)

config = {"point_set": PointSingleton(point)}
config.update(self.config)
context = PointSetContext(**config)
expr = preprocess_expression(expr, complex_mode=context.complex_mode)
return map_expr_dag(context.translator, expr)

def reference_normals(self):
Expand Down Expand Up @@ -175,10 +175,10 @@ def physical_edge_lengths(self):
expr = NegativeRestricted(expr)

expr = ufl.as_vector([ufl.sqrt(ufl.dot(expr[i, :], expr[i, :])) for i in range(3)])
expr = preprocess_expression(expr)
config = {"point_set": PointSingleton([1/3, 1/3])}
config.update(self.config)
context = PointSetContext(**config)
expr = preprocess_expression(expr, complex_mode=context.complex_mode)
return map_expr_dag(context.translator, expr)


Expand Down Expand Up @@ -273,7 +273,7 @@ def cell_avg(self, o):
integrand, degree, argument_multiindices = entity_avg(integrand / CellVolume(domain), measure, self.context.argument_multiindices)

config = {name: getattr(self.context, name)
for name in ["ufl_cell", "precision", "index_cache"]}
for name in ["ufl_cell", "precision", "index_cache", "complex_mode"]}
config.update(quadrature_degree=degree, interface=self.context,
argument_multiindices=argument_multiindices)
expr, = compile_ufl(integrand, point_sum=True, **config)
Expand All @@ -290,7 +290,7 @@ def facet_avg(self, o):
config = {name: getattr(self.context, name)
for name in ["ufl_cell", "precision", "index_cache",
"integration_dim", "entity_ids",
"integral_type"]}
"integral_type", "complex_mode"]}
config.update(quadrature_degree=degree, interface=self.context,
argument_multiindices=argument_multiindices)
expr, = compile_ufl(integrand, point_sum=True, **config)
Expand Down Expand Up @@ -423,7 +423,7 @@ def translate_spatialcoordinate(terminal, mt, ctx):
# Replace terminal with a Coefficient
terminal = ctx.coordinate(terminal.ufl_domain())
# Get back to reference space
terminal = preprocess_expression(terminal)
terminal = preprocess_expression(terminal, complex_mode=ctx.complex_mode)
# Rebuild modified terminal
expr = construct_modified_terminal(mt, terminal)
# Translate replaced UFL snippet
Expand Down Expand Up @@ -451,7 +451,7 @@ def translate_cellvolume(terminal, mt, ctx):
interface = CellVolumeKernelInterface(ctx, mt.restriction)

config = {name: getattr(ctx, name)
for name in ["ufl_cell", "precision", "index_cache"]}
for name in ["ufl_cell", "precision", "index_cache", "complex_mode"]}
config.update(interface=interface, quadrature_degree=degree)
expr, = compile_ufl(integrand, point_sum=True, **config)
return expr
Expand All @@ -465,7 +465,7 @@ def translate_facetarea(terminal, mt, ctx):

config = {name: getattr(ctx, name)
for name in ["ufl_cell", "integration_dim",
"entity_ids", "precision", "index_cache"]}
"entity_ids", "precision", "index_cache", "complex_mode"]}
config.update(interface=ctx, quadrature_degree=degree)
expr, = compile_ufl(integrand, point_sum=True, **config)
return expr
Expand All @@ -479,7 +479,7 @@ def translate_cellorigin(terminal, mt, ctx):
point_set = PointSingleton((0.0,) * domain.topological_dimension())

config = {name: getattr(ctx, name)
for name in ["ufl_cell", "precision", "index_cache"]}
for name in ["ufl_cell", "precision", "index_cache", "complex_mode"]}
config.update(interface=ctx, point_set=point_set)
context = PointSetContext(**config)
return context.translator(expression)
Expand All @@ -492,7 +492,7 @@ def translate_cell_vertices(terminal, mt, ctx):
ps = PointSet(numpy.array(ctx.fiat_cell.get_vertices()))

config = {name: getattr(ctx, name)
for name in ["ufl_cell", "precision", "index_cache"]}
for name in ["ufl_cell", "precision", "index_cache", "complex_mode"]}
config.update(interface=ctx, point_set=ps)
context = PointSetContext(**config)
expr = context.translator(ufl_expr)
Expand Down Expand Up @@ -633,7 +633,7 @@ def compile_ufl(expression, interior_facet=False, point_sum=False, **kwargs):
context = PointSetContext(**kwargs)

# Abs-simplification
expression = simplify_abs(expression)
expression = simplify_abs(expression, context.complex_mode)
if interior_facet:
expressions = []
for rs in itertools.product(("+", "-"), repeat=len(context.argument_multiindices)):
Expand Down
Loading

0 comments on commit d961ab4

Please sign in to comment.