Add type annotations to `trace` decorator. (#13328)

Functions that are decorated with `trace` are now properly typed
and the type hints for them are fixed.
1.103.0-whithout-watcha
Patrick Cloke 2 years ago committed by GitHub
parent 47822fd2e8
commit a6895dd576
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      changelog.d/13328.misc
  2. 2
      synapse/federation/federation_client.py
  3. 2
      synapse/federation/transport/client.py
  4. 16
      synapse/handlers/e2e_keys.py
  5. 48
      synapse/logging/opentracing.py
  6. 4
      synapse/replication/http/_base.py
  7. 4
      synapse/rest/client/keys.py
  8. 11
      synapse/rest/client/room_keys.py
  9. 4
      synapse/rest/client/sendtodevice.py
  10. 12
      synapse/rest/client/sync.py
  11. 2
      synapse/storage/databases/main/devices.py
  12. 47
      synapse/storage/databases/main/end_to_end_keys.py

@ -0,0 +1 @@
Add type hints to `trace` decorator.

@ -217,7 +217,7 @@ class FederationClient(FederationBase):
) )
async def claim_client_keys( async def claim_client_keys(
self, destination: str, content: JsonDict, timeout: int self, destination: str, content: JsonDict, timeout: Optional[int]
) -> JsonDict: ) -> JsonDict:
"""Claims one-time keys for a device hosted on a remote server. """Claims one-time keys for a device hosted on a remote server.

@ -619,7 +619,7 @@ class TransportLayerClient:
) )
async def claim_client_keys( async def claim_client_keys(
self, destination: str, query_content: JsonDict, timeout: int self, destination: str, query_content: JsonDict, timeout: Optional[int]
) -> JsonDict: ) -> JsonDict:
"""Claim one-time keys for a list of devices hosted on a remote server. """Claim one-time keys for a list of devices hosted on a remote server.

@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple
import attr import attr
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
@ -92,7 +92,11 @@ class E2eKeysHandler:
@trace @trace
async def query_devices( async def query_devices(
self, query_body: JsonDict, timeout: int, from_user_id: str, from_device_id: str self,
query_body: JsonDict,
timeout: int,
from_user_id: str,
from_device_id: Optional[str],
) -> JsonDict: ) -> JsonDict:
"""Handle a device key query from a client """Handle a device key query from a client
@ -120,9 +124,7 @@ class E2eKeysHandler:
the number of in-flight queries at a time. the number of in-flight queries at a time.
""" """
async with self._query_devices_linearizer.queue((from_user_id, from_device_id)): async with self._query_devices_linearizer.queue((from_user_id, from_device_id)):
device_keys_query: Dict[str, Iterable[str]] = query_body.get( device_keys_query: Dict[str, List[str]] = query_body.get("device_keys", {})
"device_keys", {}
)
# separate users by domain. # separate users by domain.
# make a map from domain to user_id to device_ids # make a map from domain to user_id to device_ids
@ -392,7 +394,7 @@ class E2eKeysHandler:
@trace @trace
async def query_local_devices( async def query_local_devices(
self, query: Dict[str, Optional[List[str]]] self, query: Mapping[str, Optional[List[str]]]
) -> Dict[str, Dict[str, dict]]: ) -> Dict[str, Dict[str, dict]]:
"""Get E2E device keys for local users """Get E2E device keys for local users
@ -461,7 +463,7 @@ class E2eKeysHandler:
@trace @trace
async def claim_one_time_keys( async def claim_one_time_keys(
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int]
) -> JsonDict: ) -> JsonDict:
local_query: List[Tuple[str, str, str]] = [] local_query: List[Tuple[str, str, str]] = []
remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {} remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}

