@ -14,13 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import logging
from collections import namedtuple
from typing import Any , List , Optional , Tuple
from typing import Any , Awaitable , Callable , List , Optional , Tuple
import attr
from synapse . replication . http . streams import ReplicationGetStreamUpdates
from synapse . types import JsonDict
logger = logging . getLogger ( __name__ )
@ -29,6 +29,15 @@ logger = logging.getLogger(__name__)
MAX_EVENTS_BEHIND = 500000
# Some type aliases to make things a bit easier.
# A stream position token
Token = int
# A pair of position in stream and args used to create an instance of `ROW_TYPE`.
StreamRow = Tuple [ Token , tuple ]
class Stream ( object ) :
""" Base class for the streams.
@ -56,6 +65,7 @@ class Stream(object):
return cls . ROW_TYPE ( * row )
def __init__ ( self , hs ) :
# The token from which we last asked for updates
self . last_token = self . current_token ( )
@ -65,61 +75,46 @@ class Stream(object):
"""
self . last_token = self . current_token ( )
async def get_updates(self) -> "Tuple[List[Tuple[Token, JsonDict]], Token, bool]":
    """Gets all updates since the last time this function was called (or
    since the stream was constructed if it hadn't been called before).

    Returns:
        A triplet `(updates, new_last_token, limited)`, where `updates` is
        a list of `(token, row)` entries, `new_last_token` is the new
        position in stream, and `limited` is whether there are more updates
        to fetch.
    """
    current_token = self.current_token()
    updates, current_token, limited = await self.get_updates_since(
        self.last_token, current_token
    )
    # Remember how far we got so the next call only returns newer rows.
    self.last_token = current_token

    return updates, current_token, limited
async def get_updates_since(
    self, from_token: "Token", upto_token: "Token", limit: int = 100
) -> "Tuple[List[Tuple[Token, JsonDict]], Token, bool]":
    """Like get_updates except allows specifying from when we should
    stream updates

    Returns:
        A triplet `(updates, new_last_token, limited)`, where `updates` is
        a list of `(token, row)` entries, `new_last_token` is the new
        position in stream, and `limited` is whether there are more updates
        to fetch.
    """
    if from_token == upto_token:
        # Nothing can have happened between identical tokens.
        return [], upto_token, False

    updates, upto_token, limited = await self.update_function(
        from_token, upto_token, limit=limit,
    )
    return updates, upto_token, limited
def current_token ( self ) :
""" Gets the current token of the underlying streams. Should be provided
@ -141,6 +136,48 @@ class Stream(object):
raise NotImplementedError ( )
def db_query_to_update_function(
    query_function: "Callable[[Token, Token, int], Awaitable[List[tuple]]]"
) -> "Callable[[Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]":
    """Wraps a db query function which returns a list of rows to make it
    suitable for use as an `update_function` for the Stream class.

    NOTE(review): assumes the query returns rows in ascending token order
    (the first element of each row) — the `rows[-1][0]` clamp below relies
    on it; confirm against each wrapped query.
    """

    async def update_function(from_token, upto_token, limit):
        rows = await query_function(from_token, upto_token, limit)
        # Split each row into (token, remaining_columns).
        updates = [(row[0], row[1:]) for row in rows]
        limited = False
        if len(updates) == limit:
            # We hit the row limit, so there may be more updates to fetch:
            # only advance the token as far as the last row we returned.
            upto_token = rows[-1][0]
            limited = True

        return updates, upto_token, limited

    return update_function
def make_http_update_function(
    hs, stream_name: str
) -> "Callable[[Token, Token, Token], Awaitable[Tuple[List[StreamRow], Token, bool]]]":
    """Makes a suitable function for use as an `update_function` that queries
    the master process for updates.

    Args:
        hs: the homeserver, used to build the replication HTTP client.
        stream_name: the name of the stream to fetch updates for.
    """

    # One replication client, shared by every call to the returned function.
    client = ReplicationGetStreamUpdates.make_client(hs)

    async def update_function(
        from_token: int, upto_token: int, limit: int
    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
        return await client(
            stream_name=stream_name,
            from_token=from_token,
            upto_token=upto_token,
            limit=limit,
        )

    return update_function
class BackfillStream ( Stream ) :
""" We fetched some old events and either we had never seen that event before
or it went from being an outlier to not .
@ -164,7 +201,7 @@ class BackfillStream(Stream):
def __init__ ( self , hs ) :
store = hs . get_datastore ( )
self . current_token = store . get_current_backfill_token # type: ignore
self . update_function = store . get_all_new_backfill_event_rows # type: ignore
self . update_function = db_query_to_update_function ( store . get_all_new_backfill_event_rows ) # type: ignore
super ( BackfillStream , self ) . __init__ ( hs )
@ -190,8 +227,15 @@ class PresenceStream(Stream):
store = hs . get_datastore ( )
presence_handler = hs . get_presence_handler ( )
self . _is_worker = hs . config . worker_app is not None
self . current_token = store . get_current_presence_token # type: ignore
self . update_function = presence_handler . get_all_presence_updates # type: ignore
if hs . config . worker_app is None :
self . update_function = db_query_to_update_function ( presence_handler . get_all_presence_updates ) # type: ignore
else :
# Query master process
self . update_function = make_http_update_function ( hs , self . NAME ) # type: ignore
super ( PresenceStream , self ) . __init__ ( hs )
@ -208,7 +252,12 @@ class TypingStream(Stream):
typing_handler = hs . get_typing_handler ( )
self . current_token = typing_handler . get_current_token # type: ignore
self . update_function = typing_handler . get_all_typing_updates # type: ignore
if hs . config . worker_app is None :
self . update_function = db_query_to_update_function ( typing_handler . get_all_typing_updates ) # type: ignore
else :
# Query master process
self . update_function = make_http_update_function ( hs , self . NAME ) # type: ignore
super ( TypingStream , self ) . __init__ ( hs )
@ -232,7 +281,7 @@ class ReceiptsStream(Stream):
store = hs . get_datastore ( )
self . current_token = store . get_max_receipt_stream_id # type: ignore
self . update_function = store . get_all_updated_receipts # type: ignore
self . update_function = db_query_to_update_function ( store . get_all_updated_receipts ) # type: ignore
super ( ReceiptsStream , self ) . __init__ ( hs )
@ -256,7 +305,13 @@ class PushRulesStream(Stream):
async def update_function(self, from_token, to_token, limit):
    """Fetch push rule updates between two stream tokens.

    Returns:
        A triplet `(updates, new_last_token, limited)` matching the
        Stream.update_function contract; each update pairs the row's
        token (`row[0]`) with a 1-tuple of the row's third column.
    """
    rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)

    limited = False
    if len(rows) == limit:
        # Hit the limit: more rows may remain, so only advance the token
        # as far as the last row actually fetched.
        to_token = rows[-1][0]
        limited = True

    return [(row[0], (row[2],)) for row in rows], to_token, limited
