@ -14,7 +14,7 @@
# limitations under the License.
import logging
from typing import Any , Dict , List , Optional , Tuple
from typing import Any , List , Optional , Tuple
import attr
@ -22,13 +22,15 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted . internet . task import LoopingCall
from twisted . web . http import HTTPChannel
from synapse . app . generic_worker import GenericWorkerServer
from synapse . app . generic_worker import (
GenericWorkerReplicationHandler ,
GenericWorkerServer ,
)
from synapse . http . site import SynapseRequest
from synapse . replication . slave . storage . _base import BaseSlavedStore
from synapse . replication . tcp . client import ReplicationDataHandler
from synapse . replication . tcp . handler import ReplicationCommandHandler
from synapse . replication . tcp . protocol import ClientReplicationStreamProtocol
from synapse . replication . tcp . resource import ReplicationStreamProtocolFactory
from synapse . server import HomeServer
from synapse . util import Clock
from tests import unittest
@ -77,7 +79,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self . _server_transport = None
def _build_replication_data_handler ( self ) :
return TestReplicationDataHandler ( self . worker_hs . get_datastore ( ) )
return TestReplicationDataHandler ( self . worker_hs )
def reconnect ( self ) :
if self . _client_transport :
@ -172,32 +174,20 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self . assertEqual ( request . method , b " GET " )
class TestReplicationDataHandler ( ReplicationData Handler ) :
class TestReplicationDataHandler ( GenericWorker ReplicationHandler) :
""" Drop-in for ReplicationDataHandler which just collects RDATA rows """
def __init__ ( self , store : BaseSlavedStore ) :
super ( ) . __init__ ( store )
# streams to subscribe to: map from stream id to position
self . stream_positions = { } # type: Dict[str, int]
def __init__ ( self , hs : HomeServer ) :
super ( ) . __init__ ( hs )
# list of received (stream_name, token, row) tuples
self . received_rdata_rows = [ ] # type: List[Tuple[str, int, Any]]
def get_streams_to_replicate ( self ) :
return self . stream_positions
async def on_rdata ( self , stream_name , token , rows ) :
await super ( ) . on_rdata ( stream_name , token , rows )
for r in rows :
self . received_rdata_rows . append ( ( stream_name , token , r ) )
if (
stream_name in self . stream_positions
and token > self . stream_positions [ stream_name ]
) :
self . stream_positions [ stream_name ] = token
@attr . s ( )
class OneShotRequestFactory :