From 45f6d9eed5ed6175b2c7f60f99fe7887536921e3 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Tue, 10 Dec 2024 15:15:48 +0000 Subject: [PATCH] PhysicallyMappedElement: implement hand-rolled basis transformation (#115) * PhysicallyMappedElement: implement hand-rolled basis transformation * comments --- finat/aw.py | 8 -------- finat/bell.py | 4 ---- finat/fiat_elements.py | 2 +- finat/hct.py | 4 ---- finat/mtw.py | 4 ---- finat/physically_mapped.py | 16 +++++++++------- finat/piola_mapped.py | 4 ---- gem/gem.py | 2 +- 8 files changed, 11 insertions(+), 33 deletions(-) diff --git a/finat/aw.py b/finat/aw.py index 80d7b7ea..0ed11cbd 100644 --- a/finat/aw.py +++ b/finat/aw.py @@ -78,10 +78,6 @@ def entity_dofs(self): 1: {0: [0, 1, 2, 3], 1: [4, 5, 6, 7], 2: [8, 9, 10, 11]}, 2: {0: [12, 13, 14]}} - @property - def index_shape(self): - return (self.space_dimension(),) - def space_dimension(self): return 15 @@ -129,9 +125,5 @@ def entity_dofs(self): 1: {0: [9, 10, 11, 12], 1: [13, 14, 15, 16], 2: [17, 18, 19, 20]}, 2: {0: [21, 22, 23]}} - @property - def index_shape(self): - return (self.space_dimension(),) - def space_dimension(self): return 24 diff --git a/finat/bell.py b/finat/bell.py index ce7d0002..98d4aecc 100644 --- a/finat/bell.py +++ b/finat/bell.py @@ -73,9 +73,5 @@ def basis_transformation(self, coordinate_mapping): def entity_dofs(self): return self._entity_dofs - @property - def index_shape(self): - return (18,) - def space_dimension(self): return 18 diff --git a/finat/fiat_elements.py b/finat/fiat_elements.py index 1f208189..0203a3a7 100644 --- a/finat/fiat_elements.py +++ b/finat/fiat_elements.py @@ -79,7 +79,7 @@ def space_dimension(self): @property def index_shape(self): - return (self._element.space_dimension(),) + return (self.space_dimension(),) @property def value_shape(self): diff --git a/finat/hct.py b/finat/hct.py index f072d8ef..efcb131b 100644 --- a/finat/hct.py +++ b/finat/hct.py @@ -93,9 +93,5 @@ def basis_transformation(self, coordinate_mapping): def entity_dofs(self): return self._entity_dofs - @property - def index_shape(self): - return (9,) - def space_dimension(self): return 9 diff --git a/finat/mtw.py b/finat/mtw.py index 55d1a1f7..e7312afb 100644 --- a/finat/mtw.py +++ b/finat/mtw.py @@ -42,9 +42,5 @@ def basis_transformation(self, coordinate_mapping): def entity_dofs(self): return self._entity_dofs - @property - def index_shape(self): - return (self._space_dimension,) - def space_dimension(self): return self._space_dimension diff --git a/finat/physically_mapped.py b/finat/physically_mapped.py index 2ea49761..4b6c6089 100644 --- a/finat/physically_mapped.py +++ b/finat/physically_mapped.py @@ -268,15 +268,17 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None): assert coordinate_mapping is not None M = self.basis_transformation(coordinate_mapping) - M, = gem.optimise.constant_fold_zero((M,)) + # we expect M to be sparse with O(1) nonzeros per row + # for each row, get the column index of each nonzero entry + csr = [[j for j in range(M.shape[1]) if not isinstance(M.array[i, j], gem.Zero)] + for i in range(M.shape[0])] def matvec(table): - table, = gem.optimise.constant_fold_zero((table,)) - i, j = gem.indices(2) - value_indices = self.get_value_indices() - table = gem.Indexed(table, (j, ) + value_indices) - val = gem.ComponentTensor(gem.IndexSum(M[i, j]*table, (j,)), (i,) + value_indices) - # Eliminate zeros + # basis recombination using hand-rolled sparse-dense matrix multiplication + table = [gem.partial_indexed(table, (j,)) for j in range(M.shape[1])] + # the sum approach is faster than calling numpy.dot or gem.IndexSum + expressions = [sum(M.array[i, j] * table[j] for j in js) for i, js in enumerate(csr)] + val = gem.ListTensor(expressions) return gem.optimise.aggressive_unroll(val) result = super().basis_evaluation(order, ps, entity=entity) diff --git a/finat/piola_mapped.py b/finat/piola_mapped.py index 18f51e99..c9557224 100644 --- a/finat/piola_mapped.py +++ b/finat/piola_mapped.py @@ -113,10 +113,6 @@ def __init__(self, fiat_element): def entity_dofs(self): return self._entity_dofs - @property - def index_shape(self): - return (self._space_dimension,) - def space_dimension(self): return self._space_dimension diff --git a/gem/gem.py b/gem/gem.py index 9f00d553..8369b6f7 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -969,7 +969,7 @@ def __new__(cls, i, j, dtype=None): # Fixed indices if isinstance(i, int) and isinstance(j, int): - return Literal(int(i == j)) + return one if i == j else Zero() self = super(Delta, cls).__new__(cls) self.i = i