Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Support MSC3814: Dehydrated Devices Part 2 #16010

Merged
merged 14 commits into from
Aug 8, 2023
Merged
1 change: 1 addition & 0 deletions changelog.d/16010.misc
clokep marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Update dehydrated devices implementation.
clokep marked this conversation as resolved.
Show resolved Hide resolved
14 changes: 11 additions & 3 deletions synapse/handlers/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,7 @@ def __init__(self, hs: "HomeServer"):
self.federation_sender = hs.get_federation_sender()
self._account_data_handler = hs.get_account_data_handler()
self._storage_controllers = hs.get_storage_controllers()
self.db_pool = hs.get_datastores().main.db_pool

self.device_list_updater = DeviceListUpdater(hs, self)

Expand Down Expand Up @@ -656,15 +657,17 @@ async def store_dehydrated_device(
device_id: Optional[str],
device_data: JsonDict,
initial_device_display_name: Optional[str] = None,
keys_for_device: Optional[JsonDict] = None,
) -> str:
"""Store a dehydrated device for a user. If the user had a previous
dehydrated device, it is removed.
"""Store a dehydrated device for a user, optionally storing the keys associated with
it as well. If the user had a previous dehydrated device, it is removed.

Args:
user_id: the user that we are storing the device for
device_id: device id supplied by client
device_data: the dehydrated device information
initial_device_display_name: The display name to use for the device
keys_for_device: keys for the dehydrated device
Returns:
device id of the dehydrated device
"""
Expand All @@ -673,11 +676,16 @@ async def store_dehydrated_device(
device_id,
initial_device_display_name,
)

time_now = self.clock.time_msec()

old_device_id = await self.store.store_dehydrated_device(
user_id, device_id, device_data
user_id, device_id, device_data, time_now, keys_for_device
)

if old_device_id is not None:
await self.delete_devices(user_id, [old_device_id])

return device_id

async def rehydrate_device(
Expand Down
13 changes: 0 additions & 13 deletions synapse/handlers/devicemessage.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,19 +367,6 @@ async def get_events_for_dehydrated_device(
errcode=Codes.INVALID_PARAM,
)

# if we have a since token, delete any to-device messages before that token
# (since we now know that the device has received them)
deleted = await self.store.delete_messages_for_device(
user_id, device_id, since_stream_id
)
logger.debug(
"Deleted %d to-device messages up to %d for user_id %s device_id %s",
deleted,
since_stream_id,
user_id,
device_id,
)

to_token = self.event_sources.get_current_token().to_device_key

messages, stream_id = await self.store.get_messages_for_device(
Expand Down
8 changes: 1 addition & 7 deletions synapse/rest/client/devices.py
H-Shay marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -549,18 +549,12 @@ async def on_PUT(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"Device key(s) not found, these must be provided.",
)

# TODO: Those two operations, creating a device and storing the
# device's keys should be atomic.
device_id = await self.device_handler.store_dehydrated_device(
requester.user.to_string(),
submission.device_id,
submission.device_data.dict(),
submission.initial_device_display_name,
)

# TODO: Do we need to do something with the result here?
await self.key_uploader(
user_id=user_id, device_id=submission.device_id, keys=submission.dict()
device_info,
)

return 200, {"device_id": device_id}
Expand Down
51 changes: 49 additions & 2 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
cast,
)

from canonicaljson import encode_canonical_json
from typing_extensions import Literal

from synapse.api.constants import EduTypes
Expand Down Expand Up @@ -1188,8 +1189,42 @@ async def get_dehydrated_device(
)

def _store_dehydrated_device_txn(
self, txn: LoggingTransaction, user_id: str, device_id: str, device_data: str
self,
txn: LoggingTransaction,
user_id: str,
device_id: str,
device_data: str,
time: int,
keys: Optional[JsonDict] = None,
) -> Optional[str]:
# TODO: make keys non-optional once support for msc2697 is dropped
if keys:
device_keys = keys.get("device_keys", None)
if device_keys:
# Type ignore - this function is defined on EndToEndKeyStore which we do
# have access to due to hs.get_datastore() "magic"
self._set_e2e_device_keys_txn( # type: ignore[attr-defined]
txn, user_id, device_id, time, device_keys
)

one_time_keys = keys.get("one_time_keys", None)
if one_time_keys:
key_list = []
for key_id, key_obj in one_time_keys.items():
algorithm, key_id = key_id.split(":")
key_list.append(
(
algorithm,
key_id,
encode_canonical_json(key_obj).decode("ascii"),
)
)
self._add_e2e_one_time_keys_txn(txn, user_id, device_id, time, key_list)

fallback_keys = keys.get("fallback_keys", None)
if fallback_keys:
self._set_e2e_fallback_keys_txn(txn, user_id, device_id, fallback_keys)

old_device_id = self.db_pool.simple_select_one_onecol_txn(
txn,
table="dehydrated_devices",
Expand All @@ -1203,26 +1238,38 @@ def _store_dehydrated_device_txn(
keyvalues={"user_id": user_id},
values={"device_id": device_id, "device_data": device_data},
)

return old_device_id

async def store_dehydrated_device(
    self,
    user_id: str,
    device_id: str,
    device_data: JsonDict,
    time_now: int,
    keys: Optional[JsonDict] = None,
) -> Optional[str]:
    """Store a dehydrated device for a user, optionally along with its keys.

    The device data is JSON-encoded here and handed, together with the keys,
    to `_store_dehydrated_device_txn` in a single `runInteraction` call, so
    the device and its keys are written by the same database interaction.

    Args:
        user_id: the user that we are storing the device for
        device_id: the ID of the dehydrated device
        device_data: the dehydrated device information
        time_now: current time at the request in milliseconds
        keys: keys for the dehydrated device. May be omitted, in which case
            only the device itself is stored.

    Returns:
        device id of the user's previous dehydrated device, if any
    """
    return await self.db_pool.runInteraction(
        "store_dehydrated_device_txn",
        self._store_dehydrated_device_txn,
        user_id,
        device_id,
        json_encoder.encode(device_data),
        time_now,
        keys,
    )

async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool:
Expand Down
169 changes: 112 additions & 57 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,36 +522,56 @@ async def add_e2e_one_time_keys(
new_keys: keys to add - each a tuple of (algorithm, key_id, key json)
"""

