
Commit 3b48d61

Remove debug logs, add constant for transact_write_items, update transaction limit, and change setitem signature to take a value tuple
KaspariK committed Oct 18, 2024
1 parent 7e775da commit 3b48d61
Showing 1 changed file with 23 additions and 25 deletions.
tron/serialize/runstate/dynamodb_state_store.py (48 changes: 23 additions & 25 deletions)
@@ -12,7 +12,9 @@
 from typing import DefaultDict
 from typing import Dict
 from typing import List
+from typing import Literal
 from typing import Sequence
+from typing import Tuple
 from typing import TypeVar
 
 import boto3  # type: ignore
@@ -28,17 +30,16 @@
 # TODO: Restore items
 
 
-# Max DynamoDB object size is 400KB. Since we save two
-# copies of the object (pickled and JSON), we need to
-# consider this max size applies to the entire item, so
-# we use a max size of 200KB for each version.
+# Max DynamoDB object size is 400KB. Since we save two copies of the object (pickled and JSON),
+# we need to consider this max size applies to the entire item, so we use a max size of 200KB
+# for each version.
 #
-# In testing I could get away with 201_000 for both
-# partitions so this should be enough overhead to
-# contain the object name and other data.
+# In testing I could get away with 201_000 for both partitions so this should be enough overhead
+# to contain the object name and other data.
 OBJECT_SIZE = 200_000  # TODO: config this to avoid rolling out new version when we swap back to 400_000?
 MAX_SAVE_QUEUE = 500
 MAX_ATTEMPTS = 10
+MAX_TRANSACT_WRITE_ITEMS = 100  # Max number of items to write in a single transaction
 log = logging.getLogger(__name__)
 T = TypeVar("T")

@@ -154,7 +155,6 @@ def _merge_items(self, first_items, remaining_items) -> dict:
         return deserialized_items
 
     def save(self, key_value_pairs) -> None:
-        log.debug(f"Adding to save queue: {key_value_pairs}")
         for key, val in key_value_pairs:
             while True:
                 qlen = len(self.save_queue)
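
The diff truncates the body of this `while True` loop, but the visible shape suggests `save` applies backpressure when the queue exceeds MAX_SAVE_QUEUE before enqueueing. A sketch of that pattern, with the sleep interval and the module-level queue assumed for illustration:

```python
import time
from collections import OrderedDict
from threading import Lock

MAX_SAVE_QUEUE = 500

save_queue = OrderedDict()
save_lock = Lock()


def save(key_value_pairs) -> None:
    for key, val in key_value_pairs:
        while True:
            qlen = len(save_queue)
            if qlen > MAX_SAVE_QUEUE:
                time.sleep(5)  # assumed backoff; the diff truncates before this
                continue
            with save_lock:
                save_queue[key] = val
            break
```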
@@ -181,20 +181,20 @@ def _consume_save_queue(self):
         for _ in range(qlen):
             try:
                 with self.save_lock:
-                    key, (pickled_val, json_val) = self.save_queue.popitem(last=False)
-                log.debug(f"Processing save for {key} with a value of {pickled_val}")
+                    key, (val, json_val) = self.save_queue.popitem(last=False)
                 # Remove all previous data with the same partition key
                 # TODO: only remove excess partitions if new data has fewer
                 self._delete_item(key)
-                if pickled_val is not None:
-                    self.__setitem__(key, pickle.dumps(pickled_val), json_val)
+                if val is not None:
+                    self[key] = (pickle.dumps(val), json_val)
                 # reset errors count if we can successfully save
                 saved += 1
             except Exception as e:
                 error = "tron_dynamodb_save_failure: failed to save key " f'"{key}" to dynamodb:\n{repr(e)}'
                 log.error(error)
                 # Add items back to the queue if we failed to save
                 with self.save_lock:
-                    self.save_queue[key] = (pickled_val, json_val)
+                    self.save_queue[key] = (val, json_val)
         duration = time.time() - start
         log.info(f"saved {saved} items in {duration}s")

@@ -206,7 +206,8 @@ def _consume_save_queue(self):
     def get_type_from_key(self, key: str) -> str:
         return key.split()[0]
 
-    def _serialize_item(self, key: str, state: Dict[str, Any]) -> str:
+    # TODO: TRON-2305 - In an ideal world, we wouldn't be passing around state/state_data dicts. It would be a lot nicer to have regular objects here
+    def _serialize_item(self, key: Literal[runstate.JOB_STATE, runstate.JOB_RUN_STATE], state: Dict[str, Any]) -> str:  # type: ignore
         if key == runstate.JOB_STATE:
             return Job.to_json(state)
         elif key == runstate.JOB_RUN_STATE:
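
`Literal[...]` only accepts literal values, and `runstate.JOB_STATE` is a module attribute, so type checkers reject the new annotation as written, hence the `# type: ignore`. A mypy-clean alternative, assuming the constants are plain strings (the exact values and the json.dumps stand-in below are illustrative):

```python
import json
from typing import Any, Dict, Literal

# Assumed string values for runstate.JOB_STATE / runstate.JOB_RUN_STATE;
# the real constants live in tron.serialize.runstate.
JobKeyType = Literal["job_state", "job_run_state"]


def serialize_item(key: JobKeyType, state: Dict[str, Any]) -> str:
    # json.dumps stands in for Job.to_json / JobRun.to_json.
    return json.dumps({"type": key, "state": state})
```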
@@ -230,24 +231,21 @@ def _save_loop(self):
                 log.error("too many dynamodb errors in a row, crashing")
                 os.exit(1)
 
-    def __setitem__(self, key: str, pickled_val: bytes, json_val: str) -> None:
+    def __setitem__(self, key: str, value: Tuple[bytes, str]) -> None:
         """
-        Partition the item and write up to 10 partitions atomically.
-        Retry up to 3 times on failure
+        Partition the item and write up to MAX_TRANSACT_WRITE_ITEMS
+        partitions atomically. Retry up to 3 times on failure.
         Examine the size of `pickled_val` and `json_val`, and
         splice them into different parts based on `OBJECT_SIZE`
         with different sort keys, and save them under the same
         partition key built.
         """
         start = time.time()
 
+        pickled_val, json_val = value
         num_partitions = math.ceil(len(pickled_val) / OBJECT_SIZE)
         num_json_val_partitions = math.ceil(len(json_val) / OBJECT_SIZE)
 
-        log.debug(
-            f"Saving key: {key} with {num_partitions} pickle partitions and {num_json_val_partitions} json partitions"
-        )
-
         items = []
 
         # Use the maximum number of partitions (JSON can be larger
@@ -281,11 +279,11 @@ def __setitem__(self, key: str, pickled_val: bytes, json_val: str) -> None:
                     "TableName": self.name,
                 },
             }
-
             count = 0
             items.append(item)
-            # Only up to 10 items are allowed per transactions
-            # TODO: transact_write_items can take up to 100 items now
-            while len(items) == 10 or index == max_partitions - 1:
+
+            while len(items) == MAX_TRANSACT_WRITE_ITEMS or index == max_partitions - 1:
                 try:
                     self.client.transact_write_items(TransactItems=items)
                     items = []
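
The rewritten condition flushes a transaction once the buffer holds MAX_TRANSACT_WRITE_ITEMS items or the final partition has been buffered; DynamoDB's TransactWriteItems API accepts up to 100 action requests per call, which is what the new constant encodes. A standalone sketch of the same batching shape with a caller-supplied client (names are illustrative):

```python
MAX_TRANSACT_WRITE_ITEMS = 100  # DynamoDB's cap on action requests per transaction


def write_partitions(client, partitions) -> None:
    """Flush in batches of MAX_TRANSACT_WRITE_ITEMS, plus one final partial batch."""
    buffer = []
    for item in partitions:
        buffer.append(item)
        if len(buffer) == MAX_TRANSACT_WRITE_ITEMS:
            client.transact_write_items(TransactItems=buffer)
            buffer = []
    if buffer:
        client.transact_write_items(TransactItems=buffer)
```

In the diff itself the flush lives inside the per-partition loop, with the `index == max_partitions - 1` test covering the final partial batch.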
