@ -14,8 +14,9 @@
import itertools
import logging
from typing import TYPE_CHECKING , Dict , Iterable , List , Optional
from typing import TYPE_CHECKING , Collection , Dict , Iterable , List , Optional , Set , Tuple
import attr
from unpaddedbase64 import decode_base64 , encode_base64
from synapse . api . constants import EventTypes , Membership
@ -32,6 +33,20 @@ if TYPE_CHECKING:
logger = logging . getLogger ( __name__ )
@attr . s ( slots = True , frozen = True , auto_attribs = True )
class _SearchResult :
# The count of results.
count : int
# A mapping of event ID to the rank of that event.
rank_map : Dict [ str , int ]
# A list of the resulting events.
allowed_events : List [ EventBase ]
# A map of room ID to results.
room_groups : Dict [ str , JsonDict ]
# A set of event IDs to highlight.
highlights : Set [ str ]
class SearchHandler :
def __init__ ( self , hs : " HomeServer " ) :
self . store = hs . get_datastore ( )
@ -100,7 +115,7 @@ class SearchHandler:
""" Performs a full text search for a user.
Args :
user
user : The user performing the search .
content : Search parameters
batch : The next_batch parameter . Used for pagination .
@ -156,6 +171,8 @@ class SearchHandler:
# Include context around each event?
event_context = room_cat . get ( " event_context " , None )
before_limit = after_limit = None
include_profile = False
# Group results together? May allow clients to paginate within a
# group
@ -182,6 +199,73 @@ class SearchHandler:
% ( set ( group_keys ) - { " room_id " , " sender " } , ) ,
)
return await self . _search (
user ,
batch_group ,
batch_group_key ,
batch_token ,
search_term ,
keys ,
filter_dict ,
order_by ,
include_state ,
group_keys ,
event_context ,
before_limit ,
after_limit ,
include_profile ,
)
async def _search (
self ,
user : UserID ,
batch_group : Optional [ str ] ,
batch_group_key : Optional [ str ] ,
batch_token : Optional [ str ] ,
search_term : str ,
keys : List [ str ] ,
filter_dict : JsonDict ,
order_by : str ,
include_state : bool ,
group_keys : List [ str ] ,
event_context : Optional [ bool ] ,
before_limit : Optional [ int ] ,
after_limit : Optional [ int ] ,
include_profile : bool ,
) - > JsonDict :
""" Performs a full text search for a user.
Args :
user : The user performing the search .
batch_group : Pagination information .
batch_group_key : Pagination information .
batch_token : Pagination information .
search_term : Search term to search for
keys : List of keys to search in , currently supports
" content.body " , " content.name " , " content.topic "
filter_dict : The JSON to build a filter out of .
order_by : How to order the results . Valid values ore " rank " and " recent " .
include_state : True if the state of the room at each result should
be included .
group_keys : A list of ways to group the results . Valid values are
" room_id " and " sender " .
event_context : True to include contextual events around results .
before_limit :
The number of events before a result to include as context .
Only used if event_context is True .
after_limit :
The number of events after a result to include as context .
Only used if event_context is True .
include_profile : True if historical profile information should be
included in the event context .
Only used if event_context is True .
Returns :
dict to be returned to the client with results of search
"""
search_filter = Filter ( self . hs , filter_dict )
# TODO: Search through left rooms too
@ -216,278 +300,399 @@ class SearchHandler:
}
}
rank_map = { } # event_id -> rank of event
allowed_events = [ ]
# Holds result of grouping by room, if applicable
room_groups : Dict [ str , JsonDict ] = { }
# Holds result of grouping by sender, if applicable
sender_group : Dict [ str , JsonDict ] = { }
sender_group : Optional [ Dict [ str , JsonDict ] ]
# Holds the next_batch for the entire result set if one of those exists
global_next_batch = None
highlights = set ( )
if order_by == " rank " :
search_result , sender_group = await self . _search_by_rank (
user , room_ids , search_term , keys , search_filter
)
# Unused return values for rank search.
global_next_batch = None
elif order_by == " recent " :
search_result , global_next_batch = await self . _search_by_recent (
user ,
room_ids ,
search_term ,
keys ,
search_filter ,
batch_group ,
batch_group_key ,
batch_token ,
)
# Unused return values for recent search.
sender_group = None
else :
# We should never get here due to the guard earlier.
raise NotImplementedError ( )
count = None
logger . info ( " Found %d events to return " , len ( search_result . allowed_events ) )
if order_by == " rank " :
search_result = await self . store . search_msgs ( room_ids , search_term , keys )
# If client has asked for "context" for each event (i.e. some surrounding
# events and state), fetch that
if event_context is not None :
# Note that before and after limit must be set in this case.
assert before_limit is not None
assert after_limit is not None
contexts = await self . _calculate_event_contexts (
user ,
search_result . allowed_events ,
before_limit ,
after_limit ,
include_profile ,
)
else :
contexts = { }
count = search_result [ " count " ]
# TODO: Add a limit
if search_result [ " highlights " ] :
highlights . update ( search_result [ " highlights " ] )
state_results = { }
if include_state :
for room_id in { e . room_id for e in search_result . allowed_events } :
state = await self . state_handler . get_current_state ( room_id )
state_results [ room_id ] = list ( state . values ( ) )
results = search_result [ " results " ]
aggregations = None
if self . _msc3666_enabled :
aggregations = await self . store . get_bundled_aggregations (
# Generate an iterable of EventBase for all the events that will be
# returned, including contextual events.
itertools . chain (
# The events_before and events_after for each context.
itertools . chain . from_iterable (
itertools . chain ( context [ " events_before " ] , context [ " events_after " ] ) # type: ignore[arg-type]
for context in contexts . values ( )
) ,
# The returned events.
search_result . allowed_events ,
) ,
user . to_string ( ) ,
)
rank_map . update ( { r [ " event " ] . event_id : r [ " rank " ] for r in results } )
# We're now about to serialize the events. We should not make any
# blocking calls after this. Otherwise, the 'age' will be wrong.
filtered_events = await search_filter . filter ( [ r [ " event " ] for r in results ] )
time_now = self . clock . time_msec ( )
events = await filter_events_for_client (
self . storage , user . to_string ( ) , filtered_events
for context in contexts . values ( ) :
context [ " events_before " ] = self . _event_serializer . serialize_events (
context [ " events_before " ] , time_now , bundle_aggregations = aggregations # type: ignore[arg-type]
)
context [ " events_after " ] = self . _event_serializer . serialize_events (
context [ " events_after " ] , time_now , bundle_aggregations = aggregations # type: ignore[arg-type]
)
events . sort ( key = lambda e : - rank_map [ e . event_id ] )
allowed_events = events [ : search_filter . limit ]
results = [
{
" rank " : search_result . rank_map [ e . event_id ] ,
" result " : self . _event_serializer . serialize_event (
e , time_now , bundle_aggregations = aggregations
) ,
" context " : contexts . get ( e . event_id , { } ) ,
}
for e in search_result . allowed_events
]
for e in allowed_events :
rm = room_groups . setdefault (
e . room_id , { " results " : [ ] , " order " : rank_map [ e . event_id ] }
)
rm [ " results " ] . append ( e . event_id )
rooms_cat_res : JsonDict = {
" results " : results ,
" count " : search_result . count ,
" highlights " : list ( search_result . highlights ) ,
}
s = sender_group . setdefault (
e . sender , { " results " : [ ] , " order " : rank_map [ e . event_id ] }
)
s [ " results " ] . append ( e . event_id )
if state_results :
rooms_cat_res [ " state " ] = {
room_id : self . _event_serializer . serialize_events ( state_events , time_now )
for room_id , state_events in state_results . items ( )
}
elif order_by == " recent " :
room_events : List [ EventBase ] = [ ]
i = 0
pagination_token = batch_token
# We keep looping and we keep filtering until we reach the limit
# or we run out of things.
# But only go around 5 times since otherwise synapse will be sad.
while len ( room_events ) < search_filter . limit and i < 5 :
i + = 1
search_result = await self . store . search_rooms (
room_ids ,
search_term ,
keys ,
search_filter . limit * 2 ,
pagination_token = pagination_token ,
)
if search_result . room_groups and " room_id " in group_keys :
rooms_cat_res . setdefault ( " groups " , { } ) [
" room_id "
] = search_result . room_groups
if search_result [ " highlights " ] :
highlights . update ( search_result [ " highlight s" ] )
if sender_group and " sender " in group_keys :
rooms_cat_res . setdefault ( " groups " , { } ) [ " sender " ] = sender_group
count = search_result [ " count " ]
if global_next_batch :
rooms_cat_res [ " next_batch " ] = global_next_batch
results = search_result [ " results " ]
return { " search_categories " : { " room_events " : rooms_cat_res } }
results_map = { r [ " event " ] . event_id : r for r in results }
async def _search_by_rank (
self ,
user : UserID ,
room_ids : Collection [ str ] ,
search_term : str ,
keys : Iterable [ str ] ,
search_filter : Filter ,
) - > Tuple [ _SearchResult , Dict [ str , JsonDict ] ] :
"""
Performs a full text search for a user ordering by rank .
rank_map . update ( { r [ " event " ] . event_id : r [ " rank " ] for r in results } )
Args :
user : The user performing the search .
room_ids : List of room ids to search in
search_term : Search term to search for
keys : List of keys to search in , currently supports
" content.body " , " content.name " , " content.topic "
search_filter : The event filter to use .
filtered_events = await search_filter . filter (
[ r [ " event " ] for r in results ]
)
Returns :
A tuple of :
The search results .
A map of sender ID to results .
"""
rank_map = { } # event_id -> rank of event
# Holds result of grouping by room, if applicable
room_groups : Dict [ str , JsonDict ] = { }
# Holds result of grouping by sender, if applicable
sender_group : Dict [ str , JsonDict ] = { }
events = await filter_events_for_client (
self . storage , user . to_string ( ) , filtered_events
)
search_result = await self . store . search_msgs ( room_ids , search_term , keys )
room_events . extend ( events )
room_events = room_events [ : search_filter . limit ]
if search_result [ " highlights " ] :
highlights = search_result [ " highlights " ]
else :
highlights = set ( )
if len ( results ) < search_filter . limit * 2 :
pagination_token = None
break
else :
pagination_token = results [ - 1 ] [ " pagination_token " ]
for event in room_events :
group = room_groups . setdefault ( event . room_id , { " results " : [ ] } )
group [ " results " ] . append ( event . event_id )
if room_events and len ( room_events ) > = search_filter . limit :
last_event_id = room_events [ - 1 ] . event_id
pagination_token = results_map [ last_event_id ] [ " pagination_token " ]
# We want to respect the given batch group and group keys so
# that if people blindly use the top level `next_batch` token
# it returns more from the same group (if applicable) rather
# than reverting to searching all results again.
if batch_group and batch_group_key :
global_next_batch = encode_base64 (
(
" %s \n %s \n %s "
% ( batch_group , batch_group_key , pagination_token )
) . encode ( " ascii " )
)
else :
global_next_batch = encode_base64 (
( " %s \n %s \n %s " % ( " all " , " " , pagination_token ) ) . encode ( " ascii " )
)
results = search_result [ " results " ]
for room_id , group in room_groups . items ( ) :
group [ " next_batch " ] = encode_base64 (
( " %s \n %s \n %s " % ( " room_id " , room_id , pagination_token ) ) . encode (
" ascii "
)
)
# event_id -> rank of event
rank_map = { r [ " event " ] . event_id : r [ " rank " ] for r in results }
allowed_events . extend ( room_events )
filtered_events = await search_filter . filter ( [ r [ " event " ] for r in results ] )
else :
# We should never get here due to the guard earlier.
raise NotImplementedError ( )
events = await filter_events_for_client (
self . storage , user . to_string ( ) , filtered_events
)
logger . info ( " Found %d events to return " , len ( allowed_events ) )
events . sort ( key = lambda e : - rank_map [ e . event_id ] )
allowed_events = events [ : search_filter . limit ]
# If client has asked for "context" for each event (i.e. some surrounding
# events and state), fetch that
if event_context is not None :
now_token = self . hs . get_event_sources ( ) . get_current_token ( )
for e in allowed_events :
rm = room_groups . setdefault (
e . room_id , { " results " : [ ] , " order " : rank_map [ e . event_id ] }
)
rm [ " results " ] . append ( e . event_id )
contexts = { }
for event in allowed_events :
res = await self . store . get_events_around (
event . room_id , event . event_id , before_limit , after_limit
)
s = sender_group . setdefault (
e . sender , { " results " : [ ] , " order " : rank_map [ e . event_id ] }
)
s [ " results " ] . append ( e . event_id )
return (
_SearchResult (
search_result [ " count " ] ,
rank_map ,
allowed_events ,
room_groups ,
highlights ,
) ,
sender_group ,
)
logger . info (
" Context for search returned %d and %d events " ,
len ( res . events_before ) ,
len ( res . events_after ) ,
)
async def _search_by_recent (
self ,
user : UserID ,
room_ids : Collection [ str ] ,
search_term : str ,
keys : Iterable [ str ] ,
search_filter : Filter ,
batch_group : Optional [ str ] ,
batch_group_key : Optional [ str ] ,
batch_token : Optional [ str ] ,
) - > Tuple [ _SearchResult , Optional [ str ] ] :
"""
Performs a full text search for a user ordering by recent .
events_before = await filter_events_for_client (
self . storage , user . to_string ( ) , res . events_before
)
Args :
user : The user performing the search .
room_ids : List of room ids to search in
search_term : Search term to search for
keys : List of keys to search in , currently supports
" content.body " , " content.name " , " content.topic "
search_filter : The event filter to use .
batch_group : Pagination information .
batch_group_key : Pagination information .
batch_token : Pagination information .
events_after = await filter_events_for_client (
self . storage , user . to_string ( ) , res . events_after
)
Returns :
A tuple of :
The search results .
Optionally , a pagination token .
"""
rank_map = { } # event_id -> rank of event
# Holds result of grouping by room, if applicable
room_groups : Dict [ str , JsonDict ] = { }
context = {
" events_before " : events_before ,
" events_after " : events_after ,
" start " : await now_token . copy_and_replace (
" room_key " , res . start
) . to_string ( self . store ) ,
" end " : await now_token . copy_and_replace (
" room_key " , res . end
) . to_string ( self . store ) ,
}
# Holds the next_batch for the entire result set if one of those exists
global_next_batch = None
if include_profile :
senders = {
ev . sender
for ev in itertools . chain ( events_before , [ event ] , events_after )
}
highlights = set ( )
if events_after :
last_event_id = events_after [ - 1 ] . event_id
else :
last_event_id = event . event_id
room_events : List [ EventBase ] = [ ]
i = 0
pagination_token = batch_token
# We keep looping and we keep filtering until we reach the limit
# or we run out of things.
# But only go around 5 times since otherwise synapse will be sad.
while len ( room_events ) < search_filter . limit and i < 5 :
i + = 1
search_result = await self . store . search_rooms (
room_ids ,
search_term ,
keys ,
search_filter . limit * 2 ,
pagination_token = pagination_token ,
)
state_filter = StateFilter . from_types (
[ ( EventTypes . Member , sender ) for sender in senders ]
)
if search_result [ " highlights " ] :
highlights . update ( search_result [ " highlights " ] )
count = search_result [ " count " ]
results = search_result [ " results " ]
results_map = { r [ " event " ] . event_id : r for r in results }
rank_map . update ( { r [ " event " ] . event_id : r [ " rank " ] for r in results } )
filtered_events = await search_filter . filter ( [ r [ " event " ] for r in results ] )
state = await self . state_store . get_state_for_event (
last_event_id , state_filter
events = await filter_events_for_client (
self . storage , user . to_string ( ) , filtered_events
)
room_events . extend ( events )
room_events = room_events [ : search_filter . limit ]
if len ( results ) < search_filter . limit * 2 :
break
else :
pagination_token = results [ - 1 ] [ " pagination_token " ]
for event in room_events :
group = room_groups . setdefault ( event . room_id , { " results " : [ ] } )
group [ " results " ] . append ( event . event_id )
if room_events and len ( room_events ) > = search_filter . limit :
last_event_id = room_events [ - 1 ] . event_id
pagination_token = results_map [ last_event_id ] [ " pagination_token " ]
# We want to respect the given batch group and group keys so
# that if people blindly use the top level `next_batch` token
# it returns more from the same group (if applicable) rather
# than reverting to searching all results again.
if batch_group and batch_group_key :
global_next_batch = encode_base64 (
(
" %s \n %s \n %s " % ( batch_group , batch_group_key , pagination_token )
) . encode ( " ascii " )
)
else :
global_next_batch = encode_base64 (
( " %s \n %s \n %s " % ( " all " , " " , pagination_token ) ) . encode ( " ascii " )
)
for room_id , group in room_groups . items ( ) :
group [ " next_batch " ] = encode_base64 (
( " %s \n %s \n %s " % ( " room_id " , room_id , pagination_token ) ) . encode (
" ascii "
)
)
context [ " profile_info " ] = {
s . state_key : {
" displayname " : s . content . get ( " displayname " , None ) ,
" avatar_url " : s . content . get ( " avatar_url " , None ) ,
}
for s in state . values ( )
if s . type == EventTypes . Member and s . state_key in senders
}
return (
_SearchResult ( count , rank_map , room_events , room_groups , highlights ) ,
global_next_batch ,
)
contexts [ event . event_id ] = context
else :
contexts = { }
async def _calculate_event_contexts (
self ,
user : UserID ,
allowed_events : List [ EventBase ] ,
before_limit : int ,
after_limit : int ,
include_profile : bool ,
) - > Dict [ str , JsonDict ] :
"""
Calculates the contextual events for any search results .
# TODO: Add a limit
Args :
user : The user performing the search .
allowed_events : The search results .
before_limit :
The number of events before a result to include as context .
after_limit :
The number of events after a result to include as context .
include_profile : True if historical profile information should be
included in the event context .
time_now = self . clock . time_msec ( )
Returns :
A map of event ID to contextual information .
"""
now_token = self . hs . get_event_sources ( ) . get_current_token ( )
aggregations = None
if self . _msc3666_enabled :
aggregations = await self . store . get_bundled_aggregations (
# Generate an iterable of EventBase for all the events that will be
# returned, including contextual events.
itertools . chain (
# The events_before and events_after for each context.
itertools . chain . from_iterable (
itertools . chain ( context [ " events_before " ] , context [ " events_after " ] ) # type: ignore[arg-type]
for context in contexts . values ( )
) ,
# The returned events.
allowed_events ,
) ,
user . to_string ( ) ,
contexts = { }
for event in allowed_events :
res = await self . store . get_events_around (
event . room_id , event . event_id , before_limit , after_limit
)
for context in contexts . values ( ) :
context [ " events_before " ] = self . _event_serializer . serialize_events (
context [ " events_before " ] , time_now , bundle_aggregations = aggregations # type: ignore[arg-type]
logger . info (
" Context for search returned %d and %d events " ,
len ( res . events_before ) ,
len ( res . events_after ) ,
)
context [ " events_after " ] = self . _event_serializer . serialize_events (
context [ " events_after " ] , time_now , bundle_aggregations = aggregations # type: ignore[arg-type]
events_before = await filter_events_for_client (
self . storage , user . to_string ( ) , res . events_before
)
state_results = { }
if include_state :
for room_id in { e . room_id for e in allowed_events } :
state = await self . state_handler . get_current_state ( room_id )
state_results [ room_id ] = list ( state . values ( ) )
events_after = await filter_events_for_client (
self . storage , user . to_string ( ) , res . events_after
)
# We're now about to serialize the events. We should not make any
# blocking calls after this. Otherwise the 'age' will be wrong
context = {
" events_before " : events_before ,
" events_after " : events_after ,
" start " : await now_token . copy_and_replace (
" room_key " , res . start
) . to_string ( self . store ) ,
" end " : await now_token . copy_and_replace ( " room_key " , res . end ) . to_string (
self . store
) ,
}
results = [ ]
for e in allowed_events :
results . append (
{
" rank " : rank_map [ e . event_id ] ,
" result " : self . _event_serializer . serialize_event (
e , time_now , bundle_aggregations = aggregations
) ,
" context " : contexts . get ( e . event_id , { } ) ,
if include_profile :
senders = {
ev . sender
for ev in itertools . chain ( events_before , [ event ] , events_after )
}
)
rooms_cat_res = {
" results " : results ,
" count " : count ,
" highlights " : list ( highlights ) ,
}
if events_after :
last_event_id = events_after [ - 1 ] . event_id
else :
last_event_id = event . event_id
if state_results :
s = { }
for room_id , state_events in state_results . items ( ) :
s [ room_id ] = self . _event_serializer . serialize_events (
state_events , time_now
state_filter = StateFilter . from_types (
[ ( EventTypes . Member , sender ) for sender in senders ]
)
rooms_cat_res [ " state " ] = s
if room_groups and " room_id " in group_keys :
rooms_cat_res . setdefault ( " groups " , { } ) [ " room_id " ] = room_groups
state = await self . state_store . get_state_for_event (
last_event_id , state_filter
)
if sender_group and " sender " in group_keys :
rooms_cat_res . setdefault ( " groups " , { } ) [ " sender " ] = sender_group
context [ " profile_info " ] = {
s . state_key : {
" displayname " : s . content . get ( " displayname " , None ) ,
" avatar_url " : s . content . get ( " avatar_url " , None ) ,
}
for s in state . values ( )
if s . type == EventTypes . Member and s . state_key in senders
}
if global_next_batch :
rooms_cat_res [ " next_batch " ] = global_next_batch
contexts [ event . event_id ] = context
return { " search_categories " : { " room_events " : rooms_cat_res } }
return contexts