@ -84,14 +84,13 @@ the function becomes the operation name for the span.
return something_usual_and_useful return something_usual_and_useful
Operation names can be explicitly set for a function by passing the Operation names can be explicitly set for a function by using ``trace_with_opname``:
operation name to ``trace``
.. code-block:: python .. code-block:: python
from synapse.logging.opentracing import trace from synapse.logging.opentracing import trace_with_opname
@trace(opname="a_better_operation_name") @trace_with_opname("a_better_operation_name")
def interesting_badly_named_function(*args, **kwargs): def interesting_badly_named_function(*args, **kwargs):
# Does all kinds of cool and expected things # Does all kinds of cool and expected things
return something_usual_and_useful return something_usual_and_useful
@ -798,33 +797,31 @@ def extract_text_map(carrier: Dict[str, str]) -> Optional["opentracing.SpanConte
# Tracing decorators # Tracing decorators
def trace(func=None, opname: Optional[str] = None): def trace_with_opname(opname: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
""" """
Decorator to trace a function. Decorator to trace a function with a custom opname.
Sets the operation name to that of the function's or that given
as operation_name. See the module's doc string for usage See the module's doc string for usage examples.
examples.
""" """
def decorator(func): def decorator(func: Callable[P, R]) -> Callable[P, R]:
if opentracing is None: if opentracing is None:
return func # type: ignore[unreachable] return func # type: ignore[unreachable]
_opname = opname if opname else func.__name__
if inspect.iscoroutinefunction(func): if inspect.iscoroutinefunction(func):
@wraps(func) @wraps(func)
async def _trace_inner(*args, **kwargs): async def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
with start_active_span(_opname): with start_active_span(opname):
return await func(*args, **kwargs) return await func(*args, **kwargs) # type: ignore[misc]
else: else:
# The other case here handles both sync functions and those # The other case here handles both sync functions and those
# decorated with inlineDeferred. # decorated with inlineDeferred.
@wraps(func) @wraps(func)
def _trace_inner(*args, **kwargs): def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
scope = start_active_span(_opname) scope = start_active_span(opname)
scope.__enter__() scope.__enter__()
try: try:
@ -858,14 +855,23 @@ def trace(func=None, opname: Optional[str] = None):
scope.__exit__(type(e), None, e.__traceback__) scope.__exit__(type(e), None, e.__traceback__)
raise raise
return _trace_inner return _trace_inner # type: ignore[return-value]
if func:
return decorator(func)
else:
return decorator return decorator
def trace(func: Callable[P, R]) -> Callable[P, R]:
"""
Decorator to trace a function.
Sets the operation name to that of the function's name.
See the module's doc string for usage examples.
"""
return trace_with_opname(func.__name__)(func)
def tag_args(func: Callable[P, R]) -> Callable[P, R]: def tag_args(func: Callable[P, R]) -> Callable[P, R]:
""" """
Tags all of the args to the active span. Tags all of the args to the active span.

@ -29,7 +29,7 @@ from synapse.http import RequestTimedOutError
from synapse.http.server import HttpServer, is_method_cancellable from synapse.http.server import HttpServer, is_method_cancellable
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.logging import opentracing from synapse.logging import opentracing
from synapse.logging.opentracing import trace from synapse.logging.opentracing import trace_with_opname
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
@ -196,7 +196,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
"ascii" "ascii"
) )
@trace(opname="outgoing_replication_request") @trace_with_opname("outgoing_replication_request")
async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any: async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any:
with outgoing_gauge.track_inprogress(): with outgoing_gauge.track_inprogress():
if instance_name == local_instance_name: if instance_name == local_instance_name:

