@ -16,7 +16,7 @@
import logging
from collections import namedtuple
from typing import Any , Awaitable , Callable , Iterable , List , Optional , Tuple
from typing import Any , Awaitable , Callable , List , Optional , Tuple
import attr
@ -53,6 +53,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
#
# The arguments are:
#
# * instance_name: the writer of the stream
# * from_token: the previous stream token: the starting point for fetching the
# updates
# * to_token: the new stream token: the point to get updates up to
@ -62,7 +63,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
# If there are more updates available, it should set `limited` in the result, and
# it will be called again to get the next batch.
#
UpdateFunction = Callable [ [ Token , Token , int ] , Awaitable [ StreamUpdateResult ] ]
UpdateFunction = Callable [ [ str , Token , Token , int ] , Awaitable [ StreamUpdateResult ] ]
class Stream ( object ) :
@ -93,6 +94,7 @@ class Stream(object):
def __init__ (
self ,
local_instance_name : str ,
current_token_function : Callable [ [ ] , Token ] ,
update_function : UpdateFunction ,
) :
@ -108,9 +110,11 @@ class Stream(object):
stream tokens . See the UpdateFunction type definition for more info .
Args :
local_instance_name : The instance name of the current process
current_token_function : callback to get the current token , as above
update_function : callback go get stream updates , as above
"""
self . local_instance_name = local_instance_name
self . current_token = current_token_function
self . update_function = update_function
@ -135,14 +139,14 @@ class Stream(object):
"""
current_token = self . current_token ( )
updates , current_token , limited = await self . get_updates_since (
self . last_token , current_token
self . local_instance_name , self . l ast_token , current_token
)
self . last_token = current_token
return updates , current_token , limited
async def get_updates_since (
self , from_token : Token , upto_token : Token
self , instance_name : str , from_token : Token , upto_token : Token
) - > StreamUpdateResult :
""" Like get_updates except allows specifying from when we should
stream updates
@ -160,19 +164,19 @@ class Stream(object):
return [ ] , upto_token , False
updates , upto_token , limited = await self . update_function (
from_token , upto_token , _STREAM_UPDATE_TARGET_ROW_COUNT ,
instance_name , from_token , upto_token , _STREAM_UPDATE_TARGET_ROW_COUNT ,
)
return updates , upto_token , limited
def db_query_to_update_function (
query_function : Callable [ [ Token , Token , int ] , Awaitable [ Iterable [ tuple ] ] ]
query_function : Callable [ [ Token , Token , int ] , Awaitable [ List [ tuple ] ] ]
) - > UpdateFunction :
""" Wraps a db query function which returns a list of rows to make it
suitable for use as an ` update_function ` for the Stream class
"""
async def update_function ( from_token , upto_token , limit ) :
async def update_function ( instance_name , from_token , upto_token , limit ) :
rows = await query_function ( from_token , upto_token , limit )
updates = [ ( row [ 0 ] , row [ 1 : ] ) for row in rows ]
limited = False
@ -193,10 +197,13 @@ def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
client = ReplicationGetStreamUpdates . make_client ( hs )
async def update_function (
from_token : int , upto_token : int , limit : int
instance_name : str , from_token : int , upto_token : int , limit : int
) - > StreamUpdateResult :
result = await client (
stream_name = stream_name , from_token = from_token , upto_token = upto_token ,
instance_name = instance_name ,
stream_name = stream_name ,
from_token = from_token ,
upto_token = upto_token ,
)
return result [ " updates " ] , result [ " upto_token " ] , result [ " limited " ]
@ -226,6 +233,7 @@ class BackfillStream(Stream):
def __init__ ( self , hs ) :
store = hs . get_datastore ( )
super ( ) . __init__ (
hs . get_instance_name ( ) ,
store . get_current_backfill_token ,
db_query_to_update_function ( store . get_all_new_backfill_event_rows ) ,
)
@ -261,7 +269,9 @@ class PresenceStream(Stream):
# Query master process
update_function = make_http_update_function ( hs , self . NAME )
super ( ) . __init__ ( store . get_current_presence_token , update_function )
super ( ) . __init__ (
hs . get_instance_name ( ) , store . get_current_presence_token , update_function
)
class TypingStream ( Stream ) :
@ -284,7 +294,9 @@ class TypingStream(Stream):
# Query master process
update_function = make_http_update_function ( hs , self . NAME )
super ( ) . __init__ ( typing_handler . get_current_token , update_function )
super ( ) . __init__ (
hs . get_instance_name ( ) , typing_handler . get_current_token , update_function
)
class ReceiptsStream ( Stream ) :
@ -305,6 +317,7 @@ class ReceiptsStream(Stream):
def __init__ ( self , hs ) :
store = hs . get_datastore ( )
super ( ) . __init__ (
hs . get_instance_name ( ) ,
store . get_max_receipt_stream_id ,
db_query_to_update_function ( store . get_all_updated_receipts ) ,
)
@ -322,14 +335,16 @@ class PushRulesStream(Stream):
def __init__ ( self , hs ) :
self . store = hs . get_datastore ( )
super ( PushRulesStream , self ) . __init__ (
self . _current_token , self . _update_function
hs . get_instance_name ( ) , self . _current_token , self . _update_function
)
def _current_token ( self ) - > int :
push_rules_token , _ = self . store . get_push_rules_stream_token ( )
return push_rules_token
async def _update_function ( self , from_token : Token , to_token : Token , limit : int ) :
async def _update_function (
self , instance_name : str , from_token : Token , to_token : Token , limit : int
) :
rows = await self . store . get_all_push_rule_updates ( from_token , to_token , limit )
limited = False
@ -356,6 +371,7 @@ class PushersStream(Stream):
store = hs . get_datastore ( )
super ( ) . __init__ (
hs . get_instance_name ( ) ,
store . get_pushers_stream_token ,
db_query_to_update_function ( store . get_all_updated_pushers_rows ) ,
)
@ -387,6 +403,7 @@ class CachesStream(Stream):
def __init__ ( self , hs ) :
store = hs . get_datastore ( )
super ( ) . __init__ (
hs . get_instance_name ( ) ,
store . get_cache_stream_token ,
db_query_to_update_function ( store . get_all_updated_caches ) ,
)
@ -412,6 +429,7 @@ class PublicRoomsStream(Stream):
def __init__ ( self , hs ) :
store = hs . get_datastore ( )
super ( ) . __init__ (
hs . get_instance_name ( ) ,
store . get_current_public_room_stream_id ,
db_query_to_update_function ( store . get_all_new_public_rooms ) ,
)
@ -432,6 +450,7 @@ class DeviceListsStream(Stream):
def __init__ ( self , hs ) :
store = hs . get_datastore ( )
super ( ) . __init__ (
hs . get_instance_name ( ) ,
store . get_device_stream_token ,
db_query_to_update_function ( store . get_all_device_list_changes_for_remotes ) ,
)
@ -449,6 +468,7 @@ class ToDeviceStream(Stream):
def __init__ ( self , hs ) :
store = hs . get_datastore ( )
super ( ) . __init__ (
hs . get_instance_name ( ) ,
store . get_to_device_stream_token ,
db_query_to_update_function ( store . get_all_new_device_messages ) ,
)
@ -468,6 +488,7 @@ class TagAccountDataStream(Stream):
def __init__ ( self , hs ) :
store = hs . get_datastore ( )
super ( ) . __init__ (
hs . get_instance_name ( ) ,
store . get_max_account_data_stream_id ,
db_query_to_update_function ( store . get_all_updated_tags ) ,
)
@ -487,6 +508,7 @@ class AccountDataStream(Stream):
def __init__ ( self , hs ) :
self . store = hs . get_datastore ( )
super ( ) . __init__ (
hs . get_instance_name ( ) ,
self . store . get_max_account_data_stream_id ,
db_query_to_update_function ( self . _update_function ) ,
)
@ -517,6 +539,7 @@ class GroupServerStream(Stream):
def __init__ ( self , hs ) :
store = hs . get_datastore ( )
super ( ) . __init__ (
hs . get_instance_name ( ) ,
store . get_group_stream_token ,
db_query_to_update_function ( store . get_all_groups_changes ) ,
)
@ -534,6 +557,7 @@ class UserSignatureStream(Stream):
def __init__ ( self , hs ) :
store = hs . get_datastore ( )
super ( ) . __init__ (
hs . get_instance_name ( ) ,
store . get_device_stream_token ,
db_query_to_update_function (
store . get_all_user_signature_changes_for_remotes