Update delete_item logic to handle json partitions #1018

Merged: 3 commits merged on Jan 9, 2025

29 changes: 27 additions & 2 deletions tests/serialize/runstate/dynamodb_state_store_test.py
@@ -262,12 +262,37 @@ def test_delete_item(self, store, small_object, large_object):
value = large_object
pairs = list(zip(keys, (value for i in range(len(keys)))))
store.save(pairs)
store._consume_save_queue()

for key, value in pairs:
store._delete_item(key)

for key, value in pairs:
assert_equal(store._get_num_of_partitions(key), 0)
num_partitions, num_json_val_partitions = store._get_num_of_partitions(key)
assert_equal(num_partitions, 0)
assert_equal(num_json_val_partitions, 0)

def test_delete_item_with_json_partitions(self, store, small_object, large_object):
Member Author: Hmm, maybe it would make more sense to test with the larger object here

Member: imo, we have both as fixtures in this test already so we could easily test with both :)

key = store.build_key("job_state", "test_job")
value = large_object

store.save([(key, value)])
store._consume_save_queue()

num_partitions, num_json_val_partitions = store._get_num_of_partitions(key)
assert num_partitions > 0
assert num_json_val_partitions > 0

store._delete_item(key)

num_partitions, num_json_val_partitions = store._get_num_of_partitions(key)
assert_equal(num_partitions, 0)
assert_equal(num_json_val_partitions, 0)

with mock.patch("tron.config.static_config.load_yaml_file", autospec=True), mock.patch(
"tron.config.static_config.build_configuration_watcher", autospec=True
):
vals = store.restore([key])
assert key not in vals

def test_retry_saving(self, store, small_object, large_object):
with mock.patch(
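As a side note on the review thread above, here is a minimal sketch of how the new test could exercise both fixtures via parametrization, as the reviewer suggests. The class name TestDynamoDBStateStore is assumed for illustration; the store, small_object, and large_object fixtures come from this test module, and request.getfixturevalue is standard pytest. This is a sketch of the suggestion, not code from the PR.

import pytest


class TestDynamoDBStateStore:  # hypothetical class name for this sketch
    @pytest.mark.parametrize("object_fixture", ["small_object", "large_object"])
    def test_delete_item_with_json_partitions(self, request, store, object_fixture):
        # Resolve the requested fixture (small or large payload) by name.
        value = request.getfixturevalue(object_fixture)
        key = store.build_key("job_state", "test_job")

        # Persist the value and flush the async save queue so partitions exist.
        store.save([(key, value)])
        store._consume_save_queue()

        store._delete_item(key)

        # Both the pickle and the JSON partition counts should now be zero.
        num_partitions, num_json_val_partitions = store._get_num_of_partitions(key)
        assert num_partitions == 0
        assert num_json_val_partitions == 0
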
19 changes: 10 additions & 9 deletions tron/serialize/runstate/dynamodb_state_store.py
@@ -126,7 +126,6 @@ def _get_first_partitions(self, keys: list):
new_keys = [{"key": {"S": key}, "index": {"N": "0"}} for key in keys]
return self._get_items(new_keys)

# TODO: Check max partitions as JSON is larger
def _get_remaining_partitions(self, items: list, read_json: bool):
"""Get items in the remaining partitions: N = 1 and beyond"""
keys_for_remaining_items = []
@@ -357,12 +356,13 @@ def __setitem__(self, key: str, value: Tuple[bytes, str]) -> None:
delta=time.time() - start,
)

# TODO: TRON-2238 - Is this ok if we just use the max number of partitions?
def _delete_item(self, key: str) -> None:
start = time.time()
try:
num_partitions, num_json_val_partitions = self._get_num_of_partitions(key)
max_partitions = max(num_partitions, num_json_val_partitions)
with self.table.batch_writer() as batch:
for index in range(self._get_num_of_partitions(key)):
for index in range(max_partitions):
batch.delete_item(
Key={
"key": key,
@@ -375,23 +375,24 @@ def _delete_item(self, key: str) -> None:
delta=time.time() - start,
)

# TODO: TRON-2238 - Get max partitions between pickle and json
def _get_num_of_partitions(self, key: str) -> int:
def _get_num_of_partitions(self, key: str) -> Tuple[int, int]:
"""
Return the number of partitions an item is divided into.
Return the number of partitions an item is divided into for both pickled and JSON data.
"""
try:
partition = self.table.get_item(
Key={
"key": key,
"index": 0,
},
ProjectionExpression="num_partitions",
ProjectionExpression="num_partitions, num_json_val_partitions",
ConsistentRead=True,
)
return int(partition.get("Item", {}).get("num_partitions", 0))
num_partitions = int(partition.get("Item", {}).get("num_partitions", 0))
num_json_val_partitions = int(partition.get("Item", {}).get("num_json_val_partitions", 0))
return num_partitions, num_json_val_partitions
except self.client.exceptions.ResourceNotFoundException:
return 0
return 0, 0

def cleanup(self) -> None:
self.stopping = True
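For context on the change above: the row at index 0 stores both num_partitions (for the pickled payload) and num_json_val_partitions (for the JSON payload), and because the JSON copy can need more rows than the pickle, _delete_item now deletes up to the larger of the two counts. A small illustrative sketch of that arithmetic follows; the key string and the counts are invented for the example and are not taken from the PR.

# Illustrative only: made-up values showing why the max of the two partition
# counts is the right loop bound for deletion.
item_index_0 = {
    "key": "job_state test_job",   # hypothetical key string
    "index": 0,
    "num_partitions": 2,           # pickled payload spans 2 rows
    "num_json_val_partitions": 3,  # JSON payload spans 3 rows
}

max_partitions = max(
    item_index_0["num_partitions"],
    item_index_0["num_json_val_partitions"],
)

# range(max_partitions) covers indexes 0, 1 and 2, so the row holding only
# JSON data (index 2) is deleted too; looping over num_partitions alone
# would have left it behind.
for index in range(max_partitions):
    print(f"delete partition key={item_index_0['key']!r} index={index}")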