Add ratelimiting function to basehandler

pull/4/merge
Mark Haines 10 years ago
parent dd2cd9312a
commit c7a7cdf734
  1. 1
      synapse/api/errors.py
  2. 1
      synapse/app/homeserver.py
  3. 4
      synapse/config/homeserver.py
  4. 17
      synapse/handlers/_base.py
  5. 5
      synapse/server.py

@ -28,6 +28,7 @@ class Codes(object):
UNKNOWN = "M_UNKNOWN" UNKNOWN = "M_UNKNOWN"
NOT_FOUND = "M_NOT_FOUND" NOT_FOUND = "M_NOT_FOUND"
UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN" UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN"
LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED"
class CodeMessageException(Exception): class CodeMessageException(Exception):

@ -247,6 +247,7 @@ def setup():
upload_dir=os.path.abspath("uploads"), upload_dir=os.path.abspath("uploads"),
db_name=config.database_path, db_name=config.database_path,
tls_context_factory=tls_context_factory, tls_context_factory=tls_context_factory,
config=config,
) )
hs.register_servlets() hs.register_servlets()

@ -17,8 +17,10 @@ from .tls import TlsConfig
from .server import ServerConfig from .server import ServerConfig
from .logger import LoggingConfig from .logger import LoggingConfig
from .database import DatabaseConfig from .database import DatabaseConfig
from .ratelimiting import RatelimitConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig): class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
RatelimitConfig):
pass pass
if __name__=='__main__': if __name__=='__main__':

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import cs_error, Codes
class BaseHandler(object): class BaseHandler(object):
@ -25,8 +26,24 @@ class BaseHandler(object):
self.room_lock = hs.get_room_lock_manager() self.room_lock = hs.get_room_lock_manager()
self.state_handler = hs.get_state_handler() self.state_handler = hs.get_state_handler()
self.distributor = hs.get_distributor() self.distributor = hs.get_distributor()
self.ratelimiter = hs.get_ratelimiter()
self.clock = hs.get_clock()
self.hs = hs self.hs = hs
def ratelimit(self, user_id):
time_now = self.clock.time()
allowed, time_allowed = self.ratelimiter.send_message(
user_id, time_now,
msg_rate_hz=self.hs.config.rc_messages_per_second,
burst_count=self.hs.config.rc_messsage_burst_count,
)
if not allowed:
raise cs_error(
"Limit exceeded",
Codes.M_LIMIT_EXCEEDED,
retry_after_ms=1000*(time_allowed - time_now),
)
class BaseRoomHandler(BaseHandler): class BaseRoomHandler(BaseHandler):

@ -32,6 +32,7 @@ from synapse.util import Clock
from synapse.util.distributor import Distributor from synapse.util.distributor import Distributor
from synapse.util.lockutils import LockManager from synapse.util.lockutils import LockManager
from synapse.streams.events import EventSources from synapse.streams.events import EventSources
from synapse.api.ratelimiting import Ratelimiter
class BaseHomeServer(object): class BaseHomeServer(object):
@ -73,6 +74,7 @@ class BaseHomeServer(object):
'resource_for_web_client', 'resource_for_web_client',
'resource_for_content_repo', 'resource_for_content_repo',
'event_sources', 'event_sources',
'ratelimiter',
] ]
def __init__(self, hostname, **kwargs): def __init__(self, hostname, **kwargs):
@ -190,6 +192,9 @@ class HomeServer(BaseHomeServer):
def build_event_sources(self): def build_event_sources(self):
return EventSources(self) return EventSources(self)
def build_ratelimiter(self):
return Ratelimiter()
def register_servlets(self): def register_servlets(self):
""" Register all servlets associated with this HomeServer. """ Register all servlets associated with this HomeServer.
""" """

Loading…
Cancel
Save