diff --git a/firedrake/formmanipulation.py b/firedrake/formmanipulation.py index 5db9fde46e..501dee02e0 100644 --- a/firedrake/formmanipulation.py +++ b/firedrake/formmanipulation.py @@ -5,11 +5,11 @@ from ufl import as_vector from ufl.classes import Zero, FixedIndex, ListTensor from ufl.algorithms.map_integrands import map_integrand_dags -from ufl.algorithms.apply_coefficient_split import remove_component_and_list_tensors from ufl.corealg.map_dag import MultiFunction, map_expr_dags from firedrake.petsc import PETSc from firedrake.ufl_expr import Argument +from tsfc.ufl_utils import remove_indices class ExtractSubBlock(MultiFunction): @@ -170,8 +170,8 @@ def split_form(form, diagonal=False, do_simplify=True): assert len(shape) == 2 for idx in numpy.ndindex(shape): f = splitter.split(form, idx) - #if do_simplify: - # f = remove_component_and_list_tensors(f) + if do_simplify: + f = remove_indices(f) if len(f.integrals()) > 0: if diagonal: i, j = idx diff --git a/firedrake/tsfc_interface.py b/firedrake/tsfc_interface.py index 3f3735b381..5b3249248f 100644 --- a/firedrake/tsfc_interface.py +++ b/firedrake/tsfc_interface.py @@ -238,9 +238,9 @@ def compile_form(form, name, parameters=None, split=True, dont_split=None, diago kernels = [] numbering = form.terminal_numbering() - all_meshes = extract_domains(form) + all_meshes_in_form = extract_domains(form) if split: - iterable = split_form(form, diagonal=diagonal, do_simplify=(len(all_meshes) > 1)) + iterable = split_form(form, diagonal=diagonal) else: nargs = len(form.arguments()) if diagonal: @@ -256,8 +256,8 @@ def compile_form(form, name, parameters=None, split=True, dont_split=None, diago continue # Map local domain/coefficient/constant numbers (as seen inside the # compiler) to the global coefficient/constant numbers - meshes = extract_domains(f) - domain_number_map = tuple(all_meshes.index(m) for m in meshes) + all_meshes_in_subform = extract_domains(f) + domain_number_map = tuple(all_meshes_in_form.index(m) for m in all_meshes_in_subform) coefficient_numbers = tuple( numbering[c] for c in f.coefficients() )