From a8a256f6f733dc66388a6791357614a5fac5bff9 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Thu, 30 Apr 2020 16:57:10 +0100 Subject: [PATCH 1/8] gem: Avoid complex warnings when initialising Literal array --- gem/gem.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index 4021e962..f0140f61 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -191,10 +191,11 @@ def __new__(cls, array): return super(Literal, cls).__new__(cls) def __init__(self, array): + array = asarray(array) try: - self.array = asarray(array, dtype=float) + self.array = array.astype(float, casting="safe") except TypeError: - self.array = asarray(array, dtype=complex) + self.array = array.astype(complex) def is_equal(self, other): if type(self) != type(other): @@ -208,10 +209,8 @@ def get_hash(self): @property def value(self): - try: - return float(self.array) - except TypeError: - return complex(self.array) + assert self.shape == () + return self.array.dtype.type(self.array) @property def shape(self): From faa08cfcdc6598b7479bd777cc787d738b2d5f44 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Thu, 30 Apr 2020 16:58:07 +0100 Subject: [PATCH 2/8] loopy: Correct rounding in complex mode --- tsfc/loopy.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tsfc/loopy.py b/tsfc/loopy.py index baa3b597..a130c743 100644 --- a/tsfc/loopy.py +++ b/tsfc/loopy.py @@ -2,8 +2,6 @@ This is the final stage of code generation in TSFC.""" -from math import isnan - import numpy from functools import singledispatch from collections import defaultdict, OrderedDict @@ -399,10 +397,10 @@ def _expression_conditional(expr, ctx): def _expression_scalar(expr, parameters): assert not expr.shape v = expr.value - if isnan(v): + if numpy.isnan(v): return p.Variable("NAN") - r = round(v, 1) - if r and abs(v - r) < parameters.epsilon: + r = numpy.round(v, 1) + if r and numpy.abs(v - r) < parameters.epsilon: return r return v From a063012d55cfe8858ded07ca4ac8f6e2ea228cf6 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Thu, 30 Apr 2020 16:59:31 +0100 Subject: [PATCH 3/8] parameters: scalar_type is a dtype --- tsfc/driver.py | 6 +++++- tsfc/parameters.py | 10 +++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/tsfc/driver.py b/tsfc/driver.py index 239dc17c..2f31f245 100644 --- a/tsfc/driver.py +++ b/tsfc/driver.py @@ -98,6 +98,10 @@ def compile_integral(integral_data, form_data, prefix, parameters, interface, co # Delayed import, loopy is a runtime dependency import tsfc.kernel_interface.firedrake_loopy as firedrake_interface_loopy interface = firedrake_interface_loopy.KernelBuilder + if coffee: + scalar_type = parameters["scalar_type_c"] + else: + scalar_type = parameters["scalar_type"] # Remove these here, they're handled below. if parameters.get("quadrature_degree") in ["auto", "default", None, -1, "-1"]: @@ -123,7 +127,7 @@ def compile_integral(integral_data, form_data, prefix, parameters, interface, co domain_numbering = form_data.original_form.domain_numbering() builder = interface(integral_type, integral_data.subdomain_id, domain_numbering[integral_data.domain], - parameters["scalar_type"], + scalar_type, diagonal=diagonal) argument_multiindices = tuple(builder.create_element(arg.ufl_element()).get_indices() for arg in arguments) diff --git a/tsfc/parameters.py b/tsfc/parameters.py index 8cd98be4..dbd65a82 100644 --- a/tsfc/parameters.py +++ b/tsfc/parameters.py @@ -14,8 +14,11 @@ # that makes compilation time much shorter. "unroll_indexsum": 3, - # Scalar type (C typename string) - "scalar_type": "double", + # Scalar type numpy dtype + "scalar_type": numpy.dtype(numpy.float64), + + # So that tests pass (needs to match scalar_type) + "scalar_type_c": "double", # Precision of float printing (number of digits) "precision": numpy.finfo(numpy.dtype("double")).precision, @@ -28,4 +31,5 @@ def default_parameters(): def is_complex(scalar_type): """Decides complex mode based on scalar type.""" - return scalar_type and 'complex' in scalar_type + return scalar_type and (isinstance(scalar_type, numpy.dtype) and scalar_type.kind == 'c') \ + or (isinstance(scalar_type, str) and "complex" in scalar_type) From 41cd51c391ad06dda783e4b52c371c3d256b9e5e Mon Sep 17 00:00:00 2001 From: David Ham Date: Thu, 30 Apr 2020 17:00:29 +0100 Subject: [PATCH 4/8] loopy: Correct math function name mappings --- tsfc/loopy.py | 39 +++++++++++++++++++++++++++++++++------ 1 file changed, 33 insertions(+), 6 deletions(-) diff --git a/tsfc/loopy.py b/tsfc/loopy.py index a130c743..05bde41b 100644 --- a/tsfc/loopy.py +++ b/tsfc/loopy.py @@ -21,6 +21,37 @@ from contextlib import contextmanager +# Table of handled math functions in real and complex modes +# Note that loopy handles addition of type prefixes and suffixes itself. +math_table = { + 'sqrt': ('sqrt', 'sqrt'), + 'abs': ('abs', 'abs'), + 'cos': ('cos', 'cos'), + 'sin': ('sin', 'sin'), + 'tan': ('tan', 'tan'), + 'acos': ('acos', 'acos'), + 'asin': ('asin', 'asin'), + 'atan': ('atan', 'atan'), + 'cosh': ('cosh', 'cosh'), + 'sinh': ('sinh', 'sinh'), + 'tanh': ('tanh', 'tanh'), + 'acosh': ('acosh', 'acosh'), + 'asinh': ('asinh', 'asinh'), + 'atanh': ('atanh', 'atanh'), + 'power': ('pow', 'pow'), + 'exp': ('exp', 'exp'), + 'ln': ('log', 'log'), + 'real': (None, 'real'), + 'imag': (None, 'imag'), + 'conj': (None, 'conj'), + 'erf': ('erf', None), + 'atan_2': ('atan2', None), + 'atan2': ('atan2', None), + 'min_value': ('min', None), + 'max_value': ('max', None) +} + + class LoopyContext(object): def __init__(self): self.indices = {} # indices for declarations and referencing values, from ImperoC @@ -309,11 +340,6 @@ def _expression_power(expr, ctx): @_expression.register(gem.MathFunction) def _expression_mathfunction(expr, ctx): - from tsfc.coffee import math_table - - math_table = math_table.copy() - math_table['abs'] = ('abs', 'cabs') - complex_mode = int(is_complex(ctx.scalar_type)) # Bessel functions @@ -352,7 +378,8 @@ def _expression_mathfunction(expr, ctx): # Other math functions name = math_table[expr.name][complex_mode] if name is None: - raise RuntimeError("{} not supported in complex mode".format(expr.name)) + raise RuntimeError("{} not supported in {} mode".format(expr.name, + ("real", "complex")[complex_mode])) return p.Variable(name)(*[expression(c, ctx) for c in expr.children]) From c768f69a56b805f7dd5f35ab26b157df9bb1a475 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Thu, 30 Apr 2020 17:03:40 +0100 Subject: [PATCH 5/8] gem: Teach refactoriser about conj/real/imag Fixes #166. --- gem/refactorise.py | 27 +++++++++++++++++++++++++-- tsfc/driver.py | 7 ------- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/gem/refactorise.py b/gem/refactorise.py index f83bbfd8..2ca6e4cc 100644 --- a/gem/refactorise.py +++ b/gem/refactorise.py @@ -8,7 +8,7 @@ from gem.node import Memoizer, traversal from gem.gem import (Node, Conditional, Zero, Product, Sum, Indexed, - ListTensor, one) + ListTensor, one, MathFunction) from gem.optimise import (remove_componenttensors, sum_factorise, traverse_product, traverse_sum, unroll_indexsum, make_rename_map, make_renamer) @@ -169,7 +169,7 @@ def stop_at(expr): sums = [] for expr in compounds: summands = traverse_sum(expr, stop_at=stop_at) - if len(summands) <= 1 and not isinstance(expr, Conditional): + if len(summands) <= 1 and not isinstance(expr, (Conditional, MathFunction)): # Compound term is not an addition, avoid infinite # recursion and fail gracefully raising an exception. raise FactorisationError(expr) @@ -211,6 +211,29 @@ def stop_at(expr): return result +@_collect_monomials.register(MathFunction) +def _collect_monomials_mathfunction(expression, self): + name = expression.name + if name in {"conj", "real", "imag"}: + # These are allowed to be applied to arguments, and hence must + # be dealt with specially. Just push the function onto each + # entry in the monomialsum of the child. + # NOTE: This presently assumes that the "atomics" part of a + # MonomialSum are real. This is true for the coffee, tensor, + # spectral modes: the atomics are indexed tabulation matrices + # (which are guaranteed real). + # If the classifier puts (potentially) complex expressions in + # atomics, then this code needs fixed. + child_ms, = map(self, expression.children) + result = MonomialSum() + for k, v in child_ms.monomials.items(): + result.monomials[k] = MathFunction(name, v) + result.ordering = child_ms.ordering.copy() + return result + else: + return _collect_monomials.dispatch(MathFunction.mro()[1])(expression, self) + + @_collect_monomials.register(Conditional) def _collect_monomials_conditional(expression, self): """Refactorises a conditional expression into a sum-of-products form, diff --git a/tsfc/driver.py b/tsfc/driver.py index 2f31f245..d03165cc 100644 --- a/tsfc/driver.py +++ b/tsfc/driver.py @@ -50,14 +50,7 @@ def compile_form(form, prefix="form", parameters=None, interface=None, coffee=Tr assert isinstance(form, Form) # Determine whether in complex mode: - # complex nodes would break the refactoriser. complex_mode = parameters and is_complex(parameters.get("scalar_type")) - if complex_mode: - logger.warning("Disabling whole expression optimisations" - " in GEM for supporting complex mode.") - parameters = parameters.copy() - parameters["mode"] = 'vanilla' - fd = ufl_utils.compute_form_data(form, complex_mode=complex_mode) logger.info(GREEN % "compute_form_data finished in %g seconds.", time.time() - cpu_time) From 6fcabe00b1ec69418bede19614534f99ffb6db40 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Thu, 30 Apr 2020 17:04:24 +0100 Subject: [PATCH 6/8] loopy: Type propagation for expressions Enables correct temporary variable dtype to be applied during code generation. Fixes #171. --- tsfc/loopy.py | 88 ++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 84 insertions(+), 4 deletions(-) diff --git a/tsfc/loopy.py b/tsfc/loopy.py index 05bde41b..ca6a5356 100644 --- a/tsfc/loopy.py +++ b/tsfc/loopy.py @@ -3,10 +3,11 @@ This is the final stage of code generation in TSFC.""" import numpy -from functools import singledispatch +from functools import singledispatch, partial from collections import defaultdict, OrderedDict from gem import gem, impero as imp +from gem.node import Memoizer import islpy as isl import loopy as lp @@ -52,6 +53,85 @@ } +maxtype = partial(numpy.find_common_type, []) + + +@singledispatch +def _assign_dtype(expression, self): + return maxtype(map(self, expression.children)) + + +@_assign_dtype.register(gem.Terminal) +def _assign_dtype_terminal(expression, self): + return self.scalar_type + + +@_assign_dtype.register(gem.Zero) +@_assign_dtype.register(gem.Identity) +@_assign_dtype.register(gem.Delta) +def _assign_dtype_real(expression, self): + return self.real_type + + +@_assign_dtype.register(gem.Literal) +def _assign_dtype_identity(expression, self): + return expression.array.dtype + + +@_assign_dtype.register(gem.Power) +def _assign_dtype_power(expression, self): + # Conservative + return self.scalar_type + + +@_assign_dtype.register(gem.MathFunction) +def _assign_dtype_mathfunction(expression, self): + if expression.name in {"abs", "real", "imag"}: + return self.real_type + elif expression.name == "sqrt": + return self.scalar_type + else: + return maxtype(map(self, expression.children)) + + +@_assign_dtype.register(gem.MinValue) +@_assign_dtype.register(gem.MaxValue) +def _assign_dtype_minmax(expression, self): + # UFL did correctness checking + return self.real_type + + +@_assign_dtype.register(gem.Conditional) +def _assign_dtype_conditional(expression, self): + return maxtype(map(self, expression.children[1:])) + + +@_assign_dtype.register(gem.Comparison) +@_assign_dtype.register(gem.LogicalNot) +@_assign_dtype.register(gem.LogicalAnd) +@_assign_dtype.register(gem.LogicalOr) +def _assign_dtype_logical(expression, self): + return numpy.int8 + + +def assign_dtypes(expressions, scalar_type): + """Assign numpy data types to expressions. + + Used for declaring temporaries when converting from Impero to lower level code. + + :arg expressions: List of GEM expressions. + :arg scalar_type: Default scalar type. + + :returns: list of tuples (expression, dtype).""" + mapper = Memoizer(_assign_dtype) + mapper.scalar_type = scalar_type + if scalar_type.kind == "c": + mapper.real_type = numpy.finfo(scalar_type).dtype + else: + mapper.real_type = scalar_type + return [(e, mapper(e)) for e in expressions] + + class LoopyContext(object): def __init__(self): self.indices = {} # indices for declarations and referencing values, from ImperoC @@ -157,13 +237,13 @@ def generate(impero_c, args, precision, scalar_type, kernel_name="loopy_kernel", # Create arguments data = list(args) - for i, temp in enumerate(impero_c.temporaries): + for i, (temp, dtype) in enumerate(assign_dtypes(impero_c.temporaries, scalar_type)): name = "t%d" % i if isinstance(temp, gem.Constant): - data.append(lp.TemporaryVariable(name, shape=temp.shape, dtype=temp.array.dtype, initializer=temp.array, address_space=lp.AddressSpace.LOCAL, read_only=True)) + data.append(lp.TemporaryVariable(name, shape=temp.shape, dtype=dtype, initializer=temp.array, address_space=lp.AddressSpace.LOCAL, read_only=True)) else: shape = tuple([i.extent for i in ctx.indices[temp]]) + temp.shape - data.append(lp.TemporaryVariable(name, shape=shape, dtype=numpy.float64, initializer=None, address_space=lp.AddressSpace.LOCAL, read_only=False)) + data.append(lp.TemporaryVariable(name, shape=shape, dtype=dtype, initializer=None, address_space=lp.AddressSpace.LOCAL, read_only=False)) ctx.gem_to_pymbolic[temp] = p.Variable(name) # Create instructions From 3cdbc8a8e9f66a32e72765e1dcdbc48d2b946e5e Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Thu, 7 May 2020 23:30:39 +0100 Subject: [PATCH 7/8] compile_ufl: correct abs-simplification for sqrt in complex mode The assumption that the sqrt of a value is positive (and therefore doesn't require wrapping in abs) is no longer true when values can be complex. --- tsfc/driver.py | 6 ++++-- tsfc/fem.py | 26 +++++++++++++------------- tsfc/ufl_utils.py | 13 +++++++++---- 3 files changed, 26 insertions(+), 19 deletions(-) diff --git a/tsfc/driver.py b/tsfc/driver.py index d03165cc..cc98895d 100644 --- a/tsfc/driver.py +++ b/tsfc/driver.py @@ -155,7 +155,8 @@ def compile_integral(integral_data, form_data, prefix, parameters, interface, co integration_dim=integration_dim, entity_ids=entity_ids, argument_multiindices=argument_multiindices, - index_cache=index_cache) + index_cache=index_cache, + complex_mode=is_complex(parameters.get("scalar_type"))) mode_irs = collections.OrderedDict() for integral in integral_data.integrals: @@ -342,7 +343,8 @@ def compile_expression_dual_evaluation(expression, to_element, coordinates, inte ufl_cell=coordinates.ufl_domain().ufl_cell(), precision=parameters["precision"], argument_multiindices=argument_multiindices, - index_cache={}) + index_cache={}, + complex_mode=complex_mode) if all(isinstance(dual, PointEvaluation) for dual in to_element.dual_basis()): # This is an optimisation for point-evaluation nodes which diff --git a/tsfc/fem.py b/tsfc/fem.py index 5c226384..1726c468 100644 --- a/tsfc/fem.py +++ b/tsfc/fem.py @@ -56,7 +56,8 @@ class ContextBase(ProxyKernelInterface): 'precision', 'argument_multiindices', 'facetarea', - 'index_cache') + 'index_cache', + 'complex_mode') def __init__(self, interface, **kwargs): ProxyKernelInterface.__init__(self, interface) @@ -130,7 +131,7 @@ def __init__(self, mt, interface): @property def config(self): config = {name: getattr(self.interface, name) - for name in ["ufl_cell", "precision", "index_cache"]} + for name in ["ufl_cell", "precision", "index_cache", "complex_mode"]} config["interface"] = self.interface return config @@ -143,11 +144,10 @@ def jacobian_at(self, point): expr = PositiveRestricted(expr) elif self.mt.restriction == '-': expr = NegativeRestricted(expr) - expr = preprocess_expression(expr) - config = {"point_set": PointSingleton(point)} config.update(self.config) context = PointSetContext(**config) + expr = preprocess_expression(expr, complex_mode=context.complex_mode) return map_expr_dag(context.translator, expr) def reference_normals(self): @@ -175,10 +175,10 @@ def physical_edge_lengths(self): expr = NegativeRestricted(expr) expr = ufl.as_vector([ufl.sqrt(ufl.dot(expr[i, :], expr[i, :])) for i in range(3)]) - expr = preprocess_expression(expr) config = {"point_set": PointSingleton([1/3, 1/3])} config.update(self.config) context = PointSetContext(**config) + expr = preprocess_expression(expr, complex_mode=context.complex_mode) return map_expr_dag(context.translator, expr) @@ -273,7 +273,7 @@ def cell_avg(self, o): integrand, degree, argument_multiindices = entity_avg(integrand / CellVolume(domain), measure, self.context.argument_multiindices) config = {name: getattr(self.context, name) - for name in ["ufl_cell", "precision", "index_cache"]} + for name in ["ufl_cell", "precision", "index_cache", "complex_mode"]} config.update(quadrature_degree=degree, interface=self.context, argument_multiindices=argument_multiindices) expr, = compile_ufl(integrand, point_sum=True, **config) @@ -290,7 +290,7 @@ def facet_avg(self, o): config = {name: getattr(self.context, name) for name in ["ufl_cell", "precision", "index_cache", "integration_dim", "entity_ids", - "integral_type"]} + "integral_type", "complex_mode"]} config.update(quadrature_degree=degree, interface=self.context, argument_multiindices=argument_multiindices) expr, = compile_ufl(integrand, point_sum=True, **config) @@ -423,7 +423,7 @@ def translate_spatialcoordinate(terminal, mt, ctx): # Replace terminal with a Coefficient terminal = ctx.coordinate(terminal.ufl_domain()) # Get back to reference space - terminal = preprocess_expression(terminal) + terminal = preprocess_expression(terminal, complex_mode=ctx.complex_mode) # Rebuild modified terminal expr = construct_modified_terminal(mt, terminal) # Translate replaced UFL snippet @@ -451,7 +451,7 @@ def translate_cellvolume(terminal, mt, ctx): interface = CellVolumeKernelInterface(ctx, mt.restriction) config = {name: getattr(ctx, name) - for name in ["ufl_cell", "precision", "index_cache"]} + for name in ["ufl_cell", "precision", "index_cache", "complex_mode"]} config.update(interface=interface, quadrature_degree=degree) expr, = compile_ufl(integrand, point_sum=True, **config) return expr @@ -465,7 +465,7 @@ def translate_facetarea(terminal, mt, ctx): config = {name: getattr(ctx, name) for name in ["ufl_cell", "integration_dim", - "entity_ids", "precision", "index_cache"]} + "entity_ids", "precision", "index_cache", "complex_mode"]} config.update(interface=ctx, quadrature_degree=degree) expr, = compile_ufl(integrand, point_sum=True, **config) return expr @@ -479,7 +479,7 @@ def translate_cellorigin(terminal, mt, ctx): point_set = PointSingleton((0.0,) * domain.topological_dimension()) config = {name: getattr(ctx, name) - for name in ["ufl_cell", "precision", "index_cache"]} + for name in ["ufl_cell", "precision", "index_cache", "complex_mode"]} config.update(interface=ctx, point_set=point_set) context = PointSetContext(**config) return context.translator(expression) @@ -492,7 +492,7 @@ def translate_cell_vertices(terminal, mt, ctx): ps = PointSet(numpy.array(ctx.fiat_cell.get_vertices())) config = {name: getattr(ctx, name) - for name in ["ufl_cell", "precision", "index_cache"]} + for name in ["ufl_cell", "precision", "index_cache", "complex_mode"]} config.update(interface=ctx, point_set=ps) context = PointSetContext(**config) expr = context.translator(ufl_expr) @@ -633,7 +633,7 @@ def compile_ufl(expression, interior_facet=False, point_sum=False, **kwargs): context = PointSetContext(**kwargs) # Abs-simplification - expression = simplify_abs(expression) + expression = simplify_abs(expression, context.complex_mode) if interior_facet: expressions = [] for rs in itertools.product(("+", "-"), repeat=len(context.argument_multiindices)): diff --git a/tsfc/ufl_utils.py b/tsfc/ufl_utils.py index 4c2da3a6..35c87bb7 100644 --- a/tsfc/ufl_utils.py +++ b/tsfc/ufl_utils.py @@ -273,8 +273,11 @@ def _simplify_abs_expr(o, self, in_abs): @_simplify_abs.register(Sqrt) def _simplify_abs_sqrt(o, self, in_abs): - # Square root is always non-negative - return ufl_reuse_if_untouched(o, self(o.ufl_operands[0], False)) + result = ufl_reuse_if_untouched(o, self(o.ufl_operands[0], False)) + if self.complex_mode and in_abs: + return Abs(result) + else: + return result @_simplify_abs.register(ScalarValue) @@ -326,11 +329,13 @@ def _simplify_abs_abs(o, self, in_abs): return self(o.ufl_operands[0], True) -def simplify_abs(expression): +def simplify_abs(expression, complex_mode): """Simplify absolute values in a UFL expression. Its primary purpose is to "neutralise" CellOrientation nodes that are surrounded by absolute values and thus not at all necessary.""" - return MemoizerArg(_simplify_abs)(expression, False) + mapper = MemoizerArg(_simplify_abs) + mapper.complex_mode = complex_mode + return mapper(expression, False) def apply_mapping(expression, mapping): From 41c3bc21a903c7fdb6fb12a20c15a1593c3a35c2 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Sat, 9 May 2020 19:02:33 +0100 Subject: [PATCH 8/8] loopy: Refactor mathfunction mapping Remove need for math_table since loopy does name mapping and the only names that need translating are Bessel functions and ln. Additionally, don't forbid mathfunctions that only operator on real values in complex mode (since their argument may be real!). --- tsfc/loopy.py | 93 ++++++++++++++------------------------------------- 1 file changed, 26 insertions(+), 67 deletions(-) diff --git a/tsfc/loopy.py b/tsfc/loopy.py index ca6a5356..e04ba2ba 100644 --- a/tsfc/loopy.py +++ b/tsfc/loopy.py @@ -22,37 +22,6 @@ from contextlib import contextmanager -# Table of handled math functions in real and complex modes -# Note that loopy handles addition of type prefixes and suffixes itself. -math_table = { - 'sqrt': ('sqrt', 'sqrt'), - 'abs': ('abs', 'abs'), - 'cos': ('cos', 'cos'), - 'sin': ('sin', 'sin'), - 'tan': ('tan', 'tan'), - 'acos': ('acos', 'acos'), - 'asin': ('asin', 'asin'), - 'atan': ('atan', 'atan'), - 'cosh': ('cosh', 'cosh'), - 'sinh': ('sinh', 'sinh'), - 'tanh': ('tanh', 'tanh'), - 'acosh': ('acosh', 'acosh'), - 'asinh': ('asinh', 'asinh'), - 'atanh': ('atanh', 'atanh'), - 'power': ('pow', 'pow'), - 'exp': ('exp', 'exp'), - 'ln': ('log', 'log'), - 'real': (None, 'real'), - 'imag': (None, 'imag'), - 'conj': (None, 'conj'), - 'erf': ('erf', None), - 'atan_2': ('atan2', None), - 'atan2': ('atan2', None), - 'min_value': ('min', None), - 'max_value': ('max', None) -} - - maxtype = partial(numpy.find_common_type, []) @@ -419,49 +388,39 @@ def _expression_power(expr, ctx): @_expression.register(gem.MathFunction) def _expression_mathfunction(expr, ctx): - - complex_mode = int(is_complex(ctx.scalar_type)) - - # Bessel functions if expr.name.startswith('cyl_bessel_'): - if complex_mode: - msg = "Bessel functions for complex numbers: missing implementation" - raise NotImplementedError(msg) + # Bessel functions + if is_complex(ctx.scalar_type): + raise NotImplementedError("Bessel functions for complex numbers: " + "missing implementation") nu, arg = expr.children - nu_thunk = lambda: expression(nu, ctx) - arg_loopy = expression(arg, ctx) - if expr.name == 'cyl_bessel_j': - if nu == gem.Zero(): - return p.Variable("j0")(arg_loopy) - elif nu == gem.one: - return p.Variable("j1")(arg_loopy) - else: - return p.Variable("jn")(nu_thunk(), arg_loopy) - if expr.name == 'cyl_bessel_y': - if nu == gem.Zero(): - return p.Variable("y0")(arg_loopy) - elif nu == gem.one: - return p.Variable("y1")(arg_loopy) - else: - return p.Variable("yn")(nu_thunk(), arg_loopy) - + nu_ = expression(nu, ctx) + arg_ = expression(arg, ctx) # Modified Bessel functions (C++ only) # # These mappings work for FEniCS only, and fail with Firedrake # since no Boost available. - if expr.name in ['cyl_bessel_i', 'cyl_bessel_k']: + if expr.name in {'cyl_bessel_i', 'cyl_bessel_k'}: name = 'boost::math::' + expr.name - return p.Variable(name)(nu_thunk(), arg_loopy) - - assert False, "Unknown Bessel function: {}".format(expr.name) - - # Other math functions - name = math_table[expr.name][complex_mode] - if name is None: - raise RuntimeError("{} not supported in {} mode".format(expr.name, - ("real", "complex")[complex_mode])) - - return p.Variable(name)(*[expression(c, ctx) for c in expr.children]) + return p.Variable(name)(nu_, arg_) + else: + # cyl_bessel_{jy} -> {jy} + name = expr.name[-1:] + if nu == gem.Zero(): + return p.Variable(f"{name}0")(arg_) + elif nu == gem.one: + return p.Variable(f"{name}1")(arg_) + else: + return p.Variable(f"{name}n")(nu_, arg_) + else: + if expr.name == "ln": + name = "log" + else: + name = expr.name + # Not all mathfunctions apply to complex numbers, but this + # will be picked up in loopy. This way we allow erf(real(...)) + # in complex mode (say). + return p.Variable(name)(*(expression(c, ctx) for c in expr.children)) @_expression.register(gem.MinValue)