mirror of https://github.com/watcha-fr/synapse
parent
6e70979973
commit
95f30ecd1f
@ -0,0 +1,111 @@ |
||||
# -*- coding: utf-8 -*- |
||||
# Copyright 2015 OpenMarket Ltd |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
from ._base import client_v2_patterns |
||||
|
||||
from synapse.http.servlet import RestServlet |
||||
from synapse.api.errors import AuthError, SynapseError |
||||
|
||||
from twisted.internet import defer |
||||
|
||||
import logging |
||||
|
||||
import simplejson as json |
||||
|
||||
logger = logging.getLogger(__name__) |
||||
|
||||
|
||||
class AccountDataServlet(RestServlet): |
||||
""" |
||||
PUT /user/{user_id}/account_data/{account_dataType} HTTP/1.1 |
||||
""" |
||||
PATTERNS = client_v2_patterns( |
||||
"/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)" |
||||
) |
||||
|
||||
def __init__(self, hs): |
||||
super(AccountDataServlet, self).__init__() |
||||
self.auth = hs.get_auth() |
||||
self.store = hs.get_datastore() |
||||
self.notifier = hs.get_notifier() |
||||
|
||||
@defer.inlineCallbacks |
||||
def on_PUT(self, request, user_id, account_data_type): |
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request) |
||||
if user_id != auth_user.to_string(): |
||||
raise AuthError(403, "Cannot add account data for other users.") |
||||
|
||||
try: |
||||
content_bytes = request.content.read() |
||||
body = json.loads(content_bytes) |
||||
except: |
||||
raise SynapseError(400, "Invalid JSON") |
||||
|
||||
max_id = yield self.store.add_account_data_for_user( |
||||
user_id, account_data_type, body |
||||
) |
||||
|
||||
yield self.notifier.on_new_event( |
||||
"account_data_key", max_id, users=[user_id] |
||||
) |
||||
|
||||
defer.returnValue((200, {})) |
||||
|
||||
|
||||
class RoomAccountDataServlet(RestServlet): |
||||
""" |
||||
PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1 |
||||
""" |
||||
PATTERNS = client_v2_patterns( |
||||
"/user/(?P<user_id>[^/]*)" |
||||
"/rooms/(?P<room_id>[^/]*)" |
||||
"/account_data/(?P<account_data_type>[^/]*)" |
||||
) |
||||
|
||||
def __init__(self, hs): |
||||
super(RoomAccountDataServlet, self).__init__() |
||||
self.auth = hs.get_auth() |
||||
self.store = hs.get_datastore() |
||||
self.notifier = hs.get_notifier() |
||||
|
||||
@defer.inlineCallbacks |
||||
def on_PUT(self, request, user_id, room_id, account_data_type): |
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request) |
||||
if user_id != auth_user.to_string(): |
||||
raise AuthError(403, "Cannot add account data for other users.") |
||||
|
||||
try: |
||||
content_bytes = request.content.read() |
||||
body = json.loads(content_bytes) |
||||
except: |
||||
raise SynapseError(400, "Invalid JSON") |
||||
|
||||
if not isinstance(body, dict): |
||||
raise ValueError("Expected a JSON object") |
||||
|
||||
max_id = yield self.store.add_account_data_to_room( |
||||
user_id, room_id, account_data_type, body |
||||
) |
||||
|
||||
yield self.notifier.on_new_event( |
||||
"account_data_key", max_id, users=[user_id] |
||||
) |
||||
|
||||
defer.returnValue((200, {})) |
||||
|
||||
|
||||
def register_servlets(hs, http_server): |
||||
AccountDataServlet(hs).register(http_server) |
||||
RoomAccountDataServlet(hs).register(http_server) |
@ -0,0 +1,211 @@ |
||||
# -*- coding: utf-8 -*- |
||||
# Copyright 2014, 2015 OpenMarket Ltd |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
from ._base import SQLBaseStore |
||||
from twisted.internet import defer |
||||
|
||||
import ujson as json |
||||
import logging |
||||
|
||||
logger = logging.getLogger(__name__) |
||||
|
||||
|
||||
class AccountDataStore(SQLBaseStore): |
||||
|
||||
def get_account_data_for_user(self, user_id): |
||||
"""Get all the client account_data for a user. |
||||
|
||||
Args: |
||||
user_id(str): The user to get the account_data for. |
||||
Returns: |
||||
A deferred pair of a dict of global account_data and a dict |
||||
mapping from room_id string to per room account_data dicts. |
||||
""" |
||||
|
||||
def get_account_data_for_user_txn(txn): |
||||
rows = self._simple_select_list_txn( |
||||
txn, "account_data", {"user_id": user_id}, |
||||
["account_data_type", "content"] |
||||
) |
||||
|
||||
global_account_data = { |
||||
row["account_data_type"]: json.loads(row["content"]) for row in rows |
||||
} |
||||
|
||||
rows = self._simple_select_list_txn( |
||||
txn, "room_account_data", {"user_id": user_id}, |
||||
["room_id", "account_data_type", "content"] |
||||
) |
||||
|
||||
by_room = {} |
||||
for row in rows: |
||||
room_data = by_room.setdefault(row["room_id"], {}) |
||||
room_data[row["account_data_type"]] = json.loads(row["content"]) |
||||
|
||||
return (global_account_data, by_room) |
||||
|
||||
return self.runInteraction( |
||||
"get_account_data_for_user", get_account_data_for_user_txn |
||||
) |
||||
|
||||
def get_account_data_for_room(self, user_id, room_id): |
||||
"""Get all the client account_data for a user for a room. |
||||
|
||||
Args: |
||||
user_id(str): The user to get the account_data for. |
||||
room_id(str): The room to get the account_data for. |
||||
Returns: |
||||
A deferred dict of the room account_data |
||||
""" |
||||
def get_account_data_for_room_txn(txn): |
||||
rows = self._simple_select_list_txn( |
||||
txn, "room_account_data", {"user_id": user_id, "room_id": room_id}, |
||||
["account_data_type", "content"] |
||||
) |
||||
|
||||
return { |
||||
row["account_data_type"]: json.loads(row["content"]) for row in rows |
||||
} |
||||
|
||||
return self.runInteraction( |
||||
"get_account_data_for_room", get_account_data_for_room_txn |
||||
) |
||||
|
||||
def get_updated_account_data_for_user(self, user_id, stream_id): |
||||
"""Get all the client account_data for a that's changed. |
||||
|
||||
Args: |
||||
user_id(str): The user to get the account_data for. |
||||
stream_id(int): The point in the stream since which to get updates |
||||
Returns: |
||||
A deferred pair of a dict of global account_data and a dict |
||||
mapping from room_id string to per room account_data dicts. |
||||
""" |
||||
|
||||
def get_updated_account_data_for_user_txn(txn): |
||||
sql = ( |
||||
"SELECT account_data_type, content FROM account_data" |
||||
" WHERE user_id = ? AND stream_id > ?" |
||||
) |
||||
|
||||
txn.execute(sql, (user_id, stream_id)) |
||||
|
||||
global_account_data = { |
||||
row[0]: json.loads(row[1]) for row in txn.fetchall() |
||||
} |
||||
|
||||
sql = ( |
||||
"SELECT room_id, account_data_type, content FROM room_account_data" |
||||
" WHERE user_id = ? AND stream_id > ?" |
||||
) |
||||
|
||||
txn.execute(sql, (user_id, stream_id)) |
||||
|
||||
account_data_by_room = {} |
||||
for row in txn.fetchall(): |
||||
room_account_data = account_data_by_room.setdefault(row[0], {}) |
||||
room_account_data[row[1]] = json.loads(row[2]) |
||||
|
||||
return (global_account_data, account_data_by_room) |
||||
|
||||
return self.runInteraction( |
||||
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn |
||||
) |
||||
|
||||
@defer.inlineCallbacks |
||||
def add_account_data_to_room(self, user_id, room_id, account_data_type, content): |
||||
"""Add some account_data to a room for a user. |
||||
Args: |
||||
user_id(str): The user to add a tag for. |
||||
room_id(str): The room to add a tag for. |
||||
account_data_type(str): The type of account_data to add. |
||||
content(dict): A json object to associate with the tag. |
||||
Returns: |
||||
A deferred that completes once the account_data has been added. |
||||
""" |
||||
content_json = json.dumps(content) |
||||
|
||||
def add_account_data_txn(txn, next_id): |
||||
self._simple_upsert_txn( |
||||
txn, |
||||
table="room_account_data", |
||||
keyvalues={ |
||||
"user_id": user_id, |
||||
"room_id": room_id, |
||||
"account_data_type": account_data_type, |
||||
}, |
||||
values={ |
||||
"stream_id": next_id, |
||||
"content": content_json, |
||||
} |
||||
) |
||||
self._update_max_stream_id(txn, next_id) |
||||
|
||||
with (yield self._account_data_id_gen.get_next(self)) as next_id: |
||||
yield self.runInteraction( |
||||
"add_room_account_data", add_account_data_txn, next_id |
||||
) |
||||
|
||||
result = yield self._account_data_id_gen.get_max_token(self) |
||||
defer.returnValue(result) |
||||
|
||||
@defer.inlineCallbacks |
||||
def add_account_data_for_user(self, user_id, account_data_type, content): |
||||
"""Add some account_data to a room for a user. |
||||
Args: |
||||
user_id(str): The user to add a tag for. |
||||
account_data_type(str): The type of account_data to add. |
||||
content(dict): A json object to associate with the tag. |
||||
Returns: |
||||
A deferred that completes once the account_data has been added. |
||||
""" |
||||
content_json = json.dumps(content) |
||||
|
||||
def add_account_data_txn(txn, next_id): |
||||
self._simple_upsert_txn( |
||||
txn, |
||||
table="account_data", |
||||
keyvalues={ |
||||
"user_id": user_id, |
||||
"account_data_type": account_data_type, |
||||
}, |
||||
values={ |
||||
"stream_id": next_id, |
||||
"content": content_json, |
||||
} |
||||
) |
||||
self._update_max_stream_id(txn, next_id) |
||||
|
||||
with (yield self._account_data_id_gen.get_next(self)) as next_id: |
||||
yield self.runInteraction( |
||||
"add_user_account_data", add_account_data_txn, next_id |
||||
) |
||||
|
||||
result = yield self._account_data_id_gen.get_max_token(self) |
||||
defer.returnValue(result) |
||||
|
||||
def _update_max_stream_id(self, txn, next_id): |
||||
"""Update the max stream_id |
||||
|
||||
Args: |
||||
txn: The database cursor |
||||
next_id(int): The the revision to advance to. |
||||
""" |
||||
update_max_id_sql = ( |
||||
"UPDATE account_data_max_stream_id" |
||||
" SET stream_id = ?" |
||||
" WHERE stream_id < ?" |
||||
) |
||||
txn.execute(update_max_id_sql, (next_id, next_id)) |
Loading…
Reference in new issue