Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Support MSC3814: Dehydrated Devices Part 2 #16010

Merged
merged 14 commits into from
Aug 8, 2023
Merged
1 change: 1 addition & 0 deletions changelog.d/16010.misc
clokep marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Update dehydrated devices implementation.
clokep marked this conversation as resolved.
Show resolved Hide resolved
14 changes: 11 additions & 3 deletions synapse/handlers/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,7 @@ def __init__(self, hs: "HomeServer"):
self.federation_sender = hs.get_federation_sender()
self._account_data_handler = hs.get_account_data_handler()
self._storage_controllers = hs.get_storage_controllers()
self.db_pool = hs.get_datastores().main.db_pool

self.device_list_updater = DeviceListUpdater(hs, self)

Expand Down Expand Up @@ -656,15 +657,17 @@ async def store_dehydrated_device(
device_id: Optional[str],
device_data: JsonDict,
initial_device_display_name: Optional[str] = None,
keys_for_device: Optional[JsonDict] = None,
) -> str:
"""Store a dehydrated device for a user. If the user had a previous
dehydrated device, it is removed.
"""Store a dehydrated device for a user, optionally storing the keys associated with
it as well. If the user had a previous dehydrated device, it is removed.
Args:
user_id: the user that we are storing the device for
device_id: device id supplied by client
device_data: the dehydrated device information
initial_device_display_name: The display name to use for the device
keys_for_device: keys for the dehydrated device
Returns:
device id of the dehydrated device
"""
Expand All @@ -673,11 +676,16 @@ async def store_dehydrated_device(
device_id,
initial_device_display_name,
)

time_now = self.clock.time_msec()

old_device_id = await self.store.store_dehydrated_device(
user_id, device_id, device_data
user_id, device_id, device_data, time_now, keys_for_device
)

if old_device_id is not None:
await self.delete_devices(user_id, [old_device_id])

return device_id

async def rehydrate_device(
Expand Down
13 changes: 0 additions & 13 deletions synapse/handlers/devicemessage.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,19 +367,6 @@ async def get_events_for_dehydrated_device(
errcode=Codes.INVALID_PARAM,
)

# if we have a since token, delete any to-device messages before that token
# (since we now know that the device has received them)
deleted = await self.store.delete_messages_for_device(
user_id, device_id, since_stream_id
)
logger.debug(
"Deleted %d to-device messages up to %d for user_id %s device_id %s",
deleted,
since_stream_id,
user_id,
device_id,
)

to_token = self.event_sources.get_current_token().to_device_key

messages, stream_id = await self.store.get_messages_for_device(
Expand Down
16 changes: 1 addition & 15 deletions synapse/rest/client/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
parse_integer,
)
from synapse.http.site import SynapseRequest
from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet
from synapse.rest.client._base import client_patterns, interactive_auth_handler
from synapse.rest.client.models import AuthenticationData
from synapse.rest.models import RequestBodyModel
Expand Down Expand Up @@ -480,13 +479,6 @@ def __init__(self, hs: "HomeServer"):
self.e2e_keys_handler = hs.get_e2e_keys_handler()
self.device_handler = handler

if hs.config.worker.worker_app is None:
# if main process
self.key_uploader = self.e2e_keys_handler.upload_keys_for_user
else:
# then a worker
self.key_uploader = ReplicationUploadKeysForUserRestServlet.make_client(hs)

async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)

Expand Down Expand Up @@ -549,18 +541,12 @@ async def on_PUT(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"Device key(s) not found, these must be provided.",
)

# TODO: Those two operations, creating a device and storing the
# device's keys should be atomic.
device_id = await self.device_handler.store_dehydrated_device(
requester.user.to_string(),
submission.device_id,
submission.device_data.dict(),
submission.initial_device_display_name,
)

# TODO: Do we need to do something with the result here?
await self.key_uploader(
user_id=user_id, device_id=submission.device_id, keys=submission.dict()
device_info,
)

return 200, {"device_id": device_id}
Expand Down
51 changes: 49 additions & 2 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
cast,
)

from canonicaljson import encode_canonical_json
from typing_extensions import Literal

from synapse.api.constants import EduTypes
Expand Down Expand Up @@ -1188,8 +1189,42 @@ async def get_dehydrated_device(
)

def _store_dehydrated_device_txn(
self, txn: LoggingTransaction, user_id: str, device_id: str, device_data: str
self,
txn: LoggingTransaction,
user_id: str,
device_id: str,
device_data: str,
time: int,
keys: Optional[JsonDict] = None,
) -> Optional[str]:
# TODO: make keys non-optional once support for msc2697 is dropped
if keys:
device_keys = keys.get("device_keys", None)
if device_keys:
# Type ignore - this function is defined on EndToEndKeyStore which we do
# have access to due to hs.get_datastore() "magic"
self._set_e2e_device_keys_txn( # type: ignore[attr-defined]
txn, user_id, device_id, time, device_keys
)

one_time_keys = keys.get("one_time_keys", None)
if one_time_keys:
key_list = []
for key_id, key_obj in one_time_keys.items():
algorithm, key_id = key_id.split(":")
key_list.append(
(
algorithm,
key_id,
encode_canonical_json(key_obj).decode("ascii"),
)
)
self._add_e2e_one_time_keys_txn(txn, user_id, device_id, time, key_list)

fallback_keys = keys.get("fallback_keys", None)
if fallback_keys:
self._set_e2e_fallback_keys_txn(txn, user_id, device_id, fallback_keys)

old_device_id = self.db_pool.simple_select_one_onecol_txn(
txn,
table="dehydrated_devices",
Expand All @@ -1203,26 +1238,38 @@ def _store_dehydrated_device_txn(
keyvalues={"user_id": user_id},
values={"device_id": device_id, "device_data": device_data},
)

return old_device_id

async def store_dehydrated_device(
self, user_id: str, device_id: str, device_data: JsonDict
self,
user_id: str,
device_id: str,
device_data: JsonDict,
time_now: int,
keys: Optional[dict] = None,
) -> Optional[str]:
"""Store a dehydrated device for a user.
Args:
user_id: the user that we are storing the device for
device_id: the ID of the dehydrated device
device_data: the dehydrated device information
time_now: current time at the request in milliseconds
keys: keys for the dehydrated device
H-Shay marked this conversation as resolved.
Show resolved Hide resolved
Returns:
device id of the user's previous dehydrated device, if any
"""

return await self.db_pool.runInteraction(
"store_dehydrated_device_txn",
self._store_dehydrated_device_txn,
user_id,
device_id,
json_encoder.encode(device_data),
time_now,
keys,
)

async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool:
Expand Down
Loading