@@ -14,8 +14,9 @@
import itertools
import logging
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
import attr
from unpaddedbase64 import decode_base64, encode_base64
from synapse.api.constants import EventTypes, Membership
@@ -32,6 +33,20 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _SearchResult:
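    """Intermediate search results, as returned by _search_by_rank / _search_by_recent."""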
    # The count of results.
    count: int
    # A mapping of event ID to the rank of that event.
    rank_map: Dict[str, int]
    # A list of the resulting events.
    allowed_events: List[EventBase]
    # A map of room ID to results.
    room_groups: Dict[str, JsonDict]
    # A set of event IDs to highlight.
    highlights: Set[str]
class SearchHandler:
    def __init__(self, hs: "HomeServer"):
        self.store = hs.get_datastore()
@@ -100,7 +115,7 @@ class SearchHandler:
        """Performs a full text search for a user.
        Args:
            user
            user: The user performing the search.
            content: Search parameters
            batch: The next_batch parameter. Used for pagination.
@@ -156,6 +171,8 @@
        # Include context around each event?
        event_context = room_cat.get("event_context", None)
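        # Defaults, used when the client does not request context around each event.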
        before_limit = after_limit = None
        include_profile = False
        # Group results together? May allow clients to paginate within a
        # group
@@ -182,6 +199,73 @@
                % (set(group_keys) - {"room_id", "sender"},),
            )
        return await self._search(
            user,
            batch_group,
            batch_group_key,
            batch_token,
            search_term,
            keys,
            filter_dict,
            order_by,
            include_state,
            group_keys,
            event_context,
            before_limit,
            after_limit,
            include_profile,
        )
    async def _search(
        self,
        user: UserID,
        batch_group: Optional[str],
        batch_group_key: Optional[str],
        batch_token: Optional[str],
        search_term: str,
        keys: List[str],
        filter_dict: JsonDict,
        order_by: str,
        include_state: bool,
        group_keys: List[str],
        event_context: Optional[bool],
        before_limit: Optional[int],
        after_limit: Optional[int],
        include_profile: bool,
    ) -> JsonDict:
        """Performs a full text search for a user.
        Args:
            user: The user performing the search.
            batch_group: Pagination information.
            batch_group_key: Pagination information.
            batch_token: Pagination information.
            search_term: Search term to search for
            keys: List of keys to search in, currently supports
                "content.body", "content.name", "content.topic"
            filter_dict: The JSON to build a filter out of.
            order_by: How to order the results. Valid values are "rank" and "recent".
            include_state: True if the state of the room at each result should
                be included.
            group_keys: A list of ways to group the results. Valid values are
                "room_id" and "sender".
            event_context: True to include contextual events around results.
            before_limit:
                The number of events before a result to include as context.
                Only used if event_context is True.
            after_limit:
                The number of events after a result to include as context.
                Only used if event_context is True.
            include_profile: True if historical profile information should be
                included in the event context.
                Only used if event_context is True.
        Returns:
            dict to be returned to the client with results of search
        """
        search_filter = Filter(self.hs, filter_dict)
        # TODO: Search through left rooms too
@@ -216,31 +300,165 @@
                }
            }
        sender_group: Optional[Dict[str, JsonDict]]
        if order_by == "rank":
            search_result, sender_group = await self._search_by_rank(
                user, room_ids, search_term, keys, search_filter
            )
            # Unused return values for rank search.
            global_next_batch = None
        elif order_by == "recent":
            search_result, global_next_batch = await self._search_by_recent(
                user,
                room_ids,
                search_term,
                keys,
                search_filter,
                batch_group,
                batch_group_key,
                batch_token,
            )
            # Unused return values for recent search.
            sender_group = None
        else:
            # We should never get here due to the guard earlier.
            raise NotImplementedError()
        logger.info("Found %d events to return", len(search_result.allowed_events))
        # If client has asked for "context" for each event (i.e. some surrounding
        # events and state), fetch that
        if event_context is not None:
            # Note that before and after limit must be set in this case.
            assert before_limit is not None
            assert after_limit is not None
            contexts = await self._calculate_event_contexts(
                user,
                search_result.allowed_events,
                before_limit,
                after_limit,
                include_profile,
            )
        else:
            contexts = {}
        # TODO: Add a limit
        state_results = {}
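        # If requested, fetch the current state of each room that has a search result.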
        if include_state:
            for room_id in {e.room_id for e in search_result.allowed_events}:
                state = await self.state_handler.get_current_state(room_id)
                state_results[room_id] = list(state.values())
        aggregations = None
        if self._msc3666_enabled:
            aggregations = await self.store.get_bundled_aggregations(
                # Generate an iterable of EventBase for all the events that will be
                # returned, including contextual events.
                itertools.chain(
                    # The events_before and events_after for each context.
                    itertools.chain.from_iterable(
                        itertools.chain(context["events_before"], context["events_after"])  # type: ignore[arg-type]
                        for context in contexts.values()
                    ),
                    # The returned events.
                    search_result.allowed_events,
                ),
                user.to_string(),
            )
        # We're now about to serialize the events. We should not make any
        # blocking calls after this. Otherwise, the 'age' will be wrong.
        time_now = self.clock.time_msec()
        for context in contexts.values():
            context["events_before"] = self._event_serializer.serialize_events(
                context["events_before"], time_now, bundle_aggregations=aggregations  # type: ignore[arg-type]
            )
            context["events_after"] = self._event_serializer.serialize_events(
                context["events_after"], time_now, bundle_aggregations=aggregations  # type: ignore[arg-type]
            )
        results = [
            {
                "rank": search_result.rank_map[e.event_id],
                "result": self._event_serializer.serialize_event(
                    e, time_now, bundle_aggregations=aggregations
                ),
                "context": contexts.get(e.event_id, {}),
            }
            for e in search_result.allowed_events
        ]
        rooms_cat_res: JsonDict = {
            "results": results,
            "count": search_result.count,
            "highlights": list(search_result.highlights),
        }
        if state_results:
            rooms_cat_res["state"] = {
                room_id: self._event_serializer.serialize_events(state_events, time_now)
                for room_id, state_events in state_results.items()
            }
        if search_result.room_groups and "room_id" in group_keys:
            rooms_cat_res.setdefault("groups", {})[
                "room_id"
            ] = search_result.room_groups
        if sender_group and "sender" in group_keys:
            rooms_cat_res.setdefault("groups", {})["sender"] = sender_group
        if global_next_batch:
            rooms_cat_res["next_batch"] = global_next_batch
        return {"search_categories": {"room_events": rooms_cat_res}}
    async def _search_by_rank(
        self,
        user: UserID,
        room_ids: Collection[str],
        search_term: str,
        keys: Iterable[str],
        search_filter: Filter,
    ) -> Tuple[_SearchResult, Dict[str, JsonDict]]:
        """
        Performs a full text search for a user ordering by rank.
        Args:
            user: The user performing the search.
            room_ids: List of room ids to search in
            search_term: Search term to search for
            keys: List of keys to search in, currently supports
                "content.body", "content.name", "content.topic"
            search_filter: The event filter to use.
        Returns:
            A tuple of:
                The search results.
                A map of sender ID to results.
        """
        rank_map = {}  # event_id -> rank of event
        allowed_events = []
        # Holds result of grouping by room, if applicable
        room_groups: Dict[str, JsonDict] = {}
        # Holds result of grouping by sender, if applicable
        sender_group: Dict[str, JsonDict] = {}
        # Holds the next_batch for the entire result set if one of those exists
        global_next_batch = None
        highlights = set()
        count = None
        if order_by == "rank":
            search_result = await self.store.search_msgs(room_ids, search_term, keys)
            count = search_result["count"]
            if search_result["highlights"]:
                highlights.update(search_result["highlights"])
                highlights = search_result["highlights"]
            else:
                highlights = set()
            results = search_result["results"]
            rank_map.update({r["event"].event_id: r["rank"] for r in results})
            # event_id -> rank of event
            rank_map = {r["event"].event_id: r["rank"] for r in results}
            filtered_events = await search_filter.filter([r["event"] for r in results])
@@ -262,7 +480,56 @@
                )
                s["results"].append(e.event_id)
        elif order_by == "recent":
        return (
            _SearchResult(
                search_result["count"],
                rank_map,
                allowed_events,
                room_groups,
                highlights,
            ),
            sender_group,
        )
    async def _search_by_recent(
        self,
        user: UserID,
        room_ids: Collection[str],
        search_term: str,
        keys: Iterable[str],
        search_filter: Filter,
        batch_group: Optional[str],
        batch_group_key: Optional[str],
        batch_token: Optional[str],
    ) -> Tuple[_SearchResult, Optional[str]]:
        """
        Performs a full text search for a user ordering by recent.
        Args:
            user: The user performing the search.
            room_ids: List of room ids to search in
            search_term: Search term to search for
            keys: List of keys to search in, currently supports
                "content.body", "content.name", "content.topic"
            search_filter: The event filter to use.
            batch_group: Pagination information.
            batch_group_key: Pagination information.
            batch_token: Pagination information.
        Returns:
            A tuple of:
                The search results.
                Optionally, a pagination token.
        """
        rank_map = {}  # event_id -> rank of event
        # Holds result of grouping by room, if applicable
        room_groups: Dict[str, JsonDict] = {}
        # Holds the next_batch for the entire result set if one of those exists
        global_next_batch = None
        highlights = set()
        room_events: List[EventBase] = []
        i = 0
@@ -292,9 +559,7 @@
            rank_map.update({r["event"].event_id: r["rank"] for r in results})
            filtered_events = await search_filter.filter(
                [r["event"] for r in results]
            )
            filtered_events = await search_filter.filter([r["event"] for r in results])
            events = await filter_events_for_client(
                self.storage, user.to_string(), filtered_events
@@ -304,7 +569,6 @@
            room_events = room_events[: search_filter.limit]
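            # Fewer than limit * 2 results back from the store likely means it has
            # been exhausted, so stop paginating.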
            if len(results) < search_filter.limit * 2:
                pagination_token = None
                break
            else:
                pagination_token = results[-1]["pagination_token"]
@@ -324,8 +588,7 @@
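            # The next_batch token encodes the batch group, group key and the store's
            # pagination token, newline-separated and base64-encoded.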
            if batch_group and batch_group_key:
                global_next_batch = encode_base64(
                    (
                        "%s\n%s\n%s"
                        % (batch_group, batch_group_key, pagination_token)
                        "%s\n%s\n%s" % (batch_group, batch_group_key, pagination_token)
                    ).encode("ascii")
                )
            else:
@@ -340,17 +603,35 @@
                    )
                )
        allowed_events.extend(room_events)
        return (
            _SearchResult(count, rank_map, room_events, room_groups, highlights),
            global_next_batch,
        )
        else:
            # We should never get here due to the guard earlier.
            raise NotImplementedError()
    async def _calculate_event_contexts(
        self,
        user: UserID,
        allowed_events: List[EventBase],
        before_limit: int,
        after_limit: int,
        include_profile: bool,
    ) -> Dict[str, JsonDict]:
        """
        Calculates the contextual events for any search results.
        logger.info("Found %d events to return", len(allowed_events))
        Args:
            user: The user performing the search.
            allowed_events: The search results.
            before_limit:
                The number of events before a result to include as context.
            after_limit:
                The number of events after a result to include as context.
            include_profile: True if historical profile information should be
                included in the event context.
        # If client has asked for "context" for each event (i.e. some surrounding
        # events and state), fetch that
        if event_context is not None:
        Returns:
            A map of event ID to contextual information.
        """
        now_token = self.hs.get_event_sources().get_current_token()
        contexts = {}
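        # Maps an event ID to its context: the events before and after it, start/end
        # pagination tokens and, optionally, profile information for the senders.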
@@ -379,9 +660,9 @@
                "start": await now_token.copy_and_replace(
                    "room_key", res.start
                ).to_string(self.store),
                "end": await now_token.copy_and_replace(
                    "room_key", res.end
                ).to_string(self.store),
                "end": await now_token.copy_and_replace("room_key", res.end).to_string(
                    self.store
                ),
            }
            if include_profile:
@@ -413,81 +694,5 @@
                }
            contexts[event.event_id] = context
        else:
            contexts = {}
        # TODO: Add a limit
        time_now = self.clock.time_msec()
        aggregations = None
        if self._msc3666_enabled:
            aggregations = await self.store.get_bundled_aggregations(
                # Generate an iterable of EventBase for all the events that will be
                # returned, including contextual events.
                itertools.chain(
                    # The events_before and events_after for each context.
                    itertools.chain.from_iterable(
                        itertools.chain(context["events_before"], context["events_after"])  # type: ignore[arg-type]
                        for context in contexts.values()
                    ),
                    # The returned events.
                    allowed_events,
                ),
                user.to_string(),
            )
        for context in contexts.values():
            context["events_before"] = self._event_serializer.serialize_events(
                context["events_before"], time_now, bundle_aggregations=aggregations  # type: ignore[arg-type]
            )
            context["events_after"] = self._event_serializer.serialize_events(
                context["events_after"], time_now, bundle_aggregations=aggregations  # type: ignore[arg-type]
            )
        state_results = {}
        if include_state:
            for room_id in {e.room_id for e in allowed_events}:
                state = await self.state_handler.get_current_state(room_id)
                state_results[room_id] = list(state.values())
        # We're now about to serialize the events. We should not make any
        # blocking calls after this. Otherwise the 'age' will be wrong
        results = []
        for e in allowed_events:
            results.append(
                {
                    "rank": rank_map[e.event_id],
                    "result": self._event_serializer.serialize_event(
                        e, time_now, bundle_aggregations=aggregations
                    ),
                    "context": contexts.get(e.event_id, {}),
                }
            )
        rooms_cat_res = {
            "results": results,
            "count": count,
            "highlights": list(highlights),
        }
        if state_results:
            s = {}
            for room_id, state_events in state_results.items():
                s[room_id] = self._event_serializer.serialize_events(
                    state_events, time_now
                )
            rooms_cat_res["state"] = s
        if room_groups and "room_id" in group_keys:
            rooms_cat_res.setdefault("groups", {})["room_id"] = room_groups
        if sender_group and "sender" in group_keys:
            rooms_cat_res.setdefault("groups", {})["sender"] = sender_group
        if global_next_batch:
            rooms_cat_res["next_batch"] = global_next_batch
        return {"search_categories": {"room_events": rooms_cat_res}}
        return contexts