def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None:
set_tag("user_id", user_id)
set_tag("device_id", device_id)
set_tag("new_keys", str(new_keys))
# We are protected from race between lookup and insertion due to
# a unique constraint. If there is a race of two calls to
# `add_e2e_one_time_keys` then they'll conflict and we will only
# insert one set.
self.db_pool.simple_insert_many_txn(
txn,
table="e2e_one_time_keys_json",
keys=(
"user_id",
"device_id",
"algorithm",
"key_id",
"ts_added_ms",
"key_json",
),
values=[
(user_id, device_id, algorithm, key_id, time_now, json_bytes)
for algorithm, key_id, json_bytes in new_keys
],
)
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)

await self.db_pool.runInteraction(
"add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
"add_e2e_one_time_keys_insert",
self._add_e2e_one_time_keys_txn,
user_id,
device_id,
time_now,
new_keys,
)

def _add_e2e_one_time_keys_txn(
    self,
    txn: LoggingTransaction,
    user_id: str,
    device_id: str,
    time_now: int,
    new_keys: Iterable[Tuple[str, str, str]],
) -> None:
    """Insert some new one time keys for a device. Errors if any of the keys already exist.

    Args:
        txn: the database transaction to perform the insert in
        user_id: id of user to add keys for
        device_id: id of device to add keys for
        time_now: insertion time to record (ms since epoch)
        new_keys: keys to add - each a tuple of (algorithm, key_id, key json).
            The key json is written to the database exactly as given, so
            callers must pass it already encoded as canonical JSON.
    """
    set_tag("user_id", user_id)
    set_tag("device_id", device_id)
    set_tag("new_keys", str(new_keys))
    # We are protected from race between lookup and insertion due to
    # a unique constraint. If there is a race of two calls to
    # `add_e2e_one_time_keys` then they'll conflict and we will only
    # insert one set.
    self.db_pool.simple_insert_many_txn(
        txn,
        table="e2e_one_time_keys_json",
        keys=(
            "user_id",
            "device_id",
            "algorithm",
            "key_id",
            "ts_added_ms",
            "key_json",
        ),
        values=[
            (user_id, device_id, algorithm, key_id, time_now, json_bytes)
            for algorithm, key_id, json_bytes in new_keys
        ],
    )
    # The stored key count for this device has changed, so drop the cached
    # value of `count_e2e_one_time_keys` for it.
    self._invalidate_cache_and_stream(
        txn, self.count_e2e_one_time_keys, (user_id, device_id)
    )

