@ -15,6 +15,9 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING , Tuple
from twisted . web . http import Request
from synapse . api . errors import AuthError , Codes , NotFoundError , SynapseError
from synapse . http . servlet import RestServlet , parse_boolean , parse_integer
@ -23,6 +26,10 @@ from synapse.rest.admin._base import (
assert_requester_is_admin ,
assert_user_is_admin ,
)
from synapse . types import JsonDict
if TYPE_CHECKING :
from synapse . app . homeserver import HomeServer
logger = logging . getLogger ( __name__ )
@ -39,11 +46,11 @@ class QuarantineMediaInRoom(RestServlet):
admin_patterns ( " /quarantine_media/(?P<room_id>[^/]+) " )
)
def __init__ ( self , hs ) :
def __init__ ( self , hs : " HomeServer " ) :
self . store = hs . get_datastore ( )
self . auth = hs . get_auth ( )
async def on_POST ( self , request , room_id : str ) :
async def on_POST ( self , request : Request , room_id : str ) - > Tuple [ int , JsonDict ] :
requester = await self . auth . get_user_by_req ( request )
await assert_user_is_admin ( self . auth , requester . user )
@ -64,11 +71,11 @@ class QuarantineMediaByUser(RestServlet):
PATTERNS = admin_patterns ( " /user/(?P<user_id>[^/]+)/media/quarantine " )
def __init__ ( self , hs ) :
def __init__ ( self , hs : " HomeServer " ) :
self . store = hs . get_datastore ( )
self . auth = hs . get_auth ( )
async def on_POST ( self , request , user_id : str ) :
async def on_POST ( self , request : Request , user_id : str ) - > Tuple [ int , JsonDict ] :
requester = await self . auth . get_user_by_req ( request )
await assert_user_is_admin ( self . auth , requester . user )
@ -91,11 +98,13 @@ class QuarantineMediaByID(RestServlet):
" /media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+) "
)
def __init__ ( self , hs ) :
def __init__ ( self , hs : " HomeServer " ) :
self . store = hs . get_datastore ( )
self . auth = hs . get_auth ( )
async def on_POST ( self , request , server_name : str , media_id : str ) :
async def on_POST (
self , request : Request , server_name : str , media_id : str
) - > Tuple [ int , JsonDict ] :
requester = await self . auth . get_user_by_req ( request )
await assert_user_is_admin ( self . auth , requester . user )
@ -109,17 +118,39 @@ class QuarantineMediaByID(RestServlet):
return 200 , { }
class ProtectMediaByID ( RestServlet ) :
""" Protect local media from being quarantined.
"""
PATTERNS = admin_patterns ( " /media/protect/(?P<media_id>[^/]+) " )
def __init__ ( self , hs : " HomeServer " ) :
self . store = hs . get_datastore ( )
self . auth = hs . get_auth ( )
async def on_POST ( self , request : Request , media_id : str ) - > Tuple [ int , JsonDict ] :
requester = await self . auth . get_user_by_req ( request )
await assert_user_is_admin ( self . auth , requester . user )
logging . info ( " Protecting local media by ID: %s " , media_id )
# Quarantine this media id
await self . store . mark_local_media_as_safe ( media_id )
return 200 , { }
class ListMediaInRoom ( RestServlet ) :
""" Lists all of the media in a given room.
"""
PATTERNS = admin_patterns ( " /room/(?P<room_id>[^/]+)/media " )
def __init__ ( self , hs ) :
def __init__ ( self , hs : " HomeServer " ) :
self . store = hs . get_datastore ( )
self . auth = hs . get_auth ( )
async def on_GET ( self , request , room_id ) :
async def on_GET ( self , request : Request , room_id : str ) - > Tuple [ int , JsonDict ] :
requester = await self . auth . get_user_by_req ( request )
is_admin = await self . auth . is_server_admin ( requester . user )
if not is_admin :
@ -133,11 +164,11 @@ class ListMediaInRoom(RestServlet):
class PurgeMediaCacheRestServlet ( RestServlet ) :
PATTERNS = admin_patterns ( " /purge_media_cache " )
def __init__ ( self , hs ) :
def __init__ ( self , hs : " HomeServer " ) :
self . media_repository = hs . get_media_repository ( )
self . auth = hs . get_auth ( )
async def on_POST ( self , request ) :
async def on_POST ( self , request : Request ) - > Tuple [ int , JsonDict ] :
await assert_requester_is_admin ( self . auth , request )
before_ts = parse_integer ( request , " before_ts " , required = True )
@ -154,13 +185,15 @@ class DeleteMediaByID(RestServlet):
PATTERNS = admin_patterns ( " /media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+) " )
def __init__ ( self , hs ) :
def __init__ ( self , hs : " HomeServer " ) :
self . store = hs . get_datastore ( )
self . auth = hs . get_auth ( )
self . server_name = hs . hostname
self . media_repository = hs . get_media_repository ( )
async def on_DELETE ( self , request , server_name : str , media_id : str ) :
async def on_DELETE (
self , request : Request , server_name : str , media_id : str
) - > Tuple [ int , JsonDict ] :
await assert_requester_is_admin ( self . auth , request )
if self . server_name != server_name :
@ -182,13 +215,13 @@ class DeleteMediaByDateSize(RestServlet):
PATTERNS = admin_patterns ( " /media/(?P<server_name>[^/]+)/delete " )
def __init__ ( self , hs ) :
def __init__ ( self , hs : " HomeServer " ) :
self . store = hs . get_datastore ( )
self . auth = hs . get_auth ( )
self . server_name = hs . hostname
self . media_repository = hs . get_media_repository ( )
async def on_POST ( self , request , server_name : str ) :
async def on_POST ( self , request : Request , server_name : str ) - > Tuple [ int , JsonDict ] :
await assert_requester_is_admin ( self . auth , request )
before_ts = parse_integer ( request , " before_ts " , required = True )
@ -222,7 +255,7 @@ class DeleteMediaByDateSize(RestServlet):
return 200 , { " deleted_media " : deleted_media , " total " : total }
def register_servlets_for_media_repo ( hs , http_server ) :
def register_servlets_for_media_repo ( hs : " HomeServer " , http_server ) :
"""
Media repo specific APIs .
"""
@ -230,6 +263,7 @@ def register_servlets_for_media_repo(hs, http_server):
QuarantineMediaInRoom ( hs ) . register ( http_server )
QuarantineMediaByID ( hs ) . register ( http_server )
QuarantineMediaByUser ( hs ) . register ( http_server )
ProtectMediaByID ( hs ) . register ( http_server )
ListMediaInRoom ( hs ) . register ( http_server )
DeleteMediaByID ( hs ) . register ( http_server )
DeleteMediaByDateSize ( hs ) . register ( http_server )