|
|
|
@ -36,6 +36,11 @@ class PushRuleRestServlet(ClientV1RestServlet): |
|
|
|
|
SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = ( |
|
|
|
|
"Unrecognised request: You probably wanted a trailing slash") |
|
|
|
|
|
|
|
|
|
def __init__(self, hs): |
|
|
|
|
super(PushRuleRestServlet, self).__init__(hs) |
|
|
|
|
self.store = hs.get_datastore() |
|
|
|
|
self.notifier = hs.get_notifier() |
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def on_PUT(self, request): |
|
|
|
|
spec = _rule_spec_from_path(request.postpath) |
|
|
|
@ -51,8 +56,11 @@ class PushRuleRestServlet(ClientV1RestServlet): |
|
|
|
|
|
|
|
|
|
content = _parse_json(request) |
|
|
|
|
|
|
|
|
|
user_id = requester.user.to_string() |
|
|
|
|
|
|
|
|
|
if 'attr' in spec: |
|
|
|
|
yield self.set_rule_attr(requester.user.to_string(), spec, content) |
|
|
|
|
yield self.set_rule_attr(user_id, spec, content) |
|
|
|
|
self.notify_user(user_id) |
|
|
|
|
defer.returnValue((200, {})) |
|
|
|
|
|
|
|
|
|
if spec['rule_id'].startswith('.'): |
|
|
|
@ -77,8 +85,8 @@ class PushRuleRestServlet(ClientV1RestServlet): |
|
|
|
|
after = _namespaced_rule_id(spec, after[0]) |
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
yield self.hs.get_datastore().add_push_rule( |
|
|
|
|
user_id=requester.user.to_string(), |
|
|
|
|
yield self.store.add_push_rule( |
|
|
|
|
user_id=user_id, |
|
|
|
|
rule_id=_namespaced_rule_id_from_spec(spec), |
|
|
|
|
priority_class=priority_class, |
|
|
|
|
conditions=conditions, |
|
|
|
@ -86,6 +94,7 @@ class PushRuleRestServlet(ClientV1RestServlet): |
|
|
|
|
before=before, |
|
|
|
|
after=after |
|
|
|
|
) |
|
|
|
|
self.notify_user(user_id) |
|
|
|
|
except InconsistentRuleException as e: |
|
|
|
|
raise SynapseError(400, e.message) |
|
|
|
|
except RuleNotFoundException as e: |
|
|
|
@ -98,13 +107,15 @@ class PushRuleRestServlet(ClientV1RestServlet): |
|
|
|
|
spec = _rule_spec_from_path(request.postpath) |
|
|
|
|
|
|
|
|
|
requester = yield self.auth.get_user_by_req(request) |
|
|
|
|
user_id = requester.user.to_string() |
|
|
|
|
|
|
|
|
|
namespaced_rule_id = _namespaced_rule_id_from_spec(spec) |
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
yield self.hs.get_datastore().delete_push_rule( |
|
|
|
|
requester.user.to_string(), namespaced_rule_id |
|
|
|
|
yield self.store.delete_push_rule( |
|
|
|
|
user_id, namespaced_rule_id |
|
|
|
|
) |
|
|
|
|
self.notify_user(user_id) |
|
|
|
|
defer.returnValue((200, {})) |
|
|
|
|
except StoreError as e: |
|
|
|
|
if e.code == 404: |
|
|
|
@ -115,14 +126,12 @@ class PushRuleRestServlet(ClientV1RestServlet): |
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def on_GET(self, request): |
|
|
|
|
requester = yield self.auth.get_user_by_req(request) |
|
|
|
|
user = requester.user |
|
|
|
|
user_id = requester.user.to_string() |
|
|
|
|
|
|
|
|
|
# we build up the full structure and then decide which bits of it |
|
|
|
|
# to send which means doing unnecessary work sometimes but is |
|
|
|
|
# is probably not going to make a whole lot of difference |
|
|
|
|
rawrules = yield self.hs.get_datastore().get_push_rules_for_user( |
|
|
|
|
user.to_string() |
|
|
|
|
) |
|
|
|
|
rawrules = yield self.store.get_push_rules_for_user(user_id) |
|
|
|
|
|
|
|
|
|
ruleslist = [] |
|
|
|
|
for rawrule in rawrules: |
|
|
|
@ -138,8 +147,7 @@ class PushRuleRestServlet(ClientV1RestServlet): |
|
|
|
|
|
|
|
|
|
rules['global'] = _add_empty_priority_class_arrays(rules['global']) |
|
|
|
|
|
|
|
|
|
enabled_map = yield self.hs.get_datastore().\ |
|
|
|
|
get_push_rules_enabled_for_user(user.to_string()) |
|
|
|
|
enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id) |
|
|
|
|
|
|
|
|
|
for r in ruleslist: |
|
|
|
|
rulearray = None |
|
|
|
@ -152,9 +160,9 @@ class PushRuleRestServlet(ClientV1RestServlet): |
|
|
|
|
|
|
|
|
|
pattern_type = c.pop("pattern_type", None) |
|
|
|
|
if pattern_type == "user_id": |
|
|
|
|
c["pattern"] = user.to_string() |
|
|
|
|
c["pattern"] = user_id |
|
|
|
|
elif pattern_type == "user_localpart": |
|
|
|
|
c["pattern"] = user.localpart |
|
|
|
|
c["pattern"] = requester.user.localpart |
|
|
|
|
|
|
|
|
|
rulearray = rules['global'][template_name] |
|
|
|
|
|
|
|
|
@ -188,6 +196,12 @@ class PushRuleRestServlet(ClientV1RestServlet): |
|
|
|
|
def on_OPTIONS(self, _): |
|
|
|
|
return 200, {} |
|
|
|
|
|
|
|
|
|
def notify_user(self, user_id): |
|
|
|
|
stream_id = self.store.get_push_rules_stream_token() |
|
|
|
|
self.notifier.on_new_event( |
|
|
|
|
"push_rules_key", stream_id, users=[user_id] |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
def set_rule_attr(self, user_id, spec, val): |
|
|
|
|
if spec['attr'] == 'enabled': |
|
|
|
|
if isinstance(val, dict) and "enabled" in val: |
|
|
|
@ -198,7 +212,7 @@ class PushRuleRestServlet(ClientV1RestServlet): |
|
|
|
|
# bools directly, so let's not break them. |
|
|
|
|
raise SynapseError(400, "Value for 'enabled' must be boolean") |
|
|
|
|
namespaced_rule_id = _namespaced_rule_id_from_spec(spec) |
|
|
|
|
return self.hs.get_datastore().set_push_rule_enabled( |
|
|
|
|
return self.store.set_push_rule_enabled( |
|
|
|
|
user_id, namespaced_rule_id, val |
|
|
|
|
) |
|
|
|
|
elif spec['attr'] == 'actions': |
|
|
|
@ -210,7 +224,7 @@ class PushRuleRestServlet(ClientV1RestServlet): |
|
|
|
|
if is_default_rule: |
|
|
|
|
if namespaced_rule_id not in BASE_RULE_IDS: |
|
|
|
|
raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,)) |
|
|
|
|
return self.hs.get_datastore().set_push_rule_actions( |
|
|
|
|
return self.store.set_push_rule_actions( |
|
|
|
|
user_id, namespaced_rule_id, actions, is_default_rule |
|
|
|
|
) |
|
|
|
|
else: |
|
|
|
|