Skip to content

Commit

Permalink
PhysicallyMappedElement: implement hand-rolled basis transformation (#…
Browse files Browse the repository at this point in the history
…115)

* PhysicallyMappedElement: implement hand-rolled basis transformation

* comments
  • Loading branch information
pbrubeck authored Dec 10, 2024
1 parent acbd449 commit 45f6d9e
Show file tree
Hide file tree
Showing 8 changed files with 11 additions and 33 deletions.
8 changes: 0 additions & 8 deletions finat/aw.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,6 @@ def entity_dofs(self):
1: {0: [0, 1, 2, 3], 1: [4, 5, 6, 7], 2: [8, 9, 10, 11]},
2: {0: [12, 13, 14]}}

@property
def index_shape(self):
return (self.space_dimension(),)

def space_dimension(self):
return 15

Expand Down Expand Up @@ -129,9 +125,5 @@ def entity_dofs(self):
1: {0: [9, 10, 11, 12], 1: [13, 14, 15, 16], 2: [17, 18, 19, 20]},
2: {0: [21, 22, 23]}}

@property
def index_shape(self):
return (self.space_dimension(),)

def space_dimension(self):
return 24
4 changes: 0 additions & 4 deletions finat/bell.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,5 @@ def basis_transformation(self, coordinate_mapping):
def entity_dofs(self):
return self._entity_dofs

@property
def index_shape(self):
return (18,)

def space_dimension(self):
return 18
2 changes: 1 addition & 1 deletion finat/fiat_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def space_dimension(self):

@property
def index_shape(self):
return (self._element.space_dimension(),)
return (self.space_dimension(),)

@property
def value_shape(self):
Expand Down
4 changes: 0 additions & 4 deletions finat/hct.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,5 @@ def basis_transformation(self, coordinate_mapping):
def entity_dofs(self):
return self._entity_dofs

@property
def index_shape(self):
return (9,)

def space_dimension(self):
return 9
4 changes: 0 additions & 4 deletions finat/mtw.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,5 @@ def basis_transformation(self, coordinate_mapping):
def entity_dofs(self):
return self._entity_dofs

@property
def index_shape(self):
return (self._space_dimension,)

def space_dimension(self):
return self._space_dimension
16 changes: 9 additions & 7 deletions finat/physically_mapped.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,15 +268,17 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None):
assert coordinate_mapping is not None

M = self.basis_transformation(coordinate_mapping)
M, = gem.optimise.constant_fold_zero((M,))
# we expect M to be sparse with O(1) nonzeros per row
# for each row, get the column index of each nonzero entry
csr = [[j for j in range(M.shape[1]) if not isinstance(M.array[i, j], gem.Zero)]
for i in range(M.shape[0])]

def matvec(table):
table, = gem.optimise.constant_fold_zero((table,))
i, j = gem.indices(2)
value_indices = self.get_value_indices()
table = gem.Indexed(table, (j, ) + value_indices)
val = gem.ComponentTensor(gem.IndexSum(M[i, j]*table, (j,)), (i,) + value_indices)
# Eliminate zeros
# basis recombination using hand-rolled sparse-dense matrix multiplication
table = [gem.partial_indexed(table, (j,)) for j in range(M.shape[1])]
# the sum approach is faster than calling numpy.dot or gem.IndexSum
expressions = [sum(M.array[i, j] * table[j] for j in js) for i, js in enumerate(csr)]
val = gem.ListTensor(expressions)
return gem.optimise.aggressive_unroll(val)

result = super().basis_evaluation(order, ps, entity=entity)
Expand Down
4 changes: 0 additions & 4 deletions finat/piola_mapped.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,6 @@ def __init__(self, fiat_element):
def entity_dofs(self):
return self._entity_dofs

@property
def index_shape(self):
return (self._space_dimension,)

def space_dimension(self):
return self._space_dimension

Expand Down
2 changes: 1 addition & 1 deletion gem/gem.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,7 +969,7 @@ def __new__(cls, i, j, dtype=None):

# Fixed indices
if isinstance(i, int) and isinstance(j, int):
return Literal(int(i == j))
return one if i == j else Zero()

self = super(Delta, cls).__new__(cls)
self.i = i
Expand Down

0 comments on commit 45f6d9e

Please sign in to comment.