From ac4f3fb1d84311bc4ed568093c73bb0ae0622801 Mon Sep 17 00:00:00 2001 From: "H. Shay" Date: Wed, 26 Jul 2023 20:53:46 -0700 Subject: [PATCH 01/13] create device and store keys in same call --- synapse/handlers/device.py | 19 +++++++++++++++++-- synapse/rest/client/devices.py | 9 +-------- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index f3a713f5fa77..ebcb95a6619b 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -41,6 +41,7 @@ run_as_background_process, wrap_as_background_process, ) +from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet from synapse.types import ( JsonDict, StrCollection, @@ -656,15 +657,17 @@ async def store_dehydrated_device( device_id: Optional[str], device_data: JsonDict, initial_device_display_name: Optional[str] = None, + device_keys: Optional[JsonDict] = None, ) -> str: - """Store a dehydrated device for a user. If the user had a previous - dehydrated device, it is removed. + """Store a dehydrated device for a user, optionally storing the keys associated with + it as well. If the user had a previous dehydrated device, it is removed. Args: user_id: the user that we are storing the device for device_id: device id supplied by client device_data: the dehydrated device information initial_device_display_name: The display name to use for the device + device_keys: keys for the dehydrated device Returns: device id of the dehydrated device """ @@ -678,6 +681,18 @@ async def store_dehydrated_device( ) if old_device_id is not None: await self.delete_devices(user_id, [old_device_id]) + + # we do this here to avoid a circular import + if self.hs.config.worker.worker_app is None: + # if main process + key_uploader = self.hs.get_e2e_keys_handler().upload_keys_for_user + else: + # if worker process + key_uploader = ReplicationUploadKeysForUserRestServlet.make_client(self.hs) + + # if keys are provided store them + if device_keys: + await key_uploader(user_id=user_id, device_id=device_id, keys=device_keys) return device_id async def rehydrate_device( diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index 690d2ec406fc..39ff29f611fb 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -536,7 +536,6 @@ class Config: async def on_PUT(self, request: SynapseRequest) -> Tuple[int, JsonDict]: submission = parse_and_validate_json_object_from_request(request, self.PutBody) requester = await self.auth.get_user_by_req(request) - user_id = requester.user.to_string() device_info = submission.dict() if "device_keys" not in device_info.keys(): @@ -545,18 +544,12 @@ async def on_PUT(self, request: SynapseRequest) -> Tuple[int, JsonDict]: "Device key(s) not found, these must be provided.", ) - # TODO: Those two operations, creating a device and storing the - # device's keys should be atomic. device_id = await self.device_handler.store_dehydrated_device( requester.user.to_string(), submission.device_id, submission.device_data.dict(), submission.initial_device_display_name, - ) - - # TODO: Do we need to do something with the result here? - await self.key_uploader( - user_id=user_id, device_id=submission.device_id, keys=submission.dict() + device_info, ) return 200, {"device_id": device_id} From ceed144eb293ad395e29c4c57fe96a30c13854c3 Mon Sep 17 00:00:00 2001 From: "H. Shay" Date: Wed, 26 Jul 2023 20:54:10 -0700 Subject: [PATCH 02/13] no longer delete to-device messages after retrieval --- synapse/handlers/devicemessage.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 15e94a03cbe7..17ff8821d974 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -367,19 +367,6 @@ async def get_events_for_dehydrated_device( errcode=Codes.INVALID_PARAM, ) - # if we have a since token, delete any to-device messages before that token - # (since we now know that the device has received them) - deleted = await self.store.delete_messages_for_device( - user_id, device_id, since_stream_id - ) - logger.debug( - "Deleted %d to-device messages up to %d for user_id %s device_id %s", - deleted, - since_stream_id, - user_id, - device_id, - ) - to_token = self.event_sources.get_current_token().to_device_key messages, stream_id = await self.store.get_messages_for_device( From 8a7db881a474c5bf1ae4e0a781b59ac6d22e8ece Mon Sep 17 00:00:00 2001 From: "H. Shay" Date: Wed, 26 Jul 2023 21:17:02 -0700 Subject: [PATCH 03/13] update tests --- tests/handlers/test_device.py | 9 +++++---- tests/rest/client/test_devices.py | 32 +++++++++++++++++++++++++++++-- 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 647ee0927984..e1e58fa6e648 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -566,15 +566,16 @@ def test_dehydrate_v2_and_fetch_events(self) -> None: self.assertEqual(len(res["events"]), 1) self.assertEqual(res["events"][0]["content"]["body"], "foo") - # Fetch the message of the dehydrated device again, which should return nothing - # and delete the old messages + # Fetch the message of the dehydrated device again, which should return + # the same message as it has not been deleted res = self.get_success( self.message_handler.get_events_for_dehydrated_device( requester=requester, device_id=stored_dehydrated_device_id, - since_token=res["next_batch"], + since_token=None, limit=10, ) ) self.assertTrue(len(res["next_batch"]) > 1) - self.assertEqual(len(res["events"]), 0) + self.assertEqual(len(res["events"]), 1) + self.assertEqual(res["events"][0]["content"]["body"], "foo") diff --git a/tests/rest/client/test_devices.py b/tests/rest/client/test_devices.py index b7d420cfec02..9e94b2ab2499 100644 --- a/tests/rest/client/test_devices.py +++ b/tests/rest/client/test_devices.py @@ -312,6 +312,23 @@ def test_dehydrate_msc3814(self) -> None: } self.assertEqual(device_data, expected_device_data) + # test that the keys are correctly uploaded + channel = self.make_request( + "POST", + "/_matrix/client/r0/keys/query", + { + "device_keys": { + user: ["device1"], + }, + }, + token, + ) + self.assertEqual(channel.code, 200) + self.assertEqual( + channel.json_body["device_keys"][user][device_id]["keys"], + content["device_keys"]["keys"], + ) + # create another device for the user ( new_device_id, @@ -348,10 +365,21 @@ def test_dehydrate_msc3814(self) -> None: self.assertEqual(channel.code, 200) expected_content = {"body": "test_message"} self.assertEqual(channel.json_body["events"][0]["content"], expected_content) + + # fetch messages again and make sure that the message was not deleted + channel = self.make_request( + "POST", + f"_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device/{device_id}/events", + content={}, + access_token=token, + shorthand=False, + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["events"][0]["content"], expected_content) next_batch_token = channel.json_body.get("next_batch") - # fetch messages again and make sure that the message was deleted and we are returned an - # empty array + # make sure fetching messages with next batch token works - there are no unfetched + # messages so we should receive an empty array content = {"next_batch": next_batch_token} channel = self.make_request( "POST", From 217e2ebcac4ec4e76ed11f05e03b68a07901a863 Mon Sep 17 00:00:00 2001 From: "H. Shay" Date: Wed, 26 Jul 2023 21:34:52 -0700 Subject: [PATCH 04/13] newsfragment --- changelog.d/16010.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/16010.misc diff --git a/changelog.d/16010.misc b/changelog.d/16010.misc new file mode 100644 index 000000000000..1e1a14806910 --- /dev/null +++ b/changelog.d/16010.misc @@ -0,0 +1 @@ +Update dehydrated devices implementation. From 61818c8bd2e275b585880f5ac8132af2ba287c57 Mon Sep 17 00:00:00 2001 From: "H. Shay" Date: Wed, 2 Aug 2023 20:11:05 -0700 Subject: [PATCH 05/13] add a function to check and prepare keys for insertion when creating dehydrated device --- synapse/handlers/device.py | 144 ++++++++++++++++++++++++++++++++----- 1 file changed, 128 insertions(+), 16 deletions(-) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index ebcb95a6619b..846665372130 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -26,6 +26,8 @@ Tuple, ) +from canonicaljson import encode_canonical_json + from synapse.api import errors from synapse.api.constants import EduTypes, EventTypes from synapse.api.errors import ( @@ -41,7 +43,6 @@ run_as_background_process, wrap_as_background_process, ) -from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet from synapse.types import ( JsonDict, StrCollection, @@ -386,6 +387,7 @@ def __init__(self, hs: "HomeServer"): self.federation_sender = hs.get_federation_sender() self._account_data_handler = hs.get_account_data_handler() self._storage_controllers = hs.get_storage_controllers() + self.db_pool = hs.get_datastores().main.db_pool self.device_list_updater = DeviceListUpdater(hs, self) @@ -657,7 +659,7 @@ async def store_dehydrated_device( device_id: Optional[str], device_data: JsonDict, initial_device_display_name: Optional[str] = None, - device_keys: Optional[JsonDict] = None, + keys_for_device: Optional[JsonDict] = None, ) -> str: """Store a dehydrated device for a user, optionally storing the keys associated with it as well. If the user had a previous dehydrated device, it is removed. @@ -667,7 +669,7 @@ async def store_dehydrated_device( device_id: device id supplied by client device_data: the dehydrated device information initial_device_display_name: The display name to use for the device - device_keys: keys for the dehydrated device + keys_for_device: keys for the dehydrated device Returns: device id of the dehydrated device """ @@ -676,24 +678,134 @@ async def store_dehydrated_device( device_id, initial_device_display_name, ) - old_device_id = await self.store.store_dehydrated_device( - user_id, device_id, device_data - ) + + time_now = self.clock.time_msec() + + if keys_for_device: + keys = await self._check_and_prepare_keys_for_dehydrated_device( + user_id, device_id, keys_for_device + ) + old_device_id = await self.store.store_dehydrated_device( + user_id, device_id, device_data, time_now, keys + ) + else: + old_device_id = await self.store.store_dehydrated_device( + user_id, device_id, device_data, time_now + ) + if old_device_id is not None: await self.delete_devices(user_id, [old_device_id]) - # we do this here to avoid a circular import - if self.hs.config.worker.worker_app is None: - # if main process - key_uploader = self.hs.get_e2e_keys_handler().upload_keys_for_user - else: - # if worker process - key_uploader = ReplicationUploadKeysForUserRestServlet.make_client(self.hs) + return device_id - # if keys are provided store them + async def _check_and_prepare_keys_for_dehydrated_device( + self, user_id: str, device_id: str, keys: JsonDict + ) -> dict: + """ + Check if any of the provided keys are duplicate and raise if they are, + prepare keys for insertion in DB + + Args: + user_id: user to store keys for + device_id: the dehydrated device to store keys for + keys: the keys - device_keys, onetime_keys, or fallback keys to store + + Returns: + keys that have been checked for duplicates and are ready to be inserted into + DB + """ + keys_to_return: dict = {} + device_keys = keys.get("device_keys", None) if device_keys: - await key_uploader(user_id=user_id, device_id=device_id, keys=device_keys) - return device_id + old_key_json = await self.db_pool.simple_select_one_onecol( + table="e2e_device_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + retcol="key_json", + allow_none=True, + ) + + # In py3 we need old_key_json to match new_key_json type. The DB + # returns unicode while encode_canonical_json returns bytes. + new_device_key_json = encode_canonical_json(device_keys).decode("utf-8") + + if old_key_json == new_device_key_json: + raise SynapseError( + 400, + f"Device key for user_id: {user_id}, device_id {device_id} already stored.", + ) + + keys_to_return["device_keys"] = new_device_key_json + + one_time_keys = keys.get("one_time_keys", None) + if one_time_keys: + # import this here to avoid a circular import + from synapse.handlers.e2e_keys import _one_time_keys_match + + # make a list of (alg, id, key) tuples + key_list = [] + for key_id, key_obj in one_time_keys.items(): + algorithm, key_id = key_id.split(":") + key_list.append((algorithm, key_id, key_obj)) + + # First we check if we have already persisted any of the keys. + existing_key_map = await self.store.get_e2e_one_time_keys( + user_id, device_id, [k_id for _, k_id, _ in key_list] + ) + + new_one_time_keys = ( + [] + ) # Keys that we need to insert. (alg, id, json) tuples. + for algorithm, key_id, key in key_list: + ex_json = existing_key_map.get((algorithm, key_id), None) + if ex_json: + if not _one_time_keys_match(ex_json, key): + raise SynapseError( + 400, + ( + "One time key %s:%s already exists. " + "Old key: %s; new key: %r" + ) + % (algorithm, key_id, ex_json, key), + ) + else: + new_one_time_keys.append( + (algorithm, key_id, encode_canonical_json(key).decode("ascii")) + ) + keys_to_return["one_time_keys"] = new_one_time_keys + + fallback_keys = keys.get("fallback_keys", None) + if fallback_keys: + new_fallback_keys = {} + # there should actually only be one item in the dict but we iterate nevertheless - + # see _set_e2e_fallback_keys_txn + for key_id, fallback_key in fallback_keys.items(): + algorithm, key_id = key_id.split(":", 1) + old_key_json = await self.db_pool.simple_select_one_onecol( + table="e2e_fallback_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm, + }, + retcol="key_json", + allow_none=True, + ) + + new_fallback_key_json = encode_canonical_json(fallback_key).decode( + "utf-8" + ) + + # If the uploaded key is the same as the current fallback key, + # don't do anything. This prevents marking the key as unused if it + # was already used. + if old_key_json == new_fallback_key_json: + raise SynapseError( + 400, f"Fallback key {old_key_json} already exists." + ) + # TODO: should this be an update? it assumes that there will only be one fallback key + new_fallback_keys[f"{algorithm}:{key_id}"] = fallback_key + keys_to_return["fallback_keys"] = new_fallback_keys + return keys_to_return async def rehydrate_device( self, user_id: str, access_token: str, device_id: str From ec98b6a7b7bb712cee1277c85075b9117e2aa7d5 Mon Sep 17 00:00:00 2001 From: "H. Shay" Date: Wed, 2 Aug 2023 20:11:41 -0700 Subject: [PATCH 06/13] store keys as part of DB transaction that creates dehydrated device --- synapse/storage/databases/main/devices.py | 73 ++++++++++++++++++++++- 1 file changed, 71 insertions(+), 2 deletions(-) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index d9df437e518a..80fd7e12877e 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1188,8 +1188,66 @@ async def get_dehydrated_device( ) def _store_dehydrated_device_txn( - self, txn: LoggingTransaction, user_id: str, device_id: str, device_data: str + self, + txn: LoggingTransaction, + user_id: str, + device_id: str, + device_data: str, + time: int, + keys: Optional[JsonDict] = None, ) -> Optional[str]: + # TODO: make keys non-optional once support for msc2697 is dropped + if keys: + device_keys = keys.get("device_keys", None) + if device_keys: + self.db_pool.simple_upsert_txn( + txn, + table="e2e_device_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + values={"ts_added_ms": time, "key_json": device_keys}, + ) + + one_time_keys = keys.get("one_time_keys", None) + if one_time_keys: + self.db_pool.simple_insert_many_txn( + txn, + table="e2e_one_time_keys_json", + keys=( + "user_id", + "device_id", + "algorithm", + "key_id", + "ts_added_ms", + "key_json", + ), + values=[ + (user_id, device_id, algorithm, key_id, time, json_bytes) + for algorithm, key_id, json_bytes in one_time_keys + ], + ) + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id) + ) + + fallback_keys = keys.get("fallback_keys", None) + if fallback_keys: + for key_id, fallback_key in fallback_keys.items(): + algorithm, key_id = key_id.split(":", 1) + self.db_pool.simple_upsert_txn( + txn, + table="e2e_fallback_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm, + }, + values={ + "key_id": key_id, + "key_json": json_encoder.encode(fallback_key), + "used": False, + }, + ) + old_device_id = self.db_pool.simple_select_one_onecol_txn( txn, table="dehydrated_devices", @@ -1203,10 +1261,16 @@ def _store_dehydrated_device_txn( keyvalues={"user_id": user_id}, values={"device_id": device_id, "device_data": device_data}, ) + return old_device_id async def store_dehydrated_device( - self, user_id: str, device_id: str, device_data: JsonDict + self, + user_id: str, + device_id: str, + device_data: JsonDict, + time_now: int, + keys: Optional[dict] = None, ) -> Optional[str]: """Store a dehydrated device for a user. @@ -1214,15 +1278,20 @@ async def store_dehydrated_device( user_id: the user that we are storing the device for device_id: the ID of the dehydrated device device_data: the dehydrated device information + time_now: current time at the request + keys: keys for the dehydrated device Returns: device id of the user's previous dehydrated device, if any """ + return await self.db_pool.runInteraction( "store_dehydrated_device_txn", self._store_dehydrated_device_txn, user_id, device_id, json_encoder.encode(device_data), + time_now, + keys, ) async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool: From ef2099d398f327ebb6cf706745c215da16372ccb Mon Sep 17 00:00:00 2001 From: "H. Shay" Date: Wed, 2 Aug 2023 20:12:03 -0700 Subject: [PATCH 07/13] add some tests to verify that keys are being stored properly --- tests/rest/client/test_devices.py | 45 ++++++++++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/tests/rest/client/test_devices.py b/tests/rest/client/test_devices.py index 9e94b2ab2499..7b100602c434 100644 --- a/tests/rest/client/test_devices.py +++ b/tests/rest/client/test_devices.py @@ -20,7 +20,7 @@ from synapse.rest import admin, devices, room, sync from synapse.rest.client import account, keys, login, register from synapse.server import HomeServer -from synapse.types import JsonDict, create_requester +from synapse.types import JsonDict, UserID, create_requester from synapse.util import Clock from tests import unittest @@ -282,6 +282,17 @@ def test_dehydrate_msc3814(self) -> None: "": {":": ""} }, }, + "fallback_keys": { + "alg1:device1": "f4llb4ckk3y", + "signed_:": { + "fallback": "true", + "key": "f4llb4ckk3y", + "signatures": { + "": {":": ""} + }, + }, + }, + "one_time_keys": {"alg1:k1": "0net1m3k3y"}, } channel = self.make_request( "PUT", @@ -328,6 +339,38 @@ def test_dehydrate_msc3814(self) -> None: channel.json_body["device_keys"][user][device_id]["keys"], content["device_keys"]["keys"], ) + # first claim should return the onetime key we uploaded + res = self.get_success( + self.hs.get_e2e_keys_handler().claim_one_time_keys( + {user: {device_id: {"alg1": 1}}}, + UserID.from_string(user), + timeout=None, + always_include_fallback_keys=False, + ) + ) + self.assertEqual( + res, + { + "failures": {}, + "one_time_keys": {user: {device_id: {"alg1:k1": "0net1m3k3y"}}}, + }, + ) + # second claim should return fallback key + res2 = self.get_success( + self.hs.get_e2e_keys_handler().claim_one_time_keys( + {user: {device_id: {"alg1": 1}}}, + UserID.from_string(user), + timeout=None, + always_include_fallback_keys=False, + ) + ) + self.assertEqual( + res2, + { + "failures": {}, + "one_time_keys": {user: {device_id: {"alg1:device1": "f4llb4ckk3y"}}}, + }, + ) # create another device for the user ( From 8ca4c1fee47f098e45939adef1a1f21fdf2748b5 Mon Sep 17 00:00:00 2001 From: Shay Date: Thu, 3 Aug 2023 11:14:38 -0700 Subject: [PATCH 08/13] Update synapse/handlers/device.py Co-authored-by: Patrick Cloke --- synapse/handlers/device.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 846665372130..74affc8ee7f6 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -685,13 +685,11 @@ async def store_dehydrated_device( keys = await self._check_and_prepare_keys_for_dehydrated_device( user_id, device_id, keys_for_device ) - old_device_id = await self.store.store_dehydrated_device( - user_id, device_id, device_data, time_now, keys - ) else: - old_device_id = await self.store.store_dehydrated_device( - user_id, device_id, device_data, time_now - ) + keys = None + old_device_id = await self.store.store_dehydrated_device( + user_id, device_id, device_data, time_now, key + ) if old_device_id is not None: await self.delete_devices(user_id, [old_device_id]) From c3cb6465244a810b69461bc4ee920b0941c21eb4 Mon Sep 17 00:00:00 2001 From: Shay Date: Thu, 3 Aug 2023 11:16:08 -0700 Subject: [PATCH 09/13] Apply suggestions from code review Co-authored-by: Patrick Cloke --- synapse/storage/databases/main/devices.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 80fd7e12877e..3f23695b077e 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1278,8 +1278,9 @@ async def store_dehydrated_device( user_id: the user that we are storing the device for device_id: the ID of the dehydrated device device_data: the dehydrated device information - time_now: current time at the request + time_now: current time at the request in milliseconds keys: keys for the dehydrated device + Returns: device id of the user's previous dehydrated device, if any """ From f14c51d7ca80a7e72bbef3a62cdafd7560f47dc9 Mon Sep 17 00:00:00 2001 From: "H. Shay" Date: Thu, 3 Aug 2023 11:37:39 -0700 Subject: [PATCH 10/13] requested changes --- synapse/handlers/device.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 74affc8ee7f6..f2ed7214b3b6 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -687,8 +687,9 @@ async def store_dehydrated_device( ) else: keys = None + old_device_id = await self.store.store_dehydrated_device( - user_id, device_id, device_data, time_now, key + user_id, device_id, device_data, time_now, keys ) if old_device_id is not None: @@ -705,12 +706,15 @@ async def _check_and_prepare_keys_for_dehydrated_device( Args: user_id: user to store keys for - device_id: the dehydrated device to store keys for - keys: the keys - device_keys, onetime_keys, or fallback keys to store + device_id: the device_id of the dehydrated device to store keys for + keys: a dict of device_keys, onetime_keys, or fallback keys provided by the client + to store associated with the dehydrated device id. Consists of pairs where + the key is the key type (i.e. device_key, onetime_key or fallback_key) + and the value is the key of that respective type. Returns: - keys that have been checked for duplicates and are ready to be inserted into - DB + A dict where the keys are key type (i.e. device_key, onetime_key or fallback_key) + and values are the respective keys of that type """ keys_to_return: dict = {} device_keys = keys.get("device_keys", None) From 9e3de14e1371809cccf4fc28f52bcb5aae96f8ae Mon Sep 17 00:00:00 2001 From: "H. Shay" Date: Mon, 7 Aug 2023 14:53:14 -0700 Subject: [PATCH 11/13] refactor to allow for calling key storage methods from _store_dehydrated_device_txn and do so --- synapse/handlers/device.py | 123 +------------ synapse/storage/databases/main/devices.py | 57 ++---- .../storage/databases/main/end_to_end_keys.py | 169 ++++++++++++------ 3 files changed, 130 insertions(+), 219 deletions(-) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index f2ed7214b3b6..f0a00fbeda96 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -26,8 +26,6 @@ Tuple, ) -from canonicaljson import encode_canonical_json - from synapse.api import errors from synapse.api.constants import EduTypes, EventTypes from synapse.api.errors import ( @@ -681,15 +679,8 @@ async def store_dehydrated_device( time_now = self.clock.time_msec() - if keys_for_device: - keys = await self._check_and_prepare_keys_for_dehydrated_device( - user_id, device_id, keys_for_device - ) - else: - keys = None - old_device_id = await self.store.store_dehydrated_device( - user_id, device_id, device_data, time_now, keys + user_id, device_id, device_data, time_now, keys_for_device ) if old_device_id is not None: @@ -697,118 +688,6 @@ async def store_dehydrated_device( return device_id - async def _check_and_prepare_keys_for_dehydrated_device( - self, user_id: str, device_id: str, keys: JsonDict - ) -> dict: - """ - Check if any of the provided keys are duplicate and raise if they are, - prepare keys for insertion in DB - - Args: - user_id: user to store keys for - device_id: the device_id of the dehydrated device to store keys for - keys: a dict of device_keys, onetime_keys, or fallback keys provided by the client - to store associated with the dehydrated device id. Consists of pairs where - the key is the key type (i.e. device_key, onetime_key or fallback_key) - and the value is the key of that respective type. - - Returns: - A dict where the keys are key type (i.e. device_key, onetime_key or fallback_key) - and values are the respective keys of that type - """ - keys_to_return: dict = {} - device_keys = keys.get("device_keys", None) - if device_keys: - old_key_json = await self.db_pool.simple_select_one_onecol( - table="e2e_device_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id}, - retcol="key_json", - allow_none=True, - ) - - # In py3 we need old_key_json to match new_key_json type. The DB - # returns unicode while encode_canonical_json returns bytes. - new_device_key_json = encode_canonical_json(device_keys).decode("utf-8") - - if old_key_json == new_device_key_json: - raise SynapseError( - 400, - f"Device key for user_id: {user_id}, device_id {device_id} already stored.", - ) - - keys_to_return["device_keys"] = new_device_key_json - - one_time_keys = keys.get("one_time_keys", None) - if one_time_keys: - # import this here to avoid a circular import - from synapse.handlers.e2e_keys import _one_time_keys_match - - # make a list of (alg, id, key) tuples - key_list = [] - for key_id, key_obj in one_time_keys.items(): - algorithm, key_id = key_id.split(":") - key_list.append((algorithm, key_id, key_obj)) - - # First we check if we have already persisted any of the keys. - existing_key_map = await self.store.get_e2e_one_time_keys( - user_id, device_id, [k_id for _, k_id, _ in key_list] - ) - - new_one_time_keys = ( - [] - ) # Keys that we need to insert. (alg, id, json) tuples. - for algorithm, key_id, key in key_list: - ex_json = existing_key_map.get((algorithm, key_id), None) - if ex_json: - if not _one_time_keys_match(ex_json, key): - raise SynapseError( - 400, - ( - "One time key %s:%s already exists. " - "Old key: %s; new key: %r" - ) - % (algorithm, key_id, ex_json, key), - ) - else: - new_one_time_keys.append( - (algorithm, key_id, encode_canonical_json(key).decode("ascii")) - ) - keys_to_return["one_time_keys"] = new_one_time_keys - - fallback_keys = keys.get("fallback_keys", None) - if fallback_keys: - new_fallback_keys = {} - # there should actually only be one item in the dict but we iterate nevertheless - - # see _set_e2e_fallback_keys_txn - for key_id, fallback_key in fallback_keys.items(): - algorithm, key_id = key_id.split(":", 1) - old_key_json = await self.db_pool.simple_select_one_onecol( - table="e2e_fallback_keys_json", - keyvalues={ - "user_id": user_id, - "device_id": device_id, - "algorithm": algorithm, - }, - retcol="key_json", - allow_none=True, - ) - - new_fallback_key_json = encode_canonical_json(fallback_key).decode( - "utf-8" - ) - - # If the uploaded key is the same as the current fallback key, - # don't do anything. This prevents marking the key as unused if it - # was already used. - if old_key_json == new_fallback_key_json: - raise SynapseError( - 400, f"Fallback key {old_key_json} already exists." - ) - # TODO: should this be an update? it assumes that there will only be one fallback key - new_fallback_keys[f"{algorithm}:{key_id}"] = fallback_key - keys_to_return["fallback_keys"] = new_fallback_keys - return keys_to_return - async def rehydrate_device( self, user_id: str, access_token: str, device_id: str ) -> dict: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 3f23695b077e..e4162f846b11 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -28,6 +28,7 @@ cast, ) +from canonicaljson import encode_canonical_json from typing_extensions import Literal from synapse.api.constants import EduTypes @@ -1200,53 +1201,29 @@ def _store_dehydrated_device_txn( if keys: device_keys = keys.get("device_keys", None) if device_keys: - self.db_pool.simple_upsert_txn( - txn, - table="e2e_device_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id}, - values={"ts_added_ms": time, "key_json": device_keys}, + # Type ignore - this function is defined on EndToEndKeyStore which we do + # have access to due to hs.get_datastore() "magic" + self._set_e2e_device_keys_txn( # type: ignore[attr-defined] + txn, user_id, device_id, time, device_keys ) one_time_keys = keys.get("one_time_keys", None) if one_time_keys: - self.db_pool.simple_insert_many_txn( - txn, - table="e2e_one_time_keys_json", - keys=( - "user_id", - "device_id", - "algorithm", - "key_id", - "ts_added_ms", - "key_json", - ), - values=[ - (user_id, device_id, algorithm, key_id, time, json_bytes) - for algorithm, key_id, json_bytes in one_time_keys - ], - ) - self._invalidate_cache_and_stream( - txn, self.count_e2e_one_time_keys, (user_id, device_id) - ) + key_list = [] + for key_id, key_obj in one_time_keys.items(): + algorithm, key_id = key_id.split(":") + key_list.append( + ( + algorithm, + key_id, + encode_canonical_json(key_obj).decode("ascii"), + ) + ) + self._add_e2e_one_time_keys_txn(txn, user_id, device_id, time, key_list) fallback_keys = keys.get("fallback_keys", None) if fallback_keys: - for key_id, fallback_key in fallback_keys.items(): - algorithm, key_id = key_id.split(":", 1) - self.db_pool.simple_upsert_txn( - txn, - table="e2e_fallback_keys_json", - keyvalues={ - "user_id": user_id, - "device_id": device_id, - "algorithm": algorithm, - }, - values={ - "key_id": key_id, - "key_json": json_encoder.encode(fallback_key), - "used": False, - }, - ) + self._set_e2e_fallback_keys_txn(txn, user_id, device_id, fallback_keys) old_device_id = self.db_pool.simple_select_one_onecol_txn( txn, diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 91ae9c457d78..f9fc2be49491 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -522,36 +522,56 @@ async def add_e2e_one_time_keys( new_keys: keys to add - each a tuple of (algorithm, key_id, key json) """ - def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None: - set_tag("user_id", user_id) - set_tag("device_id", device_id) - set_tag("new_keys", str(new_keys)) - # We are protected from race between lookup and insertion due to - # a unique constraint. If there is a race of two calls to - # `add_e2e_one_time_keys` then they'll conflict and we will only - # insert one set. - self.db_pool.simple_insert_many_txn( - txn, - table="e2e_one_time_keys_json", - keys=( - "user_id", - "device_id", - "algorithm", - "key_id", - "ts_added_ms", - "key_json", - ), - values=[ - (user_id, device_id, algorithm, key_id, time_now, json_bytes) - for algorithm, key_id, json_bytes in new_keys - ], - ) - self._invalidate_cache_and_stream( - txn, self.count_e2e_one_time_keys, (user_id, device_id) - ) - await self.db_pool.runInteraction( - "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys + "add_e2e_one_time_keys_insert", + self._add_e2e_one_time_keys_txn, + user_id, + device_id, + time_now, + new_keys, + ) + + def _add_e2e_one_time_keys_txn( + self, + txn: LoggingTransaction, + user_id: str, + device_id: str, + time_now: int, + new_keys: Iterable[Tuple[str, str, str]], + ) -> None: + """Insert some new one time keys for a device. Errors if any of the keys already exist. + + Args: + user_id: id of user to get keys for + device_id: id of device to get keys for + time_now: insertion time to record (ms since epoch) + new_keys: keys to add - each a tuple of (algorithm, key_id, key json) + """ + set_tag("user_id", user_id) + set_tag("device_id", device_id) + set_tag("new_keys", str(new_keys)) + # We are protected from race between lookup and insertion due to + # a unique constraint. If there is a race of two calls to + # `add_e2e_one_time_keys` then they'll conflict and we will only + # insert one set. + self.db_pool.simple_insert_many_txn( + txn, + table="e2e_one_time_keys_json", + keys=( + "user_id", + "device_id", + "algorithm", + "key_id", + "ts_added_ms", + "key_json", + ), + values=[ + (user_id, device_id, algorithm, key_id, time_now, json_bytes) + for algorithm, key_id, json_bytes in new_keys + ], + ) + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id) ) @cached(max_entries=10000) @@ -723,6 +743,14 @@ def _set_e2e_fallback_keys_txn( device_id: str, fallback_keys: JsonDict, ) -> None: + """Set the user's e2e fallback keys. + + Args: + user_id: the user whose keys are being set + device_id: the device whose keys are being set + fallback_keys: the keys to set. This is a map from key ID (which is + of the form "algorithm:id") to key data. + """ # fallback_keys will usually only have one item in it, so using a for # loop (as opposed to calling simple_upsert_many_txn) won't be too bad # FIXME: make sure that only one key per algorithm is uploaded @@ -1304,42 +1332,69 @@ async def set_e2e_device_keys( ) -> bool: """Stores device keys for a device. Returns whether there was a change or the keys were already in the database. + + Args: + user_id: user_id of the user to store keys for + device_id: device_id of the device to store keys for + time_now: time at the request to store the keys + device_keys: the keys to store """ - def _set_e2e_device_keys_txn(txn: LoggingTransaction) -> bool: - set_tag("user_id", user_id) - set_tag("device_id", device_id) - set_tag("time_now", time_now) - set_tag("device_keys", str(device_keys)) + return await self.db_pool.runInteraction( + "set_e2e_device_keys", + self._set_e2e_device_keys_txn, + user_id, + device_id, + time_now, + device_keys, + ) - old_key_json = self.db_pool.simple_select_one_onecol_txn( - txn, - table="e2e_device_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id}, - retcol="key_json", - allow_none=True, - ) + def _set_e2e_device_keys_txn( + self, + txn: LoggingTransaction, + user_id: str, + device_id: str, + time_now: int, + device_keys: JsonDict, + ) -> bool: + """Stores device keys for a device. Returns whether there was a change + or the keys were already in the database. - # In py3 we need old_key_json to match new_key_json type. The DB - # returns unicode while encode_canonical_json returns bytes. - new_key_json = encode_canonical_json(device_keys).decode("utf-8") + Args: + user_id: user_id of the user to store keys for + device_id: device_id of the device to store keys for + time_now: time at the request to store the keys + device_keys: the keys to store + """ + set_tag("user_id", user_id) + set_tag("device_id", device_id) + set_tag("time_now", time_now) + set_tag("device_keys", str(device_keys)) + + old_key_json = self.db_pool.simple_select_one_onecol_txn( + txn, + table="e2e_device_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + retcol="key_json", + allow_none=True, + ) - if old_key_json == new_key_json: - log_kv({"Message": "Device key already stored."}) - return False + # In py3 we need old_key_json to match new_key_json type. The DB + # returns unicode while encode_canonical_json returns bytes. + new_key_json = encode_canonical_json(device_keys).decode("utf-8") - self.db_pool.simple_upsert_txn( - txn, - table="e2e_device_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id}, - values={"ts_added_ms": time_now, "key_json": new_key_json}, - ) - log_kv({"message": "Device keys stored."}) - return True + if old_key_json == new_key_json: + log_kv({"Message": "Device key already stored."}) + return False - return await self.db_pool.runInteraction( - "set_e2e_device_keys", _set_e2e_device_keys_txn + self.db_pool.simple_upsert_txn( + txn, + table="e2e_device_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + values={"ts_added_ms": time_now, "key_json": new_key_json}, ) + log_kv({"message": "Device keys stored."}) + return True async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None: def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None: From cdc9859c721120d5221d19b608294171a454cec0 Mon Sep 17 00:00:00 2001 From: "H. Shay" Date: Mon, 7 Aug 2023 15:19:07 -0700 Subject: [PATCH 12/13] fix lint caused by merge --- synapse/rest/client/devices.py | 1 + 1 file changed, 1 insertion(+) diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index 991b2445fcde..555c8a7c50cc 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -532,6 +532,7 @@ class Config: async def on_PUT(self, request: SynapseRequest) -> Tuple[int, JsonDict]: submission = parse_and_validate_json_object_from_request(request, self.PutBody) requester = await self.auth.get_user_by_req(request) + user_id = requester.user.to_string() old_dehydrated_device = await self.device_handler.get_dehydrated_device(user_id) From c5f435702914848eb9041acfe3f9961f6cb9e14e Mon Sep 17 00:00:00 2001 From: "H. Shay" Date: Tue, 8 Aug 2023 10:26:10 -0700 Subject: [PATCH 13/13] comment + minor cleanup --- synapse/rest/client/devices.py | 8 -------- synapse/storage/databases/main/end_to_end_keys.py | 3 ++- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index 555c8a7c50cc..925f037743c0 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -29,7 +29,6 @@ parse_integer, ) from synapse.http.site import SynapseRequest -from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet from synapse.rest.client._base import client_patterns, interactive_auth_handler from synapse.rest.client.models import AuthenticationData from synapse.rest.models import RequestBodyModel @@ -480,13 +479,6 @@ def __init__(self, hs: "HomeServer"): self.e2e_keys_handler = hs.get_e2e_keys_handler() self.device_handler = handler - if hs.config.worker.worker_app is None: - # if main process - self.key_uploader = self.e2e_keys_handler.upload_keys_for_user - else: - # then a worker - self.key_uploader = ReplicationUploadKeysForUserRestServlet.make_client(hs) - async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index f9fc2be49491..b49dea577cba 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -545,7 +545,8 @@ def _add_e2e_one_time_keys_txn( user_id: id of user to get keys for device_id: id of device to get keys for time_now: insertion time to record (ms since epoch) - new_keys: keys to add - each a tuple of (algorithm, key_id, key json) + new_keys: keys to add - each a tuple of (algorithm, key_id, key json) - note + that the key JSON must be in canonical JSON form """ set_tag("user_id", user_id) set_tag("device_id", device_id)