Skip to content

Commit

Permalink
loopy: Use inferred type for output GlobalArg
Browse files Browse the repository at this point in the history
Only in the case of 0-forms, where we can control the allocated scalar type.
  • Loading branch information
wence- committed May 15, 2020
1 parent 6f3ee34 commit 0eb0af9
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 11 deletions.
5 changes: 4 additions & 1 deletion tsfc/kernel_interface/firedrake.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def make_builder(*args, **kwargs):
class Kernel(object):
__slots__ = ("ast", "integral_type", "oriented", "subdomain_id",
"domain_number", "needs_cell_sizes", "tabulations", "quadrature_rule",
"coefficient_numbers", "__weakref__")
"return_dtype", "coefficient_numbers", "__weakref__")
"""A compiled Kernel object.
:kwarg ast: The COFFEE ast for the kernel.
Expand All @@ -40,12 +40,14 @@ class Kernel(object):
:kwarg coefficient_numbers: A list of which coefficients from the
form the kernel needs.
:kwarg quadrature_rule: The finat quadrature rule used to generate this kernel
:kwarg return_dtype: numpy dtype of the return value.
:kwarg tabulations: The runtime tabulations this kernel requires
:kwarg needs_cell_sizes: Does the kernel require cell sizes.
"""
def __init__(self, ast=None, integral_type=None, oriented=False,
subdomain_id=None, domain_number=None, quadrature_rule=None,
coefficient_numbers=(),
return_dtype=None,
needs_cell_sizes=False):
# Defaults
self.ast = ast
Expand All @@ -55,6 +57,7 @@ def __init__(self, ast=None, integral_type=None, oriented=False,
self.subdomain_id = subdomain_id
self.coefficient_numbers = coefficient_numbers
self.needs_cell_sizes = needs_cell_sizes
self.return_dtype = return_dtype
super(Kernel, self).__init__()


Expand Down
23 changes: 17 additions & 6 deletions tsfc/kernel_interface/firedrake_loopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def make_builder(*args, **kwargs):
class Kernel(object):
__slots__ = ("ast", "integral_type", "oriented", "subdomain_id",
"domain_number", "needs_cell_sizes", "tabulations", "quadrature_rule",
"coefficient_numbers", "__weakref__")
"return_dtype", "coefficient_numbers", "__weakref__")
"""A compiled Kernel object.
:kwarg ast: The loopy kernel object.
Expand All @@ -40,12 +40,14 @@ class Kernel(object):
:kwarg coefficient_numbers: A list of which coefficients from the
form the kernel needs.
:kwarg quadrature_rule: The finat quadrature rule used to generate this kernel
:kwarg return_dtype: numpy dtype of the return value.
:kwarg tabulations: The runtime tabulations this kernel requires
:kwarg needs_cell_sizes: Does the kernel require cell sizes.
"""
def __init__(self, ast=None, integral_type=None, oriented=False,
subdomain_id=None, domain_number=None, quadrature_rule=None,
coefficient_numbers=(),
return_dtype=None,
needs_cell_sizes=False):
# Defaults
self.ast = ast
Expand All @@ -55,6 +57,7 @@ def __init__(self, ast=None, integral_type=None, oriented=False,
self.subdomain_id = subdomain_id
self.coefficient_numbers = coefficient_numbers
self.needs_cell_sizes = needs_cell_sizes
self.return_dtype = return_dtype
super(Kernel, self).__init__()


Expand Down Expand Up @@ -164,8 +167,8 @@ def construct_kernel(self, return_arg, impero_c, precision, index_names):
for name_, shape in self.tabulations:
args.append(lp.GlobalArg(name_, dtype=self.scalar_type, shape=shape))

loopy_kernel = generate_loopy(impero_c, args, precision, self.scalar_type,
"expression_kernel", index_names)
loopy_kernel, _ = generate_loopy(impero_c, args, precision, self.scalar_type,
"expression_kernel", index_names, ignore_return_type=True)
return ExpressionKernel(loopy_kernel, self.oriented, self.cell_sizes,
self.coefficients, self.tabulations)

Expand Down Expand Up @@ -207,6 +210,7 @@ def set_arguments(self, arguments, multiindices):
:arg multiindices: GEM argument multiindices
:returns: GEM expression representing the return variable
"""
self.rank = len(arguments)
self.local_tensor, expressions = prepare_arguments(
arguments, multiindices, self.scalar_type, interior_facet=self.interior_facet,
diagonal=self.diagonal)
Expand Down Expand Up @@ -277,7 +281,11 @@ def construct_kernel(self, name, impero_c, precision, index_names, quadrature_ru
:returns: :class:`Kernel` object
"""

args = [self.local_tensor, self.coordinates_arg]
ignore_return_type = self.rank > 0
if ignore_return_type:
args = [self.local_tensor, self.coordinates_arg]
else:
args = [self.coordinates_arg]
if self.kernel.oriented:
args.append(self.cell_orientations_loopy_arg)
if self.kernel.needs_cell_sizes:
Expand All @@ -292,8 +300,11 @@ def construct_kernel(self, name, impero_c, precision, index_names, quadrature_ru
args.append(lp.GlobalArg(name_, dtype=self.scalar_type, shape=shape))

self.kernel.quadrature_rule = quadrature_rule
self.kernel.ast = generate_loopy(impero_c, args, precision,
self.scalar_type, name, index_names)
ast, dtype = generate_loopy(impero_c, args, precision,
self.scalar_type, name, index_names,
ignore_return_type=ignore_return_type)
self.kernel.ast = ast
self.kernel.return_dtype = dtype
return self.kernel

def construct_empty_kernel(self, name):
Expand Down
15 changes: 11 additions & 4 deletions tsfc/loopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,15 +186,17 @@ def active_indices(mapping, ctx):
ctx.active_indices.pop(key)


def generate(impero_c, args, precision, scalar_type, kernel_name="loopy_kernel", index_names=[]):
def generate(impero_c, args, precision, scalar_type, kernel_name="loopy_kernel", index_names=[],
ignore_return_type=True):
"""Generates loopy code.
:arg impero_c: ImperoC tuple with Impero AST and other data
:arg args: list of loopy.GlobalArgs
:arg precision: floating-point precision for printing
:arg scalar_type: type of scalars as C typename string
:arg scalar_type: type of scalars as numpy dtype
:arg kernel_name: function name of the kernel
:arg index_names: pre-assigned index names
:arg ignore_return_type: Ignore inferred return type from impero_c?
:returns: loopy kernel
"""
ctx = LoopyContext()
Expand All @@ -205,7 +207,12 @@ def generate(impero_c, args, precision, scalar_type, kernel_name="loopy_kernel",
ctx.epsilon = 10.0 ** (-precision)

# Create arguments
data = list(args)
if ignore_return_type:
return_dtype = scalar_type
data = list(args)
else:
A, return_dtype = impero_c.return_variable
data = [lp.GlobalArg(A.name, shape=A.shape, dtype=return_dtype)] + list(args)
for i, (temp, dtype) in enumerate(assign_dtypes(impero_c.temporaries, scalar_type)):
name = "t%d" % i
if isinstance(temp, gem.Constant):
Expand Down Expand Up @@ -240,7 +247,7 @@ def generate(impero_c, args, precision, scalar_type, kernel_name="loopy_kernel",
insn_new.append(insn.copy(priority=len(knl.instructions) - i))
knl = knl.copy(instructions=insn_new)

return knl
return knl, return_dtype


@singledispatch
Expand Down

0 comments on commit 0eb0af9

Please sign in to comment.