Skip to content

Commit

Permalink
impero: Attempt type inference on return variable
Browse files Browse the repository at this point in the history
Will be used to allocate correct type of 0-form output variables. This
is necessary in complex mode where if we explicitly ask for a real
functional we need the type to match.
  • Loading branch information
wence- committed May 15, 2020
1 parent 533cc18 commit 54e2978
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 6 deletions.
33 changes: 30 additions & 3 deletions gem/impero_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from functools import singledispatch
from itertools import chain, groupby

from numpy import find_common_type

from gem.node import traversal, collect_refcount
from gem import gem, impero as imp, optimise, scheduling

Expand All @@ -21,7 +23,9 @@
# temporaries - List of GEM expressions which have assigned temporaries
# declare - Where to declare temporaries to get correct C code
# indices - Indices for declarations and referencing values
ImperoC = collections.namedtuple('ImperoC', ['tree', 'temporaries', 'declare', 'indices'])
# return_variable - 2-tuple of gem return variable and inferred numpy dtype
ImperoC = collections.namedtuple('ImperoC', ['tree', 'temporaries', 'declare', 'indices',
'return_variable'])


class NoopError(Exception):
Expand All @@ -38,11 +42,12 @@ def preprocess_gem(expressions, replace_delta=True, remove_componenttensors=True
return expressions


def compile_gem(assignments, prefix_ordering, remove_zeros=False):
def compile_gem(assignments, prefix_ordering, scalar_type, remove_zeros=False):
"""Compiles GEM to Impero.
:arg assignments: list of (return variable, expression DAG root) pairs
:arg prefix_ordering: outermost loop indices
:arg scalar_type: default scalar type
:arg remove_zeros: remove zero assignment to return variables
"""
# Remove zeros
Expand All @@ -52,6 +57,9 @@ def nonzero(assignment):
return not isinstance(expression, gem.Zero)
assignments = list(filter(nonzero, assignments))

# Type inference for return value
return_variable = infer_dtype(assignments, scalar_type)

# Just the expressions
expressions = [expression for variable, expression in assignments]

Expand Down Expand Up @@ -88,7 +96,26 @@ def nonzero(assignment):
declare, indices = place_declarations(tree, temporaries, get_indices)

# Prepare ImperoC (Impero AST + other data for code generation)
return ImperoC(tree, temporaries, declare, indices)
return ImperoC(tree, temporaries, declare, indices, return_variable)


def infer_dtype(assignments, scalar_type):
from tsfc.loopy import assign_dtypes
from gem.node import traversal

def extract_variable(expr):
x, = set(v for v in traversal([expr]) if isinstance(v, gem.Variable))
return x

vars = set()
dtypes = set()
for var, expression in assignments:
var = extract_variable(var)
((_, dtype), ) = assign_dtypes([expression], scalar_type)
vars.add(var)
dtypes.add(dtype)
var, = vars
return var, find_common_type([], dtypes)


def make_prefix_ordering(indices, prefix_ordering):
Expand Down
3 changes: 2 additions & 1 deletion tests/test_codegen.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
import numpy

from gem import impero_utils
from gem.gem import Index, Indexed, IndexSum, Product, Variable
Expand All @@ -18,7 +19,7 @@ def make_expression(i, j):
e2 = make_expression(i, i)

def gencode(expr):
impero_c = impero_utils.compile_gem([(Ri, expr)], (i, j))
impero_c = impero_utils.compile_gem([(Ri, expr)], (i, j), numpy.dtype(numpy.float64))
return impero_c.tree

assert len(gencode(e1).children) == len(gencode(e2).children)
Expand Down
6 changes: 4 additions & 2 deletions tsfc/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,8 @@ def compile_integral(integral_data, form_data, prefix, parameters, interface, co
for var in return_variables]))
index_ordering = tuple(quadrature_indices) + split_argument_indices
try:
impero_c = impero_utils.compile_gem(assignments, index_ordering, remove_zeros=True)
impero_c = impero_utils.compile_gem(assignments, index_ordering,
parameters["scalar_type"], remove_zeros=True)
except impero_utils.NoopError:
# No operations, construct empty kernel
return builder.construct_empty_kernel(kernel_name)
Expand Down Expand Up @@ -421,7 +422,8 @@ def compile_expression_dual_evaluation(expression, to_element, coordinates, inte
# TODO: one should apply some GEM optimisations as in assembly,
# but we don't for now.
ir, = impero_utils.preprocess_gem([ir])
impero_c = impero_utils.compile_gem([(return_expr, ir)], return_indices)
impero_c = impero_utils.compile_gem([(return_expr, ir)], return_indices,
parameters["scalar_type"])
index_names = dict((idx, "p%d" % i) for (i, idx) in enumerate(basis_indices))
# Handle kernel interface requirements
builder.register_requirements([ir])
Expand Down

0 comments on commit 54e2978

Please sign in to comment.