Skip to content

Commit

Permalink
loopy: Refactor mathfunction mapping
Browse files Browse the repository at this point in the history
Remove need for math_table since loopy does name mapping and the only
names that need translating are Bessel functions and ln. Additionally,
don't forbid mathfunctions that only operator on real values in
complex mode (since their argument may be real!).
  • Loading branch information
wence- committed May 13, 2020
1 parent 3cdbc8a commit 41c3bc2
Showing 1 changed file with 26 additions and 67 deletions.
93 changes: 26 additions & 67 deletions tsfc/loopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,37 +22,6 @@
from contextlib import contextmanager


# Table of handled math functions in real and complex modes
# Note that loopy handles addition of type prefixes and suffixes itself.
math_table = {
'sqrt': ('sqrt', 'sqrt'),
'abs': ('abs', 'abs'),
'cos': ('cos', 'cos'),
'sin': ('sin', 'sin'),
'tan': ('tan', 'tan'),
'acos': ('acos', 'acos'),
'asin': ('asin', 'asin'),
'atan': ('atan', 'atan'),
'cosh': ('cosh', 'cosh'),
'sinh': ('sinh', 'sinh'),
'tanh': ('tanh', 'tanh'),
'acosh': ('acosh', 'acosh'),
'asinh': ('asinh', 'asinh'),
'atanh': ('atanh', 'atanh'),
'power': ('pow', 'pow'),
'exp': ('exp', 'exp'),
'ln': ('log', 'log'),
'real': (None, 'real'),
'imag': (None, 'imag'),
'conj': (None, 'conj'),
'erf': ('erf', None),
'atan_2': ('atan2', None),
'atan2': ('atan2', None),
'min_value': ('min', None),
'max_value': ('max', None)
}


maxtype = partial(numpy.find_common_type, [])


Expand Down Expand Up @@ -419,49 +388,39 @@ def _expression_power(expr, ctx):

@_expression.register(gem.MathFunction)
def _expression_mathfunction(expr, ctx):

complex_mode = int(is_complex(ctx.scalar_type))

# Bessel functions
if expr.name.startswith('cyl_bessel_'):
if complex_mode:
msg = "Bessel functions for complex numbers: missing implementation"
raise NotImplementedError(msg)
# Bessel functions
if is_complex(ctx.scalar_type):
raise NotImplementedError("Bessel functions for complex numbers: "
"missing implementation")
nu, arg = expr.children
nu_thunk = lambda: expression(nu, ctx)
arg_loopy = expression(arg, ctx)
if expr.name == 'cyl_bessel_j':
if nu == gem.Zero():
return p.Variable("j0")(arg_loopy)
elif nu == gem.one:
return p.Variable("j1")(arg_loopy)
else:
return p.Variable("jn")(nu_thunk(), arg_loopy)
if expr.name == 'cyl_bessel_y':
if nu == gem.Zero():
return p.Variable("y0")(arg_loopy)
elif nu == gem.one:
return p.Variable("y1")(arg_loopy)
else:
return p.Variable("yn")(nu_thunk(), arg_loopy)

nu_ = expression(nu, ctx)
arg_ = expression(arg, ctx)
# Modified Bessel functions (C++ only)
#
# These mappings work for FEniCS only, and fail with Firedrake
# since no Boost available.
if expr.name in ['cyl_bessel_i', 'cyl_bessel_k']:
if expr.name in {'cyl_bessel_i', 'cyl_bessel_k'}:
name = 'boost::math::' + expr.name
return p.Variable(name)(nu_thunk(), arg_loopy)

assert False, "Unknown Bessel function: {}".format(expr.name)

# Other math functions
name = math_table[expr.name][complex_mode]
if name is None:
raise RuntimeError("{} not supported in {} mode".format(expr.name,
("real", "complex")[complex_mode]))

return p.Variable(name)(*[expression(c, ctx) for c in expr.children])
return p.Variable(name)(nu_, arg_)
else:
# cyl_bessel_{jy} -> {jy}
name = expr.name[-1:]
if nu == gem.Zero():
return p.Variable(f"{name}0")(arg_)
elif nu == gem.one:
return p.Variable(f"{name}1")(arg_)
else:
return p.Variable(f"{name}n")(nu_, arg_)
else:
if expr.name == "ln":
name = "log"
else:
name = expr.name
# Not all mathfunctions apply to complex numbers, but this
# will be picked up in loopy. This way we allow erf(real(...))
# in complex mode (say).
return p.Variable(name)(*(expression(c, ctx) for c in expr.children))


@_expression.register(gem.MinValue)
Expand Down

0 comments on commit 41c3bc2

Please sign in to comment.