Admin API to delete media for a specific user (#10558)

code_spécifique_watcha
Dirk Klimpel 3 years ago committed by GitHub
parent 3ebb6694f0
commit 915b37e5ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      changelog.d/10558.feature
  2. 9
      docs/admin_api/media_admin_api.md
  3. 54
      docs/admin_api/user_admin_api.md
  4. 4
      synapse/rest/admin/media.py
  5. 80
      synapse/rest/admin/users.py
  6. 6
      synapse/rest/media/v1/media_repository.py
  7. 321
      tests/rest/admin/test_user.py

@ -0,0 +1 @@
Admin API to delete several media for a specific user. Contributed by @dklimpel.

@ -12,6 +12,7 @@
- [Delete local media](#delete-local-media) - [Delete local media](#delete-local-media)
* [Delete a specific local media](#delete-a-specific-local-media) * [Delete a specific local media](#delete-a-specific-local-media)
* [Delete local media by date or size](#delete-local-media-by-date-or-size) * [Delete local media by date or size](#delete-local-media-by-date-or-size)
* [Delete media uploaded by a user](#delete-media-uploaded-by-a-user)
- [Purge Remote Media API](#purge-remote-media-api) - [Purge Remote Media API](#purge-remote-media-api)
# Querying media # Querying media
@ -47,7 +48,8 @@ The API returns a JSON body like the following:
## List all media uploaded by a user ## List all media uploaded by a user
Listing all media that has been uploaded by a local user can be achieved through Listing all media that has been uploaded by a local user can be achieved through
the use of the [List media of a user](user_admin_api.md#list-media-of-a-user) the use of the
[List media uploaded by a user](user_admin_api.md#list-media-uploaded-by-a-user)
Admin API. Admin API.
# Quarantine media # Quarantine media
@ -281,6 +283,11 @@ The following fields are returned in the JSON response body:
* `deleted_media`: an array of strings - List of deleted `media_id` * `deleted_media`: an array of strings - List of deleted `media_id`
* `total`: integer - Total number of deleted `media_id` * `total`: integer - Total number of deleted `media_id`
## Delete media uploaded by a user
You can find details of how to delete multiple media uploaded by a user in
[User Admin API](user_admin_api.md#delete-media-uploaded-by-a-user).
# Purge Remote Media API # Purge Remote Media API
The purge remote media API allows server admins to purge old cached remote media. The purge remote media API allows server admins to purge old cached remote media.

@ -443,8 +443,9 @@ The following fields are returned in the JSON response body:
- `joined_rooms` - An array of `room_id`. - `joined_rooms` - An array of `room_id`.
- `total` - Number of rooms. - `total` - Number of rooms.
## User media
## List media of a user ### List media uploaded by a user
Gets a list of all local media that a specific `user_id` has created. Gets a list of all local media that a specific `user_id` has created.
By default, the response is ordered by descending creation date and ascending media ID. By default, the response is ordered by descending creation date and ascending media ID.
The newest media is on top. You can change the order with parameters The newest media is on top. You can change the order with parameters
@ -543,7 +544,6 @@ The following fields are returned in the JSON response body:
- `media` - An array of objects, each containing information about a media. - `media` - An array of objects, each containing information about a media.
Media objects contain the following fields: Media objects contain the following fields:
- `created_ts` - integer - Timestamp when the content was uploaded in ms. - `created_ts` - integer - Timestamp when the content was uploaded in ms.
- `last_access_ts` - integer - Timestamp when the content was last accessed in ms. - `last_access_ts` - integer - Timestamp when the content was last accessed in ms.
- `media_id` - string - The id used to refer to the media. - `media_id` - string - The id used to refer to the media.
@ -551,13 +551,58 @@ The following fields are returned in the JSON response body:
- `media_type` - string - The MIME-type of the media. - `media_type` - string - The MIME-type of the media.
- `quarantined_by` - string - The user ID that initiated the quarantine request - `quarantined_by` - string - The user ID that initiated the quarantine request
for this media. for this media.
- `safe_from_quarantine` - bool - Status if this media is safe from quarantining. - `safe_from_quarantine` - bool - Status if this media is safe from quarantining.
- `upload_name` - string - The name the media was uploaded with. - `upload_name` - string - The name the media was uploaded with.
- `next_token`: integer - Indication for pagination. See above. - `next_token`: integer - Indication for pagination. See above.
- `total` - integer - Total number of media. - `total` - integer - Total number of media.
### Delete media uploaded by a user
This API deletes the *local* media from the disk of your own server
that a specific `user_id` has created. This includes any local thumbnails.
This API will not affect media that has been uploaded to external
media repositories (e.g https://github.com/turt2live/matrix-media-repo/).
By default, the API deletes media ordered by descending creation date and ascending media ID.
The newest media is deleted first. You can change the order with parameters
`order_by` and `dir`. If no `limit` is set the API deletes `100` files per request.
The API is:
```
DELETE /_synapse/admin/v1/users/<user_id>/media
```
To use it, you will need to authenticate by providing an `access_token` for a
server admin: [Admin API](../usage/administration/admin_api)
A response body like the following is returned:
```json
{
"deleted_media": [
"abcdefghijklmnopqrstuvwx"
],
"total": 1
}
```
The following fields are returned in the JSON response body:
* `deleted_media`: an array of strings - List of deleted `media_id`
* `total`: integer - Total number of deleted `media_id`
**Note**: There is no `next_token`. This is not useful for deleting media, because
after deleting media the remaining media have a new order.
**Parameters**
This API has the same parameters as
[List media uploaded by a user](#list-media-uploaded-by-a-user).
With the parameters you can for example limit the number of files to delete at once or
delete largest/smallest or newest/oldest files first.
## Login as a user ## Login as a user
Get an access token that can be used to authenticate as that user. Useful for Get an access token that can be used to authenticate as that user. Useful for
@ -1012,4 +1057,3 @@ The following parameters should be set in the URL:
- `user_id` - The fully qualified MXID: for example, `@user:server.com`. The user must - `user_id` - The fully qualified MXID: for example, `@user:server.com`. The user must
be local. be local.

@ -259,7 +259,9 @@ class DeleteMediaByID(RestServlet):
logging.info("Deleting local media by ID: %s", media_id) logging.info("Deleting local media by ID: %s", media_id)
deleted_media, total = await self.media_repository.delete_local_media(media_id) deleted_media, total = await self.media_repository.delete_local_media_ids(
[media_id]
)
return 200, {"deleted_media": deleted_media, "total": total} return 200, {"deleted_media": deleted_media, "total": total}

@ -172,7 +172,7 @@ class UserRestServletV2(RestServlet):
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only lookup local users") raise SynapseError(400, "Can only look up local users")
ret = await self.admin_handler.get_user(target_user) ret = await self.admin_handler.get_user(target_user)
@ -796,7 +796,7 @@ class PushersRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
if not self.is_mine(UserID.from_string(user_id)): if not self.is_mine(UserID.from_string(user_id)):
raise SynapseError(400, "Can only lookup local users") raise SynapseError(400, "Can only look up local users")
if not await self.store.get_user_by_id(user_id): if not await self.store.get_user_by_id(user_id):
raise NotFoundError("User not found") raise NotFoundError("User not found")
@ -811,10 +811,10 @@ class PushersRestServlet(RestServlet):
class UserMediaRestServlet(RestServlet): class UserMediaRestServlet(RestServlet):
""" """
Gets information about all uploaded local media for a specific `user_id`. Gets information about all uploaded local media for a specific `user_id`.
With DELETE request you can delete all this media.
Example: Example:
http://localhost:8008/_synapse/admin/v1/users/ http://localhost:8008/_synapse/admin/v1/users/@user:server/media
@user:server/media
Args: Args:
The parameters `from` and `limit` are required for pagination. The parameters `from` and `limit` are required for pagination.
@ -830,6 +830,7 @@ class UserMediaRestServlet(RestServlet):
self.is_mine = hs.is_mine self.is_mine = hs.is_mine
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.media_repository = hs.get_media_repository()
async def on_GET( async def on_GET(
self, request: SynapseRequest, user_id: str self, request: SynapseRequest, user_id: str
@ -840,7 +841,7 @@ class UserMediaRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
if not self.is_mine(UserID.from_string(user_id)): if not self.is_mine(UserID.from_string(user_id)):
raise SynapseError(400, "Can only lookup local users") raise SynapseError(400, "Can only look up local users")
user = await self.store.get_user_by_id(user_id) user = await self.store.get_user_by_id(user_id)
if user is None: if user is None:
@ -898,6 +899,73 @@ class UserMediaRestServlet(RestServlet):
return 200, ret return 200, ret
async def on_DELETE(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
# This will always be set by the time Twisted calls us.
assert request.args is not None
await assert_requester_is_admin(self.auth, request)
if not self.is_mine(UserID.from_string(user_id)):
raise SynapseError(400, "Can only look up local users")
user = await self.store.get_user_by_id(user_id)
if user is None:
raise NotFoundError("Unknown user")
start = parse_integer(request, "from", default=0)
limit = parse_integer(request, "limit", default=100)
if start < 0:
raise SynapseError(
400,
"Query parameter from must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
if limit < 0:
raise SynapseError(
400,
"Query parameter limit must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
# If neither `order_by` nor `dir` is set, set the default order
# to newest media is on top for backward compatibility.
if b"order_by" not in request.args and b"dir" not in request.args:
order_by = MediaSortOrder.CREATED_TS.value
direction = "b"
else:
order_by = parse_string(
request,
"order_by",
default=MediaSortOrder.CREATED_TS.value,
allowed_values=(
MediaSortOrder.MEDIA_ID.value,
MediaSortOrder.UPLOAD_NAME.value,
MediaSortOrder.CREATED_TS.value,
MediaSortOrder.LAST_ACCESS_TS.value,
MediaSortOrder.MEDIA_LENGTH.value,
MediaSortOrder.MEDIA_TYPE.value,
MediaSortOrder.QUARANTINED_BY.value,
MediaSortOrder.SAFE_FROM_QUARANTINE.value,
),
)
direction = parse_string(
request, "dir", default="f", allowed_values=("f", "b")
)
media, _ = await self.store.get_local_media_by_user_paginate(
start, limit, user_id, order_by, direction
)
deleted_media, total = await self.media_repository.delete_local_media_ids(
([row["media_id"] for row in media])
)
return 200, {"deleted_media": deleted_media, "total": total}
class UserTokenRestServlet(RestServlet): class UserTokenRestServlet(RestServlet):
"""An admin API for logging in as a user. """An admin API for logging in as a user.
@ -1017,7 +1085,7 @@ class RateLimitRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
if not self.hs.is_mine_id(user_id): if not self.hs.is_mine_id(user_id):
raise SynapseError(400, "Can only lookup local users") raise SynapseError(400, "Can only look up local users")
if not await self.store.get_user_by_id(user_id): if not await self.store.get_user_by_id(user_id):
raise NotFoundError("User not found") raise NotFoundError("User not found")

@ -836,7 +836,9 @@ class MediaRepository:
return {"deleted": deleted} return {"deleted": deleted}
async def delete_local_media(self, media_id: str) -> Tuple[List[str], int]: async def delete_local_media_ids(
self, media_ids: List[str]
) -> Tuple[List[str], int]:
""" """
Delete the given local or remote media ID from this server Delete the given local or remote media ID from this server
@ -845,7 +847,7 @@ class MediaRepository:
Returns: Returns:
A tuple of (list of deleted media IDs, total deleted media IDs). A tuple of (list of deleted media IDs, total deleted media IDs).
""" """
return await self._remove_local_media_from_disk([media_id]) return await self._remove_local_media_from_disk(media_ids)
async def delete_old_local_media( async def delete_old_local_media(
self, self,

@ -15,17 +15,21 @@
import hashlib import hashlib
import hmac import hmac
import json import json
import os
import urllib.parse import urllib.parse
from binascii import unhexlify from binascii import unhexlify
from typing import List, Optional from typing import List, Optional
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from parameterized import parameterized
import synapse.rest.admin import synapse.rest.admin
from synapse.api.constants import UserTypes from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.rest.client.v1 import login, logout, profile, room from synapse.rest.client.v1 import login, logout, profile, room
from synapse.rest.client.v2_alpha import devices, sync from synapse.rest.client.v2_alpha import devices, sync
from synapse.rest.media.v1.filepath import MediaFilePaths
from synapse.types import JsonDict, UserID from synapse.types import JsonDict, UserID
from tests import unittest from tests import unittest
@ -72,7 +76,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url, b"{}") channel = self.make_request("POST", self.url, b"{}")
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual( self.assertEqual(
"Shared secret registration is not enabled", channel.json_body["error"] "Shared secret registration is not enabled", channel.json_body["error"]
) )
@ -104,7 +108,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
body = json.dumps({"nonce": nonce}) body = json.dumps({"nonce": nonce})
channel = self.make_request("POST", self.url, body.encode("utf8")) channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("username must be specified", channel.json_body["error"]) self.assertEqual("username must be specified", channel.json_body["error"])
# 61 seconds # 61 seconds
@ -112,7 +116,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url, body.encode("utf8")) channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("unrecognised nonce", channel.json_body["error"]) self.assertEqual("unrecognised nonce", channel.json_body["error"])
def test_register_incorrect_nonce(self): def test_register_incorrect_nonce(self):
@ -166,7 +170,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
) )
channel = self.make_request("POST", self.url, body.encode("utf8")) channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["user_id"]) self.assertEqual("@bob:test", channel.json_body["user_id"])
def test_nonce_reuse(self): def test_nonce_reuse(self):
@ -191,13 +195,13 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
) )
channel = self.make_request("POST", self.url, body.encode("utf8")) channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["user_id"]) self.assertEqual("@bob:test", channel.json_body["user_id"])
# Now, try and reuse it # Now, try and reuse it
channel = self.make_request("POST", self.url, body.encode("utf8")) channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("unrecognised nonce", channel.json_body["error"]) self.assertEqual("unrecognised nonce", channel.json_body["error"])
def test_missing_parts(self): def test_missing_parts(self):
@ -219,7 +223,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
body = json.dumps({}) body = json.dumps({})
channel = self.make_request("POST", self.url, body.encode("utf8")) channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("nonce must be specified", channel.json_body["error"]) self.assertEqual("nonce must be specified", channel.json_body["error"])
# #
@ -230,28 +234,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
body = json.dumps({"nonce": nonce()}) body = json.dumps({"nonce": nonce()})
channel = self.make_request("POST", self.url, body.encode("utf8")) channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("username must be specified", channel.json_body["error"]) self.assertEqual("username must be specified", channel.json_body["error"])
# Must be a string # Must be a string
body = json.dumps({"nonce": nonce(), "username": 1234}) body = json.dumps({"nonce": nonce(), "username": 1234})
channel = self.make_request("POST", self.url, body.encode("utf8")) channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid username", channel.json_body["error"]) self.assertEqual("Invalid username", channel.json_body["error"])
# Must not have null bytes # Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": "abcd\u0000"}) body = json.dumps({"nonce": nonce(), "username": "abcd\u0000"})
channel = self.make_request("POST", self.url, body.encode("utf8")) channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid username", channel.json_body["error"]) self.assertEqual("Invalid username", channel.json_body["error"])
# Must not have null bytes # Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": "a" * 1000}) body = json.dumps({"nonce": nonce(), "username": "a" * 1000})
channel = self.make_request("POST", self.url, body.encode("utf8")) channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid username", channel.json_body["error"]) self.assertEqual("Invalid username", channel.json_body["error"])
# #
@ -262,28 +266,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
body = json.dumps({"nonce": nonce(), "username": "a"}) body = json.dumps({"nonce": nonce(), "username": "a"})
channel = self.make_request("POST", self.url, body.encode("utf8")) channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("password must be specified", channel.json_body["error"]) self.assertEqual("password must be specified", channel.json_body["error"])
# Must be a string # Must be a string
body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234}) body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234})
channel = self.make_request("POST", self.url, body.encode("utf8")) channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid password", channel.json_body["error"]) self.assertEqual("Invalid password", channel.json_body["error"])
# Must not have null bytes # Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": "a", "password": "abcd\u0000"}) body = json.dumps({"nonce": nonce(), "username": "a", "password": "abcd\u0000"})
channel = self.make_request("POST", self.url, body.encode("utf8")) channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid password", channel.json_body["error"]) self.assertEqual("Invalid password", channel.json_body["error"])
# Super long # Super long
body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000}) body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000})
channel = self.make_request("POST", self.url, body.encode("utf8")) channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid password", channel.json_body["error"]) self.assertEqual("Invalid password", channel.json_body["error"])
# #
@ -301,7 +305,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
) )
channel = self.make_request("POST", self.url, body.encode("utf8")) channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid user type", channel.json_body["error"]) self.assertEqual("Invalid user type", channel.json_body["error"])
def test_displayname(self): def test_displayname(self):
@ -322,11 +326,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
) )
channel = self.make_request("POST", self.url, body.encode("utf8")) channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob1:test", channel.json_body["user_id"]) self.assertEqual("@bob1:test", channel.json_body["user_id"])
channel = self.make_request("GET", "/profile/@bob1:test/displayname") channel = self.make_request("GET", "/profile/@bob1:test/displayname")
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("bob1", channel.json_body["displayname"]) self.assertEqual("bob1", channel.json_body["displayname"])
# displayname is None # displayname is None
@ -348,11 +352,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
) )
channel = self.make_request("POST", self.url, body.encode("utf8")) channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob2:test", channel.json_body["user_id"]) self.assertEqual("@bob2:test", channel.json_body["user_id"])
channel = self.make_request("GET", "/profile/@bob2:test/displayname") channel = self.make_request("GET", "/profile/@bob2:test/displayname")
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("bob2", channel.json_body["displayname"]) self.assertEqual("bob2", channel.json_body["displayname"])
# displayname is empty # displayname is empty
@ -374,7 +378,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
) )
channel = self.make_request("POST", self.url, body.encode("utf8")) channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob3:test", channel.json_body["user_id"]) self.assertEqual("@bob3:test", channel.json_body["user_id"])
channel = self.make_request("GET", "/profile/@bob3:test/displayname") channel = self.make_request("GET", "/profile/@bob3:test/displayname")
@ -399,11 +403,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
) )
channel = self.make_request("POST", self.url, body.encode("utf8")) channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob4:test", channel.json_body["user_id"]) self.assertEqual("@bob4:test", channel.json_body["user_id"])
channel = self.make_request("GET", "/profile/@bob4:test/displayname") channel = self.make_request("GET", "/profile/@bob4:test/displayname")
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("Bob's Name", channel.json_body["displayname"]) self.assertEqual("Bob's Name", channel.json_body["displayname"])
@override_config( @override_config(
@ -449,7 +453,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
) )
channel = self.make_request("POST", self.url, body.encode("utf8")) channel = self.make_request("POST", self.url, body.encode("utf8"))
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["user_id"]) self.assertEqual("@bob:test", channel.json_body["user_id"])
@ -638,7 +642,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# invalid search order # invalid search order
@ -1085,7 +1089,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
content={"erase": False}, content={"erase": False},
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.json_body)
# Get user # Get user
channel = self.make_request( channel = self.make_request(
@ -2180,7 +2184,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"]) self.assertEqual("Can only look up local users", channel.json_body["error"])
def test_get_pushers(self): def test_get_pushers(self):
""" """
@ -2249,6 +2253,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs): def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.media_repo = hs.get_media_repository_resource() self.media_repo = hs.get_media_repository_resource()
self.filepaths = MediaFilePaths(hs.config.media_store_path)
self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass") self.admin_user_tok = self.login("admin", "pass")
@ -2258,37 +2263,34 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self.other_user self.other_user
) )
def test_no_auth(self): @parameterized.expand(["GET", "DELETE"])
""" def test_no_auth(self, method: str):
Try to list media of an user without authentication. """Try to list media of an user without authentication."""
""" channel = self.make_request(method, self.url, {})
channel = self.make_request("GET", self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self): @parameterized.expand(["GET", "DELETE"])
""" def test_requester_is_no_admin(self, method: str):
If the user is not a server admin, an error is returned. """If the user is not a server admin, an error is returned."""
"""
other_user_token = self.login("user", "pass") other_user_token = self.login("user", "pass")
channel = self.make_request( channel = self.make_request(
"GET", method,
self.url, self.url,
access_token=other_user_token, access_token=other_user_token,
) )
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_does_not_exist(self): @parameterized.expand(["GET", "DELETE"])
""" def test_user_does_not_exist(self, method: str):
Tests that a lookup for a user that does not exist returns a 404 """Tests that a lookup for a user that does not exist returns a 404"""
"""
url = "/_synapse/admin/v1/users/@unknown_person:test/media" url = "/_synapse/admin/v1/users/@unknown_person:test/media"
channel = self.make_request( channel = self.make_request(
"GET", method,
url, url,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
@ -2296,25 +2298,22 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_user_is_not_local(self): @parameterized.expand(["GET", "DELETE"])
""" def test_user_is_not_local(self, method: str):
Tests that a lookup for a user that is not a local returns a 400 """Tests that a lookup for a user that is not a local returns a 400"""
"""
url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media" url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media"
channel = self.make_request( channel = self.make_request(
"GET", method,
url, url,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"]) self.assertEqual("Can only look up local users", channel.json_body["error"])
def test_limit(self): def test_limit_GET(self):
""" """Testing list of media with limit"""
Testing list of media with limit
"""
number_media = 20 number_media = 20
other_user_tok = self.login("user", "pass") other_user_tok = self.login("user", "pass")
@ -2326,16 +2325,31 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), 5) self.assertEqual(len(channel.json_body["media"]), 5)
self.assertEqual(channel.json_body["next_token"], 5) self.assertEqual(channel.json_body["next_token"], 5)
self._check_fields(channel.json_body["media"]) self._check_fields(channel.json_body["media"])
def test_from(self): def test_limit_DELETE(self):
""" """Testing delete of media with limit"""
Testing list of media with a defined starting point (from)
""" number_media = 20
other_user_tok = self.login("user", "pass")
self._create_media_for_user(other_user_tok, number_media)
channel = self.make_request(
"DELETE",
self.url + "?limit=5",
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 5)
self.assertEqual(len(channel.json_body["deleted_media"]), 5)
def test_from_GET(self):
"""Testing list of media with a defined starting point (from)"""
number_media = 20 number_media = 20
other_user_tok = self.login("user", "pass") other_user_tok = self.login("user", "pass")
@ -2347,16 +2361,31 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), 15) self.assertEqual(len(channel.json_body["media"]), 15)
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
self._check_fields(channel.json_body["media"]) self._check_fields(channel.json_body["media"])
def test_limit_and_from(self): def test_from_DELETE(self):
""" """Testing delete of media with a defined starting point (from)"""
Testing list of media with a defined starting point and limit
""" number_media = 20
other_user_tok = self.login("user", "pass")
self._create_media_for_user(other_user_tok, number_media)
channel = self.make_request(
"DELETE",
self.url + "?from=5",
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 15)
self.assertEqual(len(channel.json_body["deleted_media"]), 15)
def test_limit_and_from_GET(self):
"""Testing list of media with a defined starting point and limit"""
number_media = 20 number_media = 20
other_user_tok = self.login("user", "pass") other_user_tok = self.login("user", "pass")
@ -2368,59 +2397,78 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(channel.json_body["next_token"], 15) self.assertEqual(channel.json_body["next_token"], 15)
self.assertEqual(len(channel.json_body["media"]), 10) self.assertEqual(len(channel.json_body["media"]), 10)
self._check_fields(channel.json_body["media"]) self._check_fields(channel.json_body["media"])
def test_invalid_parameter(self): def test_limit_and_from_DELETE(self):
""" """Testing delete of media with a defined starting point and limit"""
If parameters are invalid, an error is returned.
""" number_media = 20
other_user_tok = self.login("user", "pass")
self._create_media_for_user(other_user_tok, number_media)
channel = self.make_request(
"DELETE",
self.url + "?from=5&limit=10",
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 10)
self.assertEqual(len(channel.json_body["deleted_media"]), 10)
@parameterized.expand(["GET", "DELETE"])
def test_invalid_parameter(self, method: str):
"""If parameters are invalid, an error is returned."""
# unkown order_by # unkown order_by
channel = self.make_request( channel = self.make_request(
"GET", method,
self.url + "?order_by=bar", self.url + "?order_by=bar",
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# invalid search order # invalid search order
channel = self.make_request( channel = self.make_request(
"GET", method,
self.url + "?dir=bar", self.url + "?dir=bar",
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# negative limit # negative limit
channel = self.make_request( channel = self.make_request(
"GET", method,
self.url + "?limit=-5", self.url + "?limit=-5",
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from # negative from
channel = self.make_request( channel = self.make_request(
"GET", method,
self.url + "?from=-5", self.url + "?from=-5",
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_next_token(self): def test_next_token(self):
""" """
Testing that `next_token` appears at the right place Testing that `next_token` appears at the right place
For deleting media `next_token` is not useful, because
after deleting media the media has a new order.
""" """
number_media = 20 number_media = 20
@ -2435,7 +2483,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), number_media) self.assertEqual(len(channel.json_body["media"]), number_media)
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
@ -2448,7 +2496,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), number_media) self.assertEqual(len(channel.json_body["media"]), number_media)
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
@ -2461,7 +2509,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), 19) self.assertEqual(len(channel.json_body["media"]), 19)
self.assertEqual(channel.json_body["next_token"], 19) self.assertEqual(channel.json_body["next_token"], 19)
@ -2475,12 +2523,12 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), 1) self.assertEqual(len(channel.json_body["media"]), 1)
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
def test_user_has_no_media(self): def test_user_has_no_media_GET(self):
""" """
Tests that a normal lookup for media is successfully Tests that a normal lookup for media is successfully
if user has no media created if user has no media created
@ -2496,11 +2544,24 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["media"])) self.assertEqual(0, len(channel.json_body["media"]))
def test_get_media(self): def test_user_has_no_media_DELETE(self):
""" """
Tests that a normal lookup for media is successfully Tests that a delete is successful if user has no media
""" """
channel = self.make_request(
"DELETE",
self.url,
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["deleted_media"]))
def test_get_media(self):
"""Tests that a normal lookup for media is successful"""
number_media = 5 number_media = 5
other_user_tok = self.login("user", "pass") other_user_tok = self.login("user", "pass")
self._create_media_for_user(other_user_tok, number_media) self._create_media_for_user(other_user_tok, number_media)
@ -2517,6 +2578,35 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
self._check_fields(channel.json_body["media"]) self._check_fields(channel.json_body["media"])
def test_delete_media(self):
"""Tests that a normal delete of media is successful"""
number_media = 5
other_user_tok = self.login("user", "pass")
media_ids = self._create_media_for_user(other_user_tok, number_media)
# Test if the file exists
local_paths = []
for media_id in media_ids:
local_path = self.filepaths.local_media_filepath(media_id)
self.assertTrue(os.path.exists(local_path))
local_paths.append(local_path)
channel = self.make_request(
"DELETE",
self.url,
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_media, channel.json_body["total"])
self.assertEqual(number_media, len(channel.json_body["deleted_media"]))
self.assertCountEqual(channel.json_body["deleted_media"], media_ids)
# Test if the file is deleted
for local_path in local_paths:
self.assertFalse(os.path.exists(local_path))
def test_order_by(self): def test_order_by(self):
""" """
Testing order list with parameter `order_by` Testing order list with parameter `order_by`
@ -2622,13 +2712,16 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
[media2] + sorted([media1, media3]), "safe_from_quarantine", "b" [media2] + sorted([media1, media3]), "safe_from_quarantine", "b"
) )
def _create_media_for_user(self, user_token: str, number_media: int): def _create_media_for_user(self, user_token: str, number_media: int) -> List[str]:
""" """
Create a number of media for a specific user Create a number of media for a specific user
Args: Args:
user_token: Access token of the user user_token: Access token of the user
number_media: Number of media to be created for the user number_media: Number of media to be created for the user
Returns:
List of created media ID
""" """
media_ids = []
for _ in range(number_media): for _ in range(number_media):
# file size is 67 Byte # file size is 67 Byte
image_data = unhexlify( image_data = unhexlify(
@ -2637,7 +2730,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
b"0a2db40000000049454e44ae426082" b"0a2db40000000049454e44ae426082"
) )
self._create_media_and_access(user_token, image_data) media_ids.append(self._create_media_and_access(user_token, image_data))
return media_ids
def _create_media_and_access( def _create_media_and_access(
self, self,
@ -2680,7 +2775,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
200, 200,
channel.code, channel.code,
msg=( msg=(
"Expected to receive a 200 on accessing media: %s" % server_and_media_id f"Expected to receive a 200 on accessing media: {server_and_media_id}"
), ),
) )
@ -2718,12 +2813,12 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
url = self.url + "?" url = self.url + "?"
if order_by is not None: if order_by is not None:
url += "order_by=%s&" % (order_by,) url += f"order_by={order_by}&"
if dir is not None and dir in ("b", "f"): if dir is not None and dir in ("b", "f"):
url += "dir=%s" % (dir,) url += f"dir={dir}"
channel = self.make_request( channel = self.make_request(
"GET", "GET",
url.encode("ascii"), url,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
@ -2762,7 +2857,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"POST", self.url, b"{}", access_token=self.admin_user_tok "POST", self.url, b"{}", access_token=self.admin_user_tok
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.json_body)
return channel.json_body["access_token"] return channel.json_body["access_token"]
def test_no_auth(self): def test_no_auth(self):
@ -2803,7 +2898,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok "GET", "devices", b"{}", access_token=self.other_user_tok
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.json_body)
# We should only see the one device (from the login in `prepare`) # We should only see the one device (from the login in `prepare`)
self.assertEqual(len(channel.json_body["devices"]), 1) self.assertEqual(len(channel.json_body["devices"]), 1)
@ -2815,11 +2910,11 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
# Test that we can successfully make a request # Test that we can successfully make a request
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.json_body)
# Logout with the puppet token # Logout with the puppet token
channel = self.make_request("POST", "logout", b"{}", access_token=puppet_token) channel = self.make_request("POST", "logout", b"{}", access_token=puppet_token)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.json_body)
# The puppet token should no longer work # The puppet token should no longer work
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
@ -2829,7 +2924,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok "GET", "devices", b"{}", access_token=self.other_user_tok
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.json_body)
def test_user_logout_all(self): def test_user_logout_all(self):
"""Tests that the target user calling `/logout/all` does *not* expire """Tests that the target user calling `/logout/all` does *not* expire
@ -2840,17 +2935,17 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
# Test that we can successfully make a request # Test that we can successfully make a request
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.json_body)
# Logout all with the real user token # Logout all with the real user token
channel = self.make_request( channel = self.make_request(
"POST", "logout/all", b"{}", access_token=self.other_user_tok "POST", "logout/all", b"{}", access_token=self.other_user_tok
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.json_body)
# The puppet token should still work # The puppet token should still work
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.json_body)
# .. but the real user's tokens shouldn't # .. but the real user's tokens shouldn't
channel = self.make_request( channel = self.make_request(
@ -2867,13 +2962,13 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
# Test that we can successfully make a request # Test that we can successfully make a request
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.json_body)
# Logout all with the admin user token # Logout all with the admin user token
channel = self.make_request( channel = self.make_request(
"POST", "logout/all", b"{}", access_token=self.admin_user_tok "POST", "logout/all", b"{}", access_token=self.admin_user_tok
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.json_body)
# The puppet token should no longer work # The puppet token should no longer work
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
@ -2883,7 +2978,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok "GET", "devices", b"{}", access_token=self.other_user_tok
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.json_body)
@unittest.override_config( @unittest.override_config(
{ {
@ -3243,7 +3338,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"]) self.assertEqual("Can only look up local users", channel.json_body["error"])
channel = self.make_request( channel = self.make_request(
"POST", "POST",
@ -3279,7 +3374,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
content={"messages_per_second": "string"}, content={"messages_per_second": "string"},
) )
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# messages_per_second is negative # messages_per_second is negative
@ -3290,7 +3385,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
content={"messages_per_second": -1}, content={"messages_per_second": -1},
) )
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# burst_count is a string # burst_count is a string
@ -3301,7 +3396,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
content={"burst_count": "string"}, content={"burst_count": "string"},
) )
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# burst_count is negative # burst_count is negative
@ -3312,7 +3407,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
content={"burst_count": -1}, content={"burst_count": -1},
) )
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_return_zero_when_null(self): def test_return_zero_when_null(self):
@ -3337,7 +3432,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url, self.url,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["messages_per_second"]) self.assertEqual(0, channel.json_body["messages_per_second"])
self.assertEqual(0, channel.json_body["burst_count"]) self.assertEqual(0, channel.json_body["burst_count"])
@ -3351,7 +3446,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url, self.url,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertNotIn("messages_per_second", channel.json_body) self.assertNotIn("messages_per_second", channel.json_body)
self.assertNotIn("burst_count", channel.json_body) self.assertNotIn("burst_count", channel.json_body)
@ -3362,7 +3457,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
content={"messages_per_second": 10, "burst_count": 11}, content={"messages_per_second": 10, "burst_count": 11},
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(10, channel.json_body["messages_per_second"]) self.assertEqual(10, channel.json_body["messages_per_second"])
self.assertEqual(11, channel.json_body["burst_count"]) self.assertEqual(11, channel.json_body["burst_count"])
@ -3373,7 +3468,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
content={"messages_per_second": 20, "burst_count": 21}, content={"messages_per_second": 20, "burst_count": 21},
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(20, channel.json_body["messages_per_second"]) self.assertEqual(20, channel.json_body["messages_per_second"])
self.assertEqual(21, channel.json_body["burst_count"]) self.assertEqual(21, channel.json_body["burst_count"])
@ -3383,7 +3478,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url, self.url,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(20, channel.json_body["messages_per_second"]) self.assertEqual(20, channel.json_body["messages_per_second"])
self.assertEqual(21, channel.json_body["burst_count"]) self.assertEqual(21, channel.json_body["burst_count"])
@ -3393,7 +3488,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url, self.url,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertNotIn("messages_per_second", channel.json_body) self.assertNotIn("messages_per_second", channel.json_body)
self.assertNotIn("burst_count", channel.json_body) self.assertNotIn("burst_count", channel.json_body)
@ -3403,6 +3498,6 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url, self.url,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertNotIn("messages_per_second", channel.json_body) self.assertNotIn("messages_per_second", channel.json_body)
self.assertNotIn("burst_count", channel.json_body) self.assertNotIn("burst_count", channel.json_body)

Loading…
Cancel
Save