@ -26,7 +26,7 @@ from synapse.http.servlet import (
parse_string, parse_string,
) )
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.logging.opentracing import log_kv, set_tag, trace_with_opname
from synapse.types import JsonDict, StreamToken from synapse.types import JsonDict, StreamToken
from ._base import client_patterns, interactive_auth_handler from ._base import client_patterns, interactive_auth_handler
@ -71,7 +71,7 @@ class KeyUploadServlet(RestServlet):
self.e2e_keys_handler = hs.get_e2e_keys_handler() self.e2e_keys_handler = hs.get_e2e_keys_handler()
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
@trace(opname="upload_keys") @trace_with_opname("upload_keys")
async def on_POST( async def on_POST(
self, request: SynapseRequest, device_id: Optional[str] self, request: SynapseRequest, device_id: Optional[str]
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Optional, Tuple from typing import TYPE_CHECKING, Optional, Tuple, cast
from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
@ -127,7 +127,7 @@ class RoomKeysServlet(RestServlet):
requester = await self.auth.get_user_by_req(request, allow_guest=False) requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string() user_id = requester.user.to_string()
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
version = parse_string(request, "version") version = parse_string(request, "version", required=True)
if session_id: if session_id:
body = {"sessions": {session_id: body}} body = {"sessions": {session_id: body}}
@ -196,8 +196,11 @@ class RoomKeysServlet(RestServlet):
user_id = requester.user.to_string() user_id = requester.user.to_string()
version = parse_string(request, "version", required=True) version = parse_string(request, "version", required=True)
room_keys = await self.e2e_room_keys_handler.get_room_keys( room_keys = cast(
JsonDict,
await self.e2e_room_keys_handler.get_room_keys(
user_id, version, room_id, session_id user_id, version, room_id, session_id
),
) )
# Convert room_keys to the right format to return. # Convert room_keys to the right format to return.
@ -240,7 +243,7 @@ class RoomKeysServlet(RestServlet):
requester = await self.auth.get_user_by_req(request, allow_guest=False) requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string() user_id = requester.user.to_string()
version = parse_string(request, "version") version = parse_string(request, "version", required=True)
ret = await self.e2e_room_keys_handler.delete_room_keys( ret = await self.e2e_room_keys_handler.delete_room_keys(
user_id, version, room_id, session_id user_id, version, room_id, session_id

@ -19,7 +19,7 @@ from synapse.http import servlet
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import set_tag, trace from synapse.logging.opentracing import set_tag, trace_with_opname
from synapse.rest.client.transactions import HttpTransactionCache from synapse.rest.client.transactions import HttpTransactionCache
from synapse.types import JsonDict from synapse.types import JsonDict
@ -43,7 +43,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
self.txns = HttpTransactionCache(hs) self.txns = HttpTransactionCache(hs)
self.device_message_handler = hs.get_device_message_handler() self.device_message_handler = hs.get_device_message_handler()
@trace(opname="sendToDevice") @trace_with_opname("sendToDevice")
def on_PUT( def on_PUT(
self, request: SynapseRequest, message_type: str, txn_id: str self, request: SynapseRequest, message_type: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]: ) -> Awaitable[Tuple[int, JsonDict]]:

@ -37,7 +37,7 @@ from synapse.handlers.sync import (
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import trace from synapse.logging.opentracing import trace_with_opname
from synapse.types import JsonDict, StreamToken from synapse.types import JsonDict, StreamToken
from synapse.util import json_decoder from synapse.util import json_decoder
@ -210,7 +210,7 @@ class SyncRestServlet(RestServlet):
logger.debug("Event formatting complete") logger.debug("Event formatting complete")
return 200, response_content return 200, response_content
@trace(opname="sync.encode_response") @trace_with_opname("sync.encode_response")
async def encode_response( async def encode_response(
self, self,
time_now: int, time_now: int,
@ -315,7 +315,7 @@ class SyncRestServlet(RestServlet):
] ]
} }
@trace(opname="sync.encode_joined") @trace_with_opname("sync.encode_joined")
async def encode_joined( async def encode_joined(
self, self,
rooms: List[JoinedSyncResult], rooms: List[JoinedSyncResult],
@ -340,7 +340,7 @@ class SyncRestServlet(RestServlet):
return joined return joined
@trace(opname="sync.encode_invited") @trace_with_opname("sync.encode_invited")
async def encode_invited( async def encode_invited(
self, self,
rooms: List[InvitedSyncResult], rooms: List[InvitedSyncResult],
@ -371,7 +371,7 @@ class SyncRestServlet(RestServlet):
return invited return invited
@trace(opname="sync.encode_knocked") @trace_with_opname("sync.encode_knocked")
async def encode_knocked( async def encode_knocked(
self, self,
rooms: List[KnockedSyncResult], rooms: List[KnockedSyncResult],
@ -420,7 +420,7 @@ class SyncRestServlet(RestServlet):
return knocked return knocked
@trace(opname="sync.encode_archived") @trace_with_opname("sync.encode_archived")
async def encode_archived( async def encode_archived(
self, self,
rooms: List[ArchivedSyncResult], rooms: List[ArchivedSyncResult],

@ -669,7 +669,7 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
@trace @trace
async def get_user_devices_from_cache( async def get_user_devices_from_cache(
self, query_list: List[Tuple[str, str]] self, query_list: List[Tuple[str, Optional[str]]]
) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]: ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
"""Get the devices (and keys if any) for remote users from the cache. """Get the devices (and keys if any) for remote users from the cache.

@ -22,11 +22,14 @@ from typing import (
List, List,
Optional, Optional,
Tuple, Tuple,
Union,
cast, cast,
overload,
) )
import attr import attr
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from typing_extensions import Literal
from synapse.api.constants import DeviceKeyAlgorithms from synapse.api.constants import DeviceKeyAlgorithms
from synapse.appservice import ( from synapse.appservice import (
@ -113,7 +116,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
user_devices = devices[user_id] user_devices = devices[user_id]
results = [] results = []
for device_id, device in user_devices.items(): for device_id, device in user_devices.items():
result = {"device_id": device_id} result: JsonDict = {"device_id": device_id}
keys = device.keys keys = device.keys
if keys: if keys:
@ -156,6 +159,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
rv[user_id] = {} rv[user_id] = {}
for device_id, device_info in device_keys.items(): for device_id, device_info in device_keys.items():
r = device_info.keys r = device_info.keys
if r is None:
continue
r["unsigned"] = {} r["unsigned"] = {}
display_name = device_info.display_name display_name = device_info.display_name
if display_name is not None: if display_name is not None:
@ -164,13 +170,42 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
return rv return rv
@overload
async def get_e2e_device_keys_and_signatures(
self,
query_list: Collection[Tuple[str, Optional[str]]],
include_all_devices: Literal[False] = False,
) -> Dict[str, Dict[str, DeviceKeyLookupResult]]:
...
@overload
async def get_e2e_device_keys_and_signatures(
self,
query_list: Collection[Tuple[str, Optional[str]]],
include_all_devices: bool = False,
include_deleted_devices: Literal[False] = False,
) -> Dict[str, Dict[str, DeviceKeyLookupResult]]:
...
@overload
async def get_e2e_device_keys_and_signatures(
self,
query_list: Collection[Tuple[str, Optional[str]]],
include_all_devices: Literal[True],
include_deleted_devices: Literal[True],
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
...
@trace @trace
async def get_e2e_device_keys_and_signatures( async def get_e2e_device_keys_and_signatures(
self, self,
query_list: List[Tuple[str, Optional[str]]], query_list: Collection[Tuple[str, Optional[str]]],
include_all_devices: bool = False, include_all_devices: bool = False,
include_deleted_devices: bool = False, include_deleted_devices: bool = False,
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: ) -> Union[
Dict[str, Dict[str, DeviceKeyLookupResult]],
Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]],
]:
"""Fetch a list of device keys """Fetch a list of device keys
Any cross-signatures made on the keys by the owner of the device are also Any cross-signatures made on the keys by the owner of the device are also
@ -1044,7 +1079,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
_claim_e2e_one_time_key = _claim_e2e_one_time_key_simple _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
db_autocommit = False db_autocommit = False
row = await self.db_pool.runInteraction( claim_row = await self.db_pool.runInteraction(
"claim_e2e_one_time_keys", "claim_e2e_one_time_keys",
_claim_e2e_one_time_key, _claim_e2e_one_time_key,
user_id, user_id,
@ -1052,11 +1087,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
algorithm, algorithm,
db_autocommit=db_autocommit, db_autocommit=db_autocommit,
) )
if row: if claim_row:
device_results = results.setdefault(user_id, {}).setdefault( device_results = results.setdefault(user_id, {}).setdefault(
device_id, {} device_id, {}
) )
device_results[row[0]] = row[1] device_results[claim_row[0]] = claim_row[1]
continue continue
# No one-time key available, so see if there's a fallback # No one-time key available, so see if there's a fallback

Loading…
Cancel
Save