@cached(max_entries=10000)
Expand Down Expand Up @@ -723,6 +743,14 @@ def _set_e2e_fallback_keys_txn(
device_id: str,
fallback_keys: JsonDict,
) -> None:
"""Set the user's e2e fallback keys.

Args:
user_id: the user whose keys are being set
device_id: the device whose keys are being set
fallback_keys: the keys to set. This is a map from key ID (which is
of the form "algorithm:id") to key data.
"""
# fallback_keys will usually only have one item in it, so using a for
# loop (as opposed to calling simple_upsert_many_txn) won't be too bad
# FIXME: make sure that only one key per algorithm is uploaded
Expand Down Expand Up @@ -1304,42 +1332,69 @@ async def set_e2e_device_keys(
) -> bool:
"""Stores device keys for a device. Returns whether there was a change
or the keys were already in the database.

Args:
user_id: user_id of the user to store keys for
device_id: device_id of the device to store keys for
time_now: time at the request to store the keys
device_keys: the keys to store
"""

def _set_e2e_device_keys_txn(txn: LoggingTransaction) -> bool:
set_tag("user_id", user_id)
set_tag("device_id", device_id)
set_tag("time_now", time_now)
set_tag("device_keys", str(device_keys))
return await self.db_pool.runInteraction(
"set_e2e_device_keys",
self._set_e2e_device_keys_txn,
user_id,
device_id,
time_now,
device_keys,
)

old_key_json = self.db_pool.simple_select_one_onecol_txn(
txn,
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
retcol="key_json",
allow_none=True,
)
def _set_e2e_device_keys_txn(
    self,
    txn: LoggingTransaction,
    user_id: str,
    device_id: str,
    time_now: int,
    device_keys: JsonDict,
) -> bool:
    """Stores device keys for a device. Returns whether there was a change
    or the keys were already in the database.

    Args:
        txn: the database transaction to perform the lookup and upsert in
        user_id: user_id of the user to store keys for
        device_id: device_id of the device to store keys for
        time_now: time at the request to store the keys
        device_keys: the keys to store

    Returns:
        True if the keys were written, False if the stored key JSON was
        already identical to the canonical encoding of `device_keys`.
    """
    set_tag("user_id", user_id)
    set_tag("device_id", device_id)
    set_tag("time_now", time_now)
    set_tag("device_keys", str(device_keys))

    old_key_json = self.db_pool.simple_select_one_onecol_txn(
        txn,
        table="e2e_device_keys_json",
        keyvalues={"user_id": user_id, "device_id": device_id},
        retcol="key_json",
        allow_none=True,
    )

    # In py3 we need old_key_json to match new_key_json type. The DB
    # returns unicode while encode_canonical_json returns bytes.
    new_key_json = encode_canonical_json(device_keys).decode("utf-8")

    if old_key_json == new_key_json:
        # Nothing to do: skip the write so `ts_added_ms` is left untouched.
        log_kv({"message": "Device key already stored."})
        return False

    self.db_pool.simple_upsert_txn(
        txn,
        table="e2e_device_keys_json",
        keyvalues={"user_id": user_id, "device_id": device_id},
        values={"ts_added_ms": time_now, "key_json": new_key_json},
    )
    log_kv({"message": "Device keys stored."})
    return True

async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None:
Expand Down
9 changes: 5 additions & 4 deletions tests/handlers/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,15 +566,16 @@ def test_dehydrate_v2_and_fetch_events(self) -> None:
self.assertEqual(len(res["events"]), 1)
self.assertEqual(res["events"][0]["content"]["body"], "foo")

# Fetch the message of the dehydrated device again, which should return nothing
# and delete the old messages
# Fetch the message of the dehydrated device again, which should return
# the same message as it has not been deleted
res = self.get_success(
self.message_handler.get_events_for_dehydrated_device(
requester=requester,
device_id=stored_dehydrated_device_id,
since_token=res["next_batch"],
since_token=None,
limit=10,
)
)
self.assertTrue(len(res["next_batch"]) > 1)
self.assertEqual(len(res["events"]), 0)
self.assertEqual(len(res["events"]), 1)
self.assertEqual(res["events"][0]["content"]["body"], "foo")
Loading