|
|
|
@ -12,13 +12,14 @@ |
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
|
# See the License for the specific language governing permissions and |
|
|
|
|
# limitations under the License. |
|
|
|
|
|
|
|
|
|
import logging |
|
|
|
|
from typing import Any, Callable, List, Optional, Tuple |
|
|
|
|
|
|
|
|
|
import attr |
|
|
|
|
import hiredis |
|
|
|
|
|
|
|
|
|
from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime |
|
|
|
|
from twisted.internet.protocol import Protocol |
|
|
|
|
from twisted.internet.task import LoopingCall |
|
|
|
|
from twisted.web.http import HTTPChannel |
|
|
|
|
|
|
|
|
@ -27,7 +28,7 @@ from synapse.app.generic_worker import ( |
|
|
|
|
GenericWorkerServer, |
|
|
|
|
) |
|
|
|
|
from synapse.http.server import JsonResource |
|
|
|
|
from synapse.http.site import SynapseRequest |
|
|
|
|
from synapse.http.site import SynapseRequest, SynapseSite |
|
|
|
|
from synapse.replication.http import ReplicationRestResource, streams |
|
|
|
|
from synapse.replication.tcp.handler import ReplicationCommandHandler |
|
|
|
|
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol |
|
|
|
@ -197,19 +198,37 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): |
|
|
|
|
self.server_factory = ReplicationStreamProtocolFactory(self.hs) |
|
|
|
|
self.streamer = self.hs.get_replication_streamer() |
|
|
|
|
|
|
|
|
|
# Fake in memory Redis server that servers can connect to. |
|
|
|
|
self._redis_server = FakeRedisPubSubServer() |
|
|
|
|
|
|
|
|
|
store = self.hs.get_datastore() |
|
|
|
|
self.database_pool = store.db_pool |
|
|
|
|
|
|
|
|
|
self.reactor.lookups["testserv"] = "1.2.3.4" |
|
|
|
|
self.reactor.lookups["localhost"] = "127.0.0.1" |
|
|
|
|
|
|
|
|
|
# A map from a HS instance to the associated HTTP Site to use for |
|
|
|
|
# handling inbound HTTP requests to that instance. |
|
|
|
|
self._hs_to_site = {self.hs: self.site} |
|
|
|
|
|
|
|
|
|
if self.hs.config.redis.redis_enabled: |
|
|
|
|
# Handle attempts to connect to fake redis server. |
|
|
|
|
self.reactor.add_tcp_client_callback( |
|
|
|
|
"localhost", 6379, self.connect_any_redis_attempts, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
self._worker_hs_to_resource = {} |
|
|
|
|
self.hs.get_tcp_replication().start_replication(self.hs) |
|
|
|
|
|
|
|
|
|
# When we see a connection attempt to the master replication listener we |
|
|
|
|
# automatically set up the connection. This is so that tests don't |
|
|
|
|
# manually have to go and explicitly set it up each time (plus sometimes |
|
|
|
|
# it is impossible to write the handling explicitly in the tests). |
|
|
|
|
# |
|
|
|
|
# Register the master replication listener: |
|
|
|
|
self.reactor.add_tcp_client_callback( |
|
|
|
|
"1.2.3.4", 8765, self._handle_http_replication_attempt |
|
|
|
|
"1.2.3.4", |
|
|
|
|
8765, |
|
|
|
|
lambda: self._handle_http_replication_attempt(self.hs, 8765), |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
def create_test_json_resource(self): |
|
|
|
@ -253,28 +272,63 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): |
|
|
|
|
**kwargs |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
# If the instance is in the `instance_map` config then workers may try |
|
|
|
|
# and send HTTP requests to it, so we register it with |
|
|
|
|
# `_handle_http_replication_attempt` like we do with the master HS. |
|
|
|
|
instance_name = worker_hs.get_instance_name() |
|
|
|
|
instance_loc = worker_hs.config.worker.instance_map.get(instance_name) |
|
|
|
|
if instance_loc: |
|
|
|
|
# Ensure the host is one that has a fake DNS entry. |
|
|
|
|
if instance_loc.host not in self.reactor.lookups: |
|
|
|
|
raise Exception( |
|
|
|
|
"Host does not have an IP for instance_map[%r].host = %r" |
|
|
|
|
% (instance_name, instance_loc.host,) |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
self.reactor.add_tcp_client_callback( |
|
|
|
|
self.reactor.lookups[instance_loc.host], |
|
|
|
|
instance_loc.port, |
|
|
|
|
lambda: self._handle_http_replication_attempt( |
|
|
|
|
worker_hs, instance_loc.port |
|
|
|
|
), |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
store = worker_hs.get_datastore() |
|
|
|
|
store.db_pool._db_pool = self.database_pool._db_pool |
|
|
|
|
|
|
|
|
|
repl_handler = ReplicationCommandHandler(worker_hs) |
|
|
|
|
client = ClientReplicationStreamProtocol( |
|
|
|
|
worker_hs, "client", "test", self.clock, repl_handler, |
|
|
|
|
) |
|
|
|
|
server = self.server_factory.buildProtocol(None) |
|
|
|
|
# Set up TCP replication between master and the new worker if we don't |
|
|
|
|
# have Redis support enabled. |
|
|
|
|
if not worker_hs.config.redis_enabled: |
|
|
|
|
repl_handler = ReplicationCommandHandler(worker_hs) |
|
|
|
|
client = ClientReplicationStreamProtocol( |
|
|
|
|
worker_hs, "client", "test", self.clock, repl_handler, |
|
|
|
|
) |
|
|
|
|
server = self.server_factory.buildProtocol(None) |
|
|
|
|
|
|
|
|
|
client_transport = FakeTransport(server, self.reactor) |
|
|
|
|
client.makeConnection(client_transport) |
|
|
|
|
client_transport = FakeTransport(server, self.reactor) |
|
|
|
|
client.makeConnection(client_transport) |
|
|
|
|
|
|
|
|
|
server_transport = FakeTransport(client, self.reactor) |
|
|
|
|
server.makeConnection(server_transport) |
|
|
|
|
server_transport = FakeTransport(client, self.reactor) |
|
|
|
|
server.makeConnection(server_transport) |
|
|
|
|
|
|
|
|
|
# Set up a resource for the worker |
|
|
|
|
resource = ReplicationRestResource(self.hs) |
|
|
|
|
resource = ReplicationRestResource(worker_hs) |
|
|
|
|
|
|
|
|
|
for servlet in self.servlets: |
|
|
|
|
servlet(worker_hs, resource) |
|
|
|
|
|
|
|
|
|
self._worker_hs_to_resource[worker_hs] = resource |
|
|
|
|
self._hs_to_site[worker_hs] = SynapseSite( |
|
|
|
|
logger_name="synapse.access.http.fake", |
|
|
|
|
site_tag="{}-{}".format( |
|
|
|
|
worker_hs.config.server.server_name, worker_hs.get_instance_name() |
|
|
|
|
), |
|
|
|
|
config=worker_hs.config.server.listeners[0], |
|
|
|
|
resource=resource, |
|
|
|
|
server_version_string="1", |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
if worker_hs.config.redis.redis_enabled: |
|
|
|
|
worker_hs.get_tcp_replication().start_replication(worker_hs) |
|
|
|
|
|
|
|
|
|
return worker_hs |
|
|
|
|
|
|
|
|
@ -285,7 +339,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): |
|
|
|
|
return config |
|
|
|
|
|
|
|
|
|
def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest): |
|
|
|
|
render(request, self._worker_hs_to_resource[worker_hs], self.reactor) |
|
|
|
|
render(request, self._hs_to_site[worker_hs].resource, self.reactor) |
|
|
|
|
|
|
|
|
|
def replicate(self): |
|
|
|
|
"""Tell the master side of replication that something has happened, and then |
|
|
|
@ -294,9 +348,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): |
|
|
|
|
self.streamer.on_notifier_poke() |
|
|
|
|
self.pump() |
|
|
|
|
|
|
|
|
|
def _handle_http_replication_attempt(self): |
|
|
|
|
"""Handles a connection attempt to the master replication HTTP |
|
|
|
|
listener. |
|
|
|
|
def _handle_http_replication_attempt(self, hs, repl_port): |
|
|
|
|
"""Handles a connection attempt to the given HS replication HTTP |
|
|
|
|
listener on the given port. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
# We should have at least one outbound connection attempt, where the |
|
|
|
@ -305,7 +359,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): |
|
|
|
|
self.assertGreaterEqual(len(clients), 1) |
|
|
|
|
(host, port, client_factory, _timeout, _bindAddress) = clients.pop() |
|
|
|
|
self.assertEqual(host, "1.2.3.4") |
|
|
|
|
self.assertEqual(port, 8765) |
|
|
|
|
self.assertEqual(port, repl_port) |
|
|
|
|
|
|
|
|
|
# Set up client side protocol |
|
|
|
|
client_protocol = client_factory.buildProtocol(None) |
|
|
|
@ -315,7 +369,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): |
|
|
|
|
# Set up the server side protocol |
|
|
|
|
channel = _PushHTTPChannel(self.reactor) |
|
|
|
|
channel.requestFactory = request_factory |
|
|
|
|
channel.site = self.site |
|
|
|
|
channel.site = self._hs_to_site[hs] |
|
|
|
|
|
|
|
|
|
# Connect client to server and vice versa. |
|
|
|
|
client_to_server_transport = FakeTransport( |
|
|
|
@ -333,6 +387,32 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): |
|
|
|
|
# inside `connecTCP` before the connection has been passed back to the |
|
|
|
|
# code that requested the TCP connection. |
|
|
|
|
|
|
|
|
|
def connect_any_redis_attempts(self): |
|
|
|
|
"""If redis is enabled we need to deal with workers connecting to a |
|
|
|
|
redis server. We don't want to use a real Redis server so we use a |
|
|
|
|
fake one. |
|
|
|
|
""" |
|
|
|
|
clients = self.reactor.tcpClients |
|
|
|
|
self.assertEqual(len(clients), 1) |
|
|
|
|
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) |
|
|
|
|
self.assertEqual(host, "localhost") |
|
|
|
|
self.assertEqual(port, 6379) |
|
|
|
|
|
|
|
|
|
client_protocol = client_factory.buildProtocol(None) |
|
|
|
|
server_protocol = self._redis_server.buildProtocol(None) |
|
|
|
|
|
|
|
|
|
client_to_server_transport = FakeTransport( |
|
|
|
|
server_protocol, self.reactor, client_protocol |
|
|
|
|
) |
|
|
|
|
client_protocol.makeConnection(client_to_server_transport) |
|
|
|
|
|
|
|
|
|
server_to_client_transport = FakeTransport( |
|
|
|
|
client_protocol, self.reactor, server_protocol |
|
|
|
|
) |
|
|
|
|
server_protocol.makeConnection(server_to_client_transport) |
|
|
|
|
|
|
|
|
|
return client_to_server_transport, server_to_client_transport |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestReplicationDataHandler(GenericWorkerReplicationHandler): |
|
|
|
|
"""Drop-in for ReplicationDataHandler which just collects RDATA rows""" |
|
|
|
@ -467,3 +547,105 @@ class _PullToPushProducer: |
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
self.stopProducing() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FakeRedisPubSubServer: |
|
|
|
|
"""A fake Redis server for pub/sub. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
def __init__(self): |
|
|
|
|
self._subscribers = set() |
|
|
|
|
|
|
|
|
|
def add_subscriber(self, conn): |
|
|
|
|
"""A connection has called SUBSCRIBE |
|
|
|
|
""" |
|
|
|
|
self._subscribers.add(conn) |
|
|
|
|
|
|
|
|
|
def remove_subscriber(self, conn): |
|
|
|
|
"""A connection has called UNSUBSCRIBE |
|
|
|
|
""" |
|
|
|
|
self._subscribers.discard(conn) |
|
|
|
|
|
|
|
|
|
def publish(self, conn, channel, msg) -> int: |
|
|
|
|
"""A connection want to publish a message to subscribers. |
|
|
|
|
""" |
|
|
|
|
for sub in self._subscribers: |
|
|
|
|
sub.send(["message", channel, msg]) |
|
|
|
|
|
|
|
|
|
return len(self._subscribers) |
|
|
|
|
|
|
|
|
|
def buildProtocol(self, addr): |
|
|
|
|
return FakeRedisPubSubProtocol(self) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FakeRedisPubSubProtocol(Protocol): |
|
|
|
|
"""A connection from a client talking to the fake Redis server. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
def __init__(self, server: FakeRedisPubSubServer): |
|
|
|
|
self._server = server |
|
|
|
|
self._reader = hiredis.Reader() |
|
|
|
|
|
|
|
|
|
def dataReceived(self, data): |
|
|
|
|
self._reader.feed(data) |
|
|
|
|
|
|
|
|
|
# We might get multiple messages in one packet. |
|
|
|
|
while True: |
|
|
|
|
msg = self._reader.gets() |
|
|
|
|
|
|
|
|
|
if msg is False: |
|
|
|
|
# No more messages. |
|
|
|
|
return |
|
|
|
|
|
|
|
|
|
if not isinstance(msg, list): |
|
|
|
|
# Inbound commands should always be a list |
|
|
|
|
raise Exception("Expected redis list") |
|
|
|
|
|
|
|
|
|
self.handle_command(msg[0], *msg[1:]) |
|
|
|
|
|
|
|
|
|
def handle_command(self, command, *args): |
|
|
|
|
"""Received a Redis command from the client. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
# We currently only support pub/sub. |
|
|
|
|
if command == b"PUBLISH": |
|
|
|
|
channel, message = args |
|
|
|
|
num_subscribers = self._server.publish(self, channel, message) |
|
|
|
|
self.send(num_subscribers) |
|
|
|
|
elif command == b"SUBSCRIBE": |
|
|
|
|
(channel,) = args |
|
|
|
|
self._server.add_subscriber(self) |
|
|
|
|
self.send(["subscribe", channel, 1]) |
|
|
|
|
else: |
|
|
|
|
raise Exception("Unknown command") |
|
|
|
|
|
|
|
|
|
def send(self, msg): |
|
|
|
|
"""Send a message back to the client. |
|
|
|
|
""" |
|
|
|
|
raw = self.encode(msg).encode("utf-8") |
|
|
|
|
|
|
|
|
|
self.transport.write(raw) |
|
|
|
|
self.transport.flush() |
|
|
|
|
|
|
|
|
|
def encode(self, obj): |
|
|
|
|
"""Encode an object to its Redis format. |
|
|
|
|
|
|
|
|
|
Supports: strings/bytes, integers and list/tuples. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
if isinstance(obj, bytes): |
|
|
|
|
# We assume bytes are just unicode strings. |
|
|
|
|
obj = obj.decode("utf-8") |
|
|
|
|
|
|
|
|
|
if isinstance(obj, str): |
|
|
|
|
return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj) |
|
|
|
|
if isinstance(obj, int): |
|
|
|
|
return ":{val}\r\n".format(val=obj) |
|
|
|
|
if isinstance(obj, (list, tuple)): |
|
|
|
|
items = "".join(self.encode(a) for a in obj) |
|
|
|
|
return "*{len}\r\n{items}".format(len=len(obj), items=items) |
|
|
|
|
|
|
|
|
|
raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj) |
|
|
|
|
|
|
|
|
|
def connectionLost(self, reason): |
|
|
|
|
self._server.remove_subscriber(self) |
|
|
|
|