class PushersStream ( Stream ) :
@ -275,7 +330,7 @@ class PushersStream(Stream):
store = hs . get_datastore ( )
self . current_token = store . get_pushers_stream_token # type: ignore
self . update_function = store . get_all_updated_pushers_rows # type: ignore
self . update_function = db_query_to_update_function ( store . get_all_updated_pushers_rows ) # type: ignore
super ( PushersStream , self ) . __init__ ( hs )
@ -307,7 +362,7 @@ class CachesStream(Stream):
store = hs . get_datastore ( )
self . current_token = store . get_cache_stream_token # type: ignore
self . update_function = store . get_all_updated_caches # type: ignore
self . update_function = db_query_to_update_function ( store . get_all_updated_caches ) # type: ignore
super ( CachesStream , self ) . __init__ ( hs )
@ -333,7 +388,7 @@ class PublicRoomsStream(Stream):
store = hs . get_datastore ( )
self . current_token = store . get_current_public_room_stream_id # type: ignore
self . update_function = store . get_all_new_public_rooms # type: ignore
self . update_function = db_query_to_update_function ( store . get_all_new_public_rooms ) # type: ignore
super ( PublicRoomsStream , self ) . __init__ ( hs )
@ -354,7 +409,7 @@ class DeviceListsStream(Stream):
store = hs . get_datastore ( )
self . current_token = store . get_device_stream_token # type: ignore
self . update_function = store . get_all_device_list_changes_for_remotes # type: ignore
self . update_function = db_query_to_update_function ( store . get_all_device_list_changes_for_remotes ) # type: ignore
super ( DeviceListsStream , self ) . __init__ ( hs )
@ -372,7 +427,7 @@ class ToDeviceStream(Stream):
store = hs . get_datastore ( )
self . current_token = store . get_to_device_stream_token # type: ignore
self . update_function = store . get_all_new_device_messages # type: ignore
self . update_function = db_query_to_update_function ( store . get_all_new_device_messages ) # type: ignore
super ( ToDeviceStream , self ) . __init__ ( hs )
@ -392,7 +447,7 @@ class TagAccountDataStream(Stream):
store = hs . get_datastore ( )
self . current_token = store . get_max_account_data_stream_id # type: ignore
self . update_function = store . get_all_updated_tags # type: ignore
self . update_function = db_query_to_update_function ( store . get_all_updated_tags ) # type: ignore
super ( TagAccountDataStream , self ) . __init__ ( hs )
@ -412,10 +467,11 @@ class AccountDataStream(Stream):
self . store = hs . get_datastore ( )
self . current_token = self . store . get_max_account_data_stream_id # type: ignore
self . update_function = db_query_to_update_function ( self . _update_function ) # type: ignore
super ( AccountDataStream , self ) . __init__ ( hs )
async def update_function ( self , from_token , to_token , limit ) :
async def _update_function ( self , from_token , to_token , limit ) :
global_results , room_results = await self . store . get_all_updated_account_data (
from_token , from_token , to_token , limit
)
@ -442,7 +498,7 @@ class GroupServerStream(Stream):
store = hs . get_datastore ( )
self . current_token = store . get_group_stream_token # type: ignore
self . update_function = store . get_all_groups_changes # type: ignore
self . update_function = db_query_to_update_function ( store . get_all_groups_changes ) # type: ignore
super ( GroupServerStream , self ) . __init__ ( hs )
@ -460,6 +516,6 @@ class UserSignatureStream(Stream):
store = hs . get_datastore ( )
self . current_token = store . get_device_stream_token # type: ignore
self . update_function = store . get_all_user_signature_changes_for_remotes # type: ignore
self . update_function = db_query_to_update_function ( store . get_all_user_signature_changes_for_remotes ) # type: ignore
super ( UserSignatureStream , self ) . __init__ ( hs )