Skip to content

Commit

Permalink
[query] Refactor and fix Backend's persist and unpersist (#12864)
Browse files Browse the repository at this point in the history
- There is now one persist method that accepts any of a Table,
  MatrixTable, or BlockMatrix.
- Backend.unpersist was incorrect with respect to unpersist docs, it was
  not accepting the persisted dataset, or returning the unpersisted
  dataset, or doing nothing if the dataset wasn't persisted.
  • Loading branch information
chrisvittal authored Apr 20, 2023
1 parent 970d446 commit cc8d364
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 43 deletions.
47 changes: 20 additions & 27 deletions hail/python/hail/backend/backend.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,25 @@
from typing import Mapping, List, Union, Tuple, Dict, Optional, Any, AbstractSet
from typing import Mapping, List, Union, TypeVar, Tuple, Dict, Optional, Any, AbstractSet
import abc
import orjson
import pkg_resources
import zipfile

from hailtop.config.user_config import configuration_of

from ..fs.fs import FS
from ..builtin_references import BUILTIN_REFERENCE_RESOURCE_PATHS
from ..expr import Expression
from ..expr.types import HailType
from ..fs.fs import FS
from ..ir import BaseIR
from ..linalg.blockmatrix import BlockMatrix
from ..matrixtable import MatrixTable
from ..table import Table
from ..utils.java import FatalError


Dataset = TypeVar('Dataset', Table, MatrixTable, BlockMatrix)


def fatal_error_from_java_error_triplet(short_message, expanded_message, error_id):
from .. import __version__
if error_id != -1:
Expand Down Expand Up @@ -156,32 +162,19 @@ def fs(self) -> FS:
def import_fam(self, path: str, quant_pheno: bool, delimiter: str, missing: str):
pass

def persist_table(self, t):
from hail.context import TemporaryFilename
tf = TemporaryFilename(prefix='persist_table')
self._persisted_locations[t] = tf
return t.checkpoint(tf.__enter__())

def unpersist_table(self, t):
try:
self._persisted_locations[t].__exit__(None, None, None)
except KeyError as err:
raise ValueError(f'{t} is not persisted') from err

def persist_matrix_table(self, mt):
def persist(self, dataset: Dataset) -> Dataset:
from hail.context import TemporaryFilename
tf = TemporaryFilename(prefix='persist_matrix_table')
self._persisted_locations[mt] = tf
return mt.checkpoint(tf.__enter__())

def unpersist_matrix_table(self, mt):
try:
self._persisted_locations[mt].__exit__(None, None, None)
except KeyError as err:
raise ValueError(f'{mt} is not persisted') from err

def unpersist_block_matrix(self, id):
pass
tempfile = TemporaryFilename(prefix=f'persist_{type(dataset).__name__}')
persisted = dataset.checkpoint(tempfile.__enter__())
self._persisted_locations[persisted] = (tempfile, dataset)
return persisted

def unpersist(self, dataset: Dataset) -> Dataset:
tempfile, unpersisted = self._persisted_locations.pop(dataset, (None, None))
if tempfile is None:
return dataset
tempfile.__exit__(None, None, None)
return unpersisted

@abc.abstractmethod
def register_ir_function(self,
Expand Down
14 changes: 4 additions & 10 deletions hail/python/hail/linalg/blockmatrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@
TableFromBlockMatrixNativeReader, TableRead, BlockMatrixSlice,
BlockMatrixSparsify, BlockMatrixDensify, RectangleSparsifier,
RowIntervalSparsifier, BandSparsifier, PerBlockSparsifier)
from hail.ir.blockmatrix_reader import (BlockMatrixNativeReader,
BlockMatrixBinaryReader, BlockMatrixPersistReader)
from hail.ir.blockmatrix_reader import BlockMatrixNativeReader, BlockMatrixBinaryReader
from hail.ir.blockmatrix_writer import (BlockMatrixBinaryWriter,
BlockMatrixNativeWriter, BlockMatrixRectanglesWriter, BlockMatrixPersistWriter)
BlockMatrixNativeWriter, BlockMatrixRectanglesWriter)
from hail.ir import ExportType
from hail.table import Table
from hail.typecheck import (typecheck, typecheck_method, nullable, oneof,
Expand Down Expand Up @@ -1337,9 +1336,7 @@ def persist(self, storage_level='MEMORY_AND_DISK'):
:class:`.BlockMatrix`
Persisted block matrix.
"""
id = Env.get_uid()
Env.backend().execute(BlockMatrixWrite(self._bmir, BlockMatrixPersistWriter(id, storage_level)))
return BlockMatrix(BlockMatrixRead(BlockMatrixPersistReader(id, self._bmir), _assert_type=self._bmir._type))
return Env.backend().persist_blockmatrix(self)

def unpersist(self):
"""Unpersists this block matrix from memory/disk.
Expand All @@ -1354,10 +1351,7 @@ def unpersist(self):
:class:`.BlockMatrix`
Unpersisted block matrix.
"""
if isinstance(self._bmir, BlockMatrixRead) and isinstance(self._bmir.reader, BlockMatrixPersistReader):
Env.backend().unpersist_block_matrix(self._bmir.reader.id)
return self._bmir.reader.unpersisted()
return self
return Env.backend().unpersist_blockmatrix(self)

def __pos__(self):
return self
Expand Down
4 changes: 2 additions & 2 deletions hail/python/hail/matrixtable.py
Original file line number Diff line number Diff line change
Expand Up @@ -3643,7 +3643,7 @@ def persist(self, storage_level: str = 'MEMORY_AND_DISK') -> 'MatrixTable':
:class:`.MatrixTable`
Persisted dataset.
"""
return Env.backend().persist_matrix_table(self)
return Env.backend().persist(self)

def unpersist(self) -> 'MatrixTable':
"""
Expand All @@ -3659,7 +3659,7 @@ def unpersist(self) -> 'MatrixTable':
:class:`.MatrixTable`
Unpersisted dataset.
"""
return Env.backend().unpersist_matrix_table(self)
return Env.backend().unpersist(self)

@typecheck_method(name=str)
def add_row_index(self, name: str = 'row_idx') -> 'MatrixTable':
Expand Down
4 changes: 2 additions & 2 deletions hail/python/hail/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -2124,7 +2124,7 @@ def persist(self, storage_level='MEMORY_AND_DISK') -> 'Table':
:class:`.Table`
Persisted table.
"""
return Env.backend().persist_table(self)
return Env.backend().persist(self)

def unpersist(self) -> 'Table':
"""
Expand All @@ -2140,7 +2140,7 @@ def unpersist(self) -> 'Table':
:class:`.Table`
Unpersisted table.
"""
return Env.backend().unpersist_table(self)
return Env.backend().unpersist(self)

@typecheck_method(_localize=bool, _timed=bool)
def collect(self, _localize=True, *, _timed=False):
Expand Down
2 changes: 0 additions & 2 deletions hail/python/test/hail/test_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,8 +375,6 @@ def test_parses(self):
for x in (self.blockmatrix_irs() + [persist]):
backend._parse_blockmatrix_ir(str(x))

backend.unpersist_block_matrix('x')


class ValueTests(unittest.TestCase):

Expand Down

0 comments on commit cc8d364

Please sign in to comment.