Skip to content

Commit

Permalink
Merge pull request #1244 from firedrakeproject/blocks
Browse files Browse the repository at this point in the history
* blocks:
  Rename method: block -> blocks
  • Loading branch information
wence- committed Jul 5, 2018
2 parents 92431bf + b7bc6e0 commit 6f6abba
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 41 deletions.
14 changes: 8 additions & 6 deletions firedrake/slate/slate.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,10 @@ def solve(self, B, decomposition=None):
return Solve(self, B, decomposition=decomposition)

@cached_property
def block(self):
"""Return a block of the tensor defined on the component spaces
described by indices.
def blocks(self):
"""Returns an object containing the blocks of the tensor defined
on a mixed space. Indices can then be provided to extract a
particular sub-block.
For example, consider the rank-2 tensor described by:
Expand All @@ -235,14 +236,15 @@ def block(self):
w, q, s = TestFunctions(W)
A = Tensor(u*w*dx + p*q*dx + r*s*dx)
The tensor `A` has 3x3 block structure. The the block defined
The tensor `A` has 3x3 block structure. The block defined
by the form `u*w*dx` could be extracted with:
.. code-block:: python
A.block[0, 0]
A.blocks[0, 0]
While the block coupling `p`, `r`, `q`, and `s` could be extracted with:
While the block coupling `p`, `r`, `q`, and `s` could be
extracted with:
.. code-block:: python
Expand Down
29 changes: 17 additions & 12 deletions tests/slate/test_assemble_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,8 @@ def test_vector_subblocks(mesh):
K = Tensor(inner(u, v)*dx + inner(phi, psi)*dx + inner(eta, nu)*dx)
F = Tensor(inner(q, v)*dx + inner(p, psi)*dx + inner(r, nu)*dx)
E = K.inv * F
items = [(E.block[0], q), (E.block[1], p), (E.block[2], r)]
_E = E.blocks
items = [(_E[0], q), (_E[1], p), (_E[2], r)]

for tensor, ref in items:
assert np.allclose(assemble(tensor).dat.data, ref.dat.data, rtol=1e-14)
Expand All @@ -211,24 +212,28 @@ def test_matrix_subblocks(mesh):
# Test individual blocks
indices = [(0, 0), (0, 1), (1, 0), (1, 1), (1, 2), (2, 1), (2, 2)]
refs = dict(split_form(A.form))
_A = A.blocks
for x, y in indices:
ref = assemble(refs[x, y]).M.values
block = A.block[x, y]
block = _A[x, y]
assert np.allclose(assemble(block).M.values, ref, rtol=1e-14)

# Mixed blocks
A0101 = A.block[:2, :2]
A1212 = A.block[1:3, 1:3]
A0101 = _A[:2, :2]
A1212 = _A[1:3, 1:3]

_A0101 = A0101.blocks
_A1212 = A1212.blocks

# Block of blocks
A0101_00 = A0101.block[0, 0]
A0101_11 = A0101.block[1, 1]
A0101_01 = A0101.block[0, 1]
A0101_10 = A0101.block[1, 0]
A1212_00 = A1212.block[0, 0]
A1212_11 = A1212.block[1, 1]
A1212_01 = A1212.block[0, 1]
A1212_10 = A1212.block[1, 0]
A0101_00 = _A0101[0, 0]
A0101_11 = _A0101[1, 1]
A0101_01 = _A0101[0, 1]
A0101_10 = _A0101[1, 0]
A1212_00 = _A1212[0, 0]
A1212_11 = _A1212[1, 1]
A1212_01 = _A1212[0, 1]
A1212_10 = _A1212[1, 0]

items = [(A0101_00, refs[(0, 0)]),
(A0101_11, refs[(1, 1)]),
Expand Down
49 changes: 26 additions & 23 deletions tests/slate/test_slate_infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,33 +221,36 @@ def test_blocks(zero_rank_tensor, mixed_matrix, mixed_vector):
a = M.form
L = F.form
splitter = ExtractSubBlock()
M00 = M.block[0, 0]
M11 = M.block[1, 1]
M22 = M.block[2, 2]
M0101 = M.block[:2, :2]
M012 = M.block[:2, 2]
M201 = M.block[2, :2]
F0 = F.block[0]
F1 = F.block[1]
F2 = F.block[2]
F01 = F.block[:2]
F12 = F.block[1:3]
_M = M.blocks
M00 = _M[0, 0]
M11 = _M[1, 1]
M22 = _M[2, 2]
M0101 = _M[:2, :2]
M012 = _M[:2, 2]
M201 = _M[2, :2]

_F = F.blocks
F0 = _F[0]
F1 = _F[1]
F2 = _F[2]
F01 = _F[:2]
F12 = _F[1:3]

# Test index checking
with pytest.raises(ValueError):
S.block[0]
S.blocks[0]

with pytest.raises(ValueError):
F.block[0, 1]
_F[0, 1]

with pytest.raises(ValueError):
M.block[0:5, 2]
_M[0:5, 2]

with pytest.raises(ValueError):
M.block[3, 3]
_M[3, 3]

with pytest.raises(ValueError):
F.block[3]
_F[3]

# Check Tensor is (not) mixed where appropriate
assert not M00.is_mixed
Expand All @@ -263,13 +266,13 @@ def test_blocks(zero_rank_tensor, mixed_matrix, mixed_vector):
assert F12.is_mixed

# Taking blocks of non-mixed block (or scalars) should induce a no-op
assert S.block[None] == S # This is silly, but it's technically a no-op
assert M00.block[0, 0] == M00
assert M11.block[0, 0] == M11
assert M22.block[0, 0] == M22
assert F0.block[0] == F0
assert F1.block[0] == F1
assert F2.block[0] == F2
assert S.blocks[None] == S # This is silly, but it's technically a no-op
assert M00.blocks[0, 0] == M00
assert M11.blocks[0, 0] == M11
assert M22.blocks[0, 0] == M22
assert F0.blocks[0] == F0
assert F1.blocks[0] == F1
assert F2.blocks[0] == F2

# Test arguments
assert M00.arguments() == splitter.split(a, (0, 0)).arguments()
Expand Down

0 comments on commit 6f6abba

Please sign